├── arena ├── __init__.py ├── env │ ├── __init__.py │ ├── env_int_wrapper.py │ ├── pong2p_env.py │ ├── sc2_base_env.py │ └── soccer_env.py ├── utils │ ├── __init__.py │ ├── pong2p │ │ ├── __init__.py │ │ ├── pong2p_utils.py │ │ └── pong2p_env.py │ ├── vizdoom │ │ ├── _scenarios │ │ │ ├── cig.wad │ │ │ ├── basic.wad │ │ │ ├── cig2017.wad │ │ │ ├── dogfight.wad │ │ │ ├── deathmatch.wad │ │ │ ├── multi_duel.wad │ │ │ ├── my_way_home.wad │ │ │ ├── take_cover.wad │ │ │ ├── rocket_basic.wad │ │ │ ├── simpler_basic.wad │ │ │ ├── cig_with_unknown.wad │ │ │ ├── deadly_corridor.wad │ │ │ ├── defend_the_line.wad │ │ │ ├── health_gathering.wad │ │ │ ├── multi_deathmatch.wad │ │ │ ├── predict_position.wad │ │ │ ├── defend_the_center.wad │ │ │ ├── health_gathering_supreme.wad │ │ │ ├── multi_duel.cfg │ │ │ ├── simpler_basic.cfg │ │ │ ├── rocket_basic.cfg │ │ │ ├── learning.cfg │ │ │ ├── take_cover.cfg │ │ │ ├── defend_the_line.cfg │ │ │ ├── multi.cfg │ │ │ ├── health_gathering.cfg │ │ │ ├── health_gathering_supreme.cfg │ │ │ ├── basic.cfg │ │ │ ├── deadly_corridor.cfg │ │ │ ├── my_way_home.cfg │ │ │ ├── predict_position.cfg │ │ │ ├── defend_the_center.cfg │ │ │ ├── oblige.cfg │ │ │ ├── deathmatch.cfg │ │ │ ├── dogfight.cfg │ │ │ ├── cig.cfg │ │ │ └── bots.cfg │ │ ├── random_agent.py │ │ ├── run_loop.py │ │ ├── run_parallel.py │ │ ├── player.py │ │ ├── player_vs_f1.py │ │ ├── Rect.py │ │ └── core_env.py │ ├── spaces.py │ ├── unit_util.py │ ├── run_loop.py │ └── constant.py ├── agents │ ├── __init__.py │ ├── vizdoom │ │ └── random_agent.py │ ├── simple_agent.py │ ├── base_agent.py │ └── agt_int_wrapper.py ├── interfaces │ ├── __init__.py │ ├── pommerman │ │ ├── __init__.py │ │ └── obs_int.py │ ├── vizdoom │ │ ├── __init__.py │ │ ├── act_int.py │ │ └── obs_int.py │ ├── sc2full_formal │ │ ├── __init__.py │ │ ├── noop_int.py │ │ ├── obs_int.py │ │ ├── zerg_obs_int.py │ │ └── act_int.py │ ├── raw_int.py │ ├── interface.py │ ├── soccer │ │ └── obs_int.py │ └── combine.py ├── wrappers │ 
├── __init__.py │ ├── vizdoom │ │ ├── __init__.py │ │ ├── observation.py │ │ ├── action.py │ │ ├── game.py │ │ ├── reward.py │ │ └── reward_shape.py │ ├── sc2stat_wrapper.py │ ├── pong2p │ │ ├── pong2p_compete.py │ │ └── pong2p_wrappers.py │ └── basic_env_wrapper.py └── sandbox │ ├── full_game.sh │ ├── run.sh │ └── run_mp_game.py ├── .gitignore ├── setup.py ├── README.md └── docs └── agt_int.md /arena/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arena/env/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arena/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arena/agents/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arena/interfaces/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arena/utils/pong2p/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arena/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /arena/interfaces/pommerman/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arena/interfaces/vizdoom/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arena/wrappers/vizdoom/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arena/interfaces/sc2full_formal/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/cig.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/cig.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/basic.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/basic.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/cig2017.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/cig2017.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/dogfight.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/dogfight.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/deathmatch.wad: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/deathmatch.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/multi_duel.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/multi_duel.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/my_way_home.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/my_way_home.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/take_cover.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/take_cover.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/rocket_basic.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/rocket_basic.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/simpler_basic.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/simpler_basic.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/cig_with_unknown.wad: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/cig_with_unknown.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/deadly_corridor.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/deadly_corridor.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/defend_the_line.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/defend_the_line.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/health_gathering.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/health_gathering.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/multi_deathmatch.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/multi_deathmatch.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/predict_position.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/predict_position.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/defend_the_center.wad: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/defend_the_center.wad -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/health_gathering_supreme.wad: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tencent-ailab/Arena/HEAD/arena/utils/vizdoom/_scenarios/health_gathering_supreme.wad -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *_pb2.py 3 | __pychaces__ 4 | 5 | .DS_Store 6 | 7 | .idea 8 | Arena.egg-info/ 9 | */__pycache__/ 10 | sphinx/_build/ 11 | 12 | _vizdoom.ini 13 | arena/env/_vizdoom/ 14 | -------------------------------------------------------------------------------- /arena/agents/vizdoom/random_agent.py: -------------------------------------------------------------------------------- 1 | from random import choice 2 | 3 | 4 | class RandomAgent(object): 5 | def __init__(self, allowed_action): 6 | self._allowed_action = allowed_action 7 | 8 | def step(self, obs): 9 | return choice(self._allowed_action) 10 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/random_agent.py: -------------------------------------------------------------------------------- 1 | from random import choice 2 | 3 | 4 | class RandomAgent(object): 5 | def __init__(self, allowed_action): 6 | self._allowed_action = allowed_action 7 | 8 | def step(self, obs): 9 | return choice(self._allowed_action) 10 | -------------------------------------------------------------------------------- /arena/sandbox/full_game.sh: -------------------------------------------------------------------------------- 1 | python -m sc2arena.run_mp_game \ 2 | --map AbyssalReef \ 3 | --player2 tstarbot.agents.zerg_agent.ZergAgent \ 4 | --step_mul 8 \ 5 | 
--agent_interface_format "rgb" \ 6 | --screen_resolution 640 \ 7 | --visualize \ 8 | --difficulty '1' 9 | 10 | -------------------------------------------------------------------------------- /arena/sandbox/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python -m sc2arena.run_mp_game \ 3 | --map 4MarineA \ 4 | --player1 sc2arena.agents.simple_agent.AtkWeakestAgent \ 5 | --screen_resolution 64 \ 6 | --step_mul 8 \ 7 | --novisualize \ 8 | --agent_interface_format "rgb" \ 9 | --sleep_time_per_step 0 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='Arena', 5 | version='1.3', 6 | description='Arena', 7 | keywords='Arena', 8 | packages=[ 9 | 'arena', 10 | ], 11 | 12 | install_requires=[ 13 | 'gym', 14 | 'pillow', 15 | 'pygame' 16 | ], 17 | ) 18 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/multi_duel.cfg: -------------------------------------------------------------------------------- 1 | doom_scenario_path = multi_duel.wad 2 | 3 | screen_resolution = RES_640X480 4 | screen_format = CRCGCB 5 | render_hud = true 6 | render_crosshair = false 7 | render_weapon = true 8 | render_decals = true 9 | render_particles = true 10 | window_visible = true 11 | 12 | available_buttons = 13 | { 14 | MOVE_LEFT 15 | MOVE_RIGHT 16 | ATTACK 17 | } 18 | 19 | mode = PLAYER 20 | doom_skill = 5 21 | 22 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/simpler_basic.cfg: -------------------------------------------------------------------------------- 1 | doom_scenario_path = simpler_basic.wad 2 | 3 | # Rewards 4 | living_reward = -1 5 | 6 | # Rendering options 7 | screen_resolution = RES_640X480 8 | screen_format = 
GRAY8 9 | 10 | render_hud = true 11 | render_crosshair = false 12 | render_weapon = true 13 | render_decals = false 14 | render_particles = false 15 | 16 | # make episodes start after 20 tics (after unholstering the gun) 17 | episode_start_time = 14 18 | 19 | # make episodes finish after 300 actions (tics) 20 | episode_timeout = 300 21 | 22 | # Available buttons 23 | available_buttons = 24 | { 25 | MOVE_LEFT 26 | MOVE_RIGHT 27 | ATTACK 28 | } 29 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/rocket_basic.cfg: -------------------------------------------------------------------------------- 1 | doom_scenario_path = rocket_basic.wad 2 | 3 | # Rewards 4 | living_reward = -1 5 | 6 | # Rendering options 7 | screen_resolution = RES_640X480 8 | screen_format = GRAY8 9 | render_hud = true 10 | render_crosshair = false 11 | render_weapon = true 12 | render_decals = false 13 | render_particles = false 14 | 15 | # make episodes start after 14 tics (after unholstering the gun) 16 | episode_start_time = 14 17 | 18 | # make episodes finish after 300 actions (tics) 19 | episode_timeout = 300 20 | 21 | # Available buttons 22 | available_buttons = 23 | { 24 | MOVE_LEFT 25 | MOVE_RIGHT 26 | ATTACK 27 | } 28 | 29 | game_args += +sv_noautoaim 1 30 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/learning.cfg: -------------------------------------------------------------------------------- 1 | doom_scenario_path = basic.wad 2 | 3 | # Rewards 4 | living_reward = -1 5 | 6 | # Rendering options 7 | screen_resolution = RES_640X480 8 | screen_format = GRAY8 9 | render_hud = false 10 | render_crosshair = false 11 | render_weapon = true 12 | render_decals = false 13 | render_particles = false 14 | window_visible = false 15 | 16 | # make episodes start after 20 tics (after unholstering the gun) 17 | episode_start_time = 14 18 | 19 | # make episodes finish after 
300 actions (tics) 20 | episode_timeout = 300 21 | 22 | # Available buttons 23 | available_buttons = 24 | { 25 | MOVE_LEFT 26 | MOVE_RIGHT 27 | ATTACK 28 | } 29 | 30 | mode = PLAYER 31 | 32 | 33 | -------------------------------------------------------------------------------- /arena/agents/simple_agent.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test agent to control marines 3 | """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | from arena.agents.base_agent import BaseAgent 9 | import gym 10 | 11 | 12 | class RandomAgent(BaseAgent): 13 | """Random action agent.""" 14 | def __init__(self, action_space=None): 15 | super(RandomAgent, self).__init__() 16 | self.action_space = action_space 17 | 18 | def step(self, obs): 19 | super(RandomAgent, self).step(obs) 20 | if hasattr(self.action_space, 'sample'): 21 | return self.action_space.sample() 22 | else: 23 | return None 24 | 25 | def reset(self, timestep=None): 26 | super(RandomAgent, self).reset(timestep) 27 | assert isinstance(self.action_space, gym.Space) 28 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/take_cover.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = take_cover.wad 6 | doom_map = map01 7 | 8 | # Rewards 9 | living_reward = 1 10 | 11 | # Rendering options 12 | screen_resolution = RES_320X240 13 | screen_format = CRCGCB 14 | render_hud = false 15 | render_crosshair = false 16 | render_weapon = false 17 | render_decals = false 18 | render_particles = false 19 | window_visible = true 20 | 21 | # Available buttons 22 | available_buttons = 23 | { 24 | MOVE_LEFT 25 | MOVE_RIGHT 26 | } 27 | 28 | # Game variables that will be in the state 29 | available_game_variables = { HEALTH } 30 | 31 | # Change it if you wish. 32 | doom_skill = 4 33 | 34 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/defend_the_line.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = defend_the_line.wad 6 | 7 | # Rewards 8 | death_penalty = 1 9 | 10 | # Rendering options 11 | screen_resolution = RES_320X240 12 | screen_format = CRCGCB 13 | render_hud = True 14 | render_crosshair = false 15 | render_weapon = true 16 | render_decals = false 17 | render_particles = false 18 | window_visible = true 19 | 20 | # make episodes start after 10 tics (after unholstering the gun) 21 | episode_start_time = 10 22 | 23 | 24 | # Available buttons 25 | available_buttons = 26 | { 27 | TURN_lEFT 28 | TURN_RIGHT 29 | ATTACK 30 | } 31 | 32 | # Game variables that will be in the state 33 | available_game_variables = { AMMO2 HEALTH} 34 | 35 | mode = PLAYER 36 | doom_skill = 3 37 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/multi.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = multi_deathmatch.wad 6 | 7 | # Rewards 8 | death_penalty = 1 9 | 10 | # Rendering options 11 | screen_resolution = RES_640X480 12 | screen_format = CRCGCB 13 | render_hud = true 14 | render_crosshair = true 15 | render_weapon = true 16 | render_decals = false 17 | render_particles = false 18 | 19 | window_visible = true 20 | 21 | 22 | # Available buttons 23 | available_buttons = 24 | { 25 | TURN_LEFT 26 | TURN_RIGHT 27 | ATTACK 28 | 29 | MOVE_RIGHT 30 | MOVE_LEFT 31 | 32 | MOVE_FORWARD 33 | MOVE_BACKWARD 34 | TURN_LEFT_RIGHT_DELTA 35 | LOOK_UP_DOWN_DELTA 36 | 37 | } 38 | 39 | available_game_variables = 40 | { 41 | HEALTH 42 | AMMO3 43 | } 44 | mode = ASYNC_PLAYER 45 | 46 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/health_gathering.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = health_gathering.wad 6 | 7 | # Each step is good for you! 8 | living_reward = 1 9 | # And death is not! 
10 | death_penalty = 100 11 | 12 | # Rendering options 13 | screen_resolution = RES_320X240 14 | screen_format = CRCGCB 15 | render_hud = false 16 | render_crosshair = false 17 | render_weapon = false 18 | render_decals = false 19 | render_particles = false 20 | window_visible = true 21 | 22 | # make episodes finish after 2100 actions (tics) 23 | episode_timeout = 2100 24 | 25 | # Available buttons 26 | available_buttons = 27 | { 28 | TURN_LEFT 29 | TURN_RIGHT 30 | MOVE_FORWARD 31 | } 32 | 33 | # Game variables that will be in the state 34 | available_game_variables = { HEALTH } 35 | 36 | mode = PLAYER -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/health_gathering_supreme.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = health_gathering_supreme.wad 6 | 7 | # Each step is good for you! 8 | living_reward = 1 9 | # And death is not! 
10 | death_penalty = 100 11 | 12 | # Rendering options 13 | screen_resolution = RES_320X240 14 | screen_format = CRCGCB 15 | render_hud = false 16 | render_crosshair = false 17 | render_weapon = false 18 | render_decals = false 19 | render_particles = false 20 | window_visible = true 21 | 22 | # make episodes finish after 2100 actions (tics) 23 | episode_timeout = 2100 24 | 25 | # Available buttons 26 | available_buttons = 27 | { 28 | TURN_LEFT 29 | TURN_RIGHT 30 | MOVE_FORWARD 31 | } 32 | 33 | # Game variables that will be in the state 34 | available_game_variables = { HEALTH } 35 | 36 | mode = PLAYER -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/basic.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = basic.wad 6 | doom_map = map01 7 | 8 | # Rewards 9 | living_reward = -1 10 | 11 | # Rendering options 12 | screen_resolution = RES_320X240 13 | screen_format = CRCGCB 14 | render_hud = True 15 | render_crosshair = false 16 | render_weapon = true 17 | render_decals = false 18 | render_particles = false 19 | window_visible = true 20 | 21 | # make episodes start after 20 tics (after unholstering the gun) 22 | episode_start_time = 14 23 | 24 | # make episodes finish after 300 actions (tics) 25 | episode_timeout = 300 26 | 27 | # Available buttons 28 | available_buttons = 29 | { 30 | MOVE_LEFT 31 | MOVE_RIGHT 32 | ATTACK 33 | } 34 | 35 | # Game variables that will be in the state 36 | available_game_variables = { AMMO2} 37 | 38 | mode = PLAYER 39 | doom_skill = 5 40 | -------------------------------------------------------------------------------- /arena/utils/spaces.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from collections import Iterable 3 | 4 | 5 | class NoneSpace(gym.Space): 6 | def __init__(self): 7 | super(NoneSpace, self).__init__(None, None) 8 | 9 | 10 | class SC2RawObsSpace(NoneSpace): 11 | def __init__(self): 12 | super(SC2RawObsSpace, self).__init__() 13 | from pysc2.env.environment import TimeStep 14 | self.obs_class = TimeStep 15 | 16 | def sample(self): 17 | return self.obs_class([None] * len(self.obs_class._fields)) 18 | 19 | def contains(self, x): 20 | return isinstance(x, self.obs_class) 21 | 22 | 23 | class SC2RawActSpace(NoneSpace): 24 | def __init__(self): 25 | super(SC2RawActSpace, self).__init__() 26 | from s2clientprotocol.sc2api_pb2 import Action 27 | self.act_class = Action 28 | 29 | def sample(self): 30 | return [] 31 | 32 | def contains(self, x): 33 | if isinstance(x, Iterable): 34 | return all([isinstance(a, self.act_class) for a in x]) 35 | else: 36 | return isinstance(x, self.act_class) 37 | 38 | 
-------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/deadly_corridor.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 4 | 5 | doom_scenario_path = deadly_corridor.wad 6 | 7 | # Skill 5 is reccomanded for the scenario to be a challenge. 8 | doom_skill = 5 9 | 10 | # Rewards 11 | death_penalty = 100 12 | #living_reward = 0 13 | 14 | # Rendering options 15 | screen_resolution = RES_320X240 16 | screen_format = CRCGCB 17 | render_hud = true 18 | render_crosshair = false 19 | render_weapon = true 20 | render_decals = false 21 | render_particles = false 22 | window_visible = true 23 | 24 | episode_timeout = 2100 25 | 26 | # Available buttons 27 | available_buttons = 28 | { 29 | MOVE_LEFT 30 | MOVE_RIGHT 31 | ATTACK 32 | MOVE_FORWARD 33 | MOVE_BACKWARD 34 | TURN_LEFT 35 | TURN_RIGHT 36 | } 37 | 38 | # Game variables that will be in the state 39 | available_game_variables = { HEALTH } 40 | 41 | mode = PLAYER 42 | 43 | 44 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/my_way_home.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = my_way_home.wad 6 | 7 | # Rewards 8 | living_reward = -0.0001 9 | 10 | # Rendering options 11 | screen_resolution = RES_640X480 12 | screen_format = CRCGCB 13 | render_hud = false 14 | render_crosshair = false 15 | render_weapon = true 16 | render_decals = false 17 | render_particles = false 18 | window_visible = true 19 | 20 | # make episodes start after 10 tics (after unholstering the gun) 21 | episode_start_time = 10 22 | 23 | # make episodes finish after 2100 actions (tics) 24 | episode_timeout = 2100 25 | 26 | # Available buttons 27 | available_buttons = 28 | { 29 | TURN_LEFT 30 | TURN_RIGHT 31 | MOVE_FORWARD 32 | MOVE_LEFT 33 | MOVE_RIGHT 34 | } 35 | 36 | # Game variables that will be in the state 37 | available_game_variables = { AMMO0 } 38 | 39 | mode = PLAYER 40 | doom_skill = 5 41 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/predict_position.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = predict_position.wad 6 | 7 | # Rewards 8 | living_reward = -0.001 9 | 10 | # Rendering options 11 | screen_resolution = RES_800X450 12 | screen_format = CRCGCB 13 | render_hud = false 14 | render_crosshair = false 15 | render_weapon = true 16 | render_decals = false 17 | render_particles = false 18 | window_visible = true 19 | 20 | # make episodes start after 16 tics (after producing the rocket launcher) 21 | episode_start_time = 16 22 | 23 | # make episodes finish after 300 actions (tics) 24 | episode_timeout = 300 25 | 26 | # Available buttons 27 | available_buttons = 28 | { 29 | TURN_LEFT 30 | TURN_RIGHT 31 | ATTACK 32 | } 33 | 34 | # Empty list is allowed, in case you are lazy. 35 | available_game_variables = { } 36 | 37 | game_args += +sv_noautoaim 1 38 | 39 | mode = PLAYER 40 | doom_skill = 1 41 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/defend_the_center.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = defend_the_center.wad 6 | 7 | # Rewards 8 | death_penalty = 1 9 | 10 | # Rendering options 11 | screen_resolution = RES_640X480 12 | screen_format = CRCGCB 13 | render_hud = True 14 | render_crosshair = false 15 | render_weapon = true 16 | render_decals = false 17 | render_particles = false 18 | window_visible = true 19 | 20 | # make episodes start after 10 tics (after unholstering the gun) 21 | episode_start_time = 10 22 | 23 | # make episodes finish after 2100 actions (tics) 24 | episode_timeout = 2100 25 | 26 | # Available buttons 27 | available_buttons = 28 | { 29 | TURN_LEFT 30 | TURN_RIGHT 31 | ATTACK 32 | } 33 | 34 | # Game variables that will be in the state 35 | #available_game_variables = { AMMO2 HEALTH } 36 | available_game_variables = { KILLCOUNT AMMO2 HEALTH } 37 | 38 | mode = PLAYER 39 | doom_skill = 3 40 | -------------------------------------------------------------------------------- /arena/interfaces/raw_int.py: -------------------------------------------------------------------------------- 1 | from gym import spaces 2 | from arena.utils.spaces import NoneSpace 3 | 4 | 5 | class RawInt(object): 6 | """ This interface is usually used at the env side """ 7 | 8 | def __init__(self): 9 | self._observation_space = NoneSpace() 10 | self._action_space = NoneSpace() 11 | self.steps = None 12 | 13 | def setup(self, observation_space, action_space): 14 | self._observation_space = observation_space 15 | self._action_space = action_space 16 | 17 | def reset(self, obs, **kwargs): 18 | self._obs = obs 19 | self.steps = 0 20 | 21 | @property 22 | def observation_space(self): 23 | return self._observation_space 24 | 25 | @property 26 | def action_space(self): 27 | return self._action_space 28 | 29 | def obs_trans(self, obs): 30 | self._obs = obs 31 | return obs 32 | 33 | def act_trans(self, act): 34 | self._act = act 35 | self.steps += 1 36 | return act 37 | 38 | def __str__(self): 39 | return '<'+self.__class__.__name__+'>' 40 | 41 
| def unwrapped(self): 42 | return self 43 | -------------------------------------------------------------------------------- /arena/agents/base_agent.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseAgent(object): 3 | """ Base agent class """ 4 | observation_space = None 5 | action_space = None 6 | 7 | def __init__(self): 8 | self.episodes = 0 9 | self.steps = 0 10 | self._obs = None 11 | 12 | def setup(self, observation_space, action_space): 13 | """ Set the observation space and action space 14 | 15 | Parameters: 16 | observation_space (gym.spaces.Space): Observation space 17 | action_space (gym.spaces.Space): Action space 18 | """ 19 | self.observation_space = observation_space 20 | self.action_space = action_space 21 | 22 | def reset(self, obs=None): 23 | """ Reset the agent with initial observation. 24 | 25 | Parameters: 26 | obs: Initial observation 27 | """ 28 | self._obs = obs 29 | self.episodes += 1 30 | self.steps = 0 31 | 32 | def step(self, obs): 33 | """ Step the agent, observe the obs and return the action. 
34 | 35 | Parameters: 36 | obs: Initial observation 37 | 38 | Returns: 39 | action of the agent 40 | """ 41 | self._obs = obs 42 | self.steps += 1 43 | return None 44 | 45 | -------------------------------------------------------------------------------- /arena/agents/agt_int_wrapper.py: -------------------------------------------------------------------------------- 1 | from arena.agents.base_agent import BaseAgent 2 | from arena.interfaces.raw_int import RawInt 3 | 4 | 5 | class AgtIntWrapper(BaseAgent): 6 | inter = None 7 | 8 | def __init__(self, agent, interface=RawInt(), step_mul=1): 9 | super(AgtIntWrapper, self).__init__() 10 | self.agent = agent 11 | self.inter = interface 12 | self.step_mul = step_mul 13 | self.act = None 14 | assert isinstance(self.inter, RawInt) 15 | 16 | def setup(self, observation_space, action_space): 17 | super(AgtIntWrapper, self).setup(observation_space, action_space) 18 | self.inter.unwrapped()._observation_space = observation_space 19 | self.inter.unwrapped()._action_space = action_space 20 | 21 | def reset(self, obs, inter_kwargs={}): 22 | super(AgtIntWrapper, self).reset(obs) 23 | self.inter.reset(obs, **inter_kwargs) 24 | self.agent.setup(self.inter.observation_space, self.inter.action_space) 25 | self.agent.reset(self.inter.obs_trans(obs)) 26 | self.act = None 27 | 28 | def step(self, obs): 29 | super(AgtIntWrapper, self).step(obs) 30 | obs = self.inter.obs_trans(obs) 31 | if (self.steps -1) % self.step_mul == 0: 32 | self.act = self.agent.step(obs) 33 | act = self.inter.act_trans(self.act) 34 | return act 35 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/oblige.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. 
episode_timeout is the same as episodeTimeout. 4 | 5 | # Rendering options 6 | screen_resolution = RES_320X240 7 | screen_format = CRCGCB 8 | render_hud = true 9 | render_crosshair = false 10 | render_weapon = true 11 | render_decals = false 12 | render_particles = false 13 | window_visible = true 14 | 15 | # make episodes finish after 4200 actions (tics) 16 | episode_timeout = 4200 17 | 18 | # Available buttons 19 | available_buttons = 20 | { 21 | ATTACK 22 | USE 23 | SPEED 24 | STRAFE 25 | 26 | MOVE_RIGHT 27 | MOVE_LEFT 28 | MOVE_BACKWARD 29 | MOVE_FORWARD 30 | TURN_RIGHT 31 | TURN_LEFT 32 | 33 | SELECT_WEAPON1 34 | SELECT_WEAPON2 35 | SELECT_WEAPON3 36 | SELECT_WEAPON4 37 | SELECT_WEAPON5 38 | SELECT_WEAPON6 39 | 40 | SELECT_NEXT_WEAPON 41 | SELECT_PREV_WEAPON 42 | 43 | LOOK_UP_DOWN_DELTA 44 | TURN_LEFT_RIGHT_DELTA 45 | MOVE_LEFT_RIGHT_DELTA 46 | 47 | } 48 | 49 | # Game variables that will be in the state 50 | available_game_variables = 51 | { 52 | KILLCOUNT 53 | HEALTH 54 | ARMOR 55 | SELECTED_WEAPON 56 | SELECTED_WEAPON_AMMO 57 | } 58 | mode = PLAYER 59 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/deathmatch.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = deathmatch.wad 6 | 7 | # Rendering options 8 | screen_resolution = RES_320X240 9 | screen_format = CRCGCB 10 | render_hud = true 11 | render_crosshair = false 12 | render_weapon = true 13 | render_decals = false 14 | render_particles = false 15 | window_visible = true 16 | 17 | # make episodes finish after 4200 actions (tics) 18 | episode_timeout = 4200 19 | 20 | # Available buttons 21 | available_buttons = 22 | { 23 | ATTACK 24 | SPEED 25 | STRAFE 26 | 27 | MOVE_RIGHT 28 | MOVE_LEFT 29 | MOVE_BACKWARD 30 | MOVE_FORWARD 31 | TURN_RIGHT 32 | TURN_LEFT 33 | 34 | SELECT_WEAPON1 35 | SELECT_WEAPON2 36 | SELECT_WEAPON3 37 | SELECT_WEAPON4 38 | SELECT_WEAPON5 39 | SELECT_WEAPON6 40 | 41 | SELECT_NEXT_WEAPON 42 | SELECT_PREV_WEAPON 43 | 44 | LOOK_UP_DOWN_DELTA 45 | TURN_LEFT_RIGHT_DELTA 46 | MOVE_LEFT_RIGHT_DELTA 47 | 48 | } 49 | 50 | # Game variables that will be in the state 51 | available_game_variables = 52 | { 53 | KILLCOUNT 54 | HEALTH 55 | ARMOR 56 | SELECTED_WEAPON 57 | SELECTED_WEAPON_AMMO 58 | } 59 | mode = PLAYER 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Arena 2 | Common Environment Interface for Multi-player Games. 3 | See our Technical Report https://arxiv.org/abs/1907.09467 4 | 5 | Qing Wang*, Jiechao Xiong*, Lei Han, Meng Fang, Xinghai Sun, Zhuobin Zheng, Peng Sun, and Zhengyou Zhang. 6 | Arena: a toolkit for multi-agent reinforcement learning. 7 | arXiv preprint arXiv:1907.09467, 2019. 8 | (* Equal contribution) 9 | 10 | 11 | ## Install 12 | cd to the folder and run the command: 13 | ``` 14 | pip install -e . 15 | ``` 16 | 17 | 18 | ## Dependencies 19 | All the required dependencies will be automatically installed with the `pip install` command. 
20 | There are several other dependencies that you have to install manually, in an on-demand way (Just install it when you need it): 21 | * For `StarCraft II` environment, install [TImitate](https://github.com/tencent-ailab/TImitate) 22 | * For `ViZDoom` environment, install the [vizdoom](https://github.com/mwydmuch/ViZDoom), version >= '1.1.8' 23 | * For `Pommerman` environment, [pommerman](https://github.com/MultiAgentLearning/playground) 24 | * For `soccer` environment, [dm_control](https://github.com/deepmind/dm_control) 25 | 26 | 27 | # Disclaimer 28 | This is not an officially supported Tencent product. 29 | The code and data in this repository are for research purposes only. 30 | No representation or warranty whatsoever, expressed or implied, is made as to its accuracy, reliability or completeness. 31 | We assume no liability and are not responsible for any misuse or damage caused by the code and data. 32 | Your use of the code and data are subject to applicable laws and your use of them is at your own risk. 33 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/run_loop.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | 3 | from arena.utils.vizdoom.player import player_window_cv 4 | from arena.utils.vizdoom.core_env import VecEnv 5 | 6 | 7 | def run_loop_venv(venv, agents, max_steps=3000, is_window_cv_visible=False, 8 | verbose=4): 9 | obs = venv.reset() 10 | print('new episode') 11 | 12 | t_start = time() 13 | i_ep_step = 0 14 | ep_return = None 15 | for i_step in range(0, max_steps): 16 | actions = [ag.step(o) for ag, o in zip(agents, obs)] 17 | obs, rwd, dones, infos = venv.step(actions) 18 | print('run_loop/run_loop_venv/infos') 19 | print(infos) 20 | if ep_return is None: 21 | ep_return = rwd 22 | else: 23 | ep_return = [a+b for a, b in zip(ep_return, rwd)] 24 | 25 | if verbose >= 4: 26 | print('step: ', i_step) 27 | print('ep step: ', 
i_ep_step) 28 | for i, (o, r, d) in enumerate(zip(obs, rwd, dones)): 29 | print('o shape = {}, o type = {}, r = {}, done = {}'.format( 30 | o.shape, o.dtype, r, d)) 31 | if is_window_cv_visible: 32 | player_window_cv(i, o) 33 | 34 | if all(dones): 35 | print('ep return when all dones: ', ep_return) 36 | venv.reset() 37 | i_ep_step = 0 38 | ep_return = None 39 | print('new episode') 40 | else: 41 | i_ep_step += 1 42 | 43 | t_end = time() 44 | print('elapsed_time = ', t_end - t_start) 45 | print('fps = ', float(max_steps) / float(t_end - t_start)) 46 | 47 | 48 | def run_loop_env(env, agent, **kwargs): 49 | venv = VecEnv([env]) 50 | assert(type(agent) is not list) 51 | run_loop_venv(venv, [agent], **kwargs) 52 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/dogfight.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = dogfight.wad 6 | 7 | 8 | # Rendering options 9 | screen_resolution = RES_160X120 10 | screen_format = CRCGCB 11 | render_hud = true 12 | render_crosshair = true 13 | render_weapon = true 14 | render_decals = false 15 | render_particles = false 16 | window_visible = false 17 | 18 | # start, end in tics 19 | episode_start_time = 10 20 | episode_timeout = 4200 21 | 22 | #death_penalty = 0.1 23 | living_reward = -0.0001 24 | 25 | # Available buttons 26 | available_buttons = 27 | { 28 | ATTACK 29 | USE 30 | 31 | TURN_LEFT 32 | TURN_RIGHT 33 | MOVE_RIGHT 34 | MOVE_LEFT 35 | MOVE_FORWARD 36 | MOVE_BACKWARD 37 | 38 | TURN_LEFT_RIGHT_DELTA 39 | LOOK_UP_DOWN_DELTA 40 | } 41 | # TODO(pengsun): update to vizdoom 1.1.8 to support parsing HITCOUNT, HITS_TAKEN 42 | available_game_variables = 43 | { 44 | KILLCOUNT 45 | ITEMCOUNT 46 | SECRETCOUNT 47 | FRAGCOUNT 48 | DEATHCOUNT 49 | HEALTH 50 | SELECTED_WEAPON_AMMO 51 | HITCOUNT 52 | #HITS_TAKEN 53 | #DAMAGE_TAKEN 54 | #DAMAGECOUNT 55 | #FRAGCOUNT 56 | POSITION_X 57 | POSITION_Y 58 | ARMOR 59 | PLAYER_NUMBER 60 | DEAD 61 | # ANGLE 62 | } 63 | 64 | game_args += +sv_noautoaim 1 65 | game_args += +sv_respawnprotect 1 66 | game_args += +sv_forcerespawn 1 67 | #game_args += +sv_unlimited_pickup 1 68 | 69 | mode = PLAYER 70 | doom_skill = 3 71 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/_scenarios/cig.cfg: -------------------------------------------------------------------------------- 1 | # Lines starting with # are treated as comments (or with whitespaces+#). 2 | # It doesn't matter if you use capital letters or not. 3 | # It doesn't matter if you use underscore or camel notation for keys, e.g. episode_timeout is the same as episodeTimeout. 
4 | 5 | doom_scenario_path = cig2017.wad 6 | 7 | #12 minutes 8 | episode_timeout = 21000 9 | 10 | # Rendering options 11 | screen_resolution = RES_160X120 12 | screen_format = CRCGCB 13 | render_hud = false 14 | render_crosshair = true 15 | render_weapon = false 16 | render_decals = false 17 | render_particles = false 18 | window_visible = false 19 | 20 | episode_start_time = 10 21 | #death_penalty = 0.1 22 | living_reward = -0.0001 23 | 24 | # Available buttons 25 | available_buttons = 26 | { 27 | ATTACK 28 | 29 | TURN_LEFT 30 | TURN_RIGHT 31 | MOVE_RIGHT 32 | MOVE_LEFT 33 | MOVE_FORWARD 34 | MOVE_BACKWARD 35 | SPEED 36 | TURN180 37 | TURN_LEFT_RIGHT_DELTA 38 | } 39 | # TODO(pengsun): update to vizdoom 1.1.8 to support parsing HITCOUNT, HITS_TAKEN 40 | available_game_variables = 41 | { 42 | KILLCOUNT 43 | ITEMCOUNT 44 | SECRETCOUNT 45 | FRAGCOUNT 46 | DEATHCOUNT 47 | HEALTH 48 | SELECTED_WEAPON_AMMO 49 | HITCOUNT 50 | HITS_TAKEN 51 | #DAMAGE_TAKEN 52 | #DAMAGECOUNT 53 | #FRAGCOUNT 54 | POSITION_X 55 | POSITION_Y 56 | ARMOR 57 | PLAYER_NUMBER 58 | DEAD 59 | 60 | # ANGLE 61 | } 62 | 63 | game_args += +sv_noautoaim 1 64 | game_args += +sv_respawnprotect 1 65 | game_args += +sv_forcerespawn 1 66 | game_args += +viz_respawn_delay 10 67 | #game_args += +sv_unlimited_pickup 1 68 | 69 | mode = PLAYER 70 | doom_skill = 3 71 | 72 | -------------------------------------------------------------------------------- /arena/utils/pong2p/pong2p_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import pygame 6 | 7 | class Rect(): 8 | def __init__(self, x, y, w, h): 9 | self._x = x 10 | self._y = y 11 | self._w = w 12 | self._h = h 13 | 14 | def set_right(self, right): 15 | self._x = right - self._w 16 | 17 | def set_left(self, left): 18 | self._x = left 19 | 20 | def set_top(self, top): 21 | self._y = top 22 | 23 | 
def set_bottom(self, bottom): 24 | self._y = bottom - self._h 25 | 26 | def set_centery(self, centery): 27 | self._y = centery - self._h / 2.0 28 | 29 | def set_pos(self, x, y): 30 | self._x = x 31 | self._y = y 32 | 33 | def add_x(self, dx): 34 | self._x += dx 35 | 36 | def add_y(self, dy): 37 | self._y += dy 38 | 39 | def add_pos(self, dx, dy): 40 | self._x += dx 41 | self._y += dy 42 | 43 | def x(self): 44 | return self._x 45 | 46 | def y(self): 47 | return self._y 48 | 49 | def top(self): 50 | return self._y 51 | 52 | def bottom(self): 53 | return self._y + self._h 54 | 55 | def left(self): 56 | return self._x 57 | 58 | def right(self): 59 | return self._x + self._w 60 | 61 | def width(self): 62 | return self._w 63 | 64 | def height(self): 65 | return self._h 66 | 67 | def centerx(self): 68 | return self._x + self._w / 2.0 69 | 70 | def centery(self): 71 | return self._y + self._h / 2.0 72 | 73 | def rect(self): 74 | return pygame.Rect(self._x, self._y, self._w, self._h) 75 | -------------------------------------------------------------------------------- /arena/wrappers/sc2stat_wrapper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from gym import Wrapper 4 | from pysc2.lib.typeenums import ABILITY_ID 5 | 6 | 7 | class StatAllAction(Wrapper): 8 | """Statistics for all actions counting.""" 9 | def __init__(self, env): 10 | super(StatAllAction, self).__init__(env) 11 | self._ab_dict = dict([(ab.value, 0) for ab in ABILITY_ID]) 12 | 13 | def _reset_stat(self): 14 | self._ab_dict = dict([(ab.value, 0) for ab in ABILITY_ID]) 15 | 16 | def _action_stat(self, actions): 17 | for action in actions[0]: 18 | if action.action_raw.unit_command.ability_id in self._ab_dict: 19 | self._ab_dict[action.action_raw.unit_command.ability_id] += 1 20 | 21 | def step(self, actions): 22 | self._action_stat(actions) 23 | obs, reward, done, info = self.env.step(actions) 24 | if done: 25 | for k, v in self._ab_dict.items(): 
26 | if v > 0: 27 | info[ABILITY_ID(k).name] = v 28 | self._reset_stat() 29 | return obs, reward, done, info 30 | 31 | 32 | class StatZStatFn(Wrapper): 33 | """Statistics for ZStat Filename""" 34 | def step(self, actions): 35 | obs, reward, done, info = self.env.step(actions) 36 | if done: 37 | if not hasattr(self, 'inters'): 38 | logging.warning("Cannot find the field 'inters' for this env {}".format(str(self))) 39 | return obs, reward, done, info 40 | for ind, interf in enumerate(self.inters): 41 | key = 'agt{}zstat'.format(ind) 42 | root_interf = interf.unwrapped() 43 | if not hasattr(root_interf, 'cur_zstat_fn'): 44 | logging.warning("Cannot find the field 'cur_zstat_fn' for the root interface {}".format(root_interf)) 45 | return obs, reward, done, info 46 | info[key] = root_interf.cur_zstat_fn 47 | return obs, reward, done, info -------------------------------------------------------------------------------- /arena/interfaces/sc2full_formal/noop_int.py: -------------------------------------------------------------------------------- 1 | """ Gym env wrappers """ 2 | from gym import spaces 3 | import numpy as np 4 | 5 | from arena.interfaces.interface import Interface 6 | 7 | 8 | class NoopMaskInt(Interface): 9 | """ have to be wrapped after FullActInt or NoopActIntV2 which contains self.noop_cnt """ 10 | def __init__(self, inter, max_noop_dim=10): 11 | super(self.__class__, self).__init__(inter) 12 | self.ability_dim = self.action_space.spaces[0].n 13 | self.noop_dim = max_noop_dim 14 | self.pre_obs_space = inter.observation_space 15 | 16 | def reset(self, obs, **kwargs): 17 | super(self.__class__, self).reset(obs, **kwargs) 18 | 19 | @property 20 | def observation_space(self): 21 | obs_spec = spaces.Tuple(self.pre_obs_space.spaces + [ 22 | spaces.Box(low=0, high=1, shape=(self.ability_dim,), dtype=np.bool), 23 | spaces.Box(low=0, high=1, shape=(self.noop_dim,), dtype=np.bool)]) 24 | return obs_spec 25 | 26 | def obs_trans(self, raw_obs): 27 | if self.inter and 
hasattr(self.inter, 'noop_cnt'): 28 | obs = self.inter.obs_trans(raw_obs) 29 | obs += self._make_noop_mask(self.inter.noop_cnt) 30 | return obs 31 | else: 32 | raise BaseException('NoopInt has to be used together with ' 33 | 'FullActInt or NoopActIntV2') 34 | 35 | def _make_noop_mask(self, noop_cnt): 36 | if noop_cnt > 0: 37 | ability_mask = np.zeros(shape=(self.ability_dim,), dtype=np.bool) 38 | ability_mask[0] = 1 39 | noop_mask = np.zeros(shape=(self.noop_dim,), dtype=np.bool) 40 | noop_mask[noop_cnt-1] = 1 41 | else: 42 | ability_mask = np.ones(shape=(self.ability_dim,), dtype=np.bool) 43 | noop_mask = np.ones(shape=(self.noop_dim,), dtype=np.bool) 44 | return ability_mask, noop_mask -------------------------------------------------------------------------------- /arena/interfaces/vizdoom/act_int.py: -------------------------------------------------------------------------------- 1 | """Vizdoom action interfaces""" 2 | from copy import deepcopy 3 | 4 | from gym.spaces import Discrete as GymDiscrete 5 | from vizdoom import Button 6 | 7 | from arena.interfaces.interface import Interface 8 | 9 | 10 | class Discrete6ActionInt(Interface): 11 | """Wu Yuxing's 6-action setting. 
12 | 13 | The six actions are: 14 | move forward, 15 | fire, 16 | move left, 17 | move right, 18 | turn left, 19 | turn right, 20 | """ 21 | def __init__(self, inter): 22 | super(Discrete6ActionInt, self).__init__(inter) 23 | 24 | self.allowed_buttons = [ 25 | Button.ATTACK, 26 | Button.TURN_LEFT, 27 | Button.TURN_RIGHT, 28 | Button.MOVE_RIGHT, 29 | Button.MOVE_LEFT, 30 | Button.MOVE_FORWARD, 31 | Button.MOVE_BACKWARD, 32 | Button.SPEED, 33 | Button.TURN180, 34 | Button.TURN_LEFT_RIGHT_DELTA 35 | ] 36 | # NOTE: [7:] actions are not exposed 37 | self.allowed_actions = [ 38 | [0, 0, 0, 0, 0, 1, 0, 1, 0, 0], # 0 move fast forward 39 | [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], # 1 fire 40 | [0, 0, 0, 1, 0, 0, 0, 1, 0, 0], # 2 move left 41 | [0, 0, 0, 0, 1, 0, 0, 1, 0, 0], # 3 move right 42 | [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], # 4 turn left 43 | [0, 0, 1, 0, 0, 0, 0, 1, 0, 0], # 5 turn right 44 | [0, 1, 0, 0, 0, 0, 0, 0, 0, 20], # 6 turn left 40 degree and move forward 45 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 20], # 7 turn right 40 degree and move forward 46 | [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], # 8 move forward 47 | [0, 0, 0, 0, 0, 0, 0, 0, 1, 0], # 9 turn 180 48 | [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # 10 move left 49 | [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], # 11 move right 50 | [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], # 12 turn left 51 | [0, 0, 1, 0, 0, 0, 0, 0, 0, 0], # 13 turn right 52 | ] 53 | pass 54 | 55 | @property 56 | def action_space(self): 57 | return GymDiscrete(n=6) 58 | 59 | def act_trans(self, act): 60 | # act (as index) -> button vector 61 | act = int(act) 62 | return self.allowed_actions[act] 63 | -------------------------------------------------------------------------------- /arena/utils/unit_util.py: -------------------------------------------------------------------------------- 1 | from arena.utils.constant import AllianceType 2 | import numpy as np 3 | 4 | 5 | def collect_units_by_type(units, unit_type, alliance=AllianceType.SELF.value): 6 | """ return unit's ID in the same type """ 7 | 
return [u for u in units 8 | if u.unit_type == unit_type and u.alliance == alliance] 9 | 10 | 11 | def collect_units_by_types(units, unit_types): 12 | """ return unit's ID in the unit_types list """ 13 | return [u for u in units if u.unit_type in unit_types] 14 | 15 | 16 | def collect_units_by_alliance(units, alliance=AllianceType.SELF.value): 17 | return [u for u in units if u.alliance == alliance] 18 | 19 | 20 | def find_units_by_tag(units, tag): 21 | return [u for u in units if u.tag == tag] 22 | 23 | 24 | def find_weakest(units): 25 | """ find the weakest one to 'unit' within the list 'units' """ 26 | if not units: 27 | return None 28 | dd = np.asarray([u.health for u in units]) 29 | return units[dd.argmin()] 30 | 31 | 32 | def find_strongest(units): 33 | """ find the strongest one to 'unit' within the list 'units' """ 34 | if not units: 35 | return None 36 | dd = np.asarray([u.health for u in units]) 37 | return units[dd.argmax()] 38 | 39 | 40 | def merge_units_from(l1, l2): 41 | """ Merge info from l2 to l1 """ 42 | for u2 in l2: 43 | matched_u = [u for u in l1 if u.tag == u2.tag] 44 | if len(matched_u) > 0: # Found 45 | assert len(matched_u) == 1 46 | u1 = matched_u[0] 47 | # Set each unset field 48 | if not u1.HasField('weapon_cooldown'): 49 | u1.weapon_cooldown = u2.weapon_cooldown 50 | if not u1.HasField('engaged_target_tag'): 51 | u1.engaged_target_tag = u2.engaged_target_tag 52 | for od2 in u2.orders: 53 | if not od2 in u1.orders: 54 | u1.orders.extend([od2]) 55 | else: # Not Found 56 | l1.extend([u2]) 57 | return 58 | 59 | 60 | def merge_units(l1, l2): 61 | " Merge units info from two raw_data.units to complete info" 62 | merge_units_from(l1, l2) 63 | merge_units_from(l2, l1) 64 | 65 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/run_parallel.py: -------------------------------------------------------------------------------- 1 | """ 2 | copied from PySC2 code 3 | """ 4 | from __future__ import 
absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import functools 9 | from concurrent import futures 10 | 11 | 12 | class RunParallel(object): 13 | """Run all funcs in parallel.""" 14 | 15 | def __init__(self, timeout=None): 16 | self._timeout = timeout 17 | self._executor = None 18 | self._workers = 0 19 | 20 | def run(self, funcs): 21 | """Run a set of functions in parallel, returning their results. 22 | 23 | Make sure any function you pass exits with a reasonable timeout. If it 24 | doesn't return within the timeout or the result is ignored due an exception 25 | in a separate thread it will continue to stick around until it finishes, 26 | including blocking process exit. 27 | 28 | Args: 29 | funcs: An iterable of functions or iterable of args to functools.partial. 30 | 31 | Returns: 32 | A list of return values with the values matching the order in funcs. 33 | 34 | Raises: 35 | Propagates the first exception encountered in one of the functions. 36 | """ 37 | funcs = [f if callable(f) else functools.partial(*f) for f in funcs] 38 | if len(funcs) == 1: # Ignore threads if it's not needed. 39 | return [funcs[0]()] 40 | if len(funcs) > self._workers: # Lazy init and grow as needed. 41 | self.shutdown() 42 | self._workers = len(funcs) 43 | self._executor = futures.ThreadPoolExecutor(self._workers) 44 | futs = [self._executor.submit(f) for f in funcs] 45 | done, not_done = futures.wait(futs, self._timeout, futures.FIRST_EXCEPTION) 46 | # Make sure to propagate any exceptions. 47 | for f in done: 48 | if not f.cancelled() and f.exception() is not None: 49 | if not_done: 50 | # If there are some calls that haven't finished, cancel and recreate 51 | # the thread pool. Otherwise we may have a thread running forever 52 | # blocking parallel calls. 53 | for nd in not_done: 54 | nd.cancel() 55 | self.shutdown(False) # Don't wait, they may be deadlocked. 
56 | raise f.exception() 57 | # Either done or timed out, so don't wait again. 58 | return [f.result(timeout=0) for f in futs] 59 | 60 | def shutdown(self, wait=True): 61 | if self._executor: 62 | self._executor.shutdown(wait) 63 | self._executor = None 64 | self._workers = 0 65 | 66 | def __del__(self): 67 | self.shutdown() 68 | -------------------------------------------------------------------------------- /arena/wrappers/pong2p/pong2p_compete.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import gym 6 | import gym.spaces 7 | import numpy as np 8 | import random 9 | 10 | #import gym_compete.envs 11 | #from gym_compete.wrappers.pong_wrappers import wrap_pong 12 | 13 | class WrapCompete(gym.Wrapper): 14 | def __init__(self, env): 15 | """ Wrap compete envs 16 | """ 17 | gym.Wrapper.__init__(self, env) 18 | 19 | def reset(self, **kwargs): 20 | obs = self.env.reset(**kwargs) 21 | if isinstance(obs, tuple): 22 | return np.array(obs) 23 | return obs 24 | 25 | def step(self, action): 26 | obs, reward, done, info = self.env.step(action) 27 | if isinstance(obs, tuple): 28 | obs = np.array(obs) 29 | if isinstance(reward, tuple): 30 | reward = np.array(reward) 31 | 32 | return obs, reward, done, info 33 | 34 | 35 | class TransposeWrapper(gym.ObservationWrapper): 36 | def observation(self, observation): 37 | if isinstance(observation, tuple): 38 | return tuple([ 39 | np.transpose(np.array(ob), axes=(2,0,1)) 40 | for ob in observation 41 | ]) 42 | else: 43 | return np.transpose(np.array(observation), axes=(2,0,1)) 44 | 45 | class NoRwdResetEnv(gym.Wrapper): 46 | def __init__(self, env, no_reward_thres): 47 | """Reset the environment if no reward received in N steps 48 | """ 49 | gym.Wrapper.__init__(self, env) 50 | self.no_reward_thres = no_reward_thres 51 | self.no_reward_step = 0 52 | 53 | def step(self, 
action): 54 | obs, reward, done, info = self.env.step(action) 55 | if isinstance(reward, tuple): 56 | if all(r == 0.0 for r in reward): 57 | self.no_reward_step += 1 58 | else: 59 | self.no_reward_step = 0 60 | else: 61 | if reward == 0.0: 62 | self.no_reward_step += 1 63 | else: 64 | self.no_reward_step = 0 65 | 66 | if self.no_reward_step > self.no_reward_thres: 67 | done = True 68 | return obs, reward, done, info 69 | 70 | def reset(self, **kwargs): 71 | obs = self.env.reset(**kwargs) 72 | self.no_reward_step = 0 73 | return obs 74 | 75 | #def make_pong(env_id, episode_life=True, clip_rewards=True, frame_stack=True, scale=True, seed=None): 76 | # env = wrap_pong(env_id, episode_life, clip_rewards, frame_stack, scale, seed) 77 | # #env = TransposeWrapper(env) 78 | # env = NoRwdResetEnv(env, no_reward_thres = 1000) 79 | # env = WrapCompete(env) 80 | # return env 81 | -------------------------------------------------------------------------------- /arena/utils/run_loop.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # adopted from run_loop.py in pysc2 3 | """A run loop for agent/environment interaction.""" 4 | 5 | from __future__ import absolute_import 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | from pysc2.lib.features import Features 10 | from arena.utils.unit_util import merge_units 11 | import time 12 | 13 | 14 | def run_loop(agents, env, max_frames=0, max_episodes=0, sleep_time_per_step=0, merge_units_info=False): 15 | """A run loop to have agents and an environment interact.""" 16 | total_frames = 0 17 | total_episodes = 0 18 | start_time = time.time() 19 | result_stat = [0]*3 # n_draw, n_win, n_loss 20 | 21 | observation_spec = env.observation_spec() 22 | action_spec = env.action_spec() 23 | #for agent, obs_spec, act_spec in zip(agents, observation_spec, action_spec): 24 | # agent.setup(obs_spec, act_spec) 25 | 26 | try: 27 | while not max_episodes or 
total_episodes < max_episodes: 28 | total_episodes += 1 29 | timesteps = env.reset() 30 | for a in agents: 31 | a.reset() 32 | while True: 33 | total_frames += 1 34 | if merge_units_info: 35 | assert len(timesteps)==2 36 | # Merge units from two timesteps to one 37 | merge_units(timesteps[0].observation.raw_data.units, 38 | timesteps[1].observation.raw_data.units) 39 | for i in range(2): 40 | timesteps[i].observation['units'] = \ 41 | Features.transform_unit_control(timesteps[i].observation.raw_data.units) 42 | actions = [agent.step(timestep) 43 | for agent, timestep in zip(agents, timesteps)] 44 | if max_frames and total_frames >= max_frames: 45 | return 46 | if timesteps[0].last(): # player 1 47 | result_stat[timesteps[0].reward] += 1 48 | break 49 | if sleep_time_per_step > 0: 50 | time.sleep(sleep_time_per_step) 51 | timesteps = env.step(actions) 52 | except KeyboardInterrupt: 53 | pass 54 | finally: 55 | print("Game result statistics: Win: %2d, Loss: %2d, Draw: %2d" % ( 56 | result_stat[1], result_stat[-1], result_stat[0])) 57 | elapsed_time = time.time() - start_time 58 | print("Took %.3f seconds for %s steps: %.3f fps" % ( 59 | elapsed_time, total_frames, total_frames / elapsed_time)) 60 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/player.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | class PlayerHostConfig(object): 5 | def __init__(self, port, num_players=2): 6 | self.num_players = num_players 7 | self.port = port 8 | print('Host {}'.format(self.port)) 9 | 10 | class PlayerJoinConfig(object): 11 | def __init__(self, port): 12 | self.join_ip = '127.0.0.1' 13 | self.port = port 14 | print('Player {}'.format(self.port)) 15 | 16 | class PlayerConfig(object): 17 | def __init__(self): 18 | self.config_path = None 19 | self.player_mode = None 20 | self.is_render_hud = None 21 | self.screen_resolution = None 22 | 
self.screen_format = None 23 | self.is_window_visible = None 24 | self.ticrate = None 25 | self.episode_timeout = None 26 | self.name = None 27 | self.colorset = None 28 | 29 | self.repeat_frame = 2 30 | self.num_bots = 0 31 | 32 | self.is_multiplayer_game = True 33 | self.host_cfg = None 34 | self.join_cfg = None 35 | 36 | def player_host_setup(game, host_config): 37 | game.add_game_args(' '.join([ 38 | "-host {}".format(host_config.num_players), 39 | "-port {}".format(host_config.port), 40 | "-netmode 0", 41 | "-deathmatch", 42 | "+sv_spawnfarthest 1", 43 | "+viz_nocheat 0", 44 | ])) 45 | return game 46 | 47 | def player_join_setup(game, join_config): 48 | game.add_game_args(' '.join([ 49 | "-join {}".format(join_config.join_ip), 50 | "-port {}".format(join_config.port), 51 | ])) 52 | return game 53 | 54 | def player_setup(game, player_config): 55 | pc = player_config # a short name 56 | 57 | # read in the config from file first, allow over-write later 58 | if pc.config_path is not None: 59 | game.load_config(pc.config_path) 60 | 61 | if pc.player_mode is not None: 62 | game.set_mode(pc.player_mode) 63 | if pc.screen_resolution is not None: 64 | game.set_screen_resolution(pc.screen_resolution) 65 | if pc.screen_format is not None: 66 | game.set_screen_format(pc.screen_format) 67 | if pc.is_window_visible is not None: 68 | game.set_window_visible(pc.is_window_visible) 69 | if pc.ticrate is not None: 70 | game.set_ticrate(pc.ticrate) 71 | if pc.episode_timeout is not None: 72 | game.set_episode_timeout(pc.episode_timeout) 73 | 74 | game.set_console_enabled(False) 75 | 76 | if pc.name is not None: 77 | game.add_game_args("+name {}".format(pc.name)) 78 | if pc.colorset is not None: 79 | game.add_game_args("+colorset {}".format(pc.colorset)) 80 | return game 81 | 82 | def player_window_cv(i_player, img, transpose=False): 83 | window_name = 'player ' + str(i_player) 84 | if transpose: 85 | img = np.transpose(img, axes=(1, 2, 0)) 86 | cv2.imshow(window_name, img) 87 | 
cv2.waitKey(1) 88 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/player_vs_f1.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | class PlayerHostConfig(object): 5 | def __init__(self, port, num_players=2): 6 | self.num_players = num_players 7 | self.port = port 8 | print('Host {}'.format(self.port)) 9 | 10 | class PlayerJoinConfig(object): 11 | def __init__(self): 12 | self.join_ip = 'localhost' 13 | # self.port = port 14 | # print('Player {}'.format(self.port)) 15 | 16 | class PlayerConfig(object): 17 | def __init__(self): 18 | self.config_path = None 19 | self.player_mode = None 20 | self.is_render_hud = None 21 | self.screen_resolution = None 22 | self.screen_format = None 23 | self.is_window_visible = None 24 | self.ticrate = None 25 | self.episode_timeout = None 26 | self.name = None 27 | self.colorset = None 28 | 29 | self.repeat_frame = 2 30 | self.num_bots = 0 31 | 32 | self.is_multiplayer_game = True 33 | self.host_cfg = None 34 | self.join_cfg = None 35 | 36 | def player_host_setup(game, host_config): 37 | game.add_game_args(' '.join([ 38 | "-host {}".format(host_config.num_players), 39 | "-port {}".format(host_config.port), 40 | "-netmode 0", 41 | "-deathmatch", 42 | "+sv_spawnfarthest 1", 43 | "+viz_nocheat 0", 44 | ])) 45 | return game 46 | 47 | def player_join_setup(game, join_config): 48 | game.add_game_args(' '.join([ 49 | "-join {}".format(join_config.join_ip), 50 | # "-port {}".format(join_config.port), 51 | ])) 52 | return game 53 | 54 | def player_setup(game, player_config): 55 | pc = player_config # a short name 56 | 57 | # read in the config from file first, allow over-write later 58 | if pc.config_path is not None: 59 | game.load_config(pc.config_path) 60 | 61 | if pc.player_mode is not None: 62 | game.set_mode(pc.player_mode) 63 | if pc.screen_resolution is not None: 64 | 
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from arena.interfaces.raw_int import RawInt


class Interface(RawInt):
    """Base class for stackable agent interfaces.

    An Interface wraps a previous (inner) interface and transforms
    observations and actions recursively through the chain.
    """
    inter = None  # the wrapped (inner) interface

    def __init__(self, interface):
        """Wrap `interface`; passing None wraps a fresh RawInt root.

        :param interface: previous interface to wrap on
        """
        # NOTE: in `__init__()` of a derived class, first call
        # super(self.__class__, self).__init__(interface)
        self.inter = interface if interface is not None else RawInt()
        assert isinstance(self.inter, RawInt)

    def reset(self, obs, **kwargs):
        """Reset the whole interface chain.
        Obs and action spaces may be (re)specified on reset().

        :param obs: input obs (as received by the root interface)
        """
        # NOTE: in `reset()` of a derived class, first call
        # super(self.__class__, self).reset(obs)
        self.inter.reset(obs, **kwargs)

    def obs_trans(self, obs):
        """Observation transformation; recursive through the chain."""
        # Implement customized obs_trans in the derived class.
        return self.inter.obs_trans(obs)

    def act_trans(self, act):
        """Action transformation; recursive through the chain."""
        # Implement customized act_trans in the derived class.
        return self.inter.act_trans(act)

    def unwrapped(self):
        """Return the root instance (RawInt stores the raw obs/act)."""
        return self.inter.unwrapped()

    @property
    def observation_space(self):
        """Observation space, computed recursively through the chain."""
        return self.inter.observation_space

    @property
    def action_space(self):
        """Action space, computed recursively through the chain."""
        return self.inter.action_space

    def setup(self, observation_space, action_space):
        """Forward the raw space specification to the root interface."""
        self.unwrapped().setup(observation_space, action_space)

    def __str__(self):
        """Names of all stacked interfaces, innermost first."""
        return '{}<{}>'.format(self.inter, self.__class__.__name__)
""" 70 | # TODO(peng): return my_name + '<' + wrapped_interface_name + '>' 71 | s = str(self.inter) 72 | return s+'<'+self.__class__.__name__+'>' 73 | -------------------------------------------------------------------------------- /arena/wrappers/vizdoom/observation.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | from copy import deepcopy 5 | 6 | import gym 7 | from gym.spaces import Box, Discrete 8 | import numpy as np 9 | import cv2 10 | import vizdoom as vd 11 | 12 | 13 | class PermuteAndResizeFrame(gym.ObservationWrapper): 14 | """ CHW to HWC, and Resize Frame """ 15 | def __init__(self, env, height=84, width=84): 16 | gym.ObservationWrapper.__init__(self, env) 17 | channel = self.env.observation_space.shape[0] 18 | self.observation_space = self.env.observation_space 19 | self.observation_space.shape = (height, width, channel) 20 | 21 | self._height, self._width = height, width 22 | 23 | def observation(self, frame): 24 | frame = np.transpose(frame, axes=(1, 2, 0)) 25 | frame = cv2.resize(frame, (self._height, self._width), 26 | interpolation=cv2.INTER_AREA) 27 | return frame 28 | 29 | 30 | class PermuteFrame(gym.ObservationWrapper): 31 | """ CHW to HWC, and Resize Frame """ 32 | def __init__(self, env): 33 | gym.ObservationWrapper.__init__(self, env) 34 | c, h, w = self.env.observation_space.shape 35 | self.observation_space = self.env.observation_space 36 | self.observation_space.shape = (h, w, c) 37 | 38 | def observation(self, frame): 39 | frame = np.transpose(frame, axes=(1, 2, 0)) 40 | return frame 41 | 42 | 43 | class WuObservation(gym.ObservationWrapper): 44 | """ Wu Yuxin's Observation. Screen + GameVariables. 
45 | 46 | Expose observations as a list [screen, game_var], where 47 | screen.shape = (height, width, channel) 48 | and 49 | game_var.shape = (2,) 50 | which includes (health, ammo) normalized to [0.0, 1.0] 51 | """ 52 | def __init__(self, env, height=84, width=84): 53 | gym.ObservationWrapper.__init__(self, env) 54 | # export observation space 55 | channel = self.env.observation_space.shape[0] 56 | screen_sp = self.env.observation_space 57 | screen_sp.shape = (height, width, channel) 58 | game_var_sp = Box(low=0.0, high=1.0, shape=(2,), dtype=np.float32) 59 | self.observation_space = [screen_sp, game_var_sp] 60 | 61 | self._height, self._width = height, width 62 | self._dft_gamevar = np.zeros(shape=game_var_sp.shape, 63 | dtype=game_var_sp.dtype) 64 | self._gamevar = deepcopy(self._dft_gamevar) 65 | 66 | def observation(self, frame): 67 | # Permute and resize 68 | frame = np.transpose(frame, axes=(1, 2, 0)) 69 | frame = cv2.resize(frame, (self._height, self._width), 70 | interpolation=cv2.INTER_AREA) 71 | # normalized game vars 72 | self._grab_gamevar() 73 | return [frame, self._gamevar] 74 | 75 | def _grab_gamevar(self): 76 | if self.env.unwrapped._state is not None: 77 | game = self.env.unwrapped.game 78 | self._gamevar[0] = game.get_game_variable(vd.GameVariable.HEALTH) / 100.0 79 | self._gamevar[1] = game.get_game_variable( 80 | vd.GameVariable.SELECTED_WEAPON_AMMO 81 | ) / 15.0 82 | -------------------------------------------------------------------------------- /docs/agt_int.md: -------------------------------------------------------------------------------- 1 | # AgtInt (Agent Interface) 2 | AgtInt is a class to define action space/wrapper and observation space/wrapper 3 | between agent and environment. After defining the Agent Interfaces of all the agent, 4 | one can use 5 | ``` 6 | env_new = EnvWrapper(env, (agt_int_1, ... , agt_int_n)) 7 | ``` 8 | to generate a new env with the desired action/observation space (usually simpler) for each agent. 
9 | Here env is a gym-style "multi-player" environment and agt_int_i is the Agent 10 | Interface for each agent. 11 | 12 | One can also use 13 | ``` 14 | agent_new = AgentWrapper(agent, agt_int) 15 | ``` 16 | to transform an agent in a simpler action/observation space into a new agent that can 17 | interact with the raw environment directly. 18 | 19 | 20 | ## Basic idea of AgtInt 21 | An AgtInt needs to clearly define 'obs_spec' and 'action_spec' for the agent and also two 22 | transformation functions: 23 | 24 | * "observation_transform" transforms the raw observation into the desired simple observation 25 | 26 | * "action_transform" transforms the agent's simple action into a raw action in the original env 27 | 28 | ## AgtIntWrapper 29 | AgtIntWrapper is used to transform an AgtInt in a modular way: 30 | ``` 31 | agt_int = AgtInt() 32 | agt_int = Discre4M2AWrapper(agt_int) 33 | agt_int = UnitAttrWrapper(agt_int, override=True) 34 | ``` 35 | An AgtIntWrapper can override or modify the observation_transform/action_transform in agt_int; 36 | obs_spec and action_spec also need to be specified. 37 | 38 | # "Multi-player" Environment (env/base_env.py) 39 | In this repository, we define a "multi-player" environment with the following convention: 40 | 41 | * observation space: a gym.spaces.Tuple object in which each entry gives the obs_spec 42 | of each agent. 43 | 44 | * action space: a gym.spaces.Tuple object with the same length as the observation space. 45 | Each entry gives the corresponding agent's action_spec. (An agent's action_spec may also 46 | be a gym.spaces.Tuple object in which each entry gives the action_spec of each unit controlled 47 | by this agent.) 48 | 49 | Even if there is only one player in the environment, we still use the above "multi-player" environment 50 | convention to define the observation/action space, but with a length of 1.
51 | 52 | ## EnvWrapper (env/base_env.py) 53 | EnvWrapper is used to transform "Multi-player" Environment with defined AgtInts for all the players: 54 | ``` 55 | def step(self, actions): 56 | raw_actions = self.action(actions) 57 | obs, rwd, done, info = self.env.step(raw_actions) 58 | new_obs = self.observation(obs) 59 | return new_obs, rwd, done, info 60 | ``` 61 | 62 | # BaseAgent (env/base_agent.py) 63 | A BaseAgent is basically a agent with "step" function to predict action in its action_spec 64 | given the observation in its obs_spec. 65 | 66 | ## AgentWrapper (env/base_agent.py) 67 | AgentWrapper is used to transform a BaseAgent object with defined AgtInt: 68 | ``` 69 | def step(self, obs): 70 | super(AgentWrapper, self).step(obs) 71 | obs = self.agt_int.observation_transform(obs) 72 | action = self.agent.step(obs) 73 | action = self.agt_int.action_transform(action) 74 | return action 75 | ``` 76 | 77 | -------------------------------------------------------------------------------- /arena/interfaces/soccer/obs_int.py: -------------------------------------------------------------------------------- 1 | """This file contains the observation interfaces for dm-control soccer.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import copy 8 | import numpy as np 9 | from arena.interfaces.interface import Interface 10 | from arena.interfaces.common import AppendObsInt 11 | from arena.utils.spaces import NoneSpace 12 | from gym import spaces 13 | 14 | class Dict2Vec(Interface): 15 | @property 16 | def observation_space(self): 17 | if isinstance(self.inter.observation_space, NoneSpace): 18 | return NoneSpace() 19 | assert isinstance(self.inter.observation_space, spaces.Dict) 20 | return self.convert_OrderedDict(self.inter.observation_space.spaces) 21 | 22 | def convert_OrderedDict(self, odict): 23 | # concatentation 24 | numdim = sum([np.int(np.prod(odict[key].shape)) for key in 
odict]) 25 | return spaces.Box(-np.inf, np.inf, shape=(numdim,)) 26 | 27 | def convert_observation(self, dict_obs): 28 | numdim = sum([np.int(np.prod(dict_obs[key].shape)) for key in dict_obs]) 29 | space_obs = np.zeros((numdim,)) 30 | i = 0 31 | for key in dict_obs: 32 | space_obs[i:i+np.prod(dict_obs[key].shape)] = dict_obs[key].ravel() 33 | i += np.prod(dict_obs[key].shape) 34 | return space_obs 35 | 36 | def obs_trans(self, obs): 37 | ret = self.convert_observation(obs) 38 | return ret 39 | 40 | class ConcatObsAct(Interface): 41 | @property 42 | def observation_space(self): 43 | if isinstance(self.inter.observation_space, NoneSpace): 44 | return NoneSpace() 45 | assert isinstance(self.inter.observation_space, spaces.Tuple) 46 | sps = self.inter.observation_space.spaces 47 | if any([isinstance(sp, NoneSpace) for sp in sps]): 48 | return NoneSpace() 49 | numdim = sum([ np.int(np.prod(sps[i].shape)) for i in range(len(sps))]) 50 | return spaces.Box(-np.inf, np.inf, shape=(numdim,)) 51 | 52 | @property 53 | def action_space(self): 54 | if isinstance(self.inter.action_space, NoneSpace): 55 | return NoneSpace() 56 | assert isinstance(self.inter.action_space, spaces.Tuple) 57 | sps = self.inter.action_space.spaces 58 | if any([isinstance(sp, NoneSpace) for sp in sps]): 59 | return NoneSpace() 60 | numdim = sum([ np.int(np.prod(sps[i].shape)) for i in range(len(sps))]) 61 | return spaces.Box(-1., 1., shape=(numdim,)) 62 | 63 | def _obs_trans(self, obs): 64 | return np.concatenate(obs) 65 | 66 | def obs_trans(self, obs): 67 | """ Observation Transformation. This is a recursive call. 
""" 68 | obs = self.inter.obs_trans(obs) 69 | return self._obs_trans(obs) 70 | 71 | def _act_trans(self, acts): 72 | rets = [] 73 | sps = self.inter.action_space.spaces 74 | i = 0 75 | for agsps in sps: 76 | size = agsps.shape[0] 77 | rets.append(acts[i:i+size]) 78 | i += size 79 | return rets -------------------------------------------------------------------------------- /arena/utils/vizdoom/Rect.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # File: rect.py 4 | # Author: Yuxin Wu 5 | 6 | import numpy as np 7 | 8 | class Rect(object): 9 | """ 10 | A Rectangle. 11 | Note that x1 = x+w, not x+w-1 or something 12 | """ 13 | __slots__ = ['x', 'y', 'w', 'h'] 14 | 15 | def __init__(self, x=0, y=0, w=0, h=0, allow_neg=False): 16 | self.x = x 17 | self.y = y 18 | self.w = w 19 | self.h = h 20 | if not allow_neg: 21 | assert min(self.x, self.y, self.w, self.h) >= 0 22 | 23 | @property 24 | def x0(self): 25 | return self.x 26 | 27 | @property 28 | def y0(self): 29 | return self.y 30 | 31 | @property 32 | def x1(self): 33 | return self.x + self.w 34 | 35 | @property 36 | def y1(self): 37 | return self.y + self.h 38 | 39 | def copy(self): 40 | new = type(self)() 41 | for i in self.__slots__: 42 | setattr(new, i, getattr(self, i)) 43 | return new 44 | 45 | def __str__(self): 46 | return 'Rect(x={}, y={}, w={}, h={})'.format(self.x, self.y, self.w, self.h) 47 | 48 | def area(self): 49 | return self.w * self.h 50 | 51 | def validate(self, shape=None): 52 | """ 53 | Is a valid bounding box within this shape 54 | :param shape: [h, w] 55 | :returns: boolean 56 | """ 57 | if min(self.x, self.y) < 0: 58 | return False 59 | if min(self.w, self.h) <= 0: 60 | return False 61 | if shape is None: 62 | return True 63 | if self.x1 > shape[1] - 1: 64 | return False 65 | if self.y1 > shape[0] - 1: 66 | return False 67 | return True 68 | 69 | def roi(self, img): 70 | assert 
self.validate(img.shape[:2]), "{} vs {}".format(self, img.shape[:2]) 71 | return img[self.y0:self.y1+1, self.x0:self.x1+1] 72 | 73 | def expand(self, frac): 74 | assert frac > 1.0, frac 75 | neww = self.w * frac 76 | newh = self.h * frac 77 | newx = self.x - (neww - self.w) * 0.5 78 | newy = self.y - (newh - self.h) * 0.5 79 | return Rect(*(map(int, [newx, newy, neww, newh])), allow_neg=True) 80 | 81 | def roi_zeropad(self, img): 82 | shp = list(img.shape) 83 | shp[0] = self.h 84 | shp[1] = self.w 85 | ret = np.zeros(tuple(shp), dtype=img.dtype) 86 | 87 | xstart = 0 if self.x >= 0 else -self.x 88 | ystart = 0 if self.y >= 0 else -self.y 89 | 90 | xmin = max(self.x0, 0) 91 | ymin = max(self.y0, 0) 92 | xmax = min(self.x1, img.shape[1]) 93 | ymax = min(self.y1, img.shape[0]) 94 | patch = img[ymin:ymax, xmin:xmax] 95 | ret[ystart:ystart+patch.shape[0],xstart:xstart+patch.shape[1]] = patch 96 | return ret 97 | 98 | __repr__ = __str__ 99 | 100 | 101 | if __name__ == '__main__': 102 | x = Rect(2, 1, 3, 3, allow_neg=True) 103 | 104 | img = np.random.rand(3,3) 105 | print(img) 106 | print(x.roi_zeropad(img)) 107 | -------------------------------------------------------------------------------- /arena/interfaces/combine.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:utf-8 -*- 3 | from gym import spaces 4 | from arena.interfaces.interface import Interface 5 | from arena.interfaces.raw_int import RawInt 6 | 7 | 8 | class Combine(Interface): 9 | """ 10 | Concat several Interface to form a new Interface 11 | 12 | """ 13 | inter = None 14 | 15 | def __init__(self, interface, sub_interfaces=[]): 16 | """ Initialization. 
#!/usr/bin/env python
# -*- coding:utf-8 -*-
from gym import spaces
from arena.interfaces.interface import Interface
from arena.interfaces.raw_int import RawInt


class Combine(Interface):
    """
    Concat several Interface to form a new Interface

    """
    inter = None

    def __init__(self, interface, sub_interfaces=()):
        """ Initialization.

        :param interface: previous interface to wrap on
        :param sub_interfaces: interfaces to combine (one per player);
            None entries become fresh RawInt roots
        """
        # NOTE: default changed from the mutable `[]` to `()` — a mutable
        # default argument is a classic Python pitfall; `()` is equivalent
        # here since the value is copied into a list below.
        if interface is None:
            self.inter = RawInt()
        else:
            self.inter = interface
        assert isinstance(self.inter, RawInt)

        self.sub_interfaces = list(sub_interfaces)
        for i, sub in enumerate(self.sub_interfaces):
            if sub is None:
                self.sub_interfaces[i] = RawInt()
            else:
                assert isinstance(sub, RawInt)

    def setup(self, observation_space, action_space):
        """Forward per-player raw space specs to root and sub-interfaces."""
        self.unwrapped().setup(observation_space, action_space)
        for i, sub in enumerate(self.sub_interfaces):
            sub.setup(observation_space.spaces[i], action_space.spaces[i])

    def reset(self, obs, **kwargs):
        """Reset the wrapped interface and every sub-interface.

        The inner interface must expose Tuple obs/action spaces with one
        entry per sub-interface.
        """
        inter_ob_sp = self.inter.observation_space
        inter_ac_sp = self.inter.action_space
        assert isinstance(inter_ob_sp, spaces.Tuple)
        assert isinstance(inter_ac_sp, spaces.Tuple)
        assert len(inter_ob_sp.spaces) == len(self.sub_interfaces)
        self.inter.reset(obs)
        for i, sub in enumerate(self.sub_interfaces):
            # Re-setup on every reset: the inner spaces may change then.
            sub.setup(inter_ob_sp.spaces[i], inter_ac_sp.spaces[i])
            sub.reset(obs[i])

    def obs_trans(self, obs):
        """Transform obs through the inner interface, then per sub-interface."""
        obs = self.inter.obs_trans(obs)
        sub_obs = tuple([sub_inter.obs_trans(ob)
                         for ob, sub_inter in zip(obs, self.sub_interfaces)])
        return self._obs_trans(obs, sub_obs)

    def _obs_trans(self, obs, sub_obs):
        """ Observation Transformation.
        obs is observation from self.inter
        sub_obs are observations from sub_interfaces"""
        return sub_obs

    def _act_trans(self, act):
        """Map each per-player action through its own sub-interface."""
        act = [sub_inter.act_trans(ac)
               for ac, sub_inter in zip(act, self.sub_interfaces)]
        return act

    @property
    def observation_space(self):
        """Tuple of the sub-interfaces' observation spaces."""
        return spaces.Tuple([inter.observation_space
                             for inter in self.sub_interfaces])

    @property
    def action_space(self):
        """Tuple of the sub-interfaces' action spaces."""
        return spaces.Tuple([inter.action_space
                             for inter in self.sub_interfaces])

    def __str__(self):
        """ Get the name of all stacked interface. """
        s = str(self.inter)
        combine_s = str([str(inter) for inter in self.sub_interfaces])
        return s+'<'+self.__class__.__name__+combine_s+'>'
| perfection 34 71 | reaction 78 72 | isp 50 73 | color "bf 00 00" 74 | skin base 75 | //weaponpref 012345678 76 | } 77 | 78 | { 79 | name Predator 80 | aiming 25 81 | perfection 55 82 | reaction 32 83 | isp 70 84 | color "00 00 ff" 85 | skin base 86 | //weaponpref 012345678 87 | } 88 | 89 | { 90 | name Ripley 91 | aiming 61 92 | perfection 50 93 | reaction 23 94 | isp 32 95 | color "00 00 7f" 96 | skin base 97 | //weaponpref 012345678 98 | } 99 | 100 | { 101 | name T800 102 | aiming 90 103 | perfection 85 104 | reaction 10 105 | isp 30 106 | color "ff ff 00" 107 | skin base 108 | //weaponpref 012345678 109 | } 110 | 111 | { 112 | name Dredd 113 | aiming 12 114 | perfection 35 115 | reaction 56 116 | isp 37 117 | color "40 cf 00" 118 | skin base 119 | //weaponpref 012345678 120 | } 121 | 122 | { 123 | name Conan 124 | aiming 10 125 | perfection 35 126 | reaction 10 127 | isp 100 128 | color "b0 b0 b0" 129 | skin base 130 | //weaponpref 012345678 131 | } 132 | 133 | { 134 | name Bond 135 | aiming 67 136 | perfection 15 137 | reaction 76 138 | isp 37 139 | color "50 50 60" 140 | skin base 141 | //weaponpref 012345678 142 | } 143 | 144 | { 145 | name Jones 146 | aiming 52 147 | perfection 35 148 | reaction 50 149 | isp 37 150 | color "8f 00 00" 151 | skin base 152 | //weaponpref 012345678 153 | } 154 | 155 | { 156 | name Blazkowicz 157 | aiming 80 158 | perfection 80 159 | reaction 80 160 | isp 100 161 | color "00 00 00" 162 | skin base 163 | //weaponpref 012345678 164 | } 165 | 166 | -------------------------------------------------------------------------------- /arena/wrappers/vizdoom/action.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import absolute_import 4 | 5 | import gym 6 | from gym.spaces import Box, Discrete 7 | 8 | 9 | class Discrete7Action(gym.ActionWrapper): 10 | """ Discrete 7 Actions """ 11 | def __init__(self, 
class Discrete7Action(gym.ActionWrapper):
    """ Discrete 7 Actions """
    def __init__(self, env):
        gym.ActionWrapper.__init__(self, env)

        self.action_space = Discrete(n=7)

        # One-hot rows over the 10 available game buttons.
        self._allowed_action = [
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        ]

    def action(self, a):
        return self._allowed_action[a]

    def reverse_action(self, action):
        raise NotImplementedError


class Discrete3MoveAction(gym.ActionWrapper):
    """ Discrete 3 Actions, just for moving """
    def __init__(self, env):
        gym.ActionWrapper.__init__(self, env)

        self.action_space = Discrete(n=3)

        self._allowed_action = [
            [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
        ]

    def action(self, a):
        return self._allowed_action[a]

    def reverse_action(self, action):
        raise NotImplementedError


class Discrete6MoveAction(gym.ActionWrapper):
    """ Discrete 6 Actions, just for moving """
    def __init__(self, env):
        gym.ActionWrapper.__init__(self, env)

        self.action_space = Discrete(n=6)

        self._allowed_action = [
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],  # TURN_LEFT
            [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],  # TURN_RIGHT
            [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],  # MOVE_RIGHT
            [0, 0, 0, 0, 0, 1, 0, 0, 0, 0],  # MOVE_LEFT
            [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],  # MOVE_FORWARD
            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],  # MOVE_BACKWARD
        ]

    def action(self, a):
        return self._allowed_action[a]

    def reverse_action(self, action):
        raise NotImplementedError


class Discrete5MoveAction(gym.ActionWrapper):
    """ Discrete 5 Actions (incl. no-op), just for moving """
    # DOC FIX: docstring previously said "Discrete 6 Actions".
    def __init__(self, env):
        gym.ActionWrapper.__init__(self, env)

        self.action_space = Discrete(n=5)

        self._allowed_action = [
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # No opt
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],  # TURN_LEFT
            [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],  # TURN_RIGHT
            [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],  # MOVE_FORWARD
            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],  # MOVE_BACKWARD
        ]

    def action(self, a):
        return self._allowed_action[a]

    def reverse_action(self, action):
        raise NotImplementedError


class Discrete7MoveAction(gym.ActionWrapper):
    """ Discrete 7 Actions (incl. no-op and combos), just for moving """
    # DOC FIX: docstring previously said "Discrete 6 Actions".
    def __init__(self, env):
        gym.ActionWrapper.__init__(self, env)

        self.action_space = Discrete(n=7)

        self._allowed_action = [
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],  # No opt
            [0, 0, 1, 0, 0, 0, 0, 0, 0, 0],  # TURN_LEFT
            [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],  # TURN_RIGHT
            [0, 0, 1, 0, 0, 0, 1, 0, 0, 0],  # TURN_LEFT + MOVE_FORWARD
            [0, 0, 0, 1, 0, 0, 1, 0, 0, 0],  # TURN_RIGHT + MOVE_FORWARD
            [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],  # MOVE_FORWARD
            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],  # MOVE_BACKWARD
        ]

    def action(self, a):
        return self._allowed_action[a]

    def reverse_action(self, action):
        raise NotImplementedError
13 | """ 14 | 15 | def __init__(self, env): 16 | gym.Wrapper.__init__(self, env) 17 | self.was_real_done = True 18 | 19 | def reset(self, **kwargs): 20 | """Reset only when lives are exhausted. 21 | This way all states are still reachable even though lives are episodic, 22 | and the learner need not know about any of this behind-the-scenes. 23 | """ 24 | if self.was_real_done: 25 | obs = self.env.reset(**kwargs) 26 | else: 27 | obs, _, _, _ = self.env.step(self.env.unwrapped.action_use) 28 | return obs 29 | 30 | def step(self, action): 31 | obs, reward, done, info = self.env.step(action) 32 | self.was_real_done = done 33 | done = self.env.unwrapped.game.is_player_dead() or self.was_real_done 34 | return obs, reward, done, info 35 | 36 | 37 | class SkipFrameVEnv(gym.Wrapper): 38 | """Return only every `skip`-th frame for vec env 39 | 40 | Note: in multi-player game via network connection, the most safe way seems to 41 | ensure all players simultaneously make ONLY ONE step once a time. See: 42 | https://github.com/mwydmuch/ViZDoom/issues/261 43 | """ 44 | 45 | def __init__(self, env, skip=4): 46 | gym.Wrapper.__init__(self, env) 47 | self._skip = skip 48 | 49 | def reset(self, **kwargs): 50 | return self.env.reset(**kwargs) 51 | 52 | def step(self, action): 53 | """Repeat action, sum reward""" 54 | total_reward = None 55 | done = None 56 | for i in range(self._skip): 57 | obs, reward, done, info = self.env.step(action) 58 | if total_reward is None: 59 | total_reward = reward 60 | else: 61 | total_reward = [a+b for a, b in zip(total_reward, reward)] 62 | if all(done): 63 | break 64 | return obs, total_reward, done, info 65 | 66 | 67 | class SkipFrameEnv(gym.Wrapper): 68 | """Return only every `skip`-th frame for PlayerEnv 69 | 70 | Seems problematic in multiplayer mode, see 71 | https://github.com/mwydmuch/ViZDoom/issues/261 72 | """ 73 | 74 | def __init__(self, env, skip=4): 75 | gym.Wrapper.__init__(self, env) 76 | self._skip = skip 77 | 78 | def reset(self, 
**kwargs): 79 | return self.env.reset(**kwargs) 80 | 81 | def step(self, action): 82 | """Repeat action, sum reward""" 83 | total_reward = 0.0 84 | done = None 85 | for i in range(self._skip): 86 | obs, reward, done, info = self.env.step(action) 87 | total_reward += reward 88 | if done: 89 | break 90 | return obs, total_reward, done, info 91 | 92 | 93 | class RepeatFrameEnv(gym.Wrapper): 94 | """Repeat Frame n times for PlayerEnv 95 | 96 | Seems problematic in multiplayer mode, see 97 | https://github.com/mwydmuch/ViZDoom/issues/261""" 98 | 99 | def __init__(self, env, n=4): 100 | gym.Wrapper.__init__(self, env) 101 | self._n = n 102 | self.env.unwrapped.cfg.repeat_frame = self._n 103 | 104 | def reset(self, **kwargs): 105 | return self.env.reset(**kwargs) 106 | 107 | def step(self, action): 108 | return self.env.step(action) 109 | 110 | 111 | class RandomConfigVEnv(gym.Wrapper): 112 | """ Start Doom game with randomly selected cfg file, for vec env""" 113 | 114 | def __init__(self, env, cfg_path_list): 115 | gym.Wrapper.__init__(self, env) 116 | self._cfg_path_list = cfg_path_list 117 | 118 | def reset(self, **kwargs): 119 | cfg_path = choice(self._cfg_path_list) 120 | for e in self.env.unwrapped.envs: 121 | e.unwrapped.cfg.config_path = cfg_path 122 | return self.env.reset(**kwargs) 123 | 124 | def step(self, action): 125 | return self.env.step(action) 126 | -------------------------------------------------------------------------------- /arena/env/env_int_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import Wrapper 3 | from gym import spaces 4 | 5 | class EnvIntWrapper(Wrapper): 6 | """ Environment Interface Wrapper """ 7 | inters = [] # list of Interface 8 | 9 | def __init__(self, env, inters=()): 10 | super(EnvIntWrapper, self).__init__(env) # do we need this? 
import numpy as np
from gym import Wrapper
from gym import spaces

class EnvIntWrapper(Wrapper):
    """ Environment Interface Wrapper """
    inters = []  # list of Interface (class-level default)

    def __init__(self, env, inters=()):
        """
        :param env: multi-player env whose obs/action spaces are gym Tuples
        :param inters: one Interface per player; a None entry passes that
            player's obs/actions through unchanged
        """
        super(EnvIntWrapper, self).__init__(env)
        self.env = env
        assert len(inters) == len(self.env.action_space.spaces)
        self.inters = inters
        self.__update_obs_and_act_space()

    def __update_obs_and_act_space(self):
        """Recompute the Tuple spaces from interfaces (or env passthrough)."""
        obs_space = []
        act_space = []
        for i, inter in enumerate(self.inters):
            if inter is not None:
                obs_space.append(inter.observation_space)
                act_space.append(inter.action_space)
            else:
                obs_space.append(self.env.observation_space.spaces[i])
                act_space.append(self.env.action_space.spaces[i])
        # These two spaces are required by gym
        self.observation_space = spaces.Tuple(obs_space)
        self.action_space = spaces.Tuple(act_space)

    def __act_trans(self, acts):
        """Map each player's action through its interface (if any)."""
        assert len(acts) == len(self.env.action_space.spaces)
        rets = []
        for i, inter in enumerate(self.inters):
            if inter is not None:
                rets.append(inter.act_trans(acts[i]))
            else:
                rets.append(acts[i])
        return rets

    def __obs_trans(self, obss):
        """Map each player's observation through its interface (if any)."""
        assert len(obss) == len(self.env.observation_space.spaces)
        rets = []
        for i, inter in enumerate(self.inters):
            if inter is not None:
                rets.append(inter.obs_trans(obss[i]))
            else:
                rets.append(obss[i])
        return rets

    def reset(self, **kwargs):
        """Reset env and all interfaces; per-interface kwargs may be passed
        via `inter_kwargs` (one dict per interface)."""
        inter_kwargs = [{}] * len(self.inters)
        if 'inter_kwargs' in kwargs:
            assert len(kwargs['inter_kwargs']) == len(self.inters), \
                '{}, {}'.format(len(kwargs['inter_kwargs']), len(self.inters))
            inter_kwargs = kwargs.pop('inter_kwargs')
        obs = self.env.reset(**kwargs)
        for i, inter in enumerate(self.inters):
            if inter is not None:
                inter.setup(self.env.observation_space.spaces[i],
                            self.env.action_space.spaces[i])
                inter.reset(obs[i], **inter_kwargs[i])
        # Spaces may have changed during interface reset.
        self.__update_obs_and_act_space()
        return self.__obs_trans(obs)

    def step(self, acts, **kwargs):
        assert len(acts) == len(self.env.action_space.spaces)
        a = self.__act_trans(acts)
        obs, rwd, done, info = self.env.step(a, **kwargs)
        s = self.__obs_trans(obs)
        return s, rwd, done, info

    def close(self):
        self.env.close()


class SC2EnvIntWrapper(EnvIntWrapper):
    """EnvIntWrapper that advances the SC2 env by the minimum predicted
    per-player no-op step count."""

    def __init__(self, env, inters=(), noop_fns=lambda x: 1):
        """
        :param noop_fns: either a single callable (shared by all players) or
            a list/tuple of callables, one per interface; each maps an action
            to its predicted number of no-op game steps
        """
        super(SC2EnvIntWrapper, self).__init__(env, inters)
        if callable(noop_fns):
            self.noop_fns = [noop_fns] * len(inters)
        else:
            assert isinstance(noop_fns, (list, tuple))
            assert len(noop_fns) == len(inters)
            assert all(callable(fn) for fn in noop_fns)
            # BUG FIX: self.noop_fns was never assigned in this branch,
            # so step() raised AttributeError whenever a list/tuple of
            # noop functions was passed.
            self.noop_fns = list(noop_fns)
        self._noop_steps = np.array([0] * len(inters), dtype=np.int32)

    def reset(self, **kwargs):
        self._noop_steps = np.zeros_like(self._noop_steps)
        return super(SC2EnvIntWrapper, self).reset(**kwargs)

    def step(self, acts, **kwargs):
        # A None action predicts zero no-op steps for that player.
        predict_noop_steps = [0 if act is None else noop_fn(act)
                              for act, noop_fn in zip(acts, self.noop_fns)]
        self._noop_steps += np.array(predict_noop_steps, dtype=np.int32)
        # Advance the env only as far as the least-ready player allows.
        min_noop_step = np.min(self._noop_steps)
        self.unwrapped.env._step_mul = min_noop_step
        self._noop_steps -= min_noop_step
        return super(SC2EnvIntWrapper, self).step(acts, **kwargs)
flags.DEFINE_string("player2", "Bot",
                    "Agent for player 2 ('Bot' for internal AI)")
flags.DEFINE_string("difficulty", "A",
                    "Bot difficulty (from '1' to 'A')")
flags.DEFINE_integer("max_steps_per_episode", 10000,
                     "Max number of steps allowed per episode")
flags.DEFINE_integer("episodes", 0,
                     "Number of episodes (0 for infinity)")
flags.DEFINE_integer("screen_resolution", "640",
                     "Resolution for screen feature layers.")
flags.DEFINE_float("sleep_time_per_step", 0,
                   "Sleep time (in seconds) per step")
flags.DEFINE_float("screen_ratio", "1.33",
                   "Screen ratio of width / height")
flags.DEFINE_string("agent_interface_format", "feature",
                    "Agent Interface Format: [feature|rgb]")
flags.DEFINE_integer("minimap_resolution", "64",
                     "Resolution for minimap feature layers.")
flags.DEFINE_integer("step_mul", 8, "Game steps per agent step.")
flags.DEFINE_bool("disable_fog", False, "Turn off the Fog of War.")
flags.DEFINE_bool("merge_units_info", False, "Merge units info in timesteps.")
flags.DEFINE_string("map", None, "Name of a map to use.")
flags.DEFINE_bool("visualize", False, "Visualize pygame screen")
flags.mark_flag_as_required("map")


def get_agent(agt_path):
  """Instantiate an agent from a dotted 'module.ClassName' path."""
  module, name = agt_path.rsplit('.', 1)
  agt_cls = getattr(importlib.import_module(module), name)
  return agt_cls()


def get_difficulty(level):
  """Map a difficulty character ('1'-'9', 'A') to a pysc2 Difficulty."""
  diff_dict = {
      "1": sc2_env.Difficulty.very_easy,
      "2": sc2_env.Difficulty.easy,
      "3": sc2_env.Difficulty.medium,
      "4": sc2_env.Difficulty.medium_hard,
      "5": sc2_env.Difficulty.hard,
      # Bug fix: "6" previously duplicated Difficulty.hard; the pysc2 enum
      # has a distinct 'harder' level between hard and very_hard.
      "6": sc2_env.Difficulty.harder,
      "7": sc2_env.Difficulty.very_hard,
      "8": sc2_env.Difficulty.cheat_vision,
      "9": sc2_env.Difficulty.cheat_money,
      "A": sc2_env.Difficulty.cheat_insane,
  }
  return diff_dict[level]


def main(unused_argv):
  """Run a 2-player SC2 game between agents and/or built-in bots."""
  step_mul = FLAGS.step_mul
  players = 2

  # Width is derived from the ratio and rounded down to a multiple of 4.
  screen_res = (int(FLAGS.screen_ratio * FLAGS.screen_resolution) // 4 * 4,
                FLAGS.screen_resolution)
  agent_interface_format = None
  if FLAGS.agent_interface_format == 'feature':
    agent_interface_format = sc2_env.AgentInterfaceFormat(
        feature_dimensions=sc2_env.Dimensions(
            screen=screen_res,
            minimap=FLAGS.minimap_resolution))
  elif FLAGS.agent_interface_format == 'rgb':
    agent_interface_format = sc2_env.AgentInterfaceFormat(
        rgb_dimensions=sc2_env.Dimensions(
            screen=screen_res,
            minimap=FLAGS.minimap_resolution))
  else:
    raise NotImplementedError

  # Default both slots to agents; replace with bots as requested.
  players = [sc2_env.Agent(sc2_env.Race.zerg), sc2_env.Agent(sc2_env.Race.zerg)]
  agents = []
  bot_difficulty = get_difficulty(FLAGS.difficulty)
  if FLAGS.player1 == 'Bot':
    players[0] = sc2_env.Bot(sc2_env.Race.zerg, bot_difficulty)
  else:
    agents.append(get_agent(FLAGS.player1))
  if FLAGS.player2 == 'Bot':
    players[1] = sc2_env.Bot(sc2_env.Race.zerg, bot_difficulty)
  else:
    agents.append(get_agent(FLAGS.player2))
  with sc2_env.SC2Env(
      map_name=FLAGS.map,
      visualize=FLAGS.visualize,
      players=players,
      step_mul=step_mul,
      game_steps_per_episode=FLAGS.max_steps_per_episode * step_mul,
      agent_interface_format=agent_interface_format,
      disable_fog=FLAGS.disable_fog) as env:
    run_loop(agents, env,
             max_frames=0,
             max_episodes=FLAGS.episodes,
             sleep_time_per_step=FLAGS.sleep_time_per_step,
             merge_units_info=FLAGS.merge_units_info)


if __name__ == "__main__":
  app.run(main)
from arena.interfaces.interface import Interface
from arena.utils.vizdoom.Rect import Rect
# Bug fix: player_setup is used by _get_available_game_variables_dim below
# but was never imported (NameError when called).
from arena.utils.vizdoom.player import PlayerConfig, player_setup


class FrameVarObsInt(Interface):
    """Wu Yuxin's Observation as Screen + GameVariables.

    Expose observations as a list [screen, game_var], where
      screen.shape = (height, width, channel)
    and
      game_var.shape = (2,)
    which includes (health, ammo) normalized to [0.0, 1.0].
    """

    def __init__(self, inter, env, height=84, width=84):
        super(__class__, self).__init__(inter)
        self.env = env

        # export observation space
        channel = self.env.observation_space.shape[0]
        # Bug fix: copy the env's space before overriding its shape.
        # Previously the env's own observation_space object was mutated
        # in place, corrupting it for every other consumer.
        self.screen_sp = deepcopy(self.env.observation_space)
        self.screen_sp.shape = (height, width, channel)

        self.game_var_sp = Box(low=0.0, high=1.0, shape=(2,), dtype=np.float32)

        self._height, self._width = height, width
        self._dft_gamevar = np.zeros(shape=self.game_var_sp.shape,
                                     dtype=self.game_var_sp.dtype)
        self._gamevar = np.array(deepcopy(self._dft_gamevar))

    @property
    def observation_space(self):
        return Tuple([Box(low=0.0, high=1.0, shape=self.screen_sp.shape,
                          dtype=np.float32),
                      self.game_var_sp])

    @property
    def action_space(self):
        return self.env.action_space

    def obs_trans(self, frame):
        """(C, H, W) uint8 frame -> [normalized (H, W, C) screen, game_var]."""
        frame = np.transpose(frame, axes=(1, 2, 0))
        # Bug fix: cv2.resize takes dsize as (width, height); the old
        # (height, width) order only worked because the defaults are square.
        frame = cv2.resize(frame, (self._width, self._height),
                           interpolation=cv2.INTER_AREA)
        # normalized frame
        frame = (np.array(frame) / 255)
        # normalized game vars
        self._grab_gamevar()
        # NOTE(review): this builds a ragged object array [screen, game_var];
        # kept as-is for backward compatibility with downstream consumers.
        return np.array([frame, self._gamevar])

    def reset(self, obs, **kwargs):
        super(FrameVarObsInt, self).reset(obs, **kwargs)

    def _grab_gamevar(self):
        # Only refresh from the game when an episode state is available;
        # otherwise the last values (or zeros) are kept.
        if self.env.unwrapped._state is not None:
            game = self.env.unwrapped.game
            # Normalizers: 100 HP and 15 rounds of AMMO2 map to 1.0.
            self._gamevar[0] = game.get_game_variable(vd.GameVariable.HEALTH) / 100.0
            self._gamevar[1] = game.get_game_variable(vd.GameVariable.AMMO2) / 15.0


def _get_available_game_variables_dim(player_cfg):
    """Number of game variables a game configured with player_cfg exposes.

    NOTE(review): duplicates the helper in arena/utils/vizdoom/core_env.py;
    consider importing that one instead.
    """
    g = vd.DoomGame()
    g = player_setup(g, player_cfg)
    return len(g.get_available_game_variables())


class ReshapedFrameObsInt(Interface):
    """Reshaped Frame as Observation.

    Emits the resized full frame concatenated (channel-wise) with a resized
    center crop of the frame, normalized to [0.0, 1.0].
    """

    def __init__(self, inter, env, height=168, width=168):
        super(__class__, self).__init__(inter)
        self.env = env
        # export observation space; channels double (full frame + center crop)
        channel = self.env.observation_space.shape[0] * 2
        center_patch = 0.3
        frac = center_patch / 2
        # Assumes an 800x450 source frame -- TODO confirm against the env cfg.
        W, H = 800, 450
        self._center_rect = Rect(*map(int,
            [W / 2 - W * frac, H / 2 - H * frac, W * frac * 2, H * frac * 2]))
        # Bug fix: copy the env's space before overriding its shape (see
        # FrameVarObsInt.__init__).
        self.screen_sp = deepcopy(self.env.observation_space)
        self.screen_sp.shape = (height, width, channel)
        self._height, self._width = height, width

    @property
    def observation_space(self):
        return Box(low=0.0, high=1.0, shape=self.screen_sp.shape, dtype=np.float32)

    @property
    def action_space(self):
        return self.env.action_space

    def obs_trans(self, frame):
        """Transform a raw (C, H, W) frame; already-shaped frames pass through."""
        if frame.shape != (168, 168, 6):
            # Permute and resize
            frame = np.transpose(frame, axes=(1, 2, 0))
            # normalized frame
            frame = (np.array(frame) / 255.0)
            center_patch = self._center_rect.roi(frame)
            # Bug fix: cv2.resize dsize is (width, height).
            frame = cv2.resize(frame, (self._width, self._height),
                               interpolation=cv2.INTER_AREA)
            center_patch = cv2.resize(center_patch, (self._width, self._height),
                                      interpolation=cv2.INTER_AREA)
            frame = np.concatenate((frame, center_patch), axis=2)
        return frame

    def reset(self, obs, **kwargs):
        super(ReshapedFrameObsInt, self).reset(obs, **kwargs)
import sys
import numpy as np
import gym
import pygame

from gym import error, spaces, utils
from gym.utils import seeding

from .pong2p_game import PongGame


class PongSinglePlayerEnv(gym.Env):
    """Single-player Pong rendered to an Atari-sized (210x160) RGB screen."""

    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self, ball_speed=4, bat_speed=16, max_num_rounds=20, random_seed=None):
        width, height = 160, 210

        self.observation_space = spaces.Box(
            low=0, high=255, shape=(height, width, 3))
        self.action_space = spaces.Discrete(6)

        pygame.init()
        self._surface = pygame.Surface((width, height))
        self._viewer = None
        self._rng = np.random.RandomState()
        # The playing field is square (width x width); the remaining rows at
        # the top of the screen hold the scoreboard.
        self._game = PongGame(
            has_double_players=False,
            window_size=(width, width),
            ball_speed=ball_speed,
            bat_speed=bat_speed,
            max_num_rounds=max_num_rounds)

    def set_seed(self, seed):
        """Seed the underlying game's RNG."""
        self._game.set_seed(seed)

    def _seed(self, seed=None):
        self._rng.seed(seed)

    def step(self, action):
        """Advance one frame; returns (obs, reward, done, info)."""
        direction = (0, 1, 2, 3, 4, 5)[action]
        rewards, done = self._game.step(direction, None)
        return (self._get_screen_img(), rewards[0], done, {})

    def reset(self):
        """Start a fresh game and return the first screen image."""
        self._game.reset_game()
        return self._get_screen_img()

    def _render(self, mode='human', close=False):
        if close:
            if self._viewer is not None:
                self._viewer.close()
                self._viewer = None
            pygame.quit()
            return
        img = self._get_screen_img()
        if mode == 'rgb_array':
            return img
        if mode == 'human':
            from gym.envs.classic_control import rendering
            if self._viewer is None:
                self._viewer = rendering.SimpleImageViewer()
            self._viewer.imshow(img)

    def render(self, mode='human', close=False):
        return self._render(mode, close)

    def _get_screen_img(self):
        """Draw the field plus scoreboard and convert to an RGB array."""
        self._game.draw(self._surface)
        self._game.draw_scoreboard(self._surface)
        return self._surface_to_img(self._surface)

    def _surface_to_img(self, surface):
        # pygame gives (W, H, 3); transpose to the (H, W, 3) image convention.
        raw = pygame.surfarray.array3d(surface).astype(np.uint8)
        return np.transpose(raw, (1, 0, 2))
121 | surface_flipped = pygame.transform.flip(self._surface, True, False) 122 | self._game.draw_scoreboard(self._surface) 123 | self._game.draw_scoreboard(surface_flipped) 124 | obs = self._surface_to_img(self._surface) 125 | obs_flip = self._surface_to_img(surface_flipped) 126 | return (obs, obs_flip) 127 | -------------------------------------------------------------------------------- /arena/utils/vizdoom/core_env.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import numpy as np 4 | import gym 5 | from gym.spaces import Box 6 | from gym.spaces import Tuple 7 | import vizdoom as vd 8 | 9 | from arena.utils.vizdoom.run_parallel import RunParallel 10 | from arena.utils.vizdoom.player import player_setup, player_host_setup, player_join_setup 11 | 12 | 13 | def _get_screen_shape(vd_resolution, vd_format): 14 | tmp = { 15 | (vd.ScreenFormat.CBCGCR, vd.ScreenResolution.RES_160X120): (3, 120, 160), 16 | (vd.ScreenFormat.CBCGCR, vd.ScreenResolution.RES_640X360): (3, 640, 360), 17 | (vd.ScreenFormat.CBCGCR, vd.ScreenResolution.RES_800X450): (3, 800, 450), 18 | (vd.ScreenFormat.CBCGCR, vd.ScreenResolution.RES_800X600): (3, 800, 600), 19 | } 20 | return tmp[(vd_format, vd_resolution)] 21 | 22 | 23 | def _get_action_dim(player_cfg): 24 | g = vd.DoomGame() 25 | g = player_setup(g, player_cfg) 26 | return len(g.get_available_buttons()) 27 | 28 | 29 | def _get_available_game_variables_dim(player_cfg): 30 | g = vd.DoomGame() 31 | g = player_setup(g, player_cfg) 32 | return len(g.get_available_game_variables()) 33 | 34 | 35 | class PlayerEnv(gym.Env): 36 | # TODO(pengsun): delete/trim this class 37 | def __init__(self, cfg): 38 | self.cfg = cfg 39 | self.game = None 40 | 41 | # export observation space & action space 42 | self.observation_space = gym.spaces.Box( 43 | low=0, high=255, dtype=np.uint8, 44 | shape=_get_screen_shape(cfg.screen_resolution, cfg.screen_format) 45 | ) 46 | 
self.action_space = gym.spaces.Box(low=0, high=1, dtype=np.float32, 47 | shape=(_get_action_dim(cfg),)) 48 | # export predefined actions 49 | # self.action_noop = np.zeros(shape=self.action_space.shape) 50 | # self.action_noop[0] = 1 51 | self.action_noop = [0, 0, 0, 0, 0, 0, 0, 0, 0] 52 | self.action_use = [0, 1, 0, 0, 0, 0, 0, 0, 0] 53 | self.action_fire = [1, 0, 0, 0, 0, 0, 0, 0, 0] 54 | 55 | self._state = None 56 | self._obs = None 57 | self._rwd = None 58 | self._done = None 59 | self._act = None 60 | 61 | def reset(self): 62 | if not self.cfg.is_multiplayer_game: 63 | if self.game is None: 64 | self._init_game() 65 | self.game.new_episode() 66 | else: 67 | self._init_game() 68 | if self.cfg.num_bots > 0: 69 | self._add_bot() 70 | 71 | self._state, self._obs, self._done = self._grab() 72 | return self._obs 73 | 74 | def step(self, action): 75 | self._rwd = self.game.make_action(action, self.cfg.repeat_frame) 76 | self._state, self._obs, self._done = self._grab() 77 | return self._obs, self._rwd, self._done, {} 78 | 79 | def close(self): 80 | if self.game: 81 | self.game.close() 82 | 83 | def render(self, *args): 84 | return self._obs 85 | 86 | def _init_game(self): 87 | self.close() 88 | 89 | game = vd.DoomGame() 90 | game = player_setup(game, self.cfg) 91 | 92 | if self.cfg.is_multiplayer_game: 93 | if self.cfg.host_cfg is not None: 94 | game = player_host_setup(game, self.cfg.host_cfg) 95 | elif self.cfg.join_cfg is not None: 96 | game = player_join_setup(game, self.cfg.join_cfg) 97 | else: 98 | raise ValueError('neither host nor join, error!') 99 | 100 | 101 | game.init() 102 | self.game = game 103 | 104 | def _grab(self): 105 | state = self.game.get_state() 106 | done = self.game.is_episode_finished() 107 | if done: 108 | obs = np.ndarray(shape=self.observation_space.shape, 109 | dtype=self.observation_space.dtype) 110 | else: 111 | obs = state.screen_buffer 112 | return state, obs, done 113 | 114 | def _add_bot(self): 115 | 
self.game.send_game_command("removebots") 116 | for i in range(self.cfg.num_bots): 117 | self.game.send_game_command("addbot") 118 | 119 | 120 | class VecEnv(gym.Env): 121 | # TODO(pengsun): delete/trim this class 122 | def __init__(self, envs): 123 | # export observation space & action space 124 | self.observation_space = [e.observation_space for e in envs] 125 | self.action_space = [e.action_space for e in envs] 126 | 127 | self._envs = envs 128 | self._par = RunParallel() 129 | 130 | def reset(self): 131 | observations = self._par.run((e.reset) for e in self._envs) 132 | return observations 133 | 134 | def step(self, actions): 135 | ret = self._par.run((e.step, act) 136 | for e, act in zip(self._envs, actions)) 137 | observations, rewards, dones, infos = [item for item in zip(*ret)] 138 | return observations, rewards, dones, infos 139 | 140 | def close(self): 141 | self._par.run((e.close) for e in self._envs) 142 | 143 | def render(self, *args): 144 | obs = self._par.run((e.render) for e in self._envs) 145 | return obs 146 | 147 | @property 148 | def envs(self): 149 | return self._envs 150 | -------------------------------------------------------------------------------- /arena/env/pong2p_env.py: -------------------------------------------------------------------------------- 1 | """ Arena compatible pong2p env. 
2 | 3 | written by loyavejmlu, jackzbzheng, xinghaisun """ 4 | from __future__ import absolute_import 5 | from __future__ import division 6 | from __future__ import print_function 7 | 8 | import pygame 9 | import numpy as np 10 | from gym import spaces, Env 11 | 12 | from arena.utils.pong2p.pong2p_game import PongGame 13 | 14 | 15 | class Pong2pEnv(Env): 16 | metadata = {'render.modes': ['human', 'rgb_array']} 17 | # def __init__(self, env_id='PongNoFrameskip-2p-v0', render=False): 18 | def __init__(self, ball_speed=4, bat_speed=16, max_num_rounds=20, random_seed=None): 19 | SCREEN_WIDTH, SCREEN_HEIGHT = 160, 210 20 | self.observation_space = spaces.Tuple([ 21 | spaces.Box( 22 | low=0, high=255, shape=(SCREEN_HEIGHT, SCREEN_WIDTH, 3)), 23 | spaces.Box( 24 | low=0, high=255, shape=(SCREEN_HEIGHT, SCREEN_WIDTH, 3)) 25 | ]) 26 | self.action_space = spaces.Tuple( 27 | [spaces.Discrete(6), spaces.Discrete(6)]) 28 | 29 | pygame.init() 30 | self._surface = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT)) 31 | 32 | self._viewer = None 33 | self._rng = np.random.RandomState() 34 | self._game = PongGame( 35 | has_double_players=True, 36 | window_size=(SCREEN_WIDTH, SCREEN_WIDTH), 37 | ball_speed=ball_speed, 38 | bat_speed=bat_speed, 39 | max_num_rounds=max_num_rounds) 40 | self.reward_sum = np.zeros(2) 41 | 42 | def set_seed(self, seed): 43 | self._game.set_seed(seed) 44 | 45 | def _seed(self, seed=None): 46 | self._rng.seed(seed) 47 | 48 | def step(self, action): 49 | #assert self.action_space.contains(action) 50 | left_player_action, right_player_action = action 51 | bat_directions = [0, 1, 2, 3, 4, 5] 52 | rewards, done = self._game.step(bat_directions[left_player_action], 53 | bat_directions[right_player_action]) 54 | obs = self._get_screen_img_double_player() 55 | info = {} 56 | self.reward_sum += np.array(rewards) 57 | if done: 58 | print(self.reward_sum) 59 | if self.reward_sum[0] > self.reward_sum[1]: 60 | info['outcome'] = [1,-1] 61 | elif self.reward_sum[0] < 
self.reward_sum[1]: 62 | info['outcome'] = [-1,1] 63 | else: 64 | info['outcome'] = [0,0] 65 | return (obs, rewards, done, info) 66 | 67 | def reset(self, **kwargs): 68 | self._game.reset_game() 69 | obs = self._get_screen_img_double_player() 70 | return obs 71 | 72 | def _get_screen_img_double_player(self): 73 | self._game.draw(self._surface) 74 | surface_flipped = pygame.transform.flip(self._surface, True, False) 75 | self._game.draw_scoreboard(self._surface) 76 | self._game.draw_scoreboard(surface_flipped) 77 | obs = self._surface_to_img(self._surface) 78 | obs_flip = self._surface_to_img(surface_flipped) 79 | return (obs, obs_flip) 80 | 81 | def _render(self, mode='human', close=False): 82 | if close: 83 | if self._viewer is not None: 84 | self._viewer.close() 85 | self._viewer = None 86 | pygame.quit() 87 | return 88 | img = self._get_screen_img() 89 | if mode == 'rgb_array': 90 | return img 91 | elif mode == 'human': 92 | from gym.envs.classic_control import rendering 93 | if self._viewer is None: 94 | self._viewer = rendering.SimpleImageViewer() 95 | self._viewer.imshow(img) 96 | 97 | def render(self, mode='human', close=False): 98 | return self._render(mode, close) 99 | 100 | def close(self): 101 | self._render(close=True) 102 | 103 | def _get_screen_img(self): 104 | self._game.draw(self._surface) 105 | self._game.draw_scoreboard(self._surface) 106 | obs = self._surface_to_img(self._surface) 107 | return obs 108 | 109 | def _surface_to_img(self, surface): 110 | img = pygame.surfarray.array3d(surface).astype(np.uint8) 111 | return np.transpose(img, (1, 0, 2)) 112 | 113 | 114 | def main(): 115 | import numpy as np 116 | from arena.wrappers.pong2p.pong2p_wrappers import ClipRewardEnv, WarpFrame, ScaledFloatFrame, FrameStack 117 | from arena.wrappers.pong2p.pong2p_compete import WrapCompete, NoRwdResetEnv 118 | 119 | env = Pong2pEnv() 120 | env = WarpFrame(env) 121 | env = ClipRewardEnv(env) 122 | env = FrameStack(env, 4) 123 | env = ScaledFloatFrame(env) 124 | 
from gym import ObservationWrapper
from gym import Wrapper, RewardWrapper
from gym import spaces
from itertools import chain
from arena.utils.spaces import NoneSpace
from random import random
import numpy as np


class RandPlayer(Wrapper):
    """ Wrapper for randomizing players order """

    def __init__(self, env):
        super(RandPlayer, self).__init__(env)
        # Swapping only makes sense for 2 symmetric players.
        assert len(self.env.action_space.spaces) == 2
        assert self.env.action_space.spaces[0] == self.env.action_space.spaces[1]
        assert self.env.observation_space.spaces[0] == self.env.observation_space.spaces[1]
        self.change_player = random() < 0.5

    def reset(self, **kwargs):
        obs = super(RandPlayer, self).reset(**kwargs)
        # Re-draw the player order every episode.
        self.change_player = random() < 0.5
        if self.change_player:
            obs = list(obs)
            obs.reverse()
        return obs

    def step(self, actions):
        # Un-swap actions going in, re-swap obs/rewards coming out.
        if self.change_player:
            actions = list(actions)
            actions.reverse()
        obs, rwd, done, info = self.env.step(actions)
        if self.change_player:
            obs = list(obs)
            obs.reverse()
            rwd = list(rwd)
            rwd.reverse()
        return obs, rwd, done, info


class VecRwdTransform(RewardWrapper):
    """ Reward Wrapper for sc2 full game """

    def __init__(self, env, weights):
        super(VecRwdTransform, self).__init__(env)
        self.weights = weights

    def step(self, actions):
        """Collapse each player's reward vector into a weighted scalar."""
        obs, rwd, done, info = self.env.step(actions)
        rwd = [np.array(self.weights).dot(np.array(reward)) for reward in rwd]
        return obs, rwd, done, info


class StepMul(Wrapper):
    """Repeat the same actions step_mul times, accumulating rewards."""

    def __init__(self, env, step_mul=3 * 60 * 4):
        super(StepMul, self).__init__(env)
        self._step_mul = step_mul
        self._cur_obs = None

    def reset(self, **kwargs):
        self._cur_obs = self.env.reset()
        # Spaces may only be known after the inner env has been reset.
        self.action_space = self.env.action_space
        self.observation_space = self.env.observation_space
        return self._cur_obs

    def step(self, actions):
        done, info = False, {}
        cumrew = [0.0 for _ in actions]  # number players
        for _ in range(self._step_mul):
            self._cur_obs, rew, done, info = self.env.step(actions)
            cumrew = [a + b for a, b in zip(cumrew, rew)]
            if done:
                break
        return self._cur_obs, cumrew, done, info


class AllObs(ObservationWrapper):
    """ Give all players' observation to cheat_players
    cheat_players = None means all players are cheating,
    cheat_players = [] means no one is cheating) """

    def __init__(self, env, cheat_players=None):
        super(AllObs, self).__init__(env)
        self.observation_space = NoneSpace()
        self.cheat_players = cheat_players

    def observation(self, obs):
        observation = []
        for i in range(len(obs)):
            if i in self.cheat_players:
                # Cheaters see their own obs first, then everyone else's,
                # rotated so the layout is the same for every player.
                if isinstance(obs[0], list) or isinstance(obs[0], tuple):
                    observation.append(list(chain(*obs[i:], *obs[0:i])))
                else:
                    observation.append(list(obs[i:]) + list(obs[0:i]))
            else:
                observation.append(obs[i])
        return observation

    def reset(self):
        obs = self.env.reset()
        self.action_space = self.env.action_space
        obs_space = self.env.observation_space
        assert isinstance(obs_space, spaces.Tuple)
        assert all([sp == obs_space.spaces[0] for sp in obs_space.spaces])
        n_player = len(obs_space.spaces)
        if self.cheat_players is None:
            self.cheat_players = range(n_player)
        if isinstance(obs_space.spaces[0], spaces.Tuple):
            sp = spaces.Tuple(obs_space.spaces[0].spaces * n_player)
        else:
            # Bug fix: a gym Space cannot be multiplied by an int; replicate
            # the space inside a list instead.
            sp = spaces.Tuple([obs_space.spaces[0]] * n_player)
        sps = [sp if i in self.cheat_players else obs_space.spaces[i]
               for i in range(n_player)]
        self.observation_space = spaces.Tuple(sps)
        return self.observation(obs)


class OppoObsAsObs(Wrapper):
    """ A base wrapper for appending (part of) the opponent's obs to obs """

    def __init__(self, env):
        super(OppoObsAsObs, self).__init__(env)
        self._me_id = 0
        self._oppo_id = 1

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)
        return self._process_obs(obs)

    def _expand_obs_space(self, **kwargs):
        raise NotImplementedError("Implement your own func.")

    def _parse_oppo_obs(self, raw_oppo_obs):
        raise NotImplementedError("Implement your own func.")

    def _append_obs(self, self_obs, raw_oppo_obs):
        if isinstance(self_obs, tuple):
            return self_obs + self._parse_oppo_obs(raw_oppo_obs)
        elif isinstance(self_obs, dict):
            self_obs.update(self._parse_oppo_obs(raw_oppo_obs))
            return self_obs
        else:
            raise Exception("Unknown obs type in OppoObsAsObs wrapper.")

    def _process_obs(self, obs):
        if obs[0] is None:
            return obs
        else:
            # NOTE(review): reads the inner env's private _obs to get the
            # opponent's raw observation -- assumes the unwrapped env caches
            # it per step; confirm.
            appended_self_obs = self._append_obs(
                obs[self._me_id], self.env.unwrapped._obs[self._oppo_id])
            return [appended_self_obs, obs[self._oppo_id]]

    def step(self, actions):
        obs, rwd, done, info = self.env.step(actions)
        assert len(obs) == 2, "OppoObsAsObs only supports 2 players game."
        return self._process_obs(obs), rwd, done, info
""" gym compatible sc2 env """
import random
from gym import Env
from gym import spaces
from pysc2.env import sc2_env
from pysc2.env.sc2_env import SC2Env
from arena.utils.spaces import SC2RawObsSpace, SC2RawActSpace


class SC2BaseEnv(Env):
    """gym Env over pysc2's SC2Env exposing raw obs/actions per agent player.

    The wrapped SC2Env is periodically torn down and restarted (every
    max_reset_num resets) to work around long-running game-core issues, and
    restarted every reset when map_name is a list/tuple (random map choice).
    """

    def __init__(self, map_name="4MarineA",
                 players=(sc2_env.Agent(sc2_env.Race.zerg),
                          sc2_env.Bot(sc2_env.Race.zerg,
                                      sc2_env.Difficulty.very_hard)),
                 agent_interface_format=None,
                 agent_interface="feature",
                 max_steps_per_episode=10000,
                 screen_resolution=64,
                 screen_ratio=1.33,
                 camera_width_world_units=24,
                 minimap_resolution=64,
                 step_mul=4,
                 score_index=-1,
                 score_multiplier=1.0/1000,
                 disable_fog=False,
                 random_seed=None,
                 visualize=False,
                 max_reset_num=100,
                 save_replay_episodes=0,
                 replay_dir=None,
                 version=None,
                 update_game_info=False,
                 use_pysc2_feature=True,
                 game_core_config={}
                 ):
        """
        Args:
          map_name: map name, or a list/tuple to sample one per restart.
          players: exactly two sc2_env.Agent/sc2_env.Bot instances.
          agent_interface_format: pysc2 AgentInterfaceFormat; when None it is
            built from agent_interface ('feature'|'rgb') and the resolution
            arguments below.
          max_reset_num: restart the underlying SC2Env after this many
            resets (a negative value disables restarting).
          game_core_config: extra kwargs forwarded to SC2Env.
        """
        self._version = version
        self.replay_dir = replay_dir
        self.save_replay_episodes = save_replay_episodes
        self.map_name = map_name
        self.players = players
        assert len(self.players) == 2
        assert all([isinstance(p, sc2_env.Agent) or
                    isinstance(p, sc2_env.Bot) for p in self.players])
        # Only Agent slots produce observations/take actions.
        self.agent_players = [p for p in self.players
                              if isinstance(p, sc2_env.Agent)]
        self.max_steps_per_episode = max_steps_per_episode
        self.step_mul = step_mul
        self.disable_fog = disable_fog
        self.visualize = visualize
        self.random_seed = random_seed
        self.score_index = score_index
        self.score_multiplier = score_multiplier
        self.agent_interface_format = agent_interface_format
        self.use_pysc2_feature = use_pysc2_feature
        self.game_core_config = game_core_config
        self.update_game_info = update_game_info
        if agent_interface_format is None:
            self.agent_interface = agent_interface
            # Screen width derived from the ratio, rounded down to a
            # multiple of 4 as required by the feature layers.
            screen_res = (int(screen_ratio * screen_resolution) // 4 * 4,
                          screen_resolution)
            if agent_interface == 'rgb':
                self.agent_interface_format = \
                    sc2_env.AgentInterfaceFormat(
                        rgb_dimensions=sc2_env.Dimensions(
                            screen=screen_res,
                            minimap=minimap_resolution),
                        camera_width_world_units=camera_width_world_units)
            elif agent_interface == 'feature':
                self.agent_interface_format = \
                    sc2_env.AgentInterfaceFormat(
                        feature_dimensions=sc2_env.Dimensions(
                            screen=screen_res,
                            minimap=minimap_resolution),
                        camera_width_world_units=camera_width_world_units)

        self._reset_num = 0
        self.max_reset_num = max_reset_num
        self._start_env()
        # NOTE(review): reaches into pysc2's private _controllers to fetch
        # the static game info -- pysc2 internal API.
        self._gameinfo = self.env._controllers[0].game_info()

        self._obs = None
        self._rew = None
        self._done = None
        self._info = None

        self.observation_space = spaces.Tuple([SC2RawObsSpace()] * len(self.agent_players))
        self.action_space = spaces.Tuple([SC2RawActSpace()] * len(self.agent_players))

    def _start_env(self):
        """Create the underlying SC2Env (choosing a random map if a list)."""
        if isinstance(self.map_name, list) or isinstance(self.map_name, tuple):
            map_name = random.choice(self.map_name)
            # Force a restart on every reset so a new map can be drawn.
            self.max_reset_num = 0
        else:
            map_name = self.map_name
        self.env = SC2Env(map_name=map_name,
                          players=self.players,
                          step_mul=self.step_mul,
                          agent_interface_format=self.agent_interface_format,
                          game_steps_per_episode=self.max_steps_per_episode,
                          disable_fog=self.disable_fog,
                          visualize=self.visualize,
                          random_seed=self.random_seed,
                          score_index=self.score_index,
                          score_multiplier=self.score_multiplier,
                          save_replay_episodes=self.save_replay_episodes,
                          replay_dir=self.replay_dir,
                          version=self._version,
                          use_pysc2_feature=self.use_pysc2_feature,
                          update_game_info=self.update_game_info,
                          **self.game_core_config,
                          )

    def reset(self, **kwargs):
        """Reset (and occasionally fully restart) the env; returns timesteps."""
        self._reset_num += 1
        # Restart once the reset budget is exceeded (max_reset_num < 0
        # disables restarting via the chained comparison).
        if self._reset_num > self.max_reset_num >=0:
            self._reset_num = 0
            self.close()
            self._start_env()
        self._obs = self.env.reset()
        return self._obs

    def step(self, raw_actions, **kwargs):
        """Step with raw actions; returns (timesteps, rewards, done, info).

        info['outcome'] is all zeros until the terminal step, where it holds
        per-agent game results (see get_outcome).
        """
        timesteps = self.env.step(raw_actions, **kwargs)
        self._obs = timesteps
        self._rew = [timestep.reward for timestep in timesteps]
        self._done = True if timesteps[0].last() else False
        self._info = {'outcome': [0] * len(self.agent_players)}
        if self._done:
            self._info = {'outcome': self.get_outcome()}
        return self._obs, self._rew, self._done, self._info

    def close(self):
        if self.env:
            self.env.close()

    def get_outcome(self):
        """Per-agent end-of-game result (+1/-1/0) from the player_result
        protos cached on the wrapped env (pysc2 private _obs)."""
        outcome = [0] * len(self.agent_players)
        for i, o in enumerate(self._obs):
            player_id = o.observation.player_common.player_id
            for result in self.env._obs[i].player_result:
                if result.player_id == player_id:
                    outcome[i] = sc2_env.possible_results.get(result.result, 0)
        return outcome
class FragReward(gym.RewardWrapper):
    """Reward wrapper emitting the per-step change of the FRAGCOUNT variable.

    The raw env reward is ignored; the reward is the number of frags gained
    (possibly negative) since the previous step.
    """

    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)
        self._prev_frag = 0
        self._cur_frag = 0

    def reset(self, **kwargs):
        """Reset the env and re-baseline the frag counter."""
        obs = self.env.reset(**kwargs)
        frag = self.env.unwrapped.game.get_game_variable(
            vd.GameVariable.FRAGCOUNT)
        self._prev_frag = frag
        self._cur_frag = frag
        return obs

    def reward(self, r):
        """Return frags gained since the previous step (raw reward unused)."""
        self._cur_frag = self.env.unwrapped.game.get_game_variable(
            vd.GameVariable.FRAGCOUNT)
        delta = self._cur_frag - self._prev_frag
        self._prev_frag = self._cur_frag
        return float(delta)
""" 46 | def __init__(self, venv): 47 | gym.RewardWrapper.__init__(self, venv) 48 | self._cur_pos_xya = None 49 | self._num_players = 2 50 | 51 | def reset(self, **kwargs): 52 | observations = self.env.reset(**kwargs) 53 | self._checkup() 54 | self._grab() 55 | return observations 56 | 57 | def step(self, actions): 58 | observations, _, dones, infos = self.env.step(actions) 59 | self._grab() 60 | rs = self.rewards_position_angle() 61 | return observations, rs, dones, infos 62 | 63 | def rewards_position_angle(self): 64 | def my_dist(xya_one, xya_two): 65 | dx, dy, da = [item1 - item2 for item1, item2 in zip(xya_one, xya_two)] 66 | return sqrt(dx*dx + dy*dy + da*da) 67 | 68 | if self._cur_pos_xya is None or len(self._cur_pos_xya) != self._num_players: 69 | return [0.0, 0.0] 70 | 71 | xya_tracker = self._cur_pos_xya[0] 72 | xya_object = self._cur_pos_xya[1] 73 | print('xya_tracker = ', xya_tracker) 74 | print('xya_object = ', xya_object) 75 | 76 | r_tracker = 1 / (my_dist(xya_tracker, xya_object) + 0.1) 77 | r_object = -r_tracker 78 | 79 | #return [1.0] 80 | return [r_tracker, r_object] 81 | 82 | def _checkup(self): 83 | assert(len(self.env.unwrapped.envs) == self._num_players) 84 | for e in self.env.unwrapped.envs: 85 | assert(len(e.unwrapped._state.game_variables) >= 3) # at least x, y, a 86 | 87 | def _grab(self): 88 | self._cur_pos_xya = [] 89 | for e in self.env.unwrapped.envs: 90 | this_state = e.unwrapped._state 91 | if this_state is not None: 92 | self._cur_pos_xya.append(this_state.game_variables) 93 | 94 | 95 | class TrackerObjectDistAngleReward(gym.RewardWrapper): 96 | """ Tracker-Object Distance-Angle Reward. Wrap over vec env. 97 | 98 | Fangwei's settings. 
""" 99 | def __init__(self, venv, max_steps): 100 | gym.RewardWrapper.__init__(self, venv) 101 | self._cur_pos_xya = None 102 | self._num_players = 2 103 | self.max_steps = max_steps 104 | 105 | def reset(self, **kwargs): 106 | observations = self.env.reset(**kwargs) 107 | self._checkup() 108 | self._grab() 109 | self.count_done = 0 110 | self.count_step = 0 111 | return observations 112 | 113 | def step(self, actions): 114 | self.count_step += 1 115 | 116 | observations, _, dones, infos = self.env.step(actions) 117 | self._grab() 118 | 119 | if all(dones): # episode end 120 | return observations, [0.0, 0.0], dones, infos 121 | 122 | # print('tracker xya = ', self._cur_pos_xya[0]) 123 | # print('object xya = ', self._cur_pos_xya[1]) 124 | 125 | rs, outrange = self.rewards_position_angle() 126 | if outrange: 127 | self.count_done += 1 128 | else: 129 | self.count_done=0 130 | if self.count_done > 20 or self.count_step > self.max_steps: 131 | dones = [True, True] 132 | else: 133 | dones = [False, False] 134 | return observations, rs, dones, infos 135 | 136 | def world_to_local(self, xya_one, xya_two): 137 | # vizdoom fixed point angle to radius 138 | x0, y0, a0 = xya_one 139 | xt, yt, at = xya_two 140 | theta = a0 / 180.0 * pi 141 | # orientation to rotation 142 | theta -= pi/2 143 | # common origin of world and local coordinate system 144 | dx, dy = xt - x0, yt - y0 145 | # coordinate rotation 146 | x_ = dx * cos(theta) + dy * sin(theta) 147 | y_ = -dx * sin(theta) + dy * cos(theta) 148 | a_ = a0 - at 149 | return x_, y_, a_ 150 | 151 | def get_reward(self, dx, dy, exp_dis=128): 152 | dist = sqrt(dx * dx + dy * dy) 153 | theta = abs(np.arctan2(dx, dy)/(pi/4)) 154 | e_dis_relative = abs((dist - exp_dis)/exp_dis) 155 | reward_tracker = 1.0 - min(e_dis_relative, 1.0) - min(theta, 1.0) 156 | reward_tracker = max(reward_tracker, -1) 157 | e_theta = abs(theta - 1.0) 158 | e_dis = abs(e_dis_relative - 1.0) 159 | reward_object = 1.0 - (min(e_dis, 1.0) + min(e_theta, 1.0)) 
def rewards_position_angle(self):
    """Compute ``([r_tracker, r_object], outrange)`` from current positions.

    NOTE(review): the object reward computed inside ``get_reward`` is
    discarded here in favour of a strictly zero-sum scheme
    (``r_object = -r_tracker``) — confirm this is intentional.
    """
    if self._cur_pos_xya is None:
        print('None players!')
        return [0.0, 0.0], True
    if len(self._cur_pos_xya) == 0:
        print('0 players!')
        return [0.0, 0.0], True

    tracker_xya = self._cur_pos_xya[0]
    object_xya = self._cur_pos_xya[1]
    lx, ly, _ = self.world_to_local(tracker_xya, object_xya)
    r_tracker, _, outrange = self.get_reward(lx, ly, exp_dis=128)
    return [r_tracker, -r_tracker], outrange
def update(self, pixel):
    """Blit one frame into the pyglet window.

    A depth map (single channel) is replicated into three channels first so
    it can be displayed through the RGB pipeline.
    """
    self.window.clear()
    self.window.switch_to()
    self.window.dispatch_events()
    frame = pixel
    if self.depth:
        # Fake an RGB image out of the single-channel depth map.
        frame = np.dstack([frame.astype(np.uint8)] * 3)
    image = pyglet.image.ImageData(self.width, self.height, self.format,
                                   frame.tobytes(), pitch=self.pitch)
    image.blit(0, 0)
    self.window.flip()
def step(self, a):
    """Advance the soccer env one step with team-major action lists.

    Args:
        a: ``a[team][player]`` action vectors; flattened into the player
           order dm_soccer expects and clipped to [-1, 1].

    Returns:
        ``(obs, rewards, done, info)`` where per-team rewards are the mean
        over the team's players of ``[home_score, -away_score,
        vel_to_ball, vel_ball_to_goal]`` and ``info`` holds each player's
        ``[home_score, away_score]`` pair.
    """
    # Flatten the team-major actions into flat player order.
    flat_actions = []
    for team in range(self.team_num):
        flat_actions.extend(a[team])
    flat_actions = np.clip(flat_actions, -1., 1.)
    self.timestep = self.env.step(flat_actions)

    rewards, observations, infos = [], [], []
    for team in range(self.team_num):
        team_r, team_obs, team_info = [], [], []
        for player in range(self.team_size):
            idx = team * self.team_size + player
            player_obs = self.timestep.observation[idx]
            team_r.append([
                float(player_obs["stats_home_score"]),
                -1. * float(player_obs["stats_away_score"]),
                float(player_obs["stats_vel_to_ball"]),
                float(player_obs["stats_vel_ball_to_goal"]),
            ])
            team_info.append([int(player_obs["stats_home_score"]),
                              int(player_obs["stats_away_score"])])
            team_obs.append(player_obs)
        observations.append(team_obs)
        rewards.append(np.mean(np.array(team_r), 0).tolist())
        infos.append(team_info)
    return observations, rewards, self.timestep.last(), infos
def main():
    """Smoke-test: run a 2v2 soccer env through the interface wrappers."""
    team_size = 2
    env = soccer_gym(team_size, time_limit=45.)

    from arena.env.env_int_wrapper import EnvIntWrapper
    from arena.interfaces.soccer.obs_int import ConcatObsAct, Dict2Vec

    def build_interface():
        # Each team: per-player dict->vector obs conversion, then concat.
        return ConcatObsAct(Combine(None, [Dict2Vec(None), Dict2Vec(None)]))

    env = EnvIntWrapper(env, [build_interface(), build_interface()])
    state = env.reset()
    print(env.observation_space)
    print(env.action_space)
    for _ in range(10):
        all_acts = env.action_space.sample()
        observation, reward, done, info = env.step(all_acts)
def reward(self, reward):
    """Bin reward(s) to {+1, 0, -1} by sign; supports tuple rewards."""
    if isinstance(reward, tuple):
        return tuple(np.sign(r) for r in reward)
    return np.sign(reward)
def observation(self, observation):
    """Scale uint8 frame(s) in [0, 255] to float32 in [0, 1].

    Careful: this undoes the LazyFrames memory optimization, so use it
    only with small replay buffers.
    """
    def _scale(frame):
        return np.array(frame).astype(np.float32) / 255.0

    if isinstance(observation, tuple):
        return tuple(_scale(obs) for obs in observation)
    return _scale(observation)
class LazyFrames(object):
    """Memory-saving view over a list of frames.

    Shares common frames between successive stacked observations instead of
    copying them, which matters for large (e.g. 1M-frame) replay buffers.
    Convert to a numpy array only right before feeding a model.
    """

    def __init__(self, frames):
        self._frames = frames
        self._out = None

    def _force(self):
        # Materialize (and cache) the channel-concatenated array on first use,
        # dropping the frame list so memory is not held twice.
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        result = self._force()
        return result if dtype is None else result.astype(dtype)

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]
def calc_dist(pos, pos_prev):
    """Euclidean distance between two (x, y) positions."""
    (x1, y1), (x0, y0) = pos, pos_prev
    return sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2)
def reward(self, reward):
    """Refresh per-step bookkeeping, then combine the shaping reward terms.

    When ``is_recompute_reward`` is set, the engine-provided reward is
    discarded and the shaped terms alone define the reward.
    """
    self._update_vars()
    self._update_dead()
    self.step_this_episode += 1
    # A freshly-respawned player (alive now, dead on the previous step)
    # restarts its per-life step counter.
    if self.is_dead_prev and not self.is_dead:
        self.step_this_life = 0
    else:
        self.step_this_life += 1

    total = 0. if self.is_recompute_reward else reward
    total += self._reward_living()
    total += self._reward_dist()
    total += self._reward_frag()
    total += self._reward_health()
    total += self._reward_ammo()
    return total
_reward_living(self): 125 | return -0.001 126 | 127 | def _reward_dist(self): 128 | ret = 0. 129 | if self.pos_prev is not None and self.step_this_life > self.LIFE_REAL_START: 130 | d = calc_dist(self.pos, self.pos_prev) 131 | # print(self.pos) 132 | # print(self.pos_prev) 133 | ret = 0.002 if d > self.dist_penalty_thres else 0.0 134 | return ret 135 | 136 | def _reward_frag(self): 137 | ret = 0. 138 | # unavailable for single player game; fine to keep it zero in this case 139 | if self.step_this_life > self.LIFE_REAL_START: 140 | r = (self.game_variables[self.IND_GAMEVAR_FRAG] - 141 | self.game_variables_prev[self.IND_GAMEVAR_FRAG]) 142 | ret = float(r) 143 | return ret 144 | 145 | def _reward_health(self): 146 | ret = 0. 147 | if self.step_this_life > self.LIFE_REAL_START: 148 | r = (self.game_variables[self.IND_GAMEVAR_HEALTH] - 149 | self.game_variables_prev[self.IND_GAMEVAR_HEALTH]) 150 | if r != 0: 151 | ret = 0.5 if r > 0 else -0.1 152 | return ret 153 | 154 | def _reward_ammo(self): 155 | ret = 0. 156 | if self.step_this_life > self.LIFE_REAL_START: 157 | r = (self.game_variables[self.IND_GAMEVAR_AMMO] - 158 | self.game_variables_prev[self.IND_GAMEVAR_AMMO]) 159 | if r != 0: 160 | ret = 0.5 if r > 0 else -0.1 161 | return ret 162 | 163 | class RwdShapeWu3(RwdShapeWuBasic): 164 | """ Tweak of Wu Yuxin's reward shaping """ 165 | 166 | def __init__(self, env, 167 | live, 168 | dist_inc, dist_dec, 169 | health_inc, health_dec, 170 | ammo_inc, ammo_dec): 171 | super(RwdShapeWu3, self).__init__(env) 172 | 173 | self._live = live 174 | self._dist_inc = dist_inc 175 | self._dist_dec = dist_dec 176 | self._health_inc = health_inc 177 | self._health_dec = health_dec 178 | self._ammo_inc = ammo_inc 179 | self._ammo_dec = ammo_dec 180 | 181 | # helpers for _reward 182 | def _reward_living(self): 183 | return self._live 184 | 185 | def _reward_dist(self): 186 | ret = 0. 
187 | if self.pos_prev is not None and self.step_this_life > self.LIFE_REAL_START: 188 | d = calc_dist(self.pos, self.pos_prev) 189 | ret = self._dist_inc * d if d > self.dist_penalty_thres else self._dist_dec 190 | return ret 191 | 192 | def _reward_frag(self): 193 | ret = 0. 194 | # unavailable for single player game; fine to keep it zero in this case 195 | if self.step_this_life > self.LIFE_REAL_START: 196 | r = (self.game_variables[self.IND_GAMEVAR_FRAG] - 197 | self.game_variables_prev[self.IND_GAMEVAR_FRAG]) 198 | ret = float(r) 199 | return ret 200 | 201 | def _reward_health(self): 202 | ret = 0. 203 | if self.step_this_life > self.LIFE_REAL_START: 204 | r = (self.game_variables[self.IND_GAMEVAR_HEALTH] - 205 | self.game_variables_prev[self.IND_GAMEVAR_HEALTH]) 206 | if r != 0: 207 | ret = self._health_inc * r if r > 0 else self._health_dec * r 208 | return ret 209 | 210 | def _reward_ammo(self): 211 | ret = 0. 212 | if self.step_this_life > self.LIFE_REAL_START: 213 | r = (self.game_variables[self.IND_GAMEVAR_AMMO] - 214 | self.game_variables_prev[self.IND_GAMEVAR_AMMO]) 215 | if r != 0: 216 | ret = self._ammo_inc * r if r > 0 else self._ammo_dec * r 217 | return ret 218 | -------------------------------------------------------------------------------- /arena/utils/constant.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from pysc2.lib.typeenums import UNIT_TYPEID 3 | 4 | 5 | class AllianceType(Enum): 6 | SELF = 1 7 | ALLY = 2 8 | NEUTRAL = 3 9 | ENEMY = 4 10 | 11 | 12 | COMBAT_UNITS = set([ 13 | # Zerg 14 | UNIT_TYPEID.ZERG_BANELING.value, 15 | UNIT_TYPEID.ZERG_BANELINGBURROWED.value, 16 | UNIT_TYPEID.ZERG_BROODLING.value, 17 | UNIT_TYPEID.ZERG_BROODLORD.value, 18 | UNIT_TYPEID.ZERG_CHANGELING.value, 19 | UNIT_TYPEID.ZERG_CHANGELINGMARINE.value, 20 | UNIT_TYPEID.ZERG_CHANGELINGMARINESHIELD.value, 21 | UNIT_TYPEID.ZERG_CHANGELINGZEALOT.value, 22 | UNIT_TYPEID.ZERG_CHANGELINGZERGLING.value, 
23 | UNIT_TYPEID.ZERG_CHANGELINGZERGLINGWINGS.value, 24 | UNIT_TYPEID.ZERG_CORRUPTOR.value, 25 | UNIT_TYPEID.ZERG_HYDRALISK.value, 26 | UNIT_TYPEID.ZERG_HYDRALISKBURROWED.value, 27 | UNIT_TYPEID.ZERG_INFESTOR.value, 28 | UNIT_TYPEID.ZERG_INFESTORBURROWED.value, 29 | UNIT_TYPEID.ZERG_MUTALISK.value, 30 | UNIT_TYPEID.ZERG_NYDUSCANAL.value, 31 | UNIT_TYPEID.ZERG_OVERLORD.value, 32 | UNIT_TYPEID.ZERG_OVERSEER.value, 33 | UNIT_TYPEID.ZERG_QUEEN.value, 34 | UNIT_TYPEID.ZERG_QUEENBURROWED.value, 35 | UNIT_TYPEID.ZERG_RAVAGER.value, 36 | UNIT_TYPEID.ZERG_ROACH.value, 37 | UNIT_TYPEID.ZERG_ROACHBURROWED.value, 38 | UNIT_TYPEID.ZERG_SPINECRAWLER.value, 39 | UNIT_TYPEID.ZERG_SPINECRAWLERUPROOTED.value, 40 | UNIT_TYPEID.ZERG_SPORECRAWLER.value, 41 | UNIT_TYPEID.ZERG_SPORECRAWLERUPROOTED.value, 42 | UNIT_TYPEID.ZERG_SWARMHOSTMP.value, 43 | UNIT_TYPEID.ZERG_ULTRALISK.value, 44 | UNIT_TYPEID.ZERG_ZERGLING.value, 45 | UNIT_TYPEID.ZERG_ZERGLINGBURROWED.value, 46 | UNIT_TYPEID.ZERG_LURKERMP.value, 47 | UNIT_TYPEID.ZERG_LURKERMPBURROWED.value, 48 | UNIT_TYPEID.ZERG_VIPER.value, 49 | 50 | # TERRAN 51 | UNIT_TYPEID.TERRAN_SCV.value, 52 | UNIT_TYPEID.TERRAN_GHOST.value, 53 | UNIT_TYPEID.TERRAN_MARAUDER.value, 54 | UNIT_TYPEID.TERRAN_MARINE.value, 55 | UNIT_TYPEID.TERRAN_REAPER.value, 56 | UNIT_TYPEID.TERRAN_HELLION.value, 57 | UNIT_TYPEID.TERRAN_CYCLONE.value, 58 | UNIT_TYPEID.TERRAN_SIEGETANK.value, 59 | UNIT_TYPEID.TERRAN_THOR.value, 60 | UNIT_TYPEID.TERRAN_WIDOWMINE.value, 61 | UNIT_TYPEID.TERRAN_NUKE.value, 62 | UNIT_TYPEID.TERRAN_BANSHEE.value, 63 | UNIT_TYPEID.TERRAN_BATTLECRUISER.value, 64 | UNIT_TYPEID.TERRAN_LIBERATOR.value, 65 | UNIT_TYPEID.TERRAN_VIKINGFIGHTER.value, 66 | UNIT_TYPEID.TERRAN_RAVEN.value, 67 | UNIT_TYPEID.TERRAN_MEDIVAC.value, 68 | UNIT_TYPEID.TERRAN_MULE.value, 69 | 70 | # Protoss 71 | UNIT_TYPEID.PROTOSS_PROBE.value, 72 | UNIT_TYPEID.PROTOSS_MOTHERSHIPCORE.value, 73 | UNIT_TYPEID.PROTOSS_ZEALOT.value, 74 | UNIT_TYPEID.PROTOSS_SENTRY.value, 75 | 
UNIT_TYPEID.PROTOSS_STALKER.value, 76 | UNIT_TYPEID.PROTOSS_HIGHTEMPLAR.value, 77 | UNIT_TYPEID.PROTOSS_DARKTEMPLAR.value, 78 | UNIT_TYPEID.PROTOSS_ADEPT.value, 79 | UNIT_TYPEID.PROTOSS_COLOSSUS.value, 80 | UNIT_TYPEID.PROTOSS_DISRUPTOR.value, 81 | UNIT_TYPEID.PROTOSS_WARPPRISM.value, 82 | UNIT_TYPEID.PROTOSS_OBSERVER.value, 83 | UNIT_TYPEID.PROTOSS_IMMORTAL.value, 84 | UNIT_TYPEID.PROTOSS_CARRIER.value, 85 | UNIT_TYPEID.PROTOSS_ORACLE.value, 86 | UNIT_TYPEID.PROTOSS_PHOENIX.value, 87 | UNIT_TYPEID.PROTOSS_VOIDRAY.value, 88 | UNIT_TYPEID.PROTOSS_TEMPEST.value, 89 | UNIT_TYPEID.PROTOSS_INTERCEPTOR.value, 90 | UNIT_TYPEID.PROTOSS_ORACLESTASISTRAP.value 91 | ]) 92 | 93 | ZERG_COMBAT_UNITS = set([ 94 | # Zerg 95 | UNIT_TYPEID.ZERG_BANELING.value, 96 | # UNIT_TYPEID.ZERG_BANELINGBURROWED.value, 97 | UNIT_TYPEID.ZERG_BROODLING.value, 98 | UNIT_TYPEID.ZERG_BROODLORD.value, 99 | # UNIT_TYPEID.ZERG_CHANGELING.value, 100 | # UNIT_TYPEID.ZERG_CHANGELINGMARINE.value, 101 | # UNIT_TYPEID.ZERG_CHANGELINGMARINESHIELD.value, 102 | # UNIT_TYPEID.ZERG_CHANGELINGZEALOT.value, 103 | # UNIT_TYPEID.ZERG_CHANGELINGZERGLING.value, 104 | # UNIT_TYPEID.ZERG_CHANGELINGZERGLINGWINGS.value, 105 | UNIT_TYPEID.ZERG_CORRUPTOR.value, 106 | UNIT_TYPEID.ZERG_HYDRALISK.value, 107 | # UNIT_TYPEID.ZERG_HYDRALISKBURROWED.value, 108 | UNIT_TYPEID.ZERG_INFESTOR.value, 109 | # UNIT_TYPEID.ZERG_INFESTORBURROWED.value, 110 | UNIT_TYPEID.ZERG_MUTALISK.value, 111 | # UNIT_TYPEID.ZERG_NYDUSCANAL.value, 112 | UNIT_TYPEID.ZERG_OVERLORD.value, 113 | UNIT_TYPEID.ZERG_OVERSEER.value, 114 | UNIT_TYPEID.ZERG_QUEEN.value, 115 | # UNIT_TYPEID.ZERG_QUEENBURROWED.value, 116 | UNIT_TYPEID.ZERG_RAVAGER.value, 117 | UNIT_TYPEID.ZERG_ROACH.value, 118 | # UNIT_TYPEID.ZERG_ROACHBURROWED.value, 119 | UNIT_TYPEID.ZERG_SPINECRAWLER.value, 120 | # UNIT_TYPEID.ZERG_SPINECRAWLERUPROOTED.value, 121 | UNIT_TYPEID.ZERG_SPORECRAWLER.value, 122 | # UNIT_TYPEID.ZERG_SPORECRAWLERUPROOTED.value, 123 | # UNIT_TYPEID.ZERG_SWARMHOSTMP.value, 
# Zerg structures tracked by the env (commented entries deliberately excluded).
# Fix: use a set literal instead of wrapping a throwaway list in set() (C405).
ZERG_BUILDING_UNITS = {
    # Zerg
    UNIT_TYPEID.ZERG_HATCHERY.value,
    # UNIT_TYPEID.ZERG_SPINECRAWLER.value,
    # UNIT_TYPEID.ZERG_SPORECRAWLER.value,
    UNIT_TYPEID.ZERG_EXTRACTOR.value,
    UNIT_TYPEID.ZERG_SPAWNINGPOOL.value,
    UNIT_TYPEID.ZERG_EVOLUTIONCHAMBER.value,
    UNIT_TYPEID.ZERG_ROACHWARREN.value,
    UNIT_TYPEID.ZERG_BANELINGNEST.value,
    # UNIT_TYPEID.ZERG_CREEPTUMOR.value,
    UNIT_TYPEID.ZERG_LAIR.value,
    UNIT_TYPEID.ZERG_HYDRALISKDEN.value,
    UNIT_TYPEID.ZERG_LURKERDENMP.value,
    UNIT_TYPEID.ZERG_SPIRE.value,
    UNIT_TYPEID.ZERG_SWARMHOSTBURROWEDMP.value,
    # UNIT_TYPEID.ZERG_NYDUSNETWORK.value,
    UNIT_TYPEID.ZERG_INFESTATIONPIT.value,
    UNIT_TYPEID.ZERG_HIVE.value,
    UNIT_TYPEID.ZERG_GREATERSPIRE.value,
    UNIT_TYPEID.ZERG_ULTRALISKCAVERN.value,
}
UNIT_TYPEID.ZERG_MUTALISK.value: 2.0 / NOR_CONST, 174 | # UNIT_TYPEID.ZERG_NYDUSCANAL.value, 175 | UNIT_TYPEID.ZERG_OVERLORD.value: 1.0 / 30.0, 176 | UNIT_TYPEID.ZERG_OVERSEER.value: 1.0 / 30.0, 177 | UNIT_TYPEID.ZERG_QUEEN.value: 2.0 / NOR_CONST, 178 | # UNIT_TYPEID.ZERG_QUEENBURROWED.value, 179 | UNIT_TYPEID.ZERG_RAVAGER.value: 3.0 / NOR_CONST, 180 | UNIT_TYPEID.ZERG_ROACH.value: 2.0 / NOR_CONST, 181 | # UNIT_TYPEID.ZERG_ROACHBURROWED.value, 182 | UNIT_TYPEID.ZERG_SPINECRAWLER.value: 1.0 / 3.0, 183 | # UNIT_TYPEID.ZERG_SPINECRAWLERUPROOTED.value, 184 | UNIT_TYPEID.ZERG_SPORECRAWLER.value: 1.0 / 3.0, 185 | # UNIT_TYPEID.ZERG_SPORECRAWLERUPROOTED.value, 186 | # UNIT_TYPEID.ZERG_SWARMHOSTMP.value, 187 | UNIT_TYPEID.ZERG_ULTRALISK.value: 6.0 / NOR_CONST, 188 | UNIT_TYPEID.ZERG_ZERGLING.value: 1.0 / NOR_CONST, 189 | # UNIT_TYPEID.ZERG_ZERGLINGBURROWED.value, 190 | UNIT_TYPEID.ZERG_LURKERMP.value: 3.0 / NOR_CONST, 191 | # UNIT_TYPEID.ZERG_LURKERMPBURROWED.value, 192 | UNIT_TYPEID.ZERG_VIPER.value: 3.0 / NOR_CONST, 193 | UNIT_TYPEID.ZERG_DRONE.value: 1.0 / NOR_CONST 194 | } 195 | 196 | MAIN_BASE_BUILDS = set([ 197 | # Zerg 198 | UNIT_TYPEID.ZERG_SPAWNINGPOOL.value, 199 | UNIT_TYPEID.ZERG_ROACHWARREN.value, 200 | UNIT_TYPEID.ZERG_BANELINGNEST.value, 201 | ]) 202 | 203 | MINERAL_UNITS = set([UNIT_TYPEID.NEUTRAL_RICHMINERALFIELD.value, 204 | UNIT_TYPEID.NEUTRAL_RICHMINERALFIELD750.value, 205 | UNIT_TYPEID.NEUTRAL_MINERALFIELD.value, 206 | UNIT_TYPEID.NEUTRAL_MINERALFIELD750.value, 207 | UNIT_TYPEID.NEUTRAL_LABMINERALFIELD.value, 208 | UNIT_TYPEID.NEUTRAL_LABMINERALFIELD750.value, 209 | UNIT_TYPEID.NEUTRAL_PURIFIERRICHMINERALFIELD.value, 210 | UNIT_TYPEID.NEUTRAL_PURIFIERRICHMINERALFIELD750.value, 211 | UNIT_TYPEID.NEUTRAL_PURIFIERMINERALFIELD.value, 212 | UNIT_TYPEID.NEUTRAL_PURIFIERMINERALFIELD750.value, 213 | UNIT_TYPEID.NEUTRAL_BATTLESTATIONMINERALFIELD.value, 214 | UNIT_TYPEID.NEUTRAL_BATTLESTATIONMINERALFIELD750.value]) 215 | 216 | MAXIMUM_NUM = { 217 | 
UNIT_TYPEID.ZERG_SPAWNINGPOOL: 1, 218 | UNIT_TYPEID.ZERG_ROACHWARREN: 1, 219 | UNIT_TYPEID.ZERG_HYDRALISKDEN: 1, 220 | UNIT_TYPEID.ZERG_HATCHERY: 4, 221 | UNIT_TYPEID.ZERG_EVOLUTIONCHAMBER: 2, 222 | UNIT_TYPEID.ZERG_BANELINGNEST: 1, 223 | UNIT_TYPEID.ZERG_INFESTATIONPIT: 1, 224 | UNIT_TYPEID.ZERG_SPIRE: 1, 225 | UNIT_TYPEID.ZERG_ULTRALISKCAVERN: 1, 226 | UNIT_TYPEID.ZERG_NYDUSNETWORK: 1, 227 | UNIT_TYPEID.ZERG_LURKERDENMP: 1, 228 | UNIT_TYPEID.ZERG_LAIR: 1, 229 | UNIT_TYPEID.ZERG_HIVE: 1, 230 | UNIT_TYPEID.ZERG_GREATERSPIRE: 1, 231 | UNIT_TYPEID.ZERG_OVERSEER: 10, 232 | UNIT_TYPEID.ZERG_QUEEN: 4, 233 | UNIT_TYPEID.ZERG_CORRUPTOR: 6, 234 | UNIT_TYPEID.ZERG_INFESTOR: 3, 235 | UNIT_TYPEID.ZERG_RAVAGER: 5, 236 | UNIT_TYPEID.ZERG_ULTRALISK: 6, 237 | UNIT_TYPEID.ZERG_MUTALISK: 6, 238 | UNIT_TYPEID.ZERG_BROODLORD: 4, 239 | UNIT_TYPEID.ZERG_DRONE: 66, 240 | UNIT_TYPEID.ZERG_VIPER: 3, 241 | UNIT_TYPEID.ZERG_LURKERMP: 6, 242 | UNIT_TYPEID.ZERG_ZERGLING: 30, 243 | UNIT_TYPEID.ZERG_ROACH: 30, 244 | UNIT_TYPEID.ZERG_HYDRALISK: 20, 245 | UNIT_TYPEID.ZERG_BANELING: 5, 246 | UNIT_TYPEID.ZERG_SPINECRAWLER: 6, 247 | UNIT_TYPEID.ZERG_SPORECRAWLER: 2, 248 | } 249 | -------------------------------------------------------------------------------- /arena/interfaces/sc2full_formal/obs_int.py: -------------------------------------------------------------------------------- 1 | """ SC2 full formal Observation Interfaces""" 2 | from copy import deepcopy 3 | from collections import OrderedDict 4 | import logging 5 | 6 | import gym.spaces as spaces 7 | from arena.interfaces.interface import Interface 8 | from arena.interfaces.common import ActAsObsV2 9 | import numpy as np 10 | 11 | from timitate.lib6.pb2feature_converter import PB2FeatureConverter as PB2FeatureConverterV6 12 | from timitate.lib6.pb2mask_converter import PB2MaskConverter as PB2MaskConverterV6 13 | from timitate.utils.rep_db import unique_key_to_rep_info 14 | 15 | 16 | class FullObsIntV7(Interface): 17 | def __init__(self, 
inter, zstat_data_src, input_map_resolution=(128, 128), 18 | output_map_resolution=(128, 128), mmr=3500, game_version='4.10.0', 19 | max_unit_num=600, max_bo_count=50, max_bobt_count=50, 20 | zstat_presort_order_name=None, dict_space=False, zmaker_version='v4', 21 | inj_larv_rule=False, ban_zb_rule=False, ban_rr_rule=False, 22 | ban_hydra_rule=False, rr_food_cap=40, zb_food_cap=10, 23 | hydra_food_cap=10, mof_lair_rule=False, hydra_spire_rule=False, 24 | overseer_rule=False, expl_map_rule=False, baneling_rule=False, 25 | add_cargo_to_units=False, crop_to_playable_area=False, ab_dropout_list=None): 26 | super(FullObsIntV7, self).__init__(inter) 27 | self.pb2feat = PB2FeatureConverterV6(map_resolution=input_map_resolution, 28 | zstat_data_src=zstat_data_src, 29 | max_bo_count=max_bo_count, 30 | max_bobt_count=max_bobt_count, 31 | game_version=game_version, 32 | dict_space=dict_space, 33 | zstat_version=zmaker_version, 34 | max_unit_num=max_unit_num, 35 | add_cargo_to_units=add_cargo_to_units, 36 | crop_to_playable_area=crop_to_playable_area) 37 | self.feat_spec = self.pb2feat.space 38 | self.pb2mask = PB2MaskConverterV6(map_resolution=output_map_resolution, 39 | game_version=game_version, 40 | dict_space=dict_space, 41 | max_unit_num=max_unit_num, 42 | inj_larv_rule=inj_larv_rule, 43 | ban_zb_rule=ban_zb_rule, 44 | ban_rr_rule=ban_rr_rule, 45 | ban_hydra_rule=ban_hydra_rule, 46 | rr_food_cap=rr_food_cap, 47 | zb_food_cap=zb_food_cap, 48 | hydra_food_cap=hydra_food_cap, 49 | mof_lair_rule=mof_lair_rule, 50 | hydra_spire_rule=hydra_spire_rule, 51 | overseer_rule=overseer_rule, 52 | expl_map_rule=expl_map_rule, 53 | baneling_rule=baneling_rule, 54 | add_cargo_to_units=add_cargo_to_units, 55 | ab_dropout_list=ab_dropout_list) 56 | self.mask_spec = self.pb2mask.space 57 | self._arg_mask = self.pb2mask.get_arg_mask() 58 | self.pb = None 59 | self._last_tar_tag = None 60 | self._last_units = None 61 | self._last_selected_unit_tags = None 62 | self.mmr = mmr 63 | 
self._zstat_presort_order_name = zstat_presort_order_name 64 | self._dict_space = dict_space 65 | self._max_unit_num = max_unit_num 66 | 67 | def reset(self, obs, **kwargs): 68 | super(FullObsIntV7, self).reset(obs, **kwargs) 69 | self._last_tar_tag = None 70 | self._last_units = None 71 | ################### 72 | # VERY DANGEROUS, be careful 73 | # map_name = 'KairosJunction' # for temp debugging 74 | map_name = obs.game_info.map_name 75 | ################### 76 | start_pos = obs.game_info.start_raw.start_locations[0] 77 | # get the (zstat) zeroing probability 78 | if 'zeroing_prob' not in kwargs: 79 | logging.info('FullObsIntV5.reset: no zeroing_prob, defaults to 0.0') 80 | zstat_zeroing_prob = 0.0 81 | else: 82 | zstat_zeroing_prob = kwargs['zeroing_prob'] 83 | # get the distribution 84 | if 'distrib' not in kwargs: 85 | logging.info('FullObsIntV5.reset: no distrib, defaults to None') 86 | distrib = None 87 | else: 88 | distrib = kwargs['distrib'] 89 | # get the zstat category 90 | if 'zstat_category' not in kwargs: 91 | logging.info( 92 | 'FullObsIntV5.reset: no zstat_category, defaults to None') 93 | zstat_category = None 94 | else: 95 | zstat_category = kwargs['zstat_category'] 96 | # get the concrete zstat 97 | replay_name, player_id = self._sample_replay( 98 | distrib=distrib, 99 | zstat_presort_order=self._zstat_presort_order_name, 100 | zstat_category=zstat_category 101 | ) 102 | # book-keep it to the root interface 103 | self.unwrapped().cur_zstat_fn = '{}-{}'.format(replay_name, player_id) 104 | self.pb2feat.reset(replay_name=replay_name, player_id=player_id, 105 | mmr=self.mmr, map_name=map_name, 106 | start_pos=(start_pos.x, start_pos.y), 107 | zstat_zeroing_prob=zstat_zeroing_prob) 108 | self.pb2mask.reset() 109 | self._last_selected_unit_tags = None 110 | 111 | def _sample_replay(self, distrib, zstat_presort_order, zstat_category): 112 | logging.info('FullObsIntV5._sample_reply: zstat_presort_order_name={}'.format( 113 | zstat_presort_order)) 114 | 
# check consistency 115 | if zstat_presort_order is not None and zstat_category is not None: 116 | raise ValueError('zstat_presort_order and zstat_category cannot be used simultaneously.') 117 | # decide the replay names from which we really sample 118 | if zstat_presort_order: 119 | all_replay_names = self.pb2feat.tarzstat_maker.zstat_keys_index.get_keys_by_presort_order( 120 | presort_order_name=zstat_presort_order 121 | ) 122 | elif zstat_category: 123 | all_replay_names = self.pb2feat.tarzstat_maker.zstat_keys_index.get_keys_by_category( 124 | category_name=zstat_category 125 | ) 126 | else: 127 | all_replay_names = self.pb2feat.tarzstat_maker.zstat_db.keys() 128 | # decide the distribution and do the sampling accordingly 129 | if distrib is None: 130 | logging.info('FullObsIntV5._sample_reply: distrib is None, defaults to uniform distribution.') 131 | distrib = np.ones(shape=(len(all_replay_names),), dtype=np.float32) 132 | p = distrib / distrib.sum() 133 | assert len(p) == len(all_replay_names), 'n={}, no. 
replays={}'.format( 134 | len(p), len(all_replay_names) 135 | ) 136 | key = np.random.choice(all_replay_names, p=p) 137 | logging.info('FullObsIntV5._sample_reply: sampled from n={} replay files'.format(p.size)) 138 | logging.info('FullObsIntV5._sample_reply: distrib={}'.format(p)) 139 | replay_name, player_id = unique_key_to_rep_info(key) 140 | return replay_name, player_id 141 | 142 | @property 143 | def observation_space(self): 144 | if self._dict_space: 145 | obs_spec = spaces.Dict( 146 | OrderedDict(list(self.feat_spec.spaces.items()) 147 | + list(self.mask_spec.spaces.items()))) 148 | else: 149 | obs_spec = spaces.Tuple( 150 | self.feat_spec.spaces + self.mask_spec.spaces) 151 | return obs_spec 152 | 153 | def obs(self, feat, mask): 154 | if self._dict_space: 155 | for k in mask: 156 | feat[k] = mask[k] 157 | return feat 158 | else: 159 | return list(feat) + list(mask) 160 | 161 | def act_trans(self, act): 162 | # cache the act 163 | self._last_act = act 164 | # use the cached act to determine the updated last-target-tag 165 | pysc2_timestep = self.unwrapped()._obs 166 | if not self._dict_space: 167 | # new action space and new index 168 | ab_index = self._last_act[0] 169 | last_selected_indices = self._last_act[3] 170 | target_index = self._last_act[4] 171 | else: 172 | ab_index = self._last_act['A_AB'] 173 | last_selected_indices = self._last_act['A_SELECT'] 174 | target_index = self._last_act['A_CMD_UNIT'] 175 | self._last_units = pysc2_timestep.observation.raw_data.units 176 | self._last_tar_tag = self._last_units[target_index].tag \ 177 | if self._arg_mask[ab_index, 4-1] else None 178 | self._last_selected_unit_tags = ( 179 | [] if not self._arg_mask[ab_index, 3-1] else 180 | [self._last_units[idx].tag 181 | for idx in last_selected_indices if idx != self._max_unit_num]) 182 | # do the routine 183 | if self.inter: 184 | act = self.inter.act_trans(act) 185 | return act 186 | 187 | def obs_trans(self, raw_obs): 188 | if self.inter: 189 | obs = 
self.inter.obs_trans(raw_obs) 190 | else: 191 | obs = raw_obs 192 | 193 | pb = obs, obs.game_info # TODO: to be simplified 194 | self.pb = pb 195 | feat = self.pb2feat.convert(pb, self._last_tar_tag, self._last_units, self._last_selected_unit_tags) 196 | mask = self.pb2mask.convert(pb) 197 | return self.obs(feat, mask) 198 | 199 | 200 | class ActAsObsSC2(ActAsObsV2): 201 | def __init__(self, inter, override=False): 202 | super(ActAsObsSC2, self).__init__(inter, override) 203 | self._game_loop = 0 204 | 205 | def reset(self, obs, **kwargs): 206 | super(ActAsObsSC2, self).reset(obs, **kwargs) 207 | self._game_loop = 0 208 | 209 | def obs_trans(self, obs): 210 | obs_old = obs 211 | if self.inter: 212 | obs_old = self.inter.obs_trans(obs) 213 | # obs = (obs_pb, game_info), remove game_info? 214 | game_loop = int(obs.observation.game_loop) 215 | self._action['A_NOOP_NUM'] = min(game_loop - self._game_loop - 1, 127) 216 | self._game_loop = game_loop 217 | return self.wrapper.observation_transform(obs_old, self._action) 218 | -------------------------------------------------------------------------------- /arena/interfaces/sc2full_formal/zerg_obs_int.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | from gym import spaces 7 | from pysc2.lib.typeenums import UNIT_TYPEID as UNIT_TYPE 8 | from pysc2.lib.typeenums import UNIT_TYPEID, UPGRADE_ID 9 | from arena.interfaces.common import AppendObsInt 10 | from arena.utils.spaces import NoneSpace 11 | 12 | 13 | class ZergTechObsInt(AppendObsInt): 14 | class Wrapper(object): 15 | def __init__(self, override, space_old): 16 | '''upgrade of self (enemy's upgrade is unavailable)''' 17 | self.tech_list = [UPGRADE_ID.BURROW.value, 18 | UPGRADE_ID.CENTRIFICALHOOKS.value, 19 | UPGRADE_ID.CHITINOUSPLATING.value, 20 | 
                        UPGRADE_ID.EVOLVEMUSCULARAUGMENTS.value,
                        UPGRADE_ID.GLIALRECONSTITUTION.value,
                        UPGRADE_ID.INFESTORENERGYUPGRADE.value,
                        UPGRADE_ID.ZERGLINGATTACKSPEED.value,
                        UPGRADE_ID.ZERGLINGMOVEMENTSPEED.value,
                        UPGRADE_ID.ZERGFLYERARMORSLEVEL1.value,
                        UPGRADE_ID.ZERGFLYERARMORSLEVEL2.value,
                        UPGRADE_ID.ZERGFLYERARMORSLEVEL3.value,
                        UPGRADE_ID.ZERGFLYERWEAPONSLEVEL1.value,
                        UPGRADE_ID.ZERGFLYERWEAPONSLEVEL2.value,
                        UPGRADE_ID.ZERGFLYERWEAPONSLEVEL3.value,
                        UPGRADE_ID.ZERGGROUNDARMORSLEVEL1.value,
                        UPGRADE_ID.ZERGGROUNDARMORSLEVEL2.value,
                        UPGRADE_ID.ZERGGROUNDARMORSLEVEL3.value,
                        UPGRADE_ID.ZERGMELEEWEAPONSLEVEL1.value,
                        UPGRADE_ID.ZERGMELEEWEAPONSLEVEL2.value,
                        UPGRADE_ID.ZERGMELEEWEAPONSLEVEL3.value,
                        UPGRADE_ID.ZERGMISSILEWEAPONSLEVEL1.value,
                        UPGRADE_ID.ZERGMISSILEWEAPONSLEVEL2.value,
                        UPGRADE_ID.ZERGMISSILEWEAPONSLEVEL3.value]
      # One binary flag per tracked upgrade.
      observation_space = spaces.Box(0.0, 1.0, [len(self.tech_list)], dtype=np.float32)
      self.override = override
      if self.override or isinstance(space_old, NoneSpace):
        self.observation_space = spaces.Tuple((observation_space,))
      else:
        self.observation_space = \
          spaces.Tuple(space_old.spaces + (observation_space,))

    def observation_transform(self, obs_pre, obs):
      # 1.0 for every upgrade id present in the player's completed upgrades.
      new_obs = [upgrade in obs.observation['raw_data'].player.upgrade_ids for upgrade in self.tech_list]
      new_obs = np.array(new_obs, dtype=np.float32)
      return [new_obs] if self.override else list(obs_pre) + [new_obs]

  def reset(self, obs, **kwargs):
    super(ZergTechObsInt, self).reset(obs, **kwargs)
    self.wrapper = self.Wrapper(override=self.override,
                                space_old=self.inter.observation_space)


class ZergUnitProg(object):
  def __init__(self, tech_tree, override, space_old,
               building_list=None, tech_list=None, dtype=np.float32):
    '''Return (in_progress, progress) for each building and tech.

    in_progress includes the period the ordered drone moving to target pos.
    Only self, enemy's information not available.
    '''
    self.TT = tech_tree
    self.dtype = dtype
    self.building_list = building_list or \
      [UNIT_TYPE.ZERG_SPAWNINGPOOL.value,
       UNIT_TYPE.ZERG_ROACHWARREN.value,
       UNIT_TYPE.ZERG_HYDRALISKDEN.value,
       UNIT_TYPE.ZERG_HATCHERY.value,
       UNIT_TYPE.ZERG_EVOLUTIONCHAMBER.value,
       UNIT_TYPE.ZERG_BANELINGNEST.value,
       UNIT_TYPE.ZERG_INFESTATIONPIT.value,
       UNIT_TYPE.ZERG_SPIRE.value,
       UNIT_TYPE.ZERG_ULTRALISKCAVERN.value,
       UNIT_TYPE.ZERG_LURKERDENMP.value,
       UNIT_TYPE.ZERG_LAIR.value,
       UNIT_TYPE.ZERG_HIVE.value,
       UNIT_TYPE.ZERG_GREATERSPIRE.value]
    self.tech_list = tech_list or \
      [UPGRADE_ID.BURROW.value,
       UPGRADE_ID.CENTRIFICALHOOKS.value,
       UPGRADE_ID.CHITINOUSPLATING.value,
       UPGRADE_ID.EVOLVEMUSCULARAUGMENTS.value,
       UPGRADE_ID.GLIALRECONSTITUTION.value,
       UPGRADE_ID.INFESTORENERGYUPGRADE.value,
       UPGRADE_ID.ZERGLINGATTACKSPEED.value,
       UPGRADE_ID.ZERGLINGMOVEMENTSPEED.value,
       UPGRADE_ID.ZERGFLYERARMORSLEVEL1.value,
       UPGRADE_ID.ZERGFLYERARMORSLEVEL2.value,
       UPGRADE_ID.ZERGFLYERARMORSLEVEL3.value,
       UPGRADE_ID.ZERGFLYERWEAPONSLEVEL1.value,
       UPGRADE_ID.ZERGFLYERWEAPONSLEVEL2.value,
       UPGRADE_ID.ZERGFLYERWEAPONSLEVEL3.value,
       UPGRADE_ID.ZERGGROUNDARMORSLEVEL1.value,
       UPGRADE_ID.ZERGGROUNDARMORSLEVEL2.value,
       UPGRADE_ID.ZERGGROUNDARMORSLEVEL3.value,
       UPGRADE_ID.ZERGMELEEWEAPONSLEVEL1.value,
       UPGRADE_ID.ZERGMELEEWEAPONSLEVEL2.value,
       UPGRADE_ID.ZERGMELEEWEAPONSLEVEL3.value,
       UPGRADE_ID.ZERGMISSILEWEAPONSLEVEL1.value,
       UPGRADE_ID.ZERGMISSILEWEAPONSLEVEL2.value,
       UPGRADE_ID.ZERGMISSILEWEAPONSLEVEL3.value]
    # Two scalars (in_progress flag, progress fraction) per entry.
    n_dims = len(self.building_list) * 2 + len(self.tech_list) * 2
    observation_space = spaces.Box(0.0, 1.0, [n_dims], dtype=dtype)
    self.override = override
    if self.override or isinstance(space_old, NoneSpace):
      self.observation_space = spaces.Tuple((observation_space,))
    else:
      self.observation_space = \
        spaces.Tuple(space_old.spaces + (observation_space,))
    # Keyed by the morphing unit's unit_type (not its tag -- see
    # building_progress): [build_ability_id, game_loop_start, game_loop_now].
    self.morph_history = {}

  def building_progress(self, unit_type, obs, alliance=1):
    """Return (in_progress, progress in [0, 1]) for one building type.

    Morphed buildings (Lair/Hive/Greater Spire) are timed manually via
    morph_history because the protobuf does not expose morph progress.
    """
    in_progress = 0
    progress = 0
    unit_data = self.TT.getUnitData(unit_type)
    if not unit_data.isBuilding:
      print('building_in_progress can only be used for buildings!')
    game_loop = obs.observation.game_loop
    if isinstance(game_loop, np.ndarray):
      game_loop = game_loop[0]
    if unit_type in [UNIT_TYPE.ZERG_LAIR.value,
                     UNIT_TYPE.ZERG_HIVE.value,
                     UNIT_TYPE.ZERG_GREATERSPIRE.value]:
      builders = [unit for unit in obs.observation.raw_data.units
                  if unit.unit_type in unit_data.whatBuilds
                  and unit.alliance == alliance]
      for builder in builders:
        if len(builder.orders) > 0 and builder.orders[0].ability_id == unit_data.buildAbility:
          # pb do not return the progress of unit morphing; time it ourselves
          # from the loop at which we first saw this morph order.
          if (builder.unit_type not in self.morph_history or
              self.morph_history[builder.unit_type][0] != unit_data.buildAbility):
            self.morph_history[builder.unit_type] = [unit_data.buildAbility, game_loop, game_loop]
          else:
            self.morph_history[builder.unit_type][2] = game_loop
          in_progress = 1
          progress = self.morph_history[builder.unit_type][2] - self.morph_history[builder.unit_type][1]
          progress /= float(unit_data.buildTime)
    else:
      for unit in obs.observation.raw_data.units:
        # Building physically placed and still constructing.
        if (unit.unit_type == unit_type
            and unit.alliance == alliance
            and unit.build_progress < 1):
          in_progress = 1
          progress = max(progress, unit.build_progress)
        # A drone already ordered to build it (still walking there) also
        # counts as in-progress, with progress left at its current value.
        if (unit.unit_type == UNIT_TYPEID.ZERG_DRONE.value
            and unit.alliance == alliance
            and len(unit.orders) > 0
            and unit.orders[0].ability_id == unit_data.buildAbility):
          in_progress = 1
    return in_progress, progress

  def update_morph_history(self, obs):
    # pb do not return the progress of unit morphing; invalidate entries
    # that were not refreshed during this frame by clearing their ability id.
    game_loop = obs.observation.game_loop
    if isinstance(game_loop, np.ndarray):
      game_loop = game_loop[0]
    for tag in self.morph_history:
      if self.morph_history[tag][2] != game_loop:
        self.morph_history[tag][0] = None

  def upgrade_progress(self, upgrade_type, obs, alliance=1):
    """Return (in_progress, progress) for one upgrade, read from the
    researching structure's first order."""
    in_progress = 0
    progress = 0
    data = self.TT.getUpgradeData(upgrade_type)
    builders = [unit for unit in obs.observation.raw_data.units
                if unit.unit_type in data.whatBuilds
                and unit.alliance == alliance]
    for builder in builders:
      if len(builder.orders) > 0 and builder.orders[0].ability_id == data.buildAbility:
        in_progress = 1
        progress = builder.orders[0].progress
    return in_progress, progress

  def observation_transform(self, obs_pre, obs):
    new_obs = []
    for building in self.building_list:
      new_obs.extend(self.building_progress(building, obs))
    for upgrade in self.tech_list:
      new_obs.extend(self.upgrade_progress(upgrade, obs))
    self.update_morph_history(obs)
    new_obs = np.array(new_obs, dtype=self.dtype)
    return [new_obs] if self.override else list(obs_pre) + [new_obs]


class ZergUnitProgObsInt(AppendObsInt):
  def reset(self, obs, **kwargs):
    super(ZergUnitProgObsInt, self).reset(obs, **kwargs)
    self.wrapper = ZergUnitProg(self.unwrapped().dc.sd.TT,
                                override=self.override,
                                space_old=self.inter.observation_space)

# --- arena/interfaces/pommerman/obs_int.py ------------------------------------
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import numpy as np 7 | from arena.interfaces.interface import Interface 8 | from arena.interfaces.common import AppendObsInt 9 | from arena.utils.spaces import NoneSpace 10 | from gym import spaces 11 | from pommerman.constants import Item 12 | 13 | 14 | class BoardMapObsFunc(object): 15 | def __init__(self, obs, items, override, use_attr, space_old): 16 | self.override = override 17 | self.use_attr = use_attr 18 | self.items = list(items) 19 | self.items.extend(obs['teammate'] + obs['enemies']) 20 | self.shape = list(obs['board'].shape) + [len(self.items) + 2 + 3 * use_attr] 21 | self.item_map = {v:k for k, v in enumerate(self.items)} 22 | observation_space = spaces.Box(0.0, float('inf'), self.shape, dtype=np.float32) 23 | if self.override or isinstance(space_old, NoneSpace): 24 | self.observation_space = spaces.Tuple((observation_space,)) 25 | else: 26 | self.observation_space = \ 27 | spaces.Tuple(space_old.spaces + (observation_space,)) 28 | 29 | def observation_transform(self, obs_pre, obs): 30 | new_obs = np.zeros(self.shape, dtype=np.float32) 31 | for i in range(self.shape[0]): 32 | for j in range(self.shape[1]): 33 | k = obs['board'][i, j] 34 | if k in self.item_map: 35 | new_obs[i, j, self.item_map[k]] = 1 36 | def expand_bomb_blast(board): 37 | new_board = np.zeros(board.shape) 38 | for i in range(board.shape[0]): 39 | for j in range(board.shape[1]): 40 | if board[i][j] == 0: 41 | continue 42 | s = int(board[i][j] - 1) 43 | for ii in range(max(0, i-s), min(11, i+s+1)): 44 | new_board[ii, j] = 1 45 | for jj in range(max(0, j-s), min(11, j+s+1)): 46 | new_board[i, jj] = 1 47 | return new_board 48 | new_obs[:, :, -(2 + 3 * self.use_attr)] = expand_bomb_blast(obs['bomb_blast_strength']) 49 | def expand_bomb_life(obs): 50 | new_board = np.array(obs['bomb_life']) 51 | max_x, max_y = new_board.shape 52 | bomb_pos = np.nonzero(new_board) 53 | bomb_life = new_board[bomb_pos] 54 | for i in np.argsort(bomb_life): 55 | x, y = bomb_pos[0][i], bomb_pos[1][i] 56 | s = 
int(obs['bomb_blast_strength'][x, y] - 1) 57 | for ii in range(max(0, x-s), min(max_x, x+s+1)): 58 | new_board[ii, y] = new_board[x, y] 59 | for jj in range(max(0, y-s), min(max_y, y+s+1)): 60 | new_board[x, jj] = new_board[x, y] 61 | for i in range(max_x): 62 | for j in range(max_y): 63 | if new_board[i, j] != 0: 64 | new_board[i, j] = (10 - new_board[i, j]) / 10.0 65 | return new_board 66 | new_obs[:, :, -(1 + 3 * self.use_attr)] = expand_bomb_life(obs) 67 | 68 | if self.use_attr: 69 | for i, pos in enumerate(obs['position']): 70 | if obs['alive'][i]: 71 | new_obs[pos[0], pos[1], -3] = obs['blast_strength'][i] / 5.0 72 | new_obs[pos[0], pos[1], -2] = obs['can_kick'][i] 73 | new_obs[pos[0], pos[1], -1] = obs['ammo'][i] / 3.0 74 | return (new_obs,) if self.override else obs_pre + (new_obs,) 75 | 76 | 77 | class BoardMapObs(AppendObsInt): 78 | def __init__(self, inter, override=True, use_attr=True): 79 | super(BoardMapObs, self).__init__(inter, override) 80 | self.items = (Item.Rigid.value, 81 | Item.Wood.value, 82 | Item.Bomb.value, 83 | Item.Flames.value, 84 | Item.Fog.value, 85 | Item.ExtraBomb.value, 86 | Item.IncrRange.value, 87 | Item.Kick.value) 88 | self.use_attr = use_attr 89 | 90 | def reset(self, obs, **kwargs): 91 | super(BoardMapObs, self).reset(obs, **kwargs) 92 | self.wrapper = BoardMapObsFunc(self.unwrapped()._obs, self.items, 93 | self.override, self.use_attr, 94 | space_old=self.inter.observation_space) 95 | 96 | 97 | class CombineObsInt(Interface): 98 | def __init__(self, inter, remove_dead_view = True): 99 | super(CombineObsInt, self).__init__(inter) 100 | self.remove_dead_view = remove_dead_view 101 | 102 | def reset(self, obs, **kwargs): 103 | super(CombineObsInt, self).reset(obs, **kwargs) 104 | self.obs_trans(obs) 105 | 106 | def obs_trans(self, obs): 107 | state1, state2 = obs 108 | if self.remove_dead_view and state2['teammate'].value not in state2['alive']: 109 | state = copy.deepcopy(state2) 110 | elif self.remove_dead_view and 
state1['teammate'].value not in state1['alive']: 111 | state = copy.deepcopy(state1) 112 | else: 113 | state = copy.deepcopy(state1) 114 | for i in range(state1['board'].shape[0]): 115 | for j in range(state1['board'].shape[1]): 116 | if (state1['board'][i, j] == Item.Fog.value and 117 | state2['board'][i, j] != Item.Fog.value): 118 | state['board'][i, j] = state2['board'][i, j] 119 | state['position'] = (state1['position'], state2['position']) 120 | state['blast_strength'] = (state1['blast_strength'], state2['blast_strength']) 121 | state['can_kick'] = (state1['can_kick'], state2['can_kick']) 122 | state['ammo'] = (state1['ammo'], state2['ammo']) 123 | state['teammate'] = (state2['teammate'].value, state1['teammate'].value) 124 | state['enemies'] = (state['enemies'][0].value, state['enemies'][1].value) 125 | state['alive'] = [t in state['alive'] for t in state['teammate']] + \ 126 | [e in state['alive'] for e in state['enemies']] 127 | self.unwrapped()._obs = state 128 | return state 129 | 130 | 131 | class AttrObsInt(AppendObsInt): 132 | class AttrFunc(object): 133 | def __init__(self, obs, override, space_old): 134 | self._attr_dim = 6 135 | self._board_shape = obs['board'].shape 136 | obs_space = spaces.Box(low=-1, high=2, dtype=np.float32, 137 | shape=(self._attr_dim * len(obs['position']),)) 138 | self.override = override 139 | if self.override or isinstance(space_old, NoneSpace): 140 | self.observation_space = spaces.Tuple((obs_space,)) 141 | else: 142 | self.observation_space = \ 143 | spaces.Tuple(space_old.spaces + (obs_space,)) 144 | 145 | def observation_transform(self, obs_pre, obs): 146 | units_vec = [] 147 | for i, pos in enumerate(obs['position']): 148 | alive = obs['alive'][i] 149 | units_vec.append(alive) 150 | units_vec.append(pos[0] / (self._board_shape[0] - 1) - 0.5) 151 | units_vec.append(pos[1] / (self._board_shape[1] - 1) - 0.5) 152 | if alive: 153 | units_vec.extend([obs['blast_strength'][i] / 5.0, 154 | obs['can_kick'][i], 155 | 
obs['ammo'][i] / 3.0]) 156 | else: 157 | units_vec.extend([0., 0., 0.]) 158 | observation = np.array(units_vec, dtype=np.float32) 159 | return (observation,) if self.override else obs_pre + (observation,) 160 | 161 | def reset(self, obs, **kwargs): 162 | super(AttrObsInt, self).reset(obs, **kwargs) 163 | self.wrapper = self.AttrFunc(self.unwrapped()._obs, override=self.override, 164 | space_old=self.inter.observation_space) 165 | 166 | 167 | class PosObsInt(AppendObsInt): 168 | class PosFunc(object): 169 | def __init__(self, obs, override, space_old): 170 | self._attr_dim = 2 * len(obs['position']) 171 | obs_space = spaces.Box(low=-1, high=2, dtype=np.int32, 172 | shape=(self._attr_dim,)) 173 | self.override = override 174 | if self.override or isinstance(space_old, NoneSpace): 175 | self.observation_space = spaces.Tuple((obs_space,)) 176 | else: 177 | self.observation_space = \ 178 | spaces.Tuple(space_old.spaces + (obs_space,)) 179 | 180 | def observation_transform(self, obs_pre, obs): 181 | units_vec = [] 182 | for i, pos in enumerate(obs['position']): 183 | units_vec.append(pos[0]) 184 | units_vec.append(pos[1]) 185 | observation = np.array(units_vec, dtype=np.int32) 186 | return (observation,) if self.override else obs_pre + (observation,) 187 | 188 | def reset(self, obs, **kwargs): 189 | super(PosObsInt, self).reset(obs, **kwargs) 190 | self.wrapper = self.PosFunc(self.unwrapped()._obs, override=self.override, 191 | space_old=self.inter.observation_space) 192 | 193 | 194 | class ActMaskObsInt(AppendObsInt): 195 | class ActMaskFunc(object): 196 | def __init__(self, obs, space_old): 197 | self.shape = obs['board'].shape 198 | obs_space = spaces.Box(low=-1, high=2, dtype=np.int32, 199 | shape=(6*2,)) 200 | self.observation_space = \ 201 | spaces.Tuple(space_old.spaces + (obs_space,)) 202 | self.pathing_items = [Item.Passage.value, Item.ExtraBomb.value, 203 | Item.IncrRange.value, Item.Kick.value] 204 | 205 | def in_board(self, pos): 206 | return (0 <= pos[0] < 
self.shape[0]) and (0 <= pos[1] < self.shape[1]) 207 | 208 | def observation_transform(self, obs_pre, obs): 209 | act_mask = np.zeros((2, 6)) 210 | for i, pos in enumerate(obs['position'][0:2]): 211 | act_mask[i, 0] = 1 212 | if not obs['alive'][i]: 213 | continue 214 | for j, (dx, dy) in enumerate([(-1, 0), (1, 0), (0, -1), (0, 1)]): 215 | new_pos = (pos[0] + dx, pos[1] + dy) 216 | if self.in_board(new_pos): 217 | if obs['board'][new_pos] in self.pathing_items: 218 | act_mask[i, j + 1] = 1 219 | elif obs['can_kick'][i] and obs['board'][new_pos] == Item.Bomb.value: 220 | further_pos = (pos[0] + 2 * dx, pos[1] + 2 * dy) 221 | if (self.in_board(further_pos) and 222 | obs['board'][further_pos] in self.pathing_items): 223 | act_mask[i, j + 1] = 1 224 | act_mask[i, -1] = obs['ammo'][i] > 0 and obs['bomb_blast_strength'][pos] == 0 225 | return obs_pre + (act_mask.reshape([-1]),) 226 | 227 | def reset(self, obs, **kwargs): 228 | super(ActMaskObsInt, self).reset(obs, **kwargs) 229 | self.wrapper = self.ActMaskFunc(self.unwrapped()._obs, 230 | space_old=self.inter.observation_space) -------------------------------------------------------------------------------- /arena/interfaces/sc2full_formal/act_int.py: -------------------------------------------------------------------------------- 1 | """ Gym env wrappers """ 2 | from copy import deepcopy 3 | 4 | import numpy as np 5 | from timitate.utils.commands import cmd_with_pos, cmd_with_tar, noop 6 | from pysc2.lib import UNIT_TYPEID 7 | from timitate.lib6.action2pb_converter import Action2PBConverter as Action2PBConverterV6 8 | 9 | 10 | from arena.interfaces.interface import Interface 11 | 12 | class NoopActIntV4(Interface): 13 | def __init__(self, inter, noop_nums=(i+1 for i in range(128)), 14 | noop_func=lambda x: x[1]): 15 | super(self.__class__, self).__init__(inter) 16 | self.noop_nums = list(noop_nums) 17 | self.target_game_loop = 0 18 | self.noop_func = noop_func 19 | 20 | def obs_trans(self, obs): 21 | game_loop = 
obs.observation.game_loop 22 | if game_loop < self.target_game_loop: 23 | return None 24 | else: 25 | obs = self.inter.obs_trans(obs) 26 | return obs 27 | 28 | def act_trans(self, action): 29 | game_loop = self.unwrapped()._obs.observation.game_loop 30 | if game_loop < self.target_game_loop: 31 | return [] 32 | else: 33 | self.target_game_loop = game_loop + self.noop_nums[int(self.noop_func(action))] 34 | act = self.inter.act_trans(action) 35 | return act 36 | 37 | def reset(self, obs, **kwargs): 38 | super(NoopActIntV4, self).reset(obs, **kwargs) 39 | self.target_game_loop = 0 40 | 41 | 42 | class TRTActInt(Interface): 43 | # Tower rush trick (TRT) action interface, only for KairosJunction 44 | def __init__(self, inter): 45 | super(TRTActInt, self).__init__(inter) 46 | self.use_trt = False 47 | self._started_print = False 48 | self._completed_print = False 49 | self._executors = [] 50 | self._n_drones_trt = 2 51 | self._base_pos = [(31.5, 140.5), (120.5, 27.5)] 52 | 53 | def _dist(self, x1, y1, x2, y2): 54 | return ((x1-x2)**2+(y1-y2)**2)**0.5 55 | 56 | def _ready_to_go(self): 57 | units = self.unwrapped()._obs.observation.raw_data.units 58 | drones = [u for u in units if u.alliance == 1 and 59 | u.unit_type == UNIT_TYPEID.ZERG_DRONE.value] 60 | if len(self._executors) >= self._n_drones_trt: 61 | # update executors' attributes 62 | for i, e in enumerate(self._executors): 63 | if e is None: 64 | continue 65 | is_alive = False 66 | for d in drones: 67 | if d.tag == e.tag: 68 | self._executors[i] = d 69 | is_alive = True 70 | break 71 | if not is_alive: 72 | self._executors[i] = None 73 | return True 74 | for u in units: 75 | # once self spawning pool is on building after 0.3 progresses 76 | if u.alliance == 1 and u.unit_type == UNIT_TYPEID.ZERG_SPAWNINGPOOL.value \ 77 | and u.build_progress > 0.3 and len(drones) > 0 \ 78 | and len(self._executors) < self._n_drones_trt: 79 | for d in drones: 80 | if d not in self._executors: 81 | self._executors.append(d) 82 | if 
class TRTActInt(Interface):
  """Tower rush trick (TRT) action interface; only for the KairosJunction map.

  When enabled via ``reset(obs, use_trt=True)``, once a friendly spawning
  pool passes 30% build progress, two drones are commandeered: even-indexed
  executors run to the enemy main base and build spine crawlers, odd-indexed
  ones attack enemy drones. The trick's protobuf commands are injected ahead
  of the model's actions, and any model command addressed to the executor
  drones is stripped out.
  """

  def __init__(self, inter):
    super(TRTActInt, self).__init__(inter)
    self.use_trt = False
    self._started_print = False    # one-shot guard for the "launch" log line
    self._completed_print = False  # one-shot guard for the "completed" log line
    self._executors = []           # drone snapshots running the trick; None = dead
    self._n_drones_trt = 2         # number of drones reserved for the trick
    # The two main-base locations on KairosJunction.
    self._base_pos = [(31.5, 140.5), (120.5, 27.5)]

  def _dist(self, x1, y1, x2, y2):
    """Euclidean distance between (x1, y1) and (x2, y2)."""
    return ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5

  def _ready_to_go(self):
    """Recruit executor drones and report whether the trick may run.

    Once ``_n_drones_trt`` executors exist, their unit snapshots are
    refreshed from the current observation (dead ones become None) and True
    is returned. Recruiting starts when a friendly spawning pool is past 0.3
    build progress.
    """
    units = self.unwrapped()._obs.observation.raw_data.units
    drones = [u for u in units if u.alliance == 1 and
              u.unit_type == UNIT_TYPEID.ZERG_DRONE.value]
    if len(self._executors) >= self._n_drones_trt:
      # Refresh executors' snapshots; mark dead ones with None.
      for i, e in enumerate(self._executors):
        if e is None:
          continue
        is_alive = False
        for d in drones:
          if d.tag == e.tag:
            self._executors[i] = d
            is_alive = True
            break
        if not is_alive:
          self._executors[i] = None
      return True
    for u in units:
      # Recruit once a self spawning pool is building past 0.3 progress.
      if u.alliance == 1 and u.unit_type == UNIT_TYPEID.ZERG_SPAWNINGPOOL.value \
          and u.build_progress > 0.3 and len(drones) > 0 \
          and len(self._executors) < self._n_drones_trt:
        for d in drones:
          if d not in self._executors:
            self._executors.append(d)
            if len(self._executors) >= self._n_drones_trt:
              return True
    return False

  def _mission_completed(self):
    """True iff executors were recruited and all of them are now dead."""
    if len(self._executors) == 0:
      return False
    for e in self._executors:
      if e is not None:
        return False
    return True

  def _target_pos(self):
    """Return the enemy main-base position (the base farther from our main).

    Returns (0, 0) when we own no hatchery at all.

    Raises:
      RuntimeError: when no hatchery sits on a known KairosJunction base
        location (i.e. we are not on the KJ map). RuntimeError replaces the
        previous bare BaseException; any existing ``except BaseException``
        handler still catches it.
    """
    units = self.unwrapped()._obs.observation.raw_data.units
    self_h = [u for u in units if u.alliance == 1 and
              u.unit_type == UNIT_TYPEID.ZERG_HATCHERY.value]
    if len(self_h) == 0:
      return 0, 0
    self_h0 = None
    for h in self_h:
      if min(self._dist(h.pos.x, h.pos.y,
                        self._base_pos[0][0], self._base_pos[0][1]),
             self._dist(h.pos.x, h.pos.y,
                        self._base_pos[1][0], self._base_pos[1][1])) < 1:
        self_h0 = h
    if self_h0 is None:
      raise RuntimeError('Not KJ map in TRTActInt.')
    if self._dist(self_h0.pos.x, self_h0.pos.y,
                  self._base_pos[0][0], self._base_pos[0][1]) < \
        self._dist(self_h0.pos.x, self_h0.pos.y,
                   self._base_pos[1][0], self._base_pos[1][1]):
      return self._base_pos[1]
    return self._base_pos[0]

  def _drone_micro(self):
    """Issue per-drone commands for the current step.

    Even-indexed executors build spine crawlers near the target (or move
    toward it); odd-indexed executors attack enemy drones (or move toward
    the target when none are visible).
    """

    def _build_spinecrawler(drone):
      # 1166 is the build-spine-crawler ability id; choose a random spot
      # within +/- detect_range/2 of the drone so crawlers spread out.
      detect_range = 10
      order_ab_ids = [o.ability_id for o in drone.orders]
      if 1166 not in order_ab_ids:
        random_r1 = np.random.random()
        random_r2 = np.random.random()
        build_pos = (random_r1 * (drone.pos.x - detect_range / 2) +
                     (1 - random_r1) * (drone.pos.x + detect_range / 2),
                     random_r2 * (drone.pos.y - detect_range / 2) +
                     (1 - random_r2) * (drone.pos.y + detect_range / 2))
        return cmd_with_pos(ability_id=1166,
                            x=build_pos[0],
                            y=build_pos[1],
                            tags=[drone.tag],
                            shift=False)
      return noop()

    def _move_to_tar(drone, tar_pos):
      # Ability 1 is plain move; skip if a move order is already queued.
      order_ab_ids = [o.ability_id for o in drone.orders]
      if 1 not in order_ab_ids:
        return cmd_with_pos(ability_id=1,
                            x=tar_pos[0],
                            y=tar_pos[1],
                            tags=[drone.tag],
                            shift=False)
      return noop()

    def _atk_tar(drone, tar_tag):
      # Ability 23 is attack; skip if an attack order is already queued.
      order_ab_ids = [o.ability_id for o in drone.orders]
      if 23 not in order_ab_ids:
        return cmd_with_tar(ability_id=23,
                            target_tag=tar_tag,
                            tags=[drone.tag],
                            shift=False)
      return noop()

    units = self.unwrapped()._obs.observation.raw_data.units
    enemy_d = [u for u in units if u.alliance == 4 and
               u.unit_type == UNIT_TYPEID.ZERG_DRONE.value]
    pb_actions = []
    tar_pos = self._target_pos()
    for i, drone in enumerate(self._executors):
      if drone is None:
        continue
      if i % 2 == 0:
        if self._dist(drone.pos.x, drone.pos.y, tar_pos[0], tar_pos[1]) < 10:
          # Close enough to the enemy base: start building.
          pb_actions.append(_build_spinecrawler(drone))
        else:
          pb_actions.append(_move_to_tar(drone, tar_pos))
      else:
        if len(enemy_d) > 0:
          # Pick the enemy drone with the smallest tag (deterministic choice).
          enemy_d_tags = [u.tag for u in enemy_d]
          pb_actions.append(_atk_tar(drone, min(enemy_d_tags)))
        else:
          pb_actions.append(_move_to_tar(drone, tar_pos))
    return pb_actions

  def _trt_act(self):
    """Return this step's trick actions; always a list.

    BUGFIX: the fall-through ``return []`` is now unconditional, so the
    not-yet-ready path can never yield None (which would crash the
    ``trt_act + ori_act`` concatenation in act_trans).
    """
    if self._ready_to_go():
      if not self._started_print:
        print('Launch tower rush trick.')
        self._started_print = True
      if not self._mission_completed():
        return self._drone_micro()
      if not self._completed_print:
        print('Tower rush trick completed.')
        self._completed_print = True
    return []

  def _get_cmd_tags(self, raw_acts):
    """Collect every unit tag addressed by a raw unit command in raw_acts."""
    all_tags = []
    for a in raw_acts:
      if hasattr(a, 'action_raw') and hasattr(a.action_raw, 'unit_command') \
          and hasattr(a.action_raw.unit_command, 'unit_tags'):
        all_tags += a.action_raw.unit_command.unit_tags
    return all_tags

  def _remove_tags(self, ori_act, tags):
    """Strip the given unit tags from every unit command in ori_act.

    NOTE: mutates the protobuf messages inside ori_act in place.
    """
    if len(tags) == 0:
      return ori_act
    for a in ori_act:
      if hasattr(a, 'action_raw') and hasattr(a.action_raw, 'unit_command') \
          and hasattr(a.action_raw.unit_command, 'unit_tags'):
        a_tags = set(a.action_raw.unit_command.unit_tags)
        a_tags.difference_update(tags)
        # protobuf repeated fields only support pop(); rebuild in place.
        while len(a.action_raw.unit_command.unit_tags) > 0:
          a.action_raw.unit_command.unit_tags.pop()
        for t in a_tags:
          a.action_raw.unit_command.unit_tags.append(t)
    return ori_act

  def act_trans(self, act):
    """Merge trick actions with the model's raw actions.

    The trick's actions are placed FIRST on purpose: if ori_act came before
    trt_act, the remaining selection would be determined by the trick's
    commands and the model would be confused.
    """
    # TODO: assert act is raw_action
    ori_act = act
    if self.use_trt:
      trt_act = self._trt_act()
      trt_a_tags = self._get_cmd_tags(trt_act)
      ori_act = self._remove_tags(ori_act, trt_a_tags)
      act = trt_act + ori_act
    return act

  def reset(self, obs, **kwargs):
    """Reset trick state; enable the trick with kwargs['use_trt']=True."""
    super(TRTActInt, self).reset(obs, **kwargs)
    self.use_trt = kwargs.get('use_trt', False)
    self._started_print = False
    self._completed_print = False
    self._executors = []
class FullActIntV6(Interface):
  """Full-game action interface (v6).

  Exposes the converter's structured action space and translates each agent
  action into a list of SC2 protobuf actions via Action2PBConverterV6.
  """

  def __init__(self, inter, map_resolution=(128, 128), max_noop_num=128,
               max_unit_num=600, correct_pos_radius=2.0, dict_space=False,
               correct_building_pos=False, crop_to_playable_area=False,
               verbose=30):
    # Name the class explicitly (consistent with reset()):
    # super(self.__class__, ...) recurses infinitely under subclassing.
    super(FullActIntV6, self).__init__(inter)
    self.act2pb = Action2PBConverterV6(map_padding_size=map_resolution,
                                       max_noop_num=max_noop_num,
                                       dict_space=dict_space,
                                       verbose=verbose,
                                       max_unit_num=max_unit_num,
                                       correct_building_pos=correct_building_pos,
                                       crop_to_playable_area=crop_to_playable_area,
                                       correct_pos_radius=correct_pos_radius)

  @property
  def action_space(self):
    """The converter's action space."""
    return self.act2pb.space

  def act_trans(self, action):
    """Convert an agent action into a list of protobuf actions.

    Raises:
      RuntimeError: when there is no wrapped interface, or it carries no pb
        message. (Previously a missing interface fell through to an
        UnboundLocalError and a missing pb raised bare BaseException;
        RuntimeError is still caught by any ``except BaseException``.)
    """
    if not self.inter:
      raise RuntimeError('FullActIntV6 requires a wrapped interface.')
    action = self.inter.act_trans(action)
    if not self.inter.pb:
      raise RuntimeError('Wrapped interface carries no pb message.')
    pb_action = self.act2pb.convert(action, self.inter.pb)
    if isinstance(pb_action, list):
      return pb_action
    return [pb_action]

  def reset(self, obs, **kwargs):
    """Reset the converter first, then the wrapped interface chain."""
    self.act2pb.reset(obs, **kwargs)
    super(FullActIntV6, self).reset(obs, **kwargs)