├── src ├── __init__.py ├── modules │ ├── __init__.py │ ├── mixers │ │ ├── __init__.py │ │ ├── vdn.py │ │ ├── qmix.py │ │ └── flex_qmix.py │ ├── layers │ │ ├── __init__.py │ │ └── attention.py │ └── agents │ │ ├── __init__.py │ │ ├── ff_agent.py │ │ ├── rnn_agent.py │ │ ├── entity_rnn_agent.py │ │ └── entity_ff_agent.py ├── utils │ ├── __init__.py │ ├── dict2namedtuple.py │ ├── timehelper.py │ ├── rl_utils.py │ └── logging.py ├── config │ ├── .gitignore │ ├── envs │ │ ├── group_matching.yaml │ │ ├── sc2.yaml │ │ └── sc2custom.yaml │ ├── algs │ │ ├── vdn_atten.yaml │ │ ├── refil_vdn.yaml │ │ ├── qmix_atten.yaml │ │ ├── refil.yaml │ │ ├── qmix_atten_group_matching.yaml │ │ └── refil_group_matching.yaml │ └── default.yaml ├── .gitignore ├── components │ ├── __init__.py │ ├── transforms.py │ ├── epsilon_schedules.py │ ├── action_selectors.py │ └── episode_buffer.py ├── envs │ ├── group_matching │ │ ├── __init__.py │ │ └── group_matching.py │ ├── starcraft2 │ │ ├── maps │ │ │ ├── SMAC_Maps │ │ │ │ ├── 25m.SC2Map │ │ │ │ ├── 3m.SC2Map │ │ │ │ ├── 8m.SC2Map │ │ │ │ ├── MMM.SC2Map │ │ │ │ ├── 1c3s5z.SC2Map │ │ │ │ ├── 2s3z.SC2Map │ │ │ │ ├── 3s5z.SC2Map │ │ │ │ ├── MMM2.SC2Map │ │ │ │ ├── 2m_vs_1z.SC2Map │ │ │ │ ├── 3s_vs_3z.SC2Map │ │ │ │ ├── 3s_vs_4z.SC2Map │ │ │ │ ├── 3s_vs_5z.SC2Map │ │ │ │ ├── 5m_vs_6m.SC2Map │ │ │ │ ├── 6h_vs_8z.SC2Map │ │ │ │ ├── 8m_vs_9m.SC2Map │ │ │ │ ├── corridor.SC2Map │ │ │ │ ├── 10m_vs_11m.SC2Map │ │ │ │ ├── 27m_vs_30m.SC2Map │ │ │ │ ├── 2c_vs_64zg.SC2Map │ │ │ │ ├── 2s_vs_1sc.SC2Map │ │ │ │ ├── 3s5z_vs_3s6z.SC2Map │ │ │ │ ├── bane_vs_bane.SC2Map │ │ │ │ ├── empty_passive.SC2Map │ │ │ │ └── so_many_baneling.SC2Map │ │ │ ├── __init__.py │ │ │ └── smac_maps.py │ │ ├── __init__.py │ │ └── custom_scenarios.py │ ├── __init__.py │ └── multiagentenv.py ├── learners │ ├── __init__.py │ └── q_learner.py ├── controllers │ ├── __init__.py │ ├── entity_controller.py │ └── basic_controller.py ├── runners │ ├── __init__.py │ ├── episode_runner.py │ └── parallel_runner.py ├── main.py └── run.py ├── docker ├── build.sh └── Dockerfile ├── run.sh ├── it_run.sh ├── LICENSE ├── install_sc2.sh ├── .gitignore └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/modules/mixers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/config/.gitignore: -------------------------------------------------------------------------------- 1 | mongo_creds.yaml -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | tb_logs/ 2 | results/ 3 | -------------------------------------------------------------------------------- /src/components/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | -------------------------------------------------------------------------------- 
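A quick orientation to the tree above: most packages under `src/` expose a name-keyed `REGISTRY` dict (controllers, runners, learners, agents, and action selectors all follow the pattern of the empty registry in `src/components/__init__.py` shown just above), and the YAML files under `src/config` select implementations by those string keys (`mac`, `runner`, `learner`, `agent`, `action_selector`). Implementations are looked up by name at runtime; `src/main.py` merges the YAML configs and hands them to `src/run.py`, which is not reproduced in this section. The snippet below is only a minimal sketch of that lookup pattern, reusing the repo's `dict2namedtuple` helper as a stand-in for the merged config.

```python
# Minimal sketch of the registry lookup pattern (not a file from this repo).
# Assumes src/ is on the import path, as in src/main.py.
from utils.dict2namedtuple import convert
from controllers import REGISTRY as mac_REGISTRY

args = convert({"mac": "entity_mac"})   # stand-in for the merged YAML config
mac_cls = mac_REGISTRY[args.mac]        # "entity_mac" -> EntityMAC
# mac = mac_cls(scheme, groups, args)   # real construction needs the full
#                                       # scheme/groups/args built in src/run.py
```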
/src/envs/group_matching/__init__.py: -------------------------------------------------------------------------------- 1 | from .group_matching import GroupMatching 2 | -------------------------------------------------------------------------------- /src/modules/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import EntityAttentionLayer, EntityPoolingLayer 2 | -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo 'Building Dockerfile with image name refil' 4 | docker build -t refil . 5 | -------------------------------------------------------------------------------- /src/learners/__init__.py: -------------------------------------------------------------------------------- 1 | from .q_learner import QLearner 2 | 3 | REGISTRY = {} 4 | REGISTRY["q_learner"] = QLearner 5 | -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/25m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/25m.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/3m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/3m.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/8m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/8m.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/MMM.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/MMM.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/1c3s5z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/1c3s5z.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/2s3z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/2s3z.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/3s5z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/3s5z.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/MMM2.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/MMM2.SC2Map 
-------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/2m_vs_1z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/2m_vs_1z.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/3s_vs_3z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/3s_vs_3z.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/3s_vs_4z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/3s_vs_4z.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/3s_vs_5z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/3s_vs_5z.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/5m_vs_6m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/5m_vs_6m.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/6h_vs_8z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/6h_vs_8z.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/8m_vs_9m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/8m_vs_9m.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/corridor.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/corridor.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/10m_vs_11m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/10m_vs_11m.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/27m_vs_30m.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/27m_vs_30m.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/2c_vs_64zg.SC2Map: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/2c_vs_64zg.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/2s_vs_1sc.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/2s_vs_1sc.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/bane_vs_bane.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/bane_vs_bane.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/empty_passive.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/empty_passive.SC2Map -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/SMAC_Maps/so_many_baneling.SC2Map: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shariqiqbal2810/REFIL/HEAD/src/envs/starcraft2/maps/SMAC_Maps/so_many_baneling.SC2Map -------------------------------------------------------------------------------- /src/utils/dict2namedtuple.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | def convert(dictionary): 3 | return namedtuple('GenericDict', dictionary.keys())(**dictionary) 4 | -------------------------------------------------------------------------------- /src/controllers/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from .basic_controller import BasicMAC 4 | from .entity_controller import EntityMAC 5 | 6 | REGISTRY["basic_mac"] = BasicMAC 7 | REGISTRY["entity_mac"] = EntityMAC 8 | -------------------------------------------------------------------------------- /src/runners/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from .episode_runner import EpisodeRunner 4 | REGISTRY["episode"] = EpisodeRunner 5 | 6 | from .parallel_runner import ParallelRunner 7 | REGISTRY["parallel"] = ParallelRunner 8 | -------------------------------------------------------------------------------- /src/modules/mixers/vdn.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | 4 | 5 | class VDNMixer(nn.Module): 6 | def __init__(self): 7 | super(VDNMixer, self).__init__() 8 | 9 | def forward(self, agent_qs, batch, imagine_groups=None): 10 | return th.sum(agent_qs, dim=2, keepdim=True) 11 | -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/__init__.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from . import smac_maps 6 | 7 | 8 | def get_map_params(map_name): 9 | map_param_registry = smac_maps.get_smac_map_registry() 10 | return map_param_registry[map_name] 11 | -------------------------------------------------------------------------------- /src/config/envs/group_matching.yaml: -------------------------------------------------------------------------------- 1 | env: group_matching 2 | 3 | env_args: 4 | entity_scheme: True 5 | n_agents: 8 6 | n_states: 6 7 | n_groups: 2 8 | rand_trans: 0.1 9 | episode_limit: 50 10 | fixed_scen: False 11 | 12 | test_nepisode: 80 13 | test_interval: 10000 14 | log_interval: 2000 15 | runner_log_interval: 2000 16 | learner_log_interval: 2000 17 | t_max: 1000000 18 | -------------------------------------------------------------------------------- /src/envs/starcraft2/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from .starcraft2 import StarCraft2Env 6 | from .starcraft2custom import StarCraft2CustomEnv 7 | from .custom_scenarios import custom_scenario_registry 8 | 9 | from absl import flags 10 | FLAGS = flags.FLAGS 11 | FLAGS(['main.py']) 12 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | HASH=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 4 | head -n 1) 3 | GPU=$1 4 | name=${USER}_pymarl_GPU_${GPU}_${HASH} 5 | 6 | echo "Launching container named '${name}' on GPU '${GPU}'" 7 | # Launches a docker container using our image, and runs the provided command 8 | 9 | docker run \ 10 | --gpus '"device='$GPU'"' \ 11 | --name $name \ 12 | --user $(id -u) \ 13 | -v `pwd`:/REFIL \ 14 | --entrypoint /usr/bin/python3.6 \ 15 | -t refil \ 16 | ${@:2} 17 | -------------------------------------------------------------------------------- /it_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | HASH=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 4 | head -n 1) 3 | GPU=$1 4 | name=${USER}_pymarl_GPU_${GPU}_${HASH} 5 | 6 | echo "Launching container named '${name}' on GPU '${GPU}'" 7 | # Launches a docker container using our image, and runs the provided command 8 | 9 | docker run \ 10 | --gpus '"device='$GPU'"' \ 11 | --name $name \ 12 | --cap-add=SYS_PTRACE \ 13 | --net host \ 14 | --user $(id -u) \ 15 | -v `pwd`:/REFIL \ 16 | --entrypoint /usr/bin/python3.6 \ 17 | -it refil \ 18 | ${@:2} 19 | -------------------------------------------------------------------------------- /src/modules/agents/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from .rnn_agent import RNNAgent 4 | from .ff_agent import FFAgent 5 | from .entity_rnn_agent import ImagineEntityAttentionRNNAgent, EntityAttentionRNNAgent 6 | from .entity_ff_agent import EntityAttentionFFAgent, ImagineEntityAttentionFFAgent 7 | 8 | REGISTRY["rnn"] = RNNAgent 9 | REGISTRY["ff"] = FFAgent 10 | REGISTRY["entity_attend_ff"] = EntityAttentionFFAgent 11 | REGISTRY["imagine_entity_attend_ff"] = ImagineEntityAttentionFFAgent 12 | REGISTRY["entity_attend_rnn"] = 
EntityAttentionRNNAgent 13 | REGISTRY["imagine_entity_attend_rnn"] = ImagineEntityAttentionRNNAgent 14 | -------------------------------------------------------------------------------- /src/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from .multiagentenv import MultiAgentEnv 4 | from .starcraft2 import StarCraft2Env, StarCraft2CustomEnv 5 | from .group_matching import GroupMatching 6 | from .starcraft2 import custom_scenario_registry as sc_scenarios 7 | 8 | 9 | # TODO: Do we need this? 10 | def env_fn(env, **kwargs) -> MultiAgentEnv: # TODO: this may be a more complex function 11 | # env_args = kwargs.get("env_args", {}) 12 | return env(**kwargs) 13 | 14 | 15 | REGISTRY = {} 16 | REGISTRY["sc2"] = partial(env_fn, env=StarCraft2Env) 17 | REGISTRY["sc2custom"] = partial(env_fn, env=StarCraft2CustomEnv) 18 | REGISTRY["group_matching"] = partial(env_fn, env=GroupMatching) 19 | 20 | s_REGISTRY = {} 21 | s_REGISTRY.update(sc_scenarios) 22 | -------------------------------------------------------------------------------- /src/components/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | 4 | class Transform: 5 | def transform(self, tensor): 6 | raise NotImplementedError 7 | 8 | def infer_output_info(self, vshape_in, dtype_in): 9 | raise NotImplementedError 10 | 11 | class OneHot(Transform): 12 | def __init__(self, out_dim): 13 | self.out_dim = out_dim 14 | 15 | def transform(self, tensor): 16 | y_onehot = tensor.new(*tensor.shape[:-1], self.out_dim).zero_() 17 | y_onehot.scatter_(-1, tensor.long(), 1) 18 | return y_onehot.float() 19 | 20 | def infer_output_info(self, vshape_in, dtype_in): 21 | # TODO: Check this shouldn't be here 22 | # assert vshape_in == (1,) 23 | return (self.out_dim,), th.float32 -------------------------------------------------------------------------------- /src/config/algs/vdn_atten.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 500000 8 | 9 | runner: "parallel" 10 | batch_size_run: 8 11 | training_iters: 8 12 | 13 | entity_last_action: True # Include the user-controlled agents' last actions (one_hot) in their entities 14 | 15 | buffer_size: 5000 16 | 17 | # update the target network every {} episodes 18 | target_update_interval: 200 19 | 20 | # use the Q_Learner to train 21 | agent_output_type: "q" 22 | learner: "q_learner" 23 | double_q: True 24 | mixer: "vdn" 25 | agent: "entity_attend_rnn" 26 | rnn_hidden_dim: 64 27 | mac: "entity_mac" 28 | attn_embed_dim: 128 29 | attn_n_heads: 4 30 | 31 | name: "vdn_atten" 32 | -------------------------------------------------------------------------------- /src/config/algs/refil_vdn.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 500000 8 | 9 | runner: "parallel" 10 | batch_size_run: 8 11 | training_iters: 8 12 | 13 | entity_last_action: True # Include the user-controlled agents' last actions (one_hot) in their entities 14 | 15 | buffer_size: 5000 16 | 17 | # update the target network 
every {} episodes 18 | target_update_interval: 200 19 | 20 | # use the Q_Learner to train 21 | agent_output_type: "q" 22 | learner: "q_learner" 23 | double_q: True 24 | mixer: "vdn" 25 | agent: "imagine_entity_attend_rnn" 26 | rnn_hidden_dim: 64 27 | mac: "entity_mac" 28 | attn_embed_dim: 128 29 | attn_n_heads: 4 30 | lmbda: 0.5 31 | 32 | name: "refil_vdn" 33 | -------------------------------------------------------------------------------- /src/components/epsilon_schedules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class DecayThenFlatSchedule(): 4 | 5 | def __init__(self, 6 | start, 7 | finish, 8 | time_length, 9 | decay="exp"): 10 | 11 | self.start = start 12 | self.finish = finish 13 | self.time_length = time_length 14 | self.delta = (self.start - self.finish) / self.time_length 15 | self.decay = decay 16 | 17 | if self.decay in ["exp"]: 18 | self.exp_scaling = (-1) * self.time_length / np.log(self.finish) if self.finish > 0 else 1 19 | 20 | def eval(self, T): 21 | if self.decay in ["linear"]: 22 | return max(self.finish, self.start - self.delta * T) 23 | elif self.decay in ["exp"]: 24 | return min(self.start, max(self.finish, np.exp(- T / self.exp_scaling))) 25 | pass -------------------------------------------------------------------------------- /src/config/algs/qmix_atten.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 500000 8 | 9 | runner: "parallel" 10 | batch_size_run: 8 11 | training_iters: 8 12 | 13 | entity_last_action: True # Include the user-controlled agents' last actions (one_hot) in their entities 14 | 15 | buffer_size: 5000 16 | 17 | # update the target network every {} episodes 18 | target_update_interval: 200 19 | 20 | # use the Q_Learner to train 21 | agent_output_type: "q" 22 | learner: "q_learner" 23 | double_q: True 24 | mixer: "flex_qmix" 25 | mixing_embed_dim: 32 26 | hypernet_embed: 128 27 | softmax_mixing_weights: True 28 | agent: "entity_attend_rnn" 29 | rnn_hidden_dim: 64 30 | mac: "entity_mac" 31 | attn_embed_dim: 128 32 | attn_n_heads: 4 33 | lmbda: 0.5 34 | 35 | name: "qmix_atten" 36 | -------------------------------------------------------------------------------- /src/config/algs/refil.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 500000 8 | 9 | runner: "parallel" 10 | batch_size_run: 8 11 | training_iters: 8 12 | 13 | entity_last_action: True # Include the user-controlled agents' last actions (one_hot) in their entities 14 | 15 | buffer_size: 5000 16 | 17 | # update the target network every {} episodes 18 | target_update_interval: 200 19 | 20 | # use the Q_Learner to train 21 | agent_output_type: "q" 22 | learner: "q_learner" 23 | double_q: True 24 | mixer: "flex_qmix" 25 | mixing_embed_dim: 32 26 | hypernet_embed: 128 27 | softmax_mixing_weights: True 28 | agent: "imagine_entity_attend_rnn" 29 | rnn_hidden_dim: 64 30 | mac: "entity_mac" 31 | attn_embed_dim: 128 32 | attn_n_heads: 4 33 | lmbda: 0.5 34 | 35 | name: "refil" 36 | -------------------------------------------------------------------------------- 
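The algorithm configs above are selected by name at launch time. As a usage sketch (assuming the Docker image has been built via `docker/build.sh` and GPU 0 is used), REFIL on the StarCraft scenarios from the paper can be started with `./run.sh 0 src/main.py --env-config=sc2custom --config=refil with scenario=3-8sz_symmetric`, and the attention-based QMIX baseline with `--config=qmix_atten` in place of `--config=refil`. Note that `run.sh` expects the GPU id as its first argument and forwards the remaining arguments to `python3.6` inside the container; see the README at the end of this listing for the full list of scenarios and algorithm names.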
/src/config/algs/qmix_atten_group_matching.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 5000 8 | 9 | runner: "parallel" 10 | batch_size_run: 8 11 | training_iters: 8 12 | 13 | entity_last_action: False # Include the user-controlled agents' last actions (one_hot) in their entities 14 | 15 | buffer_size: 2000 16 | 17 | # update the target network every {} episodes 18 | target_update_interval: 200 19 | 20 | # use the Q_Learner to train 21 | agent_output_type: "q" 22 | learner: "q_learner" 23 | double_q: True 24 | mixer: "lin_flex_qmix" 25 | mixing_embed_dim: 32 26 | hypernet_embed: 64 27 | softmax_mixing_weights: True 28 | agent: "entity_attend_ff" 29 | rnn_hidden_dim: 64 30 | mac: "entity_mac" 31 | attn_embed_dim: 64 32 | attn_n_heads: 4 33 | lmbda: 0.5 34 | 35 | name: "qmix_atten" 36 | -------------------------------------------------------------------------------- /src/modules/agents/ff_agent.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class FFAgent(nn.Module): 6 | def __init__(self, input_shape, args): 7 | super(FFAgent, self).__init__() 8 | self.args = args 9 | 10 | # Easiest to reuse rnn_hidden_dim variable 11 | self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim) 12 | self.fc2 = nn.Linear(args.rnn_hidden_dim, args.rnn_hidden_dim) 13 | self.fc3 = nn.Linear(args.rnn_hidden_dim, args.n_actions) 14 | 15 | def init_hidden(self): 16 | # make hidden states on same device as model 17 | return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_() 18 | 19 | def forward(self, inputs, hidden_state): 20 | x = F.relu(self.fc1(inputs)) 21 | # h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim) 22 | h = F.relu(self.fc2(x)) 23 | q = self.fc3(h) 24 | return q, h 25 | -------------------------------------------------------------------------------- /src/config/envs/sc2.yaml: -------------------------------------------------------------------------------- 1 | env: sc2 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "3m" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | state_last_action: True 27 | state_timestep_number: False 28 | step_mul: 8 29 | heuristic_ai: False 30 | heuristic_rest: False 31 | debug: False 32 | 33 | test_nepisode: 32 34 | test_interval: 10000 35 | log_interval: 2000 36 | runner_log_interval: 2000 37 | learner_log_interval: 2000 38 | t_max: 2000000 39 | -------------------------------------------------------------------------------- /src/config/algs/refil_group_matching.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | 
epsilon_anneal_time: 5000 8 | 9 | runner: "parallel" 10 | batch_size_run: 8 11 | training_iters: 8 12 | 13 | entity_last_action: False # Include the user-controlled agents' last actions (one_hot) in their entities 14 | 15 | buffer_size: 2000 16 | 17 | # update the target network every {} episodes 18 | target_update_interval: 200 19 | 20 | test_gt_factors: True # record proportion of q_tot coming from ground truth in-group 21 | 22 | # use the Q_Learner to train 23 | agent_output_type: "q" 24 | learner: "q_learner" 25 | double_q: True 26 | mixer: "lin_flex_qmix" 27 | mixing_embed_dim: 32 28 | hypernet_embed: 64 29 | softmax_mixing_weights: True 30 | agent: "imagine_entity_attend_ff" 31 | rnn_hidden_dim: 64 32 | mac: "entity_mac" 33 | attn_embed_dim: 64 34 | attn_n_heads: 4 35 | lmbda: 0.5 36 | 37 | name: "refil" 38 | -------------------------------------------------------------------------------- /src/config/envs/sc2custom.yaml: -------------------------------------------------------------------------------- 1 | env: sc2custom 2 | scenario: "1-5m_symmetric" 3 | 4 | env_args: 5 | entity_scheme: True 6 | continuing_episode: False 7 | episode_limit: 150 8 | difficulty: "7" 9 | game_version: null 10 | move_amount: 2 11 | random_tags: True 12 | obs_all_health: True 13 | obs_instead_of_state: False 14 | obs_own_health: True 15 | obs_last_action: False 16 | obs_pathing_grid: False 17 | obs_terrain_height: False 18 | obs_timestep_number: False 19 | state_last_action: True 20 | sight_range: 9 21 | reward_death_value: 10 22 | reward_defeat: 0 23 | reward_negative_scale: 0.5 24 | reward_only_positive: True 25 | reward_scale: True 26 | reward_scale_rate: 20 27 | reward_sparse: False 28 | reward_win: 200 29 | replay_dir: "" 30 | replay_prefix: "" 31 | state_timestep_number: False 32 | step_mul: 8 33 | heuristic_ai: False 34 | debug: False 35 | 36 | test_nepisode: 160 37 | test_interval: 50000 38 | log_interval: 10000 39 | runner_log_interval: 10000 40 | learner_log_interval: 10000 41 | t_max: 10000000 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Shariq Iqbal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /install_sc2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Install SC2 and add the custom maps 3 | 4 | if [ -z "$EXP_DIR" ] 5 | then 6 | EXP_DIR=~ 7 | fi 8 | 9 | echo "EXP_DIR: $EXP_DIR" 10 | cd $EXP_DIR/REFIL 11 | 12 | mkdir 3rdparty 13 | cd 3rdparty 14 | 15 | export SC2PATH=`pwd`'/StarCraftII' 16 | echo 'SC2PATH is set to '$SC2PATH 17 | 18 | if [ ! -d $SC2PATH ]; then 19 | echo 'StarCraftII is not installed. Installing now ...'; 20 | wget http://blzdistsc2-a.akamaihd.net/Linux/SC2.4.6.2.69232.zip 21 | unzip -P iagreetotheeula SC2.4.6.2.69232.zip 22 | rm -rf SC2.4.6.2.69232.zip 23 | else 24 | echo 'StarCraftII is already installed.' 25 | fi 26 | 27 | echo 'Adding SMAC maps.' 28 | MAP_DIR="$SC2PATH/Maps/" 29 | echo 'MAP_DIR is set to '$MAP_DIR 30 | 31 | if [ ! -d $MAP_DIR ]; then 32 | mkdir -p $MAP_DIR 33 | fi 34 | 35 | cd .. 36 | wget https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip 37 | unzip SMAC_Maps.zip 38 | mv SMAC_Maps $MAP_DIR 39 | rm -rf SMAC_Maps.zip 40 | 41 | cp src/envs/starcraft2/maps/SMAC_Maps/empty_passive.SC2Map $MAP_DIR 42 | 43 | echo 'StarCraft II and SMAC are installed.' 44 | 45 | -------------------------------------------------------------------------------- /src/modules/agents/rnn_agent.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class RNNAgent(nn.Module): 6 | def __init__(self, input_shape, args): 7 | super(RNNAgent, self).__init__() 8 | self.args = args 9 | 10 | self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim) 11 | self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim) 12 | self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions) 13 | 14 | def init_hidden(self): 15 | # make hidden states on same device as model 16 | return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_() 17 | 18 | def forward(self, inputs, hidden_state): 19 | bs, ts, na, os = inputs.shape 20 | 21 | x = F.relu(self.fc1(inputs)) 22 | 23 | h = hidden_state.reshape(-1, self.args.rnn_hidden_dim) 24 | hs = [] 25 | for t in range(ts): 26 | curr_x = x[:, t].reshape(-1, self.args.rnn_hidden_dim) 27 | h = self.rnn(curr_x, h) 28 | hs.append(h.view(bs, na, self.args.rnn_hidden_dim)) 29 | hs = th.stack(hs, dim=1) # Concat over time 30 | 31 | q = self.fc2(hs) 32 | return q, hs 33 | -------------------------------------------------------------------------------- /src/utils/timehelper.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | 5 | def print_time(start_time, T, t_max, episode, episode_rewards): 6 | time_elapsed = time.time() - start_time 7 | T = max(1, T) 8 | time_left = time_elapsed * (t_max - T) / T 9 | # Just in case its over 100 days 10 | time_left = min(time_left, 60 * 60 * 24 * 100) 11 | last_reward = "N\A" 12 | if len(episode_rewards) > 5: 13 | last_reward = "{:.2f}".format(np.mean(episode_rewards[-50:])) 14 | print("\033[F\033[F\x1b[KEp: {:,}, T: {:,}/{:,}, Reward: {}, \n\x1b[KElapsed: {}, Left: {}\n".format(episode, T, t_max, last_reward, time_str(time_elapsed), time_str(time_left)), " " * 10, end="\r") 15 | 16 | def time_left(start_time, t_start, t_current, t_max): 17 | if t_current >= t_max: 18 | return "-" 19 | time_elapsed = time.time() - start_time 20 | t_current = max(1, t_current) 21 | 
time_left = time_elapsed * (t_max - t_current) / (t_current - t_start) 22 | # Just in case its over 100 days 23 | time_left = min(time_left, 60 * 60 * 24 * 100) 24 | return time_str(time_left) 25 | 26 | def time_str(s): 27 | """ 28 | Convert seconds to a nicer string showing days, hours, minutes and seconds 29 | """ 30 | days, remainder = divmod(s, 60 * 60 * 24) 31 | hours, remainder = divmod(remainder, 60 * 60) 32 | minutes, seconds = divmod(remainder, 60) 33 | string = "" 34 | if days > 0: 35 | string += "{:d} days, ".format(int(days)) 36 | if hours > 0: 37 | string += "{:d} hours, ".format(int(hours)) 38 | if minutes > 0: 39 | string += "{:d} minutes, ".format(int(minutes)) 40 | string += "{:d} seconds".format(int(seconds)) 41 | return string 42 | -------------------------------------------------------------------------------- /src/utils/rl_utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | def build_td_lambda_targets__old(rewards, terminated, mask, target_qs, n_agents, gamma, td_lambda): 5 | bs = rewards.size(0) 6 | max_t = rewards.size(1) 7 | targets = rewards.new(target_qs.size()).zero_()[:,:-1] # Produce 1 less target than the inputted Q-Values 8 | running_target = rewards.new(bs, n_agents).zero_() 9 | terminated = terminated.float() 10 | for t in reversed(range(max_t)): 11 | if t == max_t - 1: 12 | running_target = mask[:, t] * (rewards[:, t] + gamma * (1 - terminated[:, t]) * target_qs[:, t]) 13 | else: 14 | running_target = mask[:, t] * ( 15 | terminated[:, t] * rewards[:, t] # Just the reward if the env terminates 16 | + (1 - terminated[:, t]) * (rewards[:, t] + gamma * (td_lambda * running_target + (1 - td_lambda) * target_qs[:, t])) 17 | ) 18 | targets[:, t, :] = running_target 19 | return targets 20 | 21 | 22 | def build_td_lambda_targets(rewards, terminated, mask, target_qs, n_agents, gamma, td_lambda): 23 | # Assumes in B*T*A and , , in (at least) B*T-1*1 24 | # Initialise last lambda -return for not terminated episodes 25 | ret = target_qs.new_zeros(*target_qs.shape) 26 | ret[:, -1] = target_qs[:, -1] * (1 - th.sum(terminated, dim=1)) 27 | # Backwards recursive update of the "forward view" 28 | for t in range(ret.shape[1] - 2, -1, -1): 29 | ret[:, t] = td_lambda * gamma * ret[:, t + 1] + mask[:, t] \ 30 | * (rewards[:, t] + (1 - td_lambda) * gamma * target_qs[:, t + 1] * (1 - terminated[:, t])) 31 | # Returns lambda-return from t=0 to t=T-1, i.e. in B*T-1*A 32 | return ret[:, 0:-1] 33 | 34 | -------------------------------------------------------------------------------- /src/controllers/entity_controller.py: -------------------------------------------------------------------------------- 1 | from .basic_controller import BasicMAC 2 | import torch as th 3 | 4 | 5 | # This multi-agent controller shares parameters between agents and takes 6 | # entities + observation masks as input 7 | class EntityMAC(BasicMAC): 8 | def __init__(self, scheme, groups, args): 9 | super(EntityMAC, self).__init__(scheme, groups, args) 10 | 11 | def _build_inputs(self, batch, t): 12 | # Assumes homogenous agents with entity + observation mask inputs. 
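# batch["entities"][:, t] is [bs, ts, n_entities, entity_feats] (see the append below).
# If entity_last_action is set, each agent entity additionally gets its one-hot action
# from the previous timestep (zeros at t=0, and all-zero action features for non-agent
# entities), concatenated along the feature dimension before the entities and the
# observation/entity masks are returned.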
13 | bs = batch.batch_size 14 | entities = [] 15 | entities.append(batch["entities"][:, t]) # bs, ts, n_entities, vshape 16 | if self.args.entity_last_action: 17 | ent_acs = th.zeros(bs, t.stop - t.start, self.args.n_entities, 18 | self.args.n_actions, device=batch.device, 19 | dtype=batch["entities"].dtype) 20 | if t.start == 0: 21 | ent_acs[:, 1:, :self.args.n_agents] = ( 22 | batch["actions_onehot"][:, slice(0, t.stop - 1)]) 23 | else: 24 | ent_acs[:, :, :self.args.n_agents] = ( 25 | batch["actions_onehot"][:, slice(t.start - 1, t.stop - 1)]) 26 | entities.append(ent_acs) 27 | entities = th.cat(entities, dim=3) 28 | if self.args.gt_mask_avail: 29 | return (entities, batch["obs_mask"][:, t], batch["entity_mask"][:, t], batch["gt_mask"][:, t]) 30 | return (entities, batch["obs_mask"][:, t], batch["entity_mask"][:, t]) 31 | 32 | def _get_input_shape(self, scheme): 33 | input_shape = scheme["entities"]["vshape"] 34 | if self.args.entity_last_action: 35 | input_shape += scheme["actions_onehot"]["vshape"][0] 36 | return input_shape 37 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:8.0-cudnn7-devel-ubuntu16.04 2 | # FROM ubuntu:16.04 3 | MAINTAINER Christian Schroeder de Witt 4 | 5 | # CUDA includes 6 | ENV CUDA_PATH /usr/local/cuda 7 | ENV CUDA_INCLUDE_PATH /usr/local/cuda/include 8 | ENV CUDA_LIBRARY_PATH /usr/local/cuda/lib64 9 | 10 | # Ubuntu Packages 11 | RUN apt-get update -y && apt-get install software-properties-common -y && \ 12 | add-apt-repository -y multiverse && \ 13 | apt-get install -y sudo curl libssl-dev openssl libopenblas-dev \ 14 | libhdf5-dev hdf5-helpers hdf5-tools libhdf5-serial-dev libprotobuf-dev \ 15 | protobuf-compiler apt-utils nano vim man build-essential wget && \ 16 | curl -sk https://raw.githubusercontent.com/torch/distro/master/install-deps | bash && \ 17 | add-apt-repository -y ppa:deadsnakes/ppa && apt-get update -y && apt-get upgrade -y && \ 18 | apt-get install -y python3.6 19 | 20 | RUN curl "https://bootstrap.pypa.io/get-pip.py" -o "get-pip.py" && python3.6 get-pip.py && \ 21 | rm -rf /var/lib/apt/lists/* 22 | 23 | RUN pip3 install numpy scipy pyyaml matplotlib imageio pygame imageio-ffmpeg tensorboard-logger ruamel.base ryd jsonpickle==0.9.6 24 | 25 | RUN mkdir /install 26 | WORKDIR /install 27 | 28 | # install Sacred 29 | RUN pip3 install setuptools 30 | RUN git clone https://github.com/oxwhirl/sacred.git /install/sacred && cd /install/sacred && python3.6 setup.py install 31 | 32 | # Install pymongo 33 | RUN pip3 install pymongo 34 | 35 | #### ------------------------------------------------------------------- 36 | #### install pytorch 37 | #### ------------------------------------------------------------------- 38 | RUN pip3 install torch==1.1.0 39 | 40 | ## -- SMAC 41 | ENV smac_ver 1 42 | RUN pip3 install git+https://github.com/oxwhirl/smac.git 43 | ENV SC2PATH /REFIL/3rdparty/StarCraftII 44 | 45 | WORKDIR /REFIL 46 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | #created by http://www.gitignore.io 2 | 3 | ### Python ### 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | 8 | # C extensions 9 | # *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | bin/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | eggs/ 19 | lib/ 20 | 
lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # Installer logs 29 | pip-log.txt 30 | pip-delete-this-directory.txt 31 | 32 | # Unit test / coverage reports 33 | htmlcov/ 34 | .tox/ 35 | .coverage 36 | .cache 37 | nosetests.xml 38 | coverage.xml 39 | 40 | # Translations 41 | *.mo 42 | 43 | # Mr Developer 44 | .mr.developer.cfg 45 | .project 46 | .pydevproject 47 | 48 | # Rope 49 | .ropeproject 50 | 51 | # Django stuff: 52 | *.log 53 | *.pot 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # Jupyter 59 | .ipynb_checkpoints 60 | 61 | 62 | ### PyCharm ### 63 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm 64 | 65 | ## Directory-based project format 66 | .idea/ 67 | # if you remove the above rule, at least ignore user-specific stuff: 68 | # .idea/workspace.xml 69 | # .idea/tasks.xml 70 | # and these sensitive or high-churn files: 71 | # .idea/dataSources.ids 72 | # .idea/dataSources.xml 73 | # .idea/sqlDataSources.xml 74 | # .idea/dynamic.xml 75 | 76 | ## File-based project format 77 | *.ipr 78 | *.iws 79 | *.iml 80 | 81 | ## Additional for IntelliJ 82 | out/ 83 | 84 | # generated by mpeltonen/sbt-idea plugin 85 | .idea_modules/ 86 | 87 | # generated by JIRA plugin 88 | atlassian-ide-plugin.xml 89 | 90 | # generated by Crashlytics plugin (for Android Studio and Intellij) 91 | com_crashlytics_export_strings.xml 92 | 93 | # GEdit temporary files 94 | *~ 95 | 96 | # own additions 97 | __pycache__ 98 | mongodb/ 99 | results/ 100 | pool/ 101 | notebook/ 102 | tb_logs/ 103 | StarCraftII/ 104 | profiler/ 105 | *.out 106 | 3rdparty/ 107 | .DS_Store 108 | venv/ 109 | 3rdparty 110 | *.txt 111 | -------------------------------------------------------------------------------- /src/envs/multiagentenv.py: -------------------------------------------------------------------------------- 1 | class MultiAgentEnv(object): 2 | 3 | def step(self, actions): 4 | """ Returns reward, terminated, info """ 5 | raise NotImplementedError 6 | 7 | def get_obs(self): 8 | """ Returns all agent observations in a list """ 9 | raise NotImplementedError 10 | 11 | def get_obs_agent(self, agent_id): 12 | """ Returns observation for agent_id """ 13 | raise NotImplementedError 14 | 15 | def get_obs_size(self): 16 | """ Returns the shape of the observation """ 17 | raise NotImplementedError 18 | 19 | def get_state(self): 20 | raise NotImplementedError 21 | 22 | def get_state_size(self): 23 | """ Returns the shape of the state""" 24 | raise NotImplementedError 25 | 26 | def get_avail_actions(self): 27 | raise NotImplementedError 28 | 29 | def get_avail_agent_actions(self, agent_id): 30 | """ Returns the available actions for agent_id """ 31 | raise NotImplementedError 32 | 33 | def get_total_actions(self): 34 | """ Returns the total number of actions an agent could ever take """ 35 | # TODO: This is only suitable for a discrete 1 dimensional action space for each agent 36 | raise NotImplementedError 37 | 38 | def get_stats(self): 39 | raise NotImplementedError 40 | 41 | # TODO: Temp hack 42 | def get_agg_stats(self, stats): 43 | return {} 44 | 45 | def reset(self): 46 | """ Returns initial observations and states""" 47 | raise NotImplementedError 48 | 49 | def render(self): 50 | raise NotImplementedError 51 | 52 | def close(self): 53 | pass # This gets called all the time. 
54 | 55 | def seed(self): 56 | raise NotImplementedError 57 | 58 | def save_replay(self): 59 | raise NotImplementedError 60 | 61 | def get_env_info(self, args): 62 | env_info = {"state_shape": self.get_state_size(), 63 | "obs_shape": self.get_obs_size(), 64 | "n_actions": self.get_total_actions(), 65 | "n_agents": self.n_agents, 66 | "episode_limit": self.episode_limit} 67 | if hasattr(self, 'get_obs_st_masks'): 68 | env_info['masks'] = self.get_obs_st_masks(args) 69 | if hasattr(self, 'unit_dim'): 70 | env_info['unit_dim'] = self.unit_dim 71 | return env_info 72 | -------------------------------------------------------------------------------- /src/components/action_selectors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | from torch.autograd import Variable 4 | from torch.distributions import Categorical 5 | from torch.nn.functional import softmax 6 | from .epsilon_schedules import DecayThenFlatSchedule 7 | 8 | REGISTRY = {} 9 | 10 | class MultinomialActionSelector(): 11 | 12 | def __init__(self, args): 13 | self.args = args 14 | 15 | self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time, 16 | decay="linear") 17 | self.epsilon = self.schedule.eval(0) 18 | self.test_greedy = getattr(args, "test_greedy", True) 19 | 20 | def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False): 21 | masked_policies = agent_inputs.clone() 22 | masked_policies[avail_actions == 0.0] = 0.0 23 | 24 | self.epsilon = self.schedule.eval(t_env) 25 | 26 | if test_mode and self.test_greedy: 27 | picked_actions = masked_policies.max(dim=2)[1] 28 | else: 29 | picked_actions = Categorical(masked_policies).sample().long() 30 | 31 | return picked_actions 32 | 33 | REGISTRY["multinomial"] = MultinomialActionSelector 34 | 35 | 36 | class EpsilonGreedyActionSelector(): 37 | 38 | def __init__(self, args): 39 | self.args = args 40 | 41 | # Was there so I used it 42 | self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time, decay="linear") 43 | self.epsilon = self.schedule.eval(0) 44 | 45 | def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False): 46 | 47 | # Assuming agent_inputs is a batch of Q-Values for each agent bav 48 | self.epsilon = self.schedule.eval(t_env) 49 | 50 | if test_mode: 51 | # Greedy action selection only 52 | self.epsilon = 0.0 53 | 54 | # mask actions that are excluded from selection 55 | masked_q_values = agent_inputs.clone() 56 | masked_q_values[avail_actions == 0.0] = -float("inf") # should never be selected! 57 | 58 | random_numbers = th.rand_like(agent_inputs[:,:,0]) 59 | pick_random = (random_numbers < self.epsilon).long() 60 | random_actions = Categorical(avail_actions.float()).sample().long() 61 | 62 | picked_actions = pick_random * random_actions + (1 - pick_random) * masked_q_values.max(dim=2)[1] 63 | return picked_actions 64 | 65 | REGISTRY["epsilon_greedy"] = EpsilonGreedyActionSelector 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # REFIL 2 | Code for [*Randomized Entity-wise Factorization for Multi-Agent Reinforcement Learning*](https://arxiv.org/abs/2006.04222) (Iqbal et al., ICML 2021) 3 | 4 | This codebase is built on top of the [PyMARL](https://github.com/oxwhirl/pymarl) framework for multi-agent reinforcement learning algorithms. 
5 | 6 | # Dependencies 7 | - Docker 8 | - NVIDIA-Docker (if you want to use GPUs) 9 | 10 | ## Setup instructions 11 | 12 | Build the Dockerfile using 13 | ```shell 14 | cd docker 15 | ./build.sh 16 | ``` 17 | 18 | Set up StarCraft II. 19 | 20 | ```shell 21 | ./install_sc2.sh 22 | ``` 23 | 24 | ## Run an experiment 25 | 26 | Run an `ALGORITHM` from the folder `src/config/algs` 27 | in an `ENVIRONMENT` from the folder `src/config/envs` 28 | on a specific `GPU` using some `PARAMETERS`: 29 | ```shell 30 | ./run.sh src/main.py --env-config= --config= with 31 | ``` 32 | 33 | Possible environments are: 34 | - `group_matching`: Group Matching environment from the paper 35 | - `sc2custom`: StarCraft environment from the paper 36 | 37 | For StarCraft you need to specify the set of tasks to train on by including the parameter `scenario=`. 38 | Here are the possible scenario sets: 39 | 40 | - Included in the paper: 41 | - `3-8sz_symmetric` 42 | - `3-8MMM_symmetric` 43 | - `3-8csz_symmetric` 44 | - Debugging/Additional: 45 | - `3-8m_symmetric` 46 | - `6-11m_mandown` 47 | 48 | Possible algorithms are: 49 | - `refil`: REFIL (our method) 50 | - `refil_group_matching`: REFIL w/ hyperparameters for Group Matching game 51 | - `qmix_atten`: QMIX (Attention) 52 | - `qmix_atten_group_matching`: QMIX (Attention) w/ hyperparameters for Group Matching game 53 | - `refil_vdn`: REFIL (VDN) 54 | - `vdn_atten`: VDN (Attention) 55 | 56 | For group matching oracle methods, include the following parameters while selecting `refil_group_matching` as the algorithm: 57 | - REFIL (Fixed Oracle): `train_gt_factors=True` 58 | - REFIL (Randomized Oracle): `train_rand_gt_factors=True` 59 | 60 | ## Citing our work 61 | 62 | If you use this repo in your work, please consider citing the corresponding paper: 63 | 64 | ```bibtex 65 | @InProceedings{iqbal2021refil, 66 | title={Randomized Entity-wise Factorization for Multi-Agent Reinforcement Learning}, 67 | author={Iqbal, Shariq and de Witt, Christian A Schroeder and Peng, Bei and B{\"o}hmer, Wendelin and Whiteson, Shimon and Sha, Fei}, 68 | booktitle = {Proceedings of the 38th International Conference on Machine Learning}, 69 | year = {2021}, 70 | series = {Proceedings of Machine Learning Research}, 71 | publisher = {PMLR}, 72 | } 73 | ``` 74 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import logging 3 | import numpy as np 4 | 5 | class Logger: 6 | def __init__(self, console_logger): 7 | self.console_logger = console_logger 8 | 9 | self.use_tb = False 10 | self.use_sacred = False 11 | self.use_hdf = False 12 | 13 | self.stats = defaultdict(lambda: []) 14 | 15 | def setup_tb(self, directory_name): 16 | # Import here so it doesn't have to be installed if you don't use it 17 | from tensorboard_logger import configure, log_value 18 | configure(directory_name) 19 | self.tb_logger = log_value 20 | self.use_tb = True 21 | 22 | def setup_sacred(self, sacred_run_dict): 23 | self.sacred_info = sacred_run_dict.info 24 | self.use_sacred = True 25 | 26 | # TODO: Setup hdf logger 27 | 28 | def log_stat(self, key, value, t, to_sacred=True): 29 | self.stats[key].append((t, value)) 30 | 31 | if self.use_tb: 32 | self.tb_logger(key, value, t) 33 | 34 | if self.use_sacred and to_sacred: 35 | if key in self.sacred_info: 36 | self.sacred_info["{}_T".format(key)].append(t) 37 | self.sacred_info[key].append(value) 38 | 
else: 39 | self.sacred_info["{}_T".format(key)] = [t] 40 | self.sacred_info[key] = [value] 41 | 42 | def print_recent_stats(self): 43 | log_str = "Recent Stats | t_env: {:>10} | Episode: {:>8}\n".format(*self.stats["episode"][-1]) 44 | i = 0 45 | for (k, v) in sorted(self.stats.items()): 46 | if k == "episode": 47 | continue 48 | i += 1 49 | window = 5 if k != "epsilon" else 1 50 | item = "{:.4f}".format(np.mean([x[1] for x in self.stats[k][-window:]])) 51 | log_str += "{:<25}{:>8}".format(k + ":", item) 52 | log_str += "\n" if i % 4 == 0 else "\t" 53 | self.console_logger.info(log_str) 54 | 55 | def print_stats_summary(self): 56 | log_str = "Summary Stats" 57 | i = 0 58 | for (k, v) in sorted(self.stats.items()): 59 | if k == "episode": 60 | continue 61 | i += 1 62 | mean_value = np.mean([x[1] for x in self.stats[k]], axis=0) 63 | if len(mean_value.shape) == 0: 64 | item = "{:.4f}".format(mean_value) 65 | else: 66 | item = mean_value.__repr__() 67 | log_str += "{:<25}{:>8}".format(k + ":", item) 68 | log_str += "\n" if i % 4 == 0 else "\t" 69 | self.console_logger.info(log_str) 70 | 71 | # set up a custom logger 72 | def get_logger(): 73 | logger = logging.getLogger() 74 | logger.handlers = [] 75 | ch = logging.StreamHandler() 76 | formatter = logging.Formatter('[%(levelname)s %(asctime)s] %(name)s %(message)s', '%H:%M:%S') 77 | ch.setFormatter(formatter) 78 | logger.addHandler(ch) 79 | logger.setLevel('DEBUG') 80 | 81 | return logger 82 | 83 | -------------------------------------------------------------------------------- /src/config/default.yaml: -------------------------------------------------------------------------------- 1 | # --- Defaults --- 2 | 3 | # --- pymarl options --- 4 | runner: "episode" # Runs 1 env for an episode 5 | mac: "entity_mac" # Basic controller 6 | env: "sc2custom" # Environment name 7 | env_args: {} # Arguments for the environment 8 | batch_size_run: 1 # Number of environments to run in parallel 9 | test_nepisode: 20 # Number of episodes to test for 10 | test_interval: 2000 # Test after {} timesteps have passed 11 | test_greedy: True # Use greedy evaluation (if False, will set epsilon floor to 0 12 | log_interval: 2000 # Log summary of stats after every {} timesteps 13 | runner_log_interval: 2000 # Log runner stats (not test stats) every {} timesteps 14 | learner_log_interval: 2000 # Log training stats every {} timesteps 15 | t_max: 10000 # Stop running after this many timesteps 16 | use_cuda: True # Use gpu by default unless it isn't available 17 | buffer_cpu_only: False # If true we won't keep all of the replay buffer in vram 18 | 19 | # --- Logging options --- 20 | use_tensorboard: False # Log results to tensorboard 21 | save_model: False # Save the models to disk 22 | save_model_interval: 5000 # Save models after this many timesteps 23 | checkpoint_path: "" # Load a checkpoint from this path 24 | evaluate: False # Evaluate model for test_nepisode episodes and quit (no training) 25 | load_step: 0 # Load model trained on this many timesteps (0 if choose max possible) 26 | save_replay: False # Saving the replay of the model loaded from checkpoint_path 27 | video_path: # if path provided, save a video for evaluation runs 28 | fps: 2 # video frames per second 29 | local_results_path: "results" # Path for local results 30 | tb_dirname: "tb_logs" 31 | eval_all_scen: False # if True, evaluate on each separate scenario and report performance individually, otherwise randomly sample and report average performance 32 | eval_path: # if path provided, save evaluation 
results here in json form 33 | 34 | # --- RL hyperparameters --- 35 | gamma: 0.99 36 | batch_size: 32 # Number of episodes to train on 37 | buffer_size: 32 # Size of the replay buffer 38 | lr: 0.0005 # Learning rate for agents 39 | optim_alpha: 0.99 # RMSProp alpha 40 | optim_eps: 0.00001 # RMSProp epsilon 41 | grad_norm_clip: 10 # Reduce magnitude of gradients above this L2 norm 42 | weight_decay: 0 # L2 penalty weight decay on agent parameters 43 | pooling_type: # 'max' or 'mean' pooling used instead of attention if provided 44 | 45 | # --- Agent parameters --- 46 | agent: "rnn" # Default rnn agent 47 | rnn_hidden_dim: 64 # Size of hidden state for default rnn agent 48 | obs_agent_id: True # Include the agent's one_hot id in the observation 49 | obs_last_action: True # Include the agent's last action (one_hot) in the observation 50 | # This section is for the group matching game where we know the ground truth relevant entities to each agent 51 | gt_obs_mask: False # Use ground-truth observation mask 52 | train_gt_factors: False # Train w/ imagine groups automatically set to be ground-truth 53 | train_rand_gt_factors: False # Train w/ randomized ground-truth factors 54 | test_gt_factors: False # Test w/ imagine groups automatically set to be ground-truth and measure proportion of in-group weights w/ linear mixing network 55 | # --- Mixing/Hypernet parameters --- 56 | softmax_mixing_weights: False 57 | 58 | training_iters: 1 59 | 60 | # --- Experiment running params --- 61 | repeat_id: 1 62 | label: "default_label" 63 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import collections 4 | from os.path import dirname, abspath 5 | from sacred import Experiment, SETTINGS 6 | from sacred.observers import FileStorageObserver 7 | from sacred.utils import apply_backspaces_and_linefeeds 8 | import sys 9 | import torch as th 10 | from utils.logging import get_logger 11 | import yaml 12 | import time 13 | 14 | from run import run 15 | 16 | SETTINGS['CAPTURE_MODE'] = "fd" # set to "no" if you want to see stdout/stderr in console 17 | logger = get_logger() 18 | 19 | ex = Experiment("pymarl") 20 | ex.logger = logger 21 | ex.captured_out_filter = apply_backspaces_and_linefeeds 22 | 23 | results_path = os.path.join(dirname(dirname(abspath(__file__))), "results") 24 | 25 | 26 | @ex.main 27 | def my_main(_run, _config, _log, env_args): 28 | # Setting the random seed throughout the modules 29 | np.random.seed(_config["seed"]) 30 | th.manual_seed(_config["seed"]) 31 | env_args['seed'] = _config["seed"] 32 | 33 | # run the framework 34 | run(_run, _config, _log) 35 | 36 | # force exit 37 | os._exit(0) 38 | 39 | 40 | def _get_config(params, arg_name, subfolder): 41 | config_name = None 42 | for _i, _v in enumerate(params): 43 | if _v.split("=")[0] == arg_name: 44 | config_name = _v.split("=")[1] 45 | del params[_i] 46 | break 47 | 48 | if config_name is not None: 49 | with open(os.path.join(os.path.dirname(__file__), "config", subfolder, "{}.yaml".format(config_name)), "r") as f: 50 | try: 51 | config_dict = yaml.load(f) 52 | except yaml.YAMLError as exc: 53 | assert False, "{}.yaml error: {}".format(config_name, exc) 54 | return config_dict 55 | 56 | 57 | def recursive_dict_update(d, u): 58 | for k, v in u.items(): 59 | if isinstance(v, collections.Mapping): 60 | d[k] = recursive_dict_update(d.get(k, {}), v) 61 | else: 62 | d[k] = v 63 | 
return d 64 | 65 | 66 | if __name__ == '__main__': 67 | import os 68 | 69 | from copy import deepcopy 70 | params = deepcopy(sys.argv) 71 | 72 | # Get the defaults from default.yaml 73 | with open(os.path.join(os.path.dirname(__file__), "config", "default.yaml"), "r") as f: 74 | try: 75 | config_dict = yaml.load(f, Loader=yaml.FullLoader) 76 | except yaml.YAMLError as exc: 77 | assert False, "default.yaml error: {}".format(exc) 78 | 79 | # Load algorithm and env base configs 80 | env_config = _get_config(params, "--env-config", "envs") 81 | alg_config = _get_config(params, "--config", "algs") 82 | # config_dict = {**config_dict, **env_config, **alg_config} 83 | config_dict = recursive_dict_update(config_dict, env_config) 84 | config_dict = recursive_dict_update(config_dict, alg_config) 85 | 86 | # now add all the config to sacred 87 | ex.add_config(config_dict) 88 | 89 | if not config_dict['evaluate']: # only log if training 90 | # Save to disk by default for sacred 91 | logger.info("Saving to FileStorageObserver in results/sacred.") 92 | file_obs_path = os.path.join(results_path, "sacred") 93 | while True: 94 | try: 95 | ex.observers.append(FileStorageObserver.create(file_obs_path)) 96 | break 97 | except FileExistsError: 98 | # sometimes we see a race condition 99 | logger.info("Creating FileStorageObserver failed. Trying again...") 100 | time.sleep(1) 101 | 102 | ex.run_commandline(params) 103 | 104 | -------------------------------------------------------------------------------- /src/modules/mixers/qmix.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class QMixer(nn.Module): 8 | def __init__(self, args): 9 | super(QMixer, self).__init__() 10 | 11 | self.args = args 12 | 13 | self.n_agents = args.n_agents 14 | self.state_dim = int(np.prod(args.state_shape)) 15 | 16 | self.embed_dim = args.mixing_embed_dim 17 | 18 | self.hyper_w_1 = nn.Linear(self.state_dim, self.embed_dim * self.n_agents) 19 | self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim) 20 | 21 | if getattr(self.args, "hypernet_layers", 1) > 1: 22 | assert self.args.hypernet_layers == 2, "Only 1 or 2 hypernet_layers is supported atm!"
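# With hypernet_layers == 2, the single linear hyper_w_1 / hyper_w_final defined above are replaced by two-layer MLPs (state_dim -> hypernet_embed -> output) constructed below.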
23 | hypernet_embed = self.args.hypernet_embed 24 | self.hyper_w_1 = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed), 25 | nn.ReLU(), 26 | nn.Linear(hypernet_embed, self.embed_dim * self.n_agents)) 27 | self.hyper_w_final = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed), 28 | nn.ReLU(), 29 | nn.Linear(hypernet_embed, self.embed_dim)) 30 | 31 | # State dependent bias for hidden layer 32 | self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim) 33 | 34 | # V(s) instead of a bias for the last layers 35 | self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), 36 | nn.ReLU(), 37 | nn.Linear(self.embed_dim, 1)) 38 | 39 | self.non_lin = F.elu 40 | if getattr(self.args, "mixer_non_lin", "elu") == "tanh": 41 | self.non_lin = F.tanh 42 | 43 | if hasattr(self.args, 'state_masks'): 44 | self.register_buffer('state_masks', th.tensor(self.args.state_masks)) 45 | 46 | def forward(self, agent_qs, states, imagine_groups=None): 47 | bs, max_t, sd = states.shape 48 | 49 | states = states.reshape(-1, self.state_dim) 50 | if imagine_groups is not None: 51 | ne = self.state_masks.shape[0] 52 | agent_qs = agent_qs.view(-1, 1, self.n_agents * 2) 53 | groupA, groupB = imagine_groups 54 | groupA = groupA.reshape(bs * max_t, ne, 1) 55 | groupB = groupB.reshape(bs * max_t, ne, 1) 56 | groupA_mask = (groupA * self.state_masks.reshape(1, ne, sd)).sum(dim=1).clamp_max(1) 57 | groupB_mask = (groupB * self.state_masks.reshape(1, ne, sd)).sum(dim=1).clamp_max(1) 58 | groupA_states = states * groupA_mask 59 | groupB_states = states * groupB_mask 60 | 61 | w1_A = self.hyper_w_1(groupA_states) 62 | w1_B = self.hyper_w_1(groupB_states) 63 | w1 = th.cat([w1_A, w1_B], dim=1) 64 | else: 65 | agent_qs = agent_qs.view(-1, 1, self.n_agents) 66 | w1 = self.hyper_w_1(states) 67 | # First layer 68 | b1 = self.hyper_b_1(states) 69 | w1 = w1.view(bs * max_t, -1, self.embed_dim) 70 | b1 = b1.view(-1, 1, self.embed_dim) 71 | if self.args.softmax_mixing_weights: 72 | w1 = F.softmax(w1, dim=-1) 73 | else: 74 | w1 = th.abs(w1) 75 | 76 | hidden = self.non_lin(th.bmm(agent_qs, w1) + b1) 77 | # Second layer 78 | if self.args.softmax_mixing_weights: 79 | w_final = F.softmax(self.hyper_w_final(states), dim=-1) 80 | else: 81 | w_final = th.abs(self.hyper_w_final(states)) 82 | w_final = w_final.view(-1, self.embed_dim, 1) 83 | # State-dependent bias 84 | v = self.V(states).view(-1, 1, 1) 85 | 86 | # Compute final output 87 | y = th.bmm(hidden, w_final) + v 88 | # Reshape and return 89 | q_tot = y.view(bs, -1, 1) 90 | return q_tot 91 | -------------------------------------------------------------------------------- /src/controllers/basic_controller.py: -------------------------------------------------------------------------------- 1 | from modules.agents import REGISTRY as agent_REGISTRY 2 | from components.action_selectors import REGISTRY as action_REGISTRY 3 | import torch as th 4 | 5 | 6 | # This multi-agent controller shares parameters between agents 7 | class BasicMAC: 8 | def __init__(self, scheme, groups, args): 9 | self.n_agents = args.n_agents 10 | self.args = args 11 | input_shape = self._get_input_shape(scheme) 12 | self._build_agents(input_shape) 13 | self.agent_output_type = args.agent_output_type 14 | 15 | self.action_selector = action_REGISTRY[args.action_selector](args) 16 | 17 | self.hidden_states = None 18 | 19 | def select_actions(self, ep_batch, t_ep, t_env, bs=slice(None), test_mode=False, ret_agent_outs=False): 20 | # Only select actions for the selected batch elements in bs 21 | avail_actions = 
ep_batch["avail_actions"][:, t_ep] 22 | agent_outputs = self.forward(ep_batch, t_ep, test_mode=test_mode) 23 | chosen_actions = self.action_selector.select_action(agent_outputs[bs], avail_actions[bs], t_env, test_mode=test_mode) 24 | if ret_agent_outs: 25 | return chosen_actions, agent_outputs[bs] 26 | return chosen_actions 27 | 28 | def forward(self, ep_batch, t, test_mode=False, **kwargs): 29 | if t is None: 30 | t = slice(0, ep_batch["avail_actions"].shape[1]) 31 | int_t = False 32 | elif type(t) is int: 33 | t = slice(t, t + 1) 34 | int_t = True 35 | 36 | agent_inputs = self._build_inputs(ep_batch, t) 37 | avail_actions = ep_batch["avail_actions"][:, t] 38 | if kwargs.get('imagine', False): 39 | agent_outs, self.hidden_states, groups = self.agent(agent_inputs, self.hidden_states, **kwargs) 40 | else: 41 | agent_outs, self.hidden_states = self.agent(agent_inputs, self.hidden_states) 42 | 43 | if self.agent_output_type == "pi_logits": 44 | 45 | if getattr(self.args, "mask_before_softmax", True): 46 | # Make the logits for unavailable actions very negative to minimise their affect on the softmax 47 | agent_outs[avail_actions == 0] = -1e10 48 | 49 | agent_outs = th.nn.functional.softmax(agent_outs, dim=-1) 50 | if not test_mode: 51 | # Epsilon floor 52 | epsilon_action_num = agent_outs.size(-1) 53 | if getattr(self.args, "mask_before_softmax", True): 54 | # With probability epsilon, we will pick an available action uniformly 55 | epsilon_action_num = avail_actions.sum(dim=-1, keepdim=True).float() 56 | 57 | agent_outs = ((1 - self.action_selector.epsilon) * agent_outs 58 | + th.ones_like(agent_outs) * self.action_selector.epsilon/epsilon_action_num) 59 | 60 | if getattr(self.args, "mask_before_softmax", True): 61 | # Zero out the unavailable actions 62 | agent_outs[avail_actions == 0] = 0.0 63 | if int_t: 64 | return agent_outs.squeeze(1) 65 | if kwargs.get('imagine', False): 66 | return agent_outs, groups 67 | return agent_outs 68 | 69 | def init_hidden(self, batch_size): 70 | self.hidden_states = self.agent.init_hidden().unsqueeze(0).expand(batch_size, self.n_agents, -1) # bav 71 | 72 | def parameters(self): 73 | return self.agent.parameters() 74 | 75 | def load_state(self, other_mac): 76 | self.agent.load_state_dict(other_mac.agent.state_dict()) 77 | 78 | def cuda(self): 79 | self.agent.cuda() 80 | 81 | def eval(self): 82 | self.agent.eval() 83 | 84 | def train(self): 85 | self.agent.train() 86 | 87 | def save_models(self, path): 88 | th.save(self.agent.state_dict(), "{}/agent.th".format(path)) 89 | 90 | def load_models(self, path): 91 | self.agent.load_state_dict(th.load("{}/agent.th".format(path), map_location=lambda storage, loc: storage)) 92 | 93 | def _build_agents(self, input_shape): 94 | self.agent = agent_REGISTRY[self.args.agent](input_shape, self.args) 95 | 96 | def _build_inputs(self, batch, t): 97 | # Assumes homogenous agents with flat observations. 98 | # Other MACs might want to e.g. 
delegate building inputs to each agent 99 | bs, ts, na, os = batch["obs"].shape 100 | inputs = [] 101 | inputs.append(batch["obs"][:, t]) # btav 102 | if self.args.obs_last_action: 103 | if t.start == 0: 104 | acs = th.zeros_like(batch["actions_onehot"][:, t]) 105 | acs[:, 1:] = batch["actions_onehot"][:, slice(0, t.stop - 1)] 106 | else: 107 | acs = batch["actions_onehot"][:, slice(t.start - 1, t.stop - 1)] 108 | inputs.append(acs) 109 | if self.args.obs_agent_id: 110 | inputs.append(th.eye(self.n_agents, device=batch.device).view(1, 1, self.n_agents, self.n_agents).expand(bs, t.stop - t.start, -1, -1)) 111 | inputs = th.cat(inputs, dim=3) 112 | return inputs 113 | 114 | def _get_input_shape(self, scheme): 115 | input_shape = scheme["obs"]["vshape"] 116 | if self.args.obs_last_action: 117 | input_shape += scheme["actions_onehot"]["vshape"][0] 118 | if self.args.obs_agent_id: 119 | input_shape += self.n_agents 120 | 121 | return input_shape 122 | -------------------------------------------------------------------------------- /src/envs/group_matching/group_matching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from ..multiagentenv import MultiAgentEnv 3 | 4 | 5 | class GroupMatching(MultiAgentEnv): 6 | def __init__(self, entity_scheme=True, n_agents=4, n_states=10, n_groups=2, 7 | rand_trans=0.1, episode_limit=50, fixed_scen=False, seed=None): 8 | super(GroupMatching, self).__init__() 9 | assert entity_scheme, "This environment only supports entity scheme" 10 | self.n_agents = n_agents 11 | self.n_states = n_states 12 | self.n_groups = n_groups 13 | self.rand_trans = rand_trans 14 | self.episode_limit = episode_limit 15 | self.fixed_scen = fixed_scen 16 | self.n_actions = 3 # left, stay, right 17 | self.seed(seed) 18 | 19 | def step(self, actions): 20 | """ Returns reward, terminated, info """ 21 | actions = [int(a) for a in actions[:self.n_agents]] 22 | for ia, ac in enumerate(actions): 23 | if self.random.uniform() < self.rand_trans: 24 | ac = self.random.randint(0, self.n_actions) 25 | if ac != 1: # if not stay action 26 | curr_loc = np.where(self.agent_locs[ia])[0].item() 27 | self.agent_locs[ia, curr_loc] = 0 28 | if ac == 0: # left 29 | # negative indices will automatically circle to end 30 | self.agent_locs[ia, curr_loc - 1] = 1 31 | elif ac == 2: # right 32 | next_loc = curr_loc + 1 33 | if next_loc >= self.n_states: 34 | next_loc -= self.n_states 35 | self.agent_locs[ia, next_loc] = 1 36 | 37 | curr_matches = self._calc_group_piles() 38 | rew = -0.1 # time penalty 39 | rew += 2.5 * (curr_matches - self.prev_matches) 40 | self.prev_matches = curr_matches 41 | 42 | info = {'solved': False} 43 | done = False 44 | if curr_matches == self.n_groups: 45 | done = True 46 | info['solved'] = True 47 | 48 | self.t += 1 49 | if self.t == self.episode_limit: 50 | done = True 51 | info['episode_limit'] = True 52 | 53 | return rew, done, info 54 | 55 | def get_masks(self): 56 | obs_mask = np.zeros((self.n_agents, self.n_agents), dtype=np.uint8) 57 | entity_mask = np.zeros(self.n_agents, dtype=np.uint8) 58 | gt_mask = np.ones((self.n_agents, self.n_agents), dtype=np.uint8) 59 | for ia in range(self.n_agents): 60 | for grp in self.agent_groups: 61 | if ia in grp: 62 | gt_mask[ia, grp] = 0 63 | break 64 | return obs_mask, entity_mask, gt_mask 65 | 66 | def get_entities(self): 67 | locs = self.agent_locs.copy() 68 | groups = np.zeros((self.n_agents, self.n_groups), dtype=np.float32) 69 | for ig, grp in enumerate(self.agent_groups): 70 
| groups[grp, ig] = 1 71 | agent_ids = np.eye(self.n_agents, dtype=np.float32) 72 | entities = np.concatenate((locs, groups, agent_ids), axis=1) 73 | return [entities[i] for i in range(self.n_agents)] 74 | 75 | def get_entity_size(self): 76 | return self.n_states + self.n_groups + self.n_agents 77 | 78 | def get_avail_actions(self): 79 | return [[1 for _ in range(self.n_actions)] for _ in range(self.n_agents)] 80 | 81 | def get_total_actions(self): 82 | """ Returns the total number of actions an agent could ever take """ 83 | return self.n_actions 84 | 85 | def get_stats(self): 86 | return {} 87 | 88 | def get_agg_stats(self, stats): 89 | return {} 90 | 91 | def reset(self, **kwargs): 92 | agents = list(range(self.n_agents)) 93 | if not self.fixed_scen: 94 | self.random.shuffle(agents) 95 | partitions = [0] + self.random.randint(0, self.n_agents, size=(self.n_groups - 1,)).tolist() + [self.n_agents] 96 | else: 97 | partitions = np.linspace(0, self.n_agents, self.n_groups + 1).round().astype(np.int).tolist() 98 | self.agent_groups = [agents[s:e] for s, e in zip(partitions[:-1], partitions[1:])] 99 | 100 | self.agent_locs = np.zeros((self.n_agents, self.n_states), dtype=np.float32) 101 | self.agent_locs[range(self.n_agents), self.random.randint(0, self.n_states, size=self.n_agents)] = 1 102 | 103 | self.prev_matches = self._calc_group_piles() 104 | 105 | self.t = 0 106 | return self.get_entities(), self.get_masks() 107 | 108 | def _calc_group_piles(self): 109 | return sum(self.agent_locs[g].sum(0).max() == len(g) for g in self.agent_groups) 110 | 111 | def close(self): 112 | return 113 | 114 | def seed(self, seed): 115 | if seed is None: 116 | self.random = np.random.RandomState() 117 | else: 118 | self.random = np.random.RandomState(seed) 119 | 120 | def get_env_info(self, args): 121 | env_info = {"entity_shape": self.get_entity_size(), 122 | "n_actions": self.get_total_actions(), 123 | "n_agents": self.n_agents, 124 | "n_entities": self.n_agents, 125 | "gt_mask_avail": True, 126 | "episode_limit": self.episode_limit} 127 | return env_info 128 | 129 | 130 | if __name__ == '__main__': 131 | env = GroupMatching(entity_scheme=True, n_agents=4, n_states=10, n_groups=2, 132 | rand_trans=0.1, episode_limit=50, seed=None) 133 | env.reset() 134 | done = False 135 | -------------------------------------------------------------------------------- /src/modules/agents/entity_rnn_agent.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from modules.layers import EntityAttentionLayer, EntityPoolingLayer 5 | 6 | 7 | class EntityAttentionRNNAgent(nn.Module): 8 | def __init__(self, input_shape, args): 9 | super(EntityAttentionRNNAgent, self).__init__() 10 | self.args = args 11 | 12 | self.fc1 = nn.Linear(input_shape, args.attn_embed_dim) 13 | if args.pooling_type is None: 14 | self.attn = EntityAttentionLayer(args.attn_embed_dim, 15 | args.attn_embed_dim, 16 | args.attn_embed_dim, args) 17 | else: 18 | self.attn = EntityPoolingLayer(args.attn_embed_dim, 19 | args.attn_embed_dim, 20 | args.attn_embed_dim, 21 | args.pooling_type, 22 | args) 23 | self.fc2 = nn.Linear(args.attn_embed_dim, args.rnn_hidden_dim) 24 | self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim) 25 | self.fc3 = nn.Linear(args.rnn_hidden_dim, args.n_actions) 26 | 27 | def init_hidden(self): 28 | # make hidden states on same device as model 29 | return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_() 30 
| 31 | def forward(self, inputs, hidden_state, ret_attn_logits=None): 32 | entities, obs_mask, entity_mask = inputs 33 | bs, ts, ne, ed = entities.shape 34 | entities = entities.reshape(bs * ts, ne, ed) 35 | obs_mask = obs_mask.reshape(bs * ts, ne, ne) 36 | entity_mask = entity_mask.reshape(bs * ts, ne) 37 | agent_mask = entity_mask[:, :self.args.n_agents] 38 | x1 = F.relu(self.fc1(entities)) 39 | attn_outs = self.attn(x1, pre_mask=obs_mask, 40 | post_mask=agent_mask, 41 | ret_attn_logits=ret_attn_logits) 42 | if ret_attn_logits is not None: 43 | x2, attn_logits = attn_outs 44 | else: 45 | x2 = attn_outs 46 | x3 = F.relu(self.fc2(x2)) 47 | x3 = x3.reshape(bs, ts, self.args.n_agents, -1) 48 | 49 | h = hidden_state.reshape(-1, self.args.rnn_hidden_dim) 50 | hs = [] 51 | for t in range(ts): 52 | curr_x3 = x3[:, t].reshape(-1, self.args.rnn_hidden_dim) 53 | h = self.rnn(curr_x3, h) 54 | hs.append(h.reshape(bs, self.args.n_agents, self.args.rnn_hidden_dim)) 55 | hs = th.stack(hs, dim=1) # Concat over time 56 | 57 | q = self.fc3(hs) 58 | # zero out output for inactive agents 59 | q = q.reshape(bs, ts, self.args.n_agents, -1) 60 | q = q.masked_fill(agent_mask.reshape(bs, ts, self.args.n_agents, 1), 0) 61 | # q = q.reshape(bs * self.args.n_agents, -1) 62 | if ret_attn_logits is not None: 63 | return q, h, attn_logits.reshape(bs, ts, self.args.n_agents, ne) 64 | return q, hs 65 | 66 | 67 | class ImagineEntityAttentionRNNAgent(EntityAttentionRNNAgent): 68 | def __init__(self, *args, **kwargs): 69 | super(ImagineEntityAttentionRNNAgent, self).__init__(*args, **kwargs) 70 | 71 | def logical_not(self, inp): 72 | return 1 - inp 73 | 74 | def logical_or(self, inp1, inp2): 75 | out = inp1 + inp2 76 | out[out > 1] = 1 77 | return out 78 | 79 | def entitymask2attnmask(self, entity_mask): 80 | bs, ts, ne = entity_mask.shape 81 | # agent_mask = entity_mask[:, :, :self.args.n_agents] 82 | in1 = (1 - entity_mask.to(th.float)).reshape(bs * ts, ne, 1) 83 | in2 = (1 - entity_mask.to(th.float)).reshape(bs * ts, 1, ne) 84 | attn_mask = 1 - th.bmm(in1, in2) 85 | return attn_mask.reshape(bs, ts, ne, ne).to(th.uint8) 86 | 87 | def forward(self, inputs, hidden_state, imagine=False, **kwargs): 88 | if not imagine: 89 | return super(ImagineEntityAttentionRNNAgent, self).forward(inputs, hidden_state) 90 | entities, obs_mask, entity_mask = inputs 91 | bs, ts, ne, ed = entities.shape 92 | 93 | # create random split of entities (once per episode) 94 | groupA_probs = th.rand(bs, 1, 1, device=entities.device).repeat(1, 1, ne) 95 | 96 | groupA = th.bernoulli(groupA_probs).to(th.uint8) 97 | groupB = self.logical_not(groupA) 98 | # mask out entities not present in env 99 | groupA = self.logical_or(groupA, entity_mask[:, [0]]) 100 | groupB = self.logical_or(groupB, entity_mask[:, [0]]) 101 | 102 | # convert entity mask to attention mask 103 | groupAattnmask = self.entitymask2attnmask(groupA) 104 | groupBattnmask = self.entitymask2attnmask(groupB) 105 | # create attention mask for interactions between groups 106 | interactattnmask = self.logical_or(self.logical_not(groupAattnmask), 107 | self.logical_not(groupBattnmask)) 108 | # get within group attention mask 109 | withinattnmask = self.logical_not(interactattnmask) 110 | 111 | activeattnmask = self.entitymask2attnmask(entity_mask[:, [0]]) 112 | # get masks to use for mixer (no obs_mask but mask out unused entities) 113 | Wattnmask_noobs = self.logical_or(withinattnmask, activeattnmask) 114 | Iattnmask_noobs = self.logical_or(interactattnmask, activeattnmask) 115 | # mask out 
agents that aren't observable (also expands time dim due to shape of obs_mask) 116 | withinattnmask = self.logical_or(withinattnmask, obs_mask) 117 | interactattnmask = self.logical_or(interactattnmask, obs_mask) 118 | 119 | entities = entities.repeat(3, 1, 1, 1) 120 | obs_mask = th.cat([obs_mask, withinattnmask, interactattnmask], dim=0) 121 | entity_mask = entity_mask.repeat(3, 1, 1) 122 | 123 | inputs = (entities, obs_mask, entity_mask) 124 | hidden_state = hidden_state.repeat(3, 1, 1) 125 | q, h = super(ImagineEntityAttentionRNNAgent, self).forward(inputs, hidden_state) 126 | return q, h, (Wattnmask_noobs.repeat(1, ts, 1, 1), Iattnmask_noobs.repeat(1, ts, 1, 1)) 127 | -------------------------------------------------------------------------------- /src/modules/layers/attention.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class EntityAttentionLayer(nn.Module): 7 | def __init__(self, in_dim, embed_dim, out_dim, args): 8 | super(EntityAttentionLayer, self).__init__() 9 | self.in_dim = in_dim 10 | self.embed_dim = embed_dim 11 | self.out_dim = out_dim 12 | self.n_heads = args.attn_n_heads 13 | self.n_agents = args.n_agents 14 | self.args = args 15 | 16 | assert self.embed_dim % self.n_heads == 0, "Embed dim must be divisible by n_heads" 17 | self.head_dim = self.embed_dim // self.n_heads 18 | self.register_buffer('scale_factor', 19 | th.scalar_tensor(self.head_dim).sqrt()) 20 | 21 | self.in_trans = nn.Linear(self.in_dim, self.embed_dim * 3, bias=False) 22 | self.out_trans = nn.Linear(self.embed_dim, self.out_dim) 23 | 24 | def forward(self, entities, pre_mask=None, post_mask=None, ret_attn_logits=None): 25 | """ 26 | entities: Entity representations 27 | shape: batch size, # of entities, embedding dimension 28 | pre_mask: Which agent-entity pairs are not available (observability and/or padding). 29 | Mask out before attention. 30 | shape: batch_size, # of agents, # of entities 31 | post_mask: Which agents/entities are not available. Zero out their outputs to 32 | prevent gradients from flowing back. Shape of 2nd dim determines 33 | whether to compute queries for all entities or just agents. 
34 | shape: batch size, # of agents (or entities) 35 | ret_attn_logits: whether to return attention logits 36 | None: do not return 37 | "max": take max over heads 38 | "mean": take mean over heads 39 | 40 | Return shape: batch size, # of agents, embedding dimension 41 | """ 42 | entities_t = entities.transpose(0, 1) 43 | n_queries = post_mask.shape[1] 44 | pre_mask = pre_mask[:, :n_queries] 45 | ne, bs, ed = entities_t.shape 46 | query, key, value = self.in_trans(entities_t).chunk(3, dim=2) 47 | 48 | query = query[:n_queries] 49 | 50 | query_spl = query.reshape(n_queries, bs * self.n_heads, self.head_dim).transpose(0, 1) 51 | key_spl = key.reshape(ne, bs * self.n_heads, self.head_dim).permute(1, 2, 0) 52 | value_spl = value.reshape(ne, bs * self.n_heads, self.head_dim).transpose(0, 1) 53 | 54 | attn_logits = th.bmm(query_spl, key_spl) / self.scale_factor 55 | if pre_mask is not None: 56 | pre_mask_rep = pre_mask.repeat_interleave(self.n_heads, dim=0) 57 | masked_attn_logits = attn_logits.masked_fill(pre_mask_rep[:, :, :ne], -float('Inf')) 58 | attn_weights = F.softmax(masked_attn_logits, dim=2) 59 | # some weights might be NaN (if agent is inactive and all entities were masked) 60 | attn_weights = attn_weights.masked_fill(attn_weights != attn_weights, 0) 61 | attn_outs = th.bmm(attn_weights, value_spl) 62 | attn_outs = attn_outs.transpose( 63 | 0, 1).reshape(n_queries, bs, self.embed_dim) 64 | attn_outs = attn_outs.transpose(0, 1) 65 | attn_outs = self.out_trans(attn_outs) 66 | if post_mask is not None: 67 | attn_outs = attn_outs.masked_fill(post_mask.unsqueeze(2), 0) 68 | if ret_attn_logits is not None: 69 | # bs * n_heads, nq, ne 70 | attn_logits = attn_logits.reshape(bs, self.n_heads, 71 | n_queries, ne) 72 | if ret_attn_logits == 'max': 73 | attn_logits = attn_logits.max(dim=1)[0] 74 | elif ret_attn_logits == 'mean': 75 | attn_logits = attn_logits.mean(dim=1) 76 | elif ret_attn_logits == 'norm': 77 | attn_logits = attn_logits.mean(dim=1) 78 | return attn_outs, attn_logits 79 | return attn_outs 80 | 81 | 82 | class EntityPoolingLayer(nn.Module): 83 | def __init__(self, in_dim, embed_dim, out_dim, pooling_type, args): 84 | super(EntityPoolingLayer, self).__init__() 85 | self.in_dim = in_dim 86 | self.embed_dim = embed_dim 87 | self.out_dim = out_dim 88 | self.pooling_type = pooling_type 89 | self.n_agents = args.n_agents 90 | self.args = args 91 | 92 | self.in_trans = nn.Linear(self.in_dim, self.embed_dim) 93 | self.out_trans = nn.Linear(self.embed_dim, self.out_dim) 94 | 95 | def forward(self, entities, pre_mask=None, post_mask=None, ret_attn_logits=None): 96 | """ 97 | entities: Entity representations 98 | shape: batch size, # of entities, embedding dimension 99 | pre_mask: Which agent-entity pairs are not available (observability and/or padding). 100 | Mask out before pooling. 101 | shape: batch_size, # of agents, # of entities 102 | post_mask: Which agents are not available. Zero out their outputs to 103 | prevent gradients from flowing back. 
104 | shape: batch size, # of agents 105 | ret_attn_logits: not used, here to match attention layer args 106 | 107 | Return shape: batch size, # of agents, embedding dimension 108 | """ 109 | bs, ne, ed = entities.shape 110 | 111 | ents_trans = self.in_trans(entities) 112 | n_queries = post_mask.shape[1] 113 | pre_mask = pre_mask[:, :n_queries] 114 | # duplicate all entities per agent so we can mask separately 115 | ents_trans_rep = ents_trans.reshape(bs, 1, ne, ed).repeat(1, self.n_agents, 1, 1) 116 | 117 | if pre_mask is not None: 118 | ents_trans_rep = ents_trans_rep.masked_fill(pre_mask.unsqueeze(3), 0) 119 | 120 | if self.pooling_type == 'max': 121 | pool_outs = ents_trans_rep.max(dim=2)[0] 122 | elif self.pooling_type == 'mean': 123 | pool_outs = ents_trans_rep.mean(dim=2) 124 | 125 | pool_outs = self.out_trans(pool_outs) 126 | 127 | if post_mask is not None: 128 | pool_outs = pool_outs.masked_fill(post_mask.unsqueeze(2), 0) 129 | 130 | if ret_attn_logits is not None: 131 | return pool_outs, None 132 | return pool_outs 133 | -------------------------------------------------------------------------------- /src/runners/episode_runner.py: -------------------------------------------------------------------------------- 1 | from envs import REGISTRY as env_REGISTRY 2 | from functools import partial 3 | from components.episode_buffer import EpisodeBatch 4 | import numpy as np 5 | 6 | 7 | class EpisodeRunner: 8 | 9 | def __init__(self, args, logger): 10 | self.args = args 11 | self.logger = logger 12 | self.batch_size = self.args.batch_size_run 13 | assert self.batch_size == 1 14 | 15 | if ('sc2' in self.args.env) or ('group_matching' in self.args.env): 16 | self.env = env_REGISTRY[self.args.env](**self.args.env_args) 17 | else: 18 | self.env = env_REGISTRY[self.args.env](env_args=self.args.env_args, args=args) 19 | 20 | self.episode_limit = self.env.episode_limit 21 | self.t = 0 22 | 23 | self.t_env = 0 24 | 25 | self.train_returns = [] 26 | self.test_returns = [] 27 | self.train_stats = {} 28 | self.test_stats = {} 29 | 30 | # Log the first run 31 | self.log_train_stats_t = -1000000 32 | 33 | def setup(self, scheme, groups, preprocess, mac): 34 | self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1, 35 | preprocess=preprocess, device=self.args.device) 36 | self.mac = mac 37 | 38 | def get_env_info(self): 39 | return self.env.get_env_info(self.args) 40 | 41 | def save_replay(self): 42 | self.env.save_replay() 43 | 44 | def close_env(self): 45 | self.env.close() 46 | 47 | def reset(self, test=False, index=None): 48 | self.batch = self.new_batch() 49 | self.env.reset(test=test, index=index) 50 | self.t = 0 51 | 52 | def _get_pre_transition_data(self): 53 | if self.args.entity_scheme: 54 | masks = self.env.get_masks() 55 | if len(masks) == 2: 56 | obs_mask, entity_mask = masks 57 | gt_mask = None 58 | else: 59 | obs_mask, entity_mask, gt_mask = masks 60 | pre_transition_data = { 61 | "entities": [self.env.get_entities()], 62 | "obs_mask": [obs_mask], 63 | "entity_mask": [entity_mask], 64 | "avail_actions": [self.env.get_avail_actions()] 65 | } 66 | if gt_mask is not None: 67 | pre_transition_data["gt_mask"] = gt_mask 68 | else: 69 | pre_transition_data = { 70 | "state": [self.env.get_state()], 71 | "avail_actions": [self.env.get_avail_actions()], 72 | "obs": [self.env.get_obs()] 73 | } 74 | return pre_transition_data 75 | 76 | def run(self, test_mode=False, test_scen=None, index=None, vid_writer=None): 77 | """ 78 | test_mode: whether to use greedy 
action selection or sample actions 79 | test_scen: whether to run on test scenarios. defaults to matching test_mode. 80 | vid_writer: imageio video writer object 81 | """ 82 | if test_scen is None: 83 | test_scen = test_mode 84 | self.reset(test=test_scen, index=index) 85 | if vid_writer is not None: 86 | vid_writer.append_data(self.env.render()) 87 | terminated = False 88 | episode_return = 0 89 | self.mac.init_hidden(batch_size=self.batch_size) 90 | # make sure things like dropout are disabled 91 | self.mac.eval() 92 | 93 | while not terminated: 94 | pre_transition_data = self._get_pre_transition_data() 95 | 96 | self.batch.update(pre_transition_data, ts=self.t) 97 | 98 | # Pass the entire batch of experiences up till now to the agents 99 | # Receive the actions for each agent at this timestep in a batch of size 1 100 | actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode) 101 | 102 | reward, terminated, env_info = self.env.step(actions[0].cpu()) 103 | if vid_writer is not None: 104 | vid_writer.append_data(self.env.render()) 105 | episode_return += reward 106 | 107 | post_transition_data = { 108 | "actions": actions, 109 | "reward": [(reward,)], 110 | "terminated": [(terminated != env_info.get("episode_limit", False),)], 111 | } 112 | 113 | self.batch.update(post_transition_data, ts=self.t) 114 | 115 | self.t += 1 116 | 117 | last_data = self._get_pre_transition_data() 118 | self.batch.update(last_data, ts=self.t) 119 | 120 | # Select actions in the last stored state 121 | actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode) 122 | self.batch.update({"actions": actions}, ts=self.t) 123 | 124 | cur_stats = self.test_stats if test_mode else self.train_stats 125 | cur_returns = self.test_returns if test_mode else self.train_returns 126 | log_prefix = "test_" if test_mode else "" 127 | cur_stats.update({k: cur_stats.get(k, 0) + env_info.get(k, 0) for k in set(cur_stats) | set(env_info)}) 128 | cur_stats["n_episodes"] = 1 + cur_stats.get("n_episodes", 0) 129 | cur_stats["ep_length"] = self.t + cur_stats.get("ep_length", 0) 130 | 131 | if not test_mode: 132 | self.t_env += self.t 133 | 134 | cur_returns.append(episode_return) 135 | 136 | if test_mode and (len(self.test_returns) == self.args.test_nepisode): 137 | self._log(cur_returns, cur_stats, log_prefix) 138 | elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval: 139 | self._log(cur_returns, cur_stats, log_prefix) 140 | if hasattr(self.mac.action_selector, "epsilon"): 141 | self.logger.log_stat("epsilon", self.mac.action_selector.epsilon, self.t_env) 142 | self.log_train_stats_t = self.t_env 143 | 144 | return self.batch 145 | 146 | def _log(self, returns, stats, prefix): 147 | self.logger.log_stat(prefix + "return_mean", np.mean(returns), self.t_env) 148 | self.logger.log_stat(prefix + "return_std", np.std(returns), self.t_env) 149 | returns.clear() 150 | 151 | for k, v in stats.items(): 152 | if k != "n_episodes": 153 | self.logger.log_stat(prefix + k + "_mean" , v/stats["n_episodes"], self.t_env) 154 | stats.clear() 155 | -------------------------------------------------------------------------------- /src/modules/agents/entity_ff_agent.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ..layers import EntityAttentionLayer, EntityPoolingLayer 5 | 6 | 7 | class EntityAttentionFFAgent(nn.Module): 8 | def 
__init__(self, input_shape, args): 9 | super(EntityAttentionFFAgent, self).__init__() 10 | self.args = args 11 | 12 | self.fc1 = nn.Linear(input_shape, args.attn_embed_dim) 13 | if args.pooling_type is None: 14 | self.attn = EntityAttentionLayer(args.attn_embed_dim, 15 | args.attn_embed_dim, 16 | args.attn_embed_dim, args) 17 | else: 18 | self.attn = EntityPoolingLayer(args.attn_embed_dim, 19 | args.attn_embed_dim, 20 | args.attn_embed_dim, 21 | args.pooling_type, 22 | args) 23 | self.fc2 = nn.Linear(args.attn_embed_dim, args.n_actions) 24 | 25 | def init_hidden(self): 26 | # make hidden states on same device as model 27 | return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_() 28 | 29 | def forward(self, inputs, hidden_state, ret_attn_logits=None): 30 | if len(inputs) == 3: 31 | entities, obs_mask, entity_mask = inputs 32 | else: 33 | entities, obs_mask, entity_mask, gt_mask = inputs 34 | if self.args.gt_obs_mask: 35 | obs_mask = gt_mask 36 | bs, ts, ne, ed = entities.shape 37 | entities = entities.reshape(bs * ts, ne, ed) 38 | obs_mask = obs_mask.reshape(bs * ts, ne, ne) 39 | entity_mask = entity_mask.reshape(bs * ts, ne) 40 | agent_mask = entity_mask[:, :self.args.n_agents] 41 | x1 = F.relu(self.fc1(entities)) 42 | attn_outs = self.attn(x1, pre_mask=obs_mask, 43 | post_mask=agent_mask, 44 | ret_attn_logits=ret_attn_logits) 45 | attn_outs = F.relu(attn_outs) 46 | if ret_attn_logits is not None: 47 | x2, attn_logits = attn_outs 48 | else: 49 | x2 = attn_outs 50 | q = self.fc2(x2) 51 | # zero out output for inactive agents 52 | q = q.reshape(bs, ts, self.args.n_agents, -1) 53 | q = q.masked_fill(agent_mask.reshape(bs, ts, self.args.n_agents, 1), 0) 54 | # q = q.reshape(bs * self.args.n_agents, -1) 55 | if ret_attn_logits is not None: 56 | return q, attn_outs, attn_logits.reshape(bs, ts, self.args.n_agents, ne) 57 | return q, attn_outs 58 | 59 | 60 | class ImagineEntityAttentionFFAgent(EntityAttentionFFAgent): 61 | def __init__(self, *args, **kwargs): 62 | super(ImagineEntityAttentionFFAgent, self).__init__(*args, **kwargs) 63 | 64 | def logical_not(self, inp): 65 | return 1 - inp 66 | 67 | def logical_or(self, inp1, inp2): 68 | out = inp1 + inp2 69 | out[out > 1] = 1 70 | return out 71 | 72 | def entitymask2attnmask(self, entity_mask): 73 | bs, ts, ne = entity_mask.shape 74 | agent_mask = entity_mask[:, :, :self.args.n_agents] 75 | in1 = (1 - agent_mask.to(th.float)).reshape(bs * ts, self.args.n_agents, 1) 76 | in2 = (1 - entity_mask.to(th.float)).reshape(bs * ts, 1, ne) 77 | attn_mask = 1 - th.bmm(in1, in2) 78 | return attn_mask.reshape(bs, ts, self.args.n_agents, ne).to(th.uint8) 79 | 80 | def forward(self, inputs, hidden_state, imagine=False, use_gt_factors=False, use_rand_gt_factors=False): 81 | if not imagine: 82 | return super(ImagineEntityAttentionFFAgent, self).forward(inputs, hidden_state) 83 | if len(inputs) == 3: 84 | entities, obs_mask, entity_mask = inputs 85 | else: 86 | entities, obs_mask, entity_mask, gt_mask = inputs 87 | bs, ts, ne, ed = entities.shape 88 | 89 | # create random split of entities (once per episode) 90 | groupA_probs = th.rand(bs, 1, 1, device=entities.device).repeat(1, 1, ne) 91 | 92 | activeattnmask = self.entitymask2attnmask(entity_mask[:, [0]]) 93 | if use_gt_factors: 94 | withinattnmask = gt_mask 95 | interactattnmask = self.logical_not(withinattnmask) 96 | else: 97 | groupA = th.bernoulli(groupA_probs).to(th.uint8) 98 | groupB = self.logical_not(groupA) 99 | # mask out entities not present in env 100 | groupA = self.logical_or(groupA, 
entity_mask[:, [0]]) 101 | groupB = self.logical_or(groupB, entity_mask[:, [0]]) 102 | 103 | # convert entity mask to attention mask 104 | groupAattnmask = self.entitymask2attnmask(groupA) 105 | groupBattnmask = self.entitymask2attnmask(groupB) 106 | # create attention mask for interactions between groups 107 | interactattnmask = self.logical_or(self.logical_not(groupAattnmask), 108 | self.logical_not(groupBattnmask)) 109 | # get within group attention mask 110 | withinattnmask = self.logical_not(interactattnmask) 111 | if use_rand_gt_factors: 112 | assert not use_gt_factors, "Can only select one of use_rand_gt_factors and use_gt_factors" 113 | withinattnmask = self.logical_or(withinattnmask, gt_mask) 114 | interactattnmask = self.logical_not(withinattnmask) 115 | 116 | 117 | # get masks to use for mixer (no obs_mask but mask out unused entities) 118 | Wattnmask_noobs = self.logical_or(withinattnmask, activeattnmask) 119 | Iattnmask_noobs = self.logical_or(interactattnmask, activeattnmask) 120 | # mask out agents that aren't observable (also expands time dim due to shape of obs_mask) 121 | withinattnmask = self.logical_or(withinattnmask, obs_mask) 122 | interactattnmask = self.logical_or(interactattnmask, obs_mask) 123 | 124 | entities = entities.repeat(3, 1, 1, 1) 125 | obs_mask = th.cat([obs_mask, withinattnmask, interactattnmask], dim=0) 126 | entity_mask = entity_mask.repeat(3, 1, 1) 127 | 128 | inputs = (entities, obs_mask, entity_mask) 129 | hidden_state = hidden_state.repeat(3, 1, 1) 130 | q, h = super(ImagineEntityAttentionFFAgent, self).forward(inputs, hidden_state) 131 | if use_gt_factors or use_rand_gt_factors: 132 | rep_t = 1 133 | else: 134 | rep_t = ts 135 | return q, h, (Wattnmask_noobs.repeat(1, rep_t, 1, 1), Iattnmask_noobs.repeat(1, rep_t, 1, 1)) 136 | -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/smac_maps.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from pysc2.maps import lib 6 | 7 | 8 | class SMACMap(lib.Map): 9 | directory = "SMAC_Maps" 10 | download = "https://github.com/oxwhirl/smac#smac-maps" 11 | players = 2 12 | step_mul = 8 13 | game_steps_per_episode = 0 14 | 15 | 16 | map_param_registry = { 17 | "3m": { 18 | "n_agents": 3, 19 | "n_enemies": 3, 20 | "limit": 60, 21 | "a_race": "T", 22 | "b_race": "T", 23 | "unit_type_bits": 0, 24 | "map_type": "marines", 25 | }, 26 | "8m": { 27 | "n_agents": 8, 28 | "n_enemies": 8, 29 | "limit": 120, 30 | "a_race": "T", 31 | "b_race": "T", 32 | "unit_type_bits": 0, 33 | "map_type": "marines", 34 | }, 35 | "25m": { 36 | "n_agents": 25, 37 | "n_enemies": 25, 38 | "limit": 150, 39 | "a_race": "T", 40 | "b_race": "T", 41 | "unit_type_bits": 0, 42 | "map_type": "marines", 43 | }, 44 | "5m_vs_6m": { 45 | "n_agents": 5, 46 | "n_enemies": 6, 47 | "limit": 70, 48 | "a_race": "T", 49 | "b_race": "T", 50 | "unit_type_bits": 0, 51 | "map_type": "marines", 52 | }, 53 | "8m_vs_9m": { 54 | "n_agents": 8, 55 | "n_enemies": 9, 56 | "limit": 120, 57 | "a_race": "T", 58 | "b_race": "T", 59 | "unit_type_bits": 0, 60 | "map_type": "marines", 61 | }, 62 | "10m_vs_11m": { 63 | "n_agents": 10, 64 | "n_enemies": 11, 65 | "limit": 150, 66 | "a_race": "T", 67 | "b_race": "T", 68 | "unit_type_bits": 0, 69 | "map_type": "marines", 70 | }, 71 | "27m_vs_30m": { 72 | "n_agents": 27, 73 | "n_enemies": 30, 74 | "limit": 180, 75 | 
"a_race": "T", 76 | "b_race": "T", 77 | "unit_type_bits": 0, 78 | "map_type": "marines", 79 | }, 80 | "MMM": { 81 | "n_agents": 10, 82 | "n_enemies": 10, 83 | "limit": 150, 84 | "a_race": "T", 85 | "b_race": "T", 86 | "unit_type_bits": 3, 87 | "map_type": "MMM", 88 | }, 89 | "MMM2": { 90 | "n_agents": 10, 91 | "n_enemies": 12, 92 | "limit": 180, 93 | "a_race": "T", 94 | "b_race": "T", 95 | "unit_type_bits": 3, 96 | "map_type": "MMM", 97 | }, 98 | "2s3z": { 99 | "n_agents": 5, 100 | "n_enemies": 5, 101 | "limit": 120, 102 | "a_race": "P", 103 | "b_race": "P", 104 | "unit_type_bits": 2, 105 | "map_type": "stalkers_and_zealots", 106 | }, 107 | "3s5z": { 108 | "n_agents": 8, 109 | "n_enemies": 8, 110 | "limit": 150, 111 | "a_race": "P", 112 | "b_race": "P", 113 | "unit_type_bits": 2, 114 | "map_type": "stalkers_and_zealots", 115 | }, 116 | "5s10z": { 117 | "n_agents": 15, 118 | "n_enemies": 15, 119 | "limit": 150, 120 | "a_race": "P", 121 | "b_race": "P", 122 | "unit_type_bits": 2, 123 | "map_type": "stalkers_and_zealots", 124 | }, 125 | "3s5z_vs_3s6z": { 126 | "n_agents": 8, 127 | "n_enemies": 9, 128 | "limit": 170, 129 | "a_race": "P", 130 | "b_race": "P", 131 | "unit_type_bits": 2, 132 | "map_type": "stalkers_and_zealots", 133 | }, 134 | "3s_vs_3z": { 135 | "n_agents": 3, 136 | "n_enemies": 3, 137 | "limit": 150, 138 | "a_race": "P", 139 | "b_race": "P", 140 | "unit_type_bits": 0, 141 | "map_type": "stalkers", 142 | }, 143 | "3s_vs_4z": { 144 | "n_agents": 3, 145 | "n_enemies": 4, 146 | "limit": 200, 147 | "a_race": "P", 148 | "b_race": "P", 149 | "unit_type_bits": 0, 150 | "map_type": "stalkers", 151 | }, 152 | "3s_vs_5z": { 153 | "n_agents": 3, 154 | "n_enemies": 5, 155 | "limit": 250, 156 | "a_race": "P", 157 | "b_race": "P", 158 | "unit_type_bits": 0, 159 | "map_type": "stalkers", 160 | }, 161 | "1c3s5z": { 162 | "n_agents": 9, 163 | "n_enemies": 9, 164 | "limit": 180, 165 | "a_race": "P", 166 | "b_race": "P", 167 | "unit_type_bits": 3, 168 | "map_type": "colossi_stalkers_zealots", 169 | }, 170 | "2m_vs_1z": { 171 | "n_agents": 2, 172 | "n_enemies": 1, 173 | "limit": 150, 174 | "a_race": "T", 175 | "b_race": "P", 176 | "unit_type_bits": 0, 177 | "map_type": "marines", 178 | }, 179 | "corridor": { 180 | "n_agents": 6, 181 | "n_enemies": 24, 182 | "limit": 400, 183 | "a_race": "P", 184 | "b_race": "Z", 185 | "unit_type_bits": 0, 186 | "map_type": "zealots", 187 | }, 188 | "6h_vs_8z": { 189 | "n_agents": 6, 190 | "n_enemies": 8, 191 | "limit": 150, 192 | "a_race": "Z", 193 | "b_race": "P", 194 | "unit_type_bits": 0, 195 | "map_type": "hydralisks", 196 | }, 197 | "2s_vs_1sc": { 198 | "n_agents": 2, 199 | "n_enemies": 1, 200 | "limit": 300, 201 | "a_race": "P", 202 | "b_race": "Z", 203 | "unit_type_bits": 0, 204 | "map_type": "stalkers", 205 | }, 206 | "so_many_baneling": { 207 | "n_agents": 7, 208 | "n_enemies": 32, 209 | "limit": 100, 210 | "a_race": "P", 211 | "b_race": "Z", 212 | "unit_type_bits": 0, 213 | "map_type": "zealots", 214 | }, 215 | "bane_vs_bane": { 216 | "n_agents": 24, 217 | "n_enemies": 24, 218 | "limit": 200, 219 | "a_race": "Z", 220 | "b_race": "Z", 221 | "unit_type_bits": 2, 222 | "map_type": "bane", 223 | }, 224 | "2c_vs_64zg": { 225 | "n_agents": 2, 226 | "n_enemies": 64, 227 | "limit": 400, 228 | "a_race": "P", 229 | "b_race": "Z", 230 | "unit_type_bits": 0, 231 | "map_type": "colossus", 232 | }, 233 | } 234 | 235 | 236 | def get_smac_map_registry(): 237 | return map_param_registry 238 | 239 | custom_maps = ["empty_passive", "empty_aggressive", "terran_vs_terran", 
"5m_vs_6m_alt", "5m_vs_6m_alt1xtra"] 240 | 241 | for name in list(map_param_registry.keys()) + custom_maps: 242 | globals()[name] = type(name, (SMACMap,), dict(filename=name)) 243 | -------------------------------------------------------------------------------- /src/envs/starcraft2/custom_scenarios.py: -------------------------------------------------------------------------------- 1 | from numpy.random import RandomState 2 | from os.path import dirname, join 3 | from functools import partial 4 | from itertools import combinations_with_replacement, product 5 | 6 | 7 | def get_all_unique_teams(all_types, min_len, max_len): 8 | all_uniq = [] 9 | for i in range(min_len, max_len + 1): 10 | all_uniq += list(combinations_with_replacement(all_types, i)) 11 | all_uniq_counts = [] 12 | for scen in all_uniq: 13 | curr_uniq = list(set(scen)) 14 | uniq_counts = list(zip([scen.count(u) for u in curr_uniq], curr_uniq)) 15 | all_uniq_counts.append(uniq_counts) 16 | return all_uniq_counts 17 | 18 | 19 | def fixed_armies(ally_army, enemy_army, ally_centered=False, rotate=False, 20 | separation=10, jitter=0, episode_limit=100, 21 | map_name="empty_passive", rs=None): 22 | scenario_dict = {'scenarios': [(ally_army, enemy_army)], 23 | 'max_types_and_units_scenario': (ally_army, enemy_army), 24 | 'ally_centered': ally_centered, 25 | 'rotate': rotate, 26 | 'separation': separation, 27 | 'jitter': jitter, 28 | 'episode_limit': episode_limit, 29 | 'map_name': map_name} 30 | return scenario_dict 31 | 32 | 33 | def symmetric_armies(army_spec, ally_centered=False, 34 | rotate=False, separation=10, 35 | jitter=0, episode_limit=100, map_name="empty_passive", 36 | n_extra_tags=0, 37 | rs=None): 38 | if rs is None: 39 | rs = RandomState() 40 | 41 | unique_sub_teams = [] 42 | for unit_types, n_unit_range in army_spec: 43 | unique_sub_teams.append(get_all_unique_teams(unit_types, n_unit_range[0], 44 | n_unit_range[1])) 45 | unique_teams = [sum(prod, []) for prod in product(*unique_sub_teams)] 46 | 47 | scenarios = list(zip(unique_teams, unique_teams)) 48 | # sort by number of types and total number of units 49 | max_types_and_units_team = sorted(unique_teams, key=lambda x: (len(x), sum(num for num, unit in x)), reverse=True)[0] 50 | max_types_and_units_scenario = (max_types_and_units_team, 51 | max_types_and_units_team) 52 | 53 | scenario_dict = {'scenarios': scenarios, 54 | 'max_types_and_units_scenario': max_types_and_units_scenario, 55 | 'ally_centered': ally_centered, 56 | 'rotate': rotate, 57 | 'separation': separation, 58 | 'jitter': jitter, 59 | 'episode_limit': episode_limit, 60 | 'n_extra_tags': n_extra_tags, 61 | 'map_name': map_name} 62 | return scenario_dict 63 | 64 | 65 | def asymm_armies(army_spec, spec_delta, ally_centered=False, 66 | rotate=False, separation=10, 67 | jitter=0, episode_limit=100, map_name="empty_passive", 68 | n_extra_tags=0, 69 | rs=None): 70 | if rs is None: 71 | rs = RandomState() 72 | 73 | unique_sub_teams = [] 74 | for unit_types, n_unit_range in army_spec: 75 | unique_sub_teams.append(get_all_unique_teams(unit_types, n_unit_range[0], 76 | n_unit_range[1])) 77 | enemy_teams = [sum(prod, []) for prod in product(*unique_sub_teams)] 78 | agent_teams = [[(max(num + spec_delta.get(typ, 0), 0), typ) for num, typ in team] for team in enemy_teams] 79 | 80 | scenarios = list(zip(agent_teams, enemy_teams)) 81 | # sort by number of types and total number of units 82 | max_types_and_units_ag_team = sorted(agent_teams, key=lambda x: (len(x), sum(num for num, unit in x)), reverse=True)[0] 83 | 
max_types_and_units_en_team = sorted(enemy_teams, key=lambda x: (len(x), sum(num for num, unit in x)), reverse=True)[0] 84 | max_types_and_units_scenario = (max_types_and_units_ag_team, 85 | max_types_and_units_en_team) 86 | 87 | scenario_dict = {'scenarios': scenarios, 88 | 'max_types_and_units_scenario': max_types_and_units_scenario, 89 | 'ally_centered': ally_centered, 90 | 'rotate': rotate, 91 | 'separation': separation, 92 | 'jitter': jitter, 93 | 'episode_limit': episode_limit, 94 | 'n_extra_tags': n_extra_tags, 95 | 'map_name': map_name} 96 | return scenario_dict 97 | 98 | 99 | """ 100 | The function in the registry needs to return a tuple of two lists, one for the 101 | ally army and one for the enemy. 102 | Each is of the form [(number, unit_type, pos), ....], where pos is the starting 103 | positiong (relative to center of map) for the corresponding units. 104 | The function will be called on each episode start. 105 | Currently, we only support the same number of agents and enemies each episode. 106 | """ 107 | 108 | custom_scenario_registry = { 109 | "3-8m_symmetric": partial(symmetric_armies, 110 | [(('Marine',), (3, 8))], 111 | rotate=True, 112 | ally_centered=False, 113 | separation=14, 114 | jitter=1, episode_limit=100, map_name="empty_passive"), 115 | "6-11m_mandown": partial(asymm_armies, 116 | [(('Marine',), (6, 11))], 117 | {'Marine': -1}, 118 | rotate=True, 119 | ally_centered=False, 120 | separation=14, 121 | jitter=1, episode_limit=100, map_name="empty_passive"), 122 | "3-8sz_symmetric": partial(symmetric_armies, 123 | [(('Stalker', 'Zealot'), (3, 8))], 124 | rotate=True, 125 | ally_centered=False, 126 | separation=14, 127 | jitter=1, episode_limit=150, map_name="empty_passive"), 128 | "3-8MMM_symmetric": partial(symmetric_armies, 129 | [(('Marine', 'Marauder'), (3, 6)), 130 | (('Medivac',), (0, 2))], 131 | rotate=True, 132 | ally_centered=False, 133 | separation=14, 134 | jitter=1, episode_limit=150, map_name="empty_passive"), 135 | "3-8csz_symmetric": partial(symmetric_armies, 136 | [(('Stalker', 'Zealot'), (3, 6)), 137 | (('Colossus',), (0, 2))], 138 | rotate=True, 139 | ally_centered=False, 140 | separation=14, 141 | jitter=1, episode_limit=150, map_name="empty_passive"), 142 | } 143 | -------------------------------------------------------------------------------- /src/modules/mixers/flex_qmix.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from modules.layers import EntityAttentionLayer, EntityPoolingLayer 5 | 6 | 7 | class AttentionHyperNet(nn.Module): 8 | """ 9 | mode='matrix' gets you a sized matrix 10 | mode='vector' gets you a sized vector by averaging over agents 11 | mode='alt_vector' gets you a sized vector by averaging over embedding dim 12 | mode='scalar' gets you a scalar by averaging over agents and embed dim 13 | ...per set of entities 14 | """ 15 | def __init__(self, args, extra_dims=0, mode='matrix'): 16 | super(AttentionHyperNet, self).__init__() 17 | self.args = args 18 | self.mode = mode 19 | self.extra_dims = extra_dims 20 | self.entity_dim = args.entity_shape 21 | if self.args.entity_last_action: 22 | self.entity_dim += args.n_actions 23 | if extra_dims > 0: 24 | self.entity_dim += extra_dims 25 | 26 | hypernet_embed = args.hypernet_embed 27 | self.fc1 = nn.Linear(self.entity_dim, hypernet_embed) 28 | if args.pooling_type is None: 29 | self.attn = EntityAttentionLayer(hypernet_embed, 30 | hypernet_embed, 31 | 
hypernet_embed, args) 32 | else: 33 | self.attn = EntityPoolingLayer(hypernet_embed, 34 | hypernet_embed, 35 | hypernet_embed, 36 | args.pooling_type, 37 | args) 38 | self.fc2 = nn.Linear(hypernet_embed, args.mixing_embed_dim) 39 | 40 | def forward(self, entities, entity_mask, attn_mask=None): 41 | x1 = F.relu(self.fc1(entities)) 42 | agent_mask = entity_mask[:, :self.args.n_agents] 43 | if attn_mask is None: 44 | # create attn_mask from entity mask 45 | attn_mask = 1 - th.bmm((1 - agent_mask.to(th.float)).unsqueeze(2), 46 | (1 - entity_mask.to(th.float)).unsqueeze(1)) 47 | x2 = self.attn(x1, pre_mask=attn_mask.to(th.uint8), 48 | post_mask=agent_mask) 49 | x3 = self.fc2(x2) 50 | x3 = x3.masked_fill(agent_mask.unsqueeze(2), 0) 51 | if self.mode == 'vector': 52 | return x3.mean(dim=1) 53 | elif self.mode == 'alt_vector': 54 | return x3.mean(dim=2) 55 | elif self.mode == 'scalar': 56 | return x3.mean(dim=(1, 2)) 57 | return x3 58 | 59 | 60 | class FlexQMixer(nn.Module): 61 | def __init__(self, args): 62 | super(FlexQMixer, self).__init__() 63 | self.args = args 64 | 65 | self.n_agents = args.n_agents 66 | 67 | self.embed_dim = args.mixing_embed_dim 68 | 69 | self.hyper_w_1 = AttentionHyperNet(args, mode='matrix') 70 | self.hyper_w_final = AttentionHyperNet(args, mode='vector') 71 | self.hyper_b_1 = AttentionHyperNet(args, mode='vector') 72 | # V(s) instead of a bias for the last layers 73 | self.V = AttentionHyperNet(args, mode='scalar') 74 | 75 | self.non_lin = F.elu 76 | if getattr(self.args, "mixer_non_lin", "elu") == "tanh": 77 | self.non_lin = F.tanh 78 | 79 | def forward(self, agent_qs, inputs, imagine_groups=None): 80 | entities, entity_mask = inputs 81 | bs, max_t, ne, ed = entities.shape 82 | 83 | entities = entities.reshape(bs * max_t, ne, ed) 84 | entity_mask = entity_mask.reshape(bs * max_t, ne) 85 | if imagine_groups is not None: 86 | agent_qs = agent_qs.view(-1, 1, self.n_agents * 2) 87 | Wmask, Imask = imagine_groups 88 | w1_W = self.hyper_w_1(entities, entity_mask, 89 | attn_mask=Wmask.reshape(bs * max_t, 90 | ne, ne)) 91 | w1_I = self.hyper_w_1(entities, entity_mask, 92 | attn_mask=Imask.reshape(bs * max_t, 93 | ne, ne)) 94 | w1 = th.cat([w1_W, w1_I], dim=1) 95 | else: 96 | agent_qs = agent_qs.view(-1, 1, self.n_agents) 97 | # First layer 98 | w1 = self.hyper_w_1(entities, entity_mask) 99 | b1 = self.hyper_b_1(entities, entity_mask) 100 | w1 = w1.view(bs * max_t, -1, self.embed_dim) 101 | b1 = b1.view(-1, 1, self.embed_dim) 102 | if self.args.softmax_mixing_weights: 103 | w1 = F.softmax(w1, dim=-1) 104 | else: 105 | w1 = th.abs(w1) 106 | 107 | hidden = self.non_lin(th.bmm(agent_qs, w1) + b1) 108 | # Second layer 109 | if self.args.softmax_mixing_weights: 110 | w_final = F.softmax(self.hyper_w_final(entities, entity_mask), dim=-1) 111 | else: 112 | w_final = th.abs(self.hyper_w_final(entities, entity_mask)) 113 | w_final = w_final.view(-1, self.embed_dim, 1) 114 | # State-dependent bias 115 | v = self.V(entities, entity_mask).view(-1, 1, 1) 116 | 117 | # Compute final output 118 | y = th.bmm(hidden, w_final) + v 119 | # Reshape and return 120 | q_tot = y.view(bs, -1, 1) 121 | return q_tot 122 | 123 | 124 | class LinearFlexQMixer(nn.Module): 125 | def __init__(self, args): 126 | super(LinearFlexQMixer, self).__init__() 127 | self.args = args 128 | 129 | self.n_agents = args.n_agents 130 | 131 | self.embed_dim = args.mixing_embed_dim 132 | 133 | self.hyper_w_1 = AttentionHyperNet(args, mode='alt_vector') 134 | self.V = AttentionHyperNet(args, mode='scalar') 135 | 136 | def 
forward(self, agent_qs, inputs, imagine_groups=None, ret_ingroup_prop=False): 137 | entities, entity_mask = inputs 138 | bs, max_t, ne, ed = entities.shape 139 | 140 | entities = entities.reshape(bs * max_t, ne, ed) 141 | entity_mask = entity_mask.reshape(bs * max_t, ne) 142 | if imagine_groups is not None: 143 | agent_qs = agent_qs.view(-1, self.n_agents * 2) 144 | Wmask, Imask = imagine_groups 145 | w1_W = self.hyper_w_1(entities, entity_mask, 146 | attn_mask=Wmask.reshape(bs * max_t, 147 | self.n_agents, ne)) 148 | w1_I = self.hyper_w_1(entities, entity_mask, 149 | attn_mask=Imask.reshape(bs * max_t, 150 | self.n_agents, ne)) 151 | w1 = th.cat([w1_W, w1_I], dim=1) 152 | else: 153 | agent_qs = agent_qs.view(-1, self.n_agents) 154 | # First layer 155 | w1 = self.hyper_w_1(entities, entity_mask) 156 | w1 = w1.view(bs * max_t, -1) 157 | if self.args.softmax_mixing_weights: 158 | w1 = F.softmax(w1, dim=1) 159 | else: 160 | w1 = th.abs(w1) 161 | v = self.V(entities, entity_mask) 162 | 163 | q_cont = agent_qs * w1 164 | q_tot = q_cont.sum(dim=1) + v 165 | # Reshape and return 166 | q_tot = q_tot.view(bs, -1, 1) 167 | if ret_ingroup_prop: 168 | ingroup_w = w1.clone() 169 | ingroup_w[:, self.n_agents:] = 0 # zero-out out of group weights 170 | ingroup_prop = (ingroup_w.sum(dim=1)).mean() 171 | return q_tot, ingroup_prop 172 | return q_tot 173 | -------------------------------------------------------------------------------- /src/learners/q_learner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from components.episode_buffer import EpisodeBatch 3 | from modules.mixers.vdn import VDNMixer 4 | from modules.mixers.qmix import QMixer 5 | from modules.mixers.flex_qmix import FlexQMixer, LinearFlexQMixer 6 | import torch as th 7 | from torch.optim import RMSprop 8 | 9 | 10 | class QLearner: 11 | def __init__(self, mac, scheme, logger, args): 12 | self.args = args 13 | self.mac = mac 14 | self.logger = logger 15 | 16 | self.params = list(mac.parameters()) 17 | 18 | self.last_target_update_episode = 0 19 | 20 | self.mixer = None 21 | if args.mixer is not None: 22 | if args.mixer == "vdn": 23 | self.mixer = VDNMixer() 24 | elif args.mixer == "qmix": 25 | self.mixer = QMixer(args) 26 | elif args.mixer == "flex_qmix": 27 | assert args.entity_scheme, "FlexQMixer only available with entity scheme" 28 | self.mixer = FlexQMixer(args) 29 | elif args.mixer == "lin_flex_qmix": 30 | assert args.entity_scheme, "FlexQMixer only available with entity scheme" 31 | self.mixer = LinearFlexQMixer(args) 32 | else: 33 | raise ValueError("Mixer {} not recognised.".format(args.mixer)) 34 | self.params += list(self.mixer.parameters()) 35 | self.target_mixer = copy.deepcopy(self.mixer) 36 | 37 | self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps, 38 | weight_decay=args.weight_decay) 39 | 40 | # a little wasteful to deepcopy (e.g. 
duplicates action selector), but should work for any MAC 41 | self.target_mac = copy.deepcopy(mac) 42 | 43 | self.log_stats_t = -self.args.learner_log_interval - 1 44 | 45 | def _get_mixer_ins(self, batch, repeat_batch=1): 46 | if not self.args.entity_scheme: 47 | return (batch["state"][:, :-1].repeat(repeat_batch, 1, 1), 48 | batch["state"][:, 1:]) 49 | else: 50 | entities = [] 51 | bs, max_t, ne, ed = batch["entities"].shape 52 | entities.append(batch["entities"]) 53 | if self.args.entity_last_action: 54 | last_actions = th.zeros(bs, max_t, ne, self.args.n_actions, 55 | device=batch.device, 56 | dtype=batch["entities"].dtype) 57 | last_actions[:, 1:, :self.args.n_agents] = batch["actions_onehot"][:, :-1] 58 | entities.append(last_actions) 59 | 60 | entities = th.cat(entities, dim=3) 61 | return ((entities[:, :-1].repeat(repeat_batch, 1, 1, 1), 62 | batch["entity_mask"][:, :-1].repeat(repeat_batch, 1, 1)), 63 | (entities[:, 1:], 64 | batch["entity_mask"][:, 1:])) 65 | 66 | def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): 67 | # Get the relevant quantities 68 | rewards = batch["reward"][:, :-1] 69 | actions = batch["actions"][:, :-1] 70 | terminated = batch["terminated"][:, :-1].float() 71 | mask = batch["filled"][:, :-1].float() 72 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 73 | avail_actions = batch["avail_actions"] 74 | 75 | will_log = (t_env - self.log_stats_t >= self.args.learner_log_interval) 76 | 77 | # # Calculate estimated Q-Values 78 | # mac_out = [] 79 | self.mac.init_hidden(batch.batch_size) 80 | # enable things like dropout on mac and mixer, but not target_mac and target_mixer 81 | self.mac.train() 82 | self.mixer.train() 83 | self.target_mac.eval() 84 | self.target_mixer.eval() 85 | 86 | if 'imagine' in self.args.agent: 87 | all_mac_out, groups = self.mac.forward(batch, t=None, imagine=True, 88 | use_gt_factors=self.args.train_gt_factors, 89 | use_rand_gt_factors=self.args.train_rand_gt_factors) 90 | # Pick the Q-Values for the actions taken by each agent 91 | rep_actions = actions.repeat(3, 1, 1, 1) 92 | all_chosen_action_qvals = th.gather(all_mac_out[:, :-1], dim=3, index=rep_actions).squeeze(3) # Remove the last dim 93 | 94 | mac_out, moW, moI = all_mac_out.chunk(3, dim=0) 95 | chosen_action_qvals, caqW, caqI = all_chosen_action_qvals.chunk(3, dim=0) 96 | caq_imagine = th.cat([caqW, caqI], dim=2) 97 | 98 | if will_log and self.args.test_gt_factors: 99 | gt_all_mac_out, gt_groups = self.mac.forward(batch, t=None, imagine=True, use_gt_factors=True) 100 | # Pick the Q-Values for the actions taken by each agent 101 | gt_all_chosen_action_qvals = th.gather(gt_all_mac_out[:, :-1], dim=3, index=rep_actions).squeeze(3) # Remove the last dim 102 | 103 | gt_mac_out, gt_moW, gt_moI = gt_all_mac_out.chunk(3, dim=0) 104 | gt_chosen_action_qvals, gt_caqW, gt_caqI = gt_all_chosen_action_qvals.chunk(3, dim=0) 105 | gt_caq_imagine = th.cat([gt_caqW, gt_caqI], dim=2) 106 | else: 107 | mac_out = self.mac.forward(batch, t=None) 108 | # Pick the Q-Values for the actions taken by each agent 109 | chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim 110 | 111 | self.target_mac.init_hidden(batch.batch_size) 112 | 113 | target_mac_out = self.target_mac.forward(batch, t=None) 114 | avail_actions_targ = avail_actions 115 | target_mac_out = target_mac_out[:, 1:] 116 | 117 | # Mask out unavailable actions 118 | target_mac_out[avail_actions_targ[:, 1:] == 0] = -9999999 # From OG deepmarl 119 | 120 | # Max over target 
Q-Values 121 | if self.args.double_q: 122 | # Get actions that maximise live Q (for double q-learning) 123 | mac_out_detach = mac_out.clone().detach() 124 | mac_out_detach[avail_actions_targ == 0] = -9999999 125 | cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1] 126 | target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) 127 | else: 128 | target_max_qvals = target_mac_out.max(dim=3)[0] 129 | 130 | # Mix 131 | if self.mixer is not None: 132 | if 'imagine' in self.args.agent: 133 | mix_ins, targ_mix_ins = self._get_mixer_ins(batch) 134 | chosen_action_qvals = self.mixer(chosen_action_qvals, 135 | mix_ins) 136 | # don't need last timestep 137 | groups = [gr[:, :-1] for gr in groups] 138 | if will_log and self.args.test_gt_factors: 139 | caq_imagine, ingroup_prop = self.mixer( 140 | caq_imagine, mix_ins, 141 | imagine_groups=groups, 142 | ret_ingroup_prop=True) 143 | gt_groups = [gr[:, :-1] for gr in gt_groups] 144 | gt_caq_imagine, gt_ingroup_prop = self.mixer( 145 | gt_caq_imagine, mix_ins, 146 | imagine_groups=gt_groups, 147 | ret_ingroup_prop=True) 148 | else: 149 | caq_imagine = self.mixer(caq_imagine, mix_ins, 150 | imagine_groups=groups) 151 | else: 152 | mix_ins, targ_mix_ins = self._get_mixer_ins(batch) 153 | chosen_action_qvals = self.mixer(chosen_action_qvals, mix_ins) 154 | target_max_qvals = self.target_mixer(target_max_qvals, targ_mix_ins) 155 | 156 | # Calculate 1-step Q-Learning targets 157 | targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals 158 | 159 | # Td-error 160 | td_error = (chosen_action_qvals - targets.detach()) 161 | mask = mask.expand_as(td_error) 162 | # 0-out the targets that came from padded data 163 | masked_td_error = td_error * mask 164 | # Normal L2 loss, take mean over actual data 165 | loss = (masked_td_error ** 2).sum() / mask.sum() 166 | 167 | if 'imagine' in self.args.agent: 168 | im_prop = self.args.lmbda 169 | im_td_error = (caq_imagine - targets.detach()) 170 | im_masked_td_error = im_td_error * mask 171 | im_loss = (im_masked_td_error ** 2).sum() / mask.sum() 172 | loss = (1 - im_prop) * loss + im_prop * im_loss 173 | 174 | # Optimise 175 | self.optimiser.zero_grad() 176 | loss.backward() 177 | grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) 178 | self.optimiser.step() 179 | 180 | if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0: 181 | self._update_targets() 182 | self.last_target_update_episode = episode_num 183 | 184 | if t_env - self.log_stats_t >= self.args.learner_log_interval: 185 | self.logger.log_stat("loss", loss.item(), t_env) 186 | if 'imagine' in self.args.agent: 187 | self.logger.log_stat("im_loss", im_loss.item(), t_env) 188 | if self.args.test_gt_factors: 189 | self.logger.log_stat("ingroup_prop", ingroup_prop.item(), t_env) 190 | self.logger.log_stat("gt_ingroup_prop", gt_ingroup_prop.item(), t_env) 191 | self.logger.log_stat("grad_norm", grad_norm, t_env) 192 | mask_elems = mask.sum().item() 193 | self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env) 194 | self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents), t_env) 195 | self.logger.log_stat("target_mean", (targets * mask).sum().item()/(mask_elems * self.args.n_agents), t_env) 196 | if batch.max_seq_length == 2: 197 | # We are in a 1-step env. 
Calculate the max Q-Value for logging 198 | max_agent_qvals = mac_out_detach[:,0].max(dim=2, keepdim=True)[0] 199 | max_qtots = self.mixer(max_agent_qvals, batch["state"][:,0]) 200 | self.logger.log_stat("max_qtot", max_qtots.mean().item(), t_env) 201 | self.log_stats_t = t_env 202 | 203 | def _update_targets(self): 204 | self.target_mac.load_state(self.mac) 205 | if self.mixer is not None: 206 | self.target_mixer.load_state_dict(self.mixer.state_dict()) 207 | self.logger.console_logger.info("Updated target network") 208 | 209 | def cuda(self): 210 | self.mac.cuda() 211 | self.target_mac.cuda() 212 | if self.mixer is not None: 213 | self.mixer.cuda() 214 | self.target_mixer.cuda() 215 | 216 | def save_models(self, path): 217 | self.mac.save_models(path) 218 | if self.mixer is not None: 219 | th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) 220 | th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) 221 | 222 | def load_models(self, path, evaluate=False): 223 | self.mac.load_models(path) 224 | # Not quite right but I don't want to save target networks 225 | self.target_mac.load_models(path) 226 | if not evaluate: 227 | if self.mixer is not None: 228 | self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) 229 | self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage)) 230 | -------------------------------------------------------------------------------- /src/components/episode_buffer.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | from types import SimpleNamespace as SN 4 | 5 | 6 | class EpisodeBatch: 7 | def __init__(self, 8 | scheme, 9 | groups, 10 | batch_size, 11 | max_seq_length, 12 | data=None, 13 | preprocess=None, 14 | device="cpu"): 15 | self.scheme = scheme.copy() 16 | self.groups = groups 17 | self.batch_size = batch_size 18 | self.max_seq_length = max_seq_length 19 | self.preprocess = {} if preprocess is None else preprocess 20 | self.device = device 21 | 22 | if data is not None: 23 | self.data = data 24 | else: 25 | self.data = SN() 26 | self.data.transition_data = {} 27 | self.data.episode_data = {} 28 | self._setup_data(self.scheme, self.groups, batch_size, max_seq_length, self.preprocess) 29 | 30 | def _setup_data(self, scheme, groups, batch_size, max_seq_length, preprocess): 31 | if preprocess is not None: 32 | for k in preprocess: 33 | assert k in scheme 34 | new_k = preprocess[k][0] 35 | transforms = preprocess[k][1] 36 | 37 | vshape = self.scheme[k]["vshape"] 38 | dtype = self.scheme[k]["dtype"] 39 | for transform in transforms: 40 | vshape, dtype = transform.infer_output_info(vshape, dtype) 41 | 42 | self.scheme[new_k] = { 43 | "vshape": vshape, 44 | "dtype": dtype 45 | } 46 | if "group" in self.scheme[k]: 47 | self.scheme[new_k]["group"] = self.scheme[k]["group"] 48 | if "episode_const" in self.scheme[k]: 49 | self.scheme[new_k]["episode_const"] = self.scheme[k]["episode_const"] 50 | 51 | assert "filled" not in scheme, '"filled" is a reserved key for masking.' 
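# Note: the reserved "filled" entry registered just below is what downstream code keys on --
# update() sets filled=1 whenever real transition data is written for a (batch, timestep) slot,
# the learner builds its TD-loss mask from batch["filled"], and max_t_filled() uses it to
# truncate sampled batches to the longest actually-filled episode.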
52 | scheme.update({ 53 | "filled": {"vshape": (1,), "dtype": th.long}, 54 | }) 55 | 56 | for field_key, field_info in scheme.items(): 57 | assert "vshape" in field_info, "Scheme must define vshape for {}".format(field_key) 58 | vshape = field_info["vshape"] 59 | episode_const = field_info.get("episode_const", False) 60 | group = field_info.get("group", None) 61 | dtype = field_info.get("dtype", th.float32) 62 | 63 | if isinstance(vshape, int): 64 | vshape = (vshape,) 65 | 66 | if group: 67 | assert group in groups, "Group {} must have its number of members defined in _groups_".format(group) 68 | shape = (groups[group], *vshape) 69 | else: 70 | shape = vshape 71 | 72 | if episode_const: 73 | self.data.episode_data[field_key] = th.zeros((batch_size, *shape), dtype=dtype, device=self.device) 74 | else: 75 | self.data.transition_data[field_key] = th.zeros((batch_size, max_seq_length, *shape), dtype=dtype, device=self.device) 76 | 77 | def extend(self, scheme, groups=None): 78 | self._setup_data(scheme, self.groups if groups is None else groups, self.batch_size, self.max_seq_length) 79 | 80 | def to(self, device): 81 | for k, v in self.data.transition_data.items(): 82 | self.data.transition_data[k] = v.to(device) 83 | for k, v in self.data.episode_data.items(): 84 | self.data.episode_data[k] = v.to(device) 85 | self.device = device 86 | 87 | def update(self, data, bs=slice(None), ts=slice(None), mark_filled=True): 88 | slices = self._parse_slices((bs, ts)) 89 | for k, v in data.items(): 90 | if k in self.data.transition_data: 91 | target = self.data.transition_data 92 | if mark_filled: 93 | target["filled"][slices] = 1 94 | mark_filled = False 95 | _slices = slices 96 | elif k in self.data.episode_data: 97 | target = self.data.episode_data 98 | _slices = slices[0] 99 | else: 100 | raise KeyError("{} not found in transition or episode data".format(k)) 101 | 102 | dtype = self.scheme[k].get("dtype", th.float32) 103 | v = th.tensor(v, dtype=dtype, device=self.device) 104 | self._check_safe_view(v, target[k][_slices]) 105 | target[k][_slices] = v.view_as(target[k][_slices]) 106 | 107 | if k in self.preprocess: 108 | new_k = self.preprocess[k][0] 109 | v = target[k][_slices] 110 | for transform in self.preprocess[k][1]: 111 | v = transform.transform(v) 112 | target[new_k][_slices] = v.view_as(target[new_k][_slices]) 113 | 114 | def _check_safe_view(self, v, dest): 115 | idx = len(v.shape) - 1 116 | for s in dest.shape[::-1]: 117 | if v.shape[idx] != s: 118 | if s != 1: 119 | raise ValueError("Unsafe reshape of {} to {}".format(v.shape, dest.shape)) 120 | else: 121 | idx -= 1 122 | 123 | def __getitem__(self, item): 124 | if isinstance(item, str): 125 | if item in self.data.episode_data: 126 | return self.data.episode_data[item] 127 | elif item in self.data.transition_data: 128 | return self.data.transition_data[item] 129 | else: 130 | raise ValueError 131 | elif isinstance(item, tuple) and all([isinstance(it, str) for it in item]): 132 | new_data = self._new_data_sn() 133 | for key in item: 134 | if key in self.data.transition_data: 135 | new_data.transition_data[key] = self.data.transition_data[key] 136 | elif key in self.data.episode_data: 137 | new_data.episode_data[key] = self.data.episode_data[key] 138 | else: 139 | raise KeyError("Unrecognised key {}".format(key)) 140 | 141 | # Update the scheme to only have the requested keys 142 | new_scheme = {key: self.scheme[key] for key in item} 143 | new_groups = {self.scheme[key]["group"]: self.groups[self.scheme[key]["group"]] 144 | for key in item 
if "group" in self.scheme[key]} 145 | ret = EpisodeBatch(new_scheme, new_groups, self.batch_size, self.max_seq_length, data=new_data, device=self.device) 146 | return ret 147 | else: 148 | item = self._parse_slices(item) 149 | new_data = self._new_data_sn() 150 | for k, v in self.data.transition_data.items(): 151 | new_data.transition_data[k] = v[item] 152 | for k, v in self.data.episode_data.items(): 153 | new_data.episode_data[k] = v[item[0]] 154 | 155 | ret_bs = self._get_num_items(item[0], self.batch_size) 156 | ret_max_t = self._get_num_items(item[1], self.max_seq_length) 157 | 158 | ret = EpisodeBatch(self.scheme, self.groups, ret_bs, ret_max_t, data=new_data, device=self.device) 159 | return ret 160 | 161 | def _get_num_items(self, indexing_item, max_size): 162 | if isinstance(indexing_item, list) or isinstance(indexing_item, np.ndarray): 163 | return len(indexing_item) 164 | elif isinstance(indexing_item, slice): 165 | _range = indexing_item.indices(max_size) 166 | return 1 + (_range[1] - _range[0] - 1)//_range[2] 167 | 168 | def _new_data_sn(self): 169 | new_data = SN() 170 | new_data.transition_data = {} 171 | new_data.episode_data = {} 172 | return new_data 173 | 174 | def _parse_slices(self, items): 175 | parsed = [] 176 | # Only batch slice given, add full time slice 177 | if (isinstance(items, slice) # slice a:b 178 | or isinstance(items, int) # int i 179 | or (isinstance(items, (list, np.ndarray, th.LongTensor, th.cuda.LongTensor))) # [a,b,c] 180 | ): 181 | items = (items, slice(None)) 182 | 183 | # Need the time indexing to be contiguous 184 | if isinstance(items[1], list): 185 | raise IndexError("Indexing across Time must be contiguous") 186 | 187 | for item in items: 188 | #TODO: stronger checks to ensure only supported options get through 189 | if isinstance(item, int): 190 | # Convert single indices to slices 191 | parsed.append(slice(item, item+1)) 192 | else: 193 | # Leave slices and lists as is 194 | parsed.append(item) 195 | return parsed 196 | 197 | def max_t_filled(self): 198 | return th.sum(self.data.transition_data["filled"], 1).max(0)[0] 199 | 200 | def __repr__(self): 201 | return "EpisodeBatch. 
Batch Size:{} Max_seq_len:{} Keys:{} Groups:{}".format(self.batch_size, 202 | self.max_seq_length, 203 | self.scheme.keys(), 204 | self.groups.keys()) 205 | 206 | class ReplayBuffer(EpisodeBatch): 207 | def __init__(self, scheme, groups, buffer_size, max_seq_length, preprocess=None, device="cpu"): 208 | super(ReplayBuffer, self).__init__(scheme, groups, buffer_size, max_seq_length, preprocess=preprocess, device=device) 209 | self.buffer_size = buffer_size # same as self.batch_size but more explicit 210 | self.buffer_index = 0 211 | self.episodes_in_buffer = 0 212 | 213 | def insert_episode_batch(self, ep_batch): 214 | if self.buffer_index + ep_batch.batch_size <= self.buffer_size: 215 | self.update(ep_batch.data.transition_data, 216 | slice(self.buffer_index, self.buffer_index + ep_batch.batch_size), 217 | slice(0, ep_batch.max_seq_length), 218 | mark_filled=False) 219 | self.update(ep_batch.data.episode_data, 220 | slice(self.buffer_index, self.buffer_index + ep_batch.batch_size)) 221 | self.buffer_index = (self.buffer_index + ep_batch.batch_size) 222 | self.episodes_in_buffer = max(self.episodes_in_buffer, self.buffer_index) 223 | self.buffer_index = self.buffer_index % self.buffer_size 224 | assert self.buffer_index < self.buffer_size 225 | else: 226 | buffer_left = self.buffer_size - self.buffer_index 227 | self.insert_episode_batch(ep_batch[0:buffer_left, :]) 228 | self.insert_episode_batch(ep_batch[buffer_left:, :]) 229 | 230 | def can_sample(self, batch_size): 231 | return self.episodes_in_buffer >= batch_size 232 | 233 | def sample(self, batch_size): 234 | assert self.can_sample(batch_size) 235 | if self.episodes_in_buffer == batch_size: 236 | return self[:batch_size] 237 | else: 238 | # Uniform sampling only atm 239 | ep_ids = np.random.choice(self.episodes_in_buffer, batch_size, replace=False) 240 | return self[ep_ids] 241 | 242 | def __repr__(self): 243 | return "ReplayBuffer. {}/{} episodes. 
Keys:{} Groups:{}".format(self.episodes_in_buffer, 244 | self.buffer_size, 245 | self.scheme.keys(), 246 | self.groups.keys()) 247 | 248 | # def _check_slice(self, slice, max_size): 249 | # if slice.step is not None: 250 | # return slice.step > 0 # pytorch doesn't support negative steps so neither do we 251 | # if slice.start is None and slice.stop is None: 252 | # return True 253 | # elif slice.start is None: 254 | # return 0 < slice.stop <= max_size 255 | # elif slice.stop is None: 256 | # return 0 <= slice.start < max_size 257 | # else: 258 | # return (0 < slice.stop <= max_size) and (0 <= slice.start < max_size) 259 | 260 | if __name__ == "__main__": 261 | bs = 4 262 | n_agents = 2 263 | groups = {"agents": n_agents} 264 | 265 | # "input": {"vshape": (shape), "episode_const": bool, "group": (name), "dtype": dtype} 266 | scheme = { 267 | "actions": {"vshape": (1,), "group": "agents", "dtype": th.long}, 268 | "obs": {"vshape": (3,), "group": "agents"}, 269 | "state": {"vshape": (3,3)}, 270 | "epsilon": {"vshape": (1,), "episode_const": True} 271 | } 272 | from transforms import OneHot 273 | preprocess = { 274 | "actions": ("actions_onehot", [OneHot(out_dim=5)]) 275 | } 276 | 277 | ep_batch = EpisodeBatch(scheme, groups, bs, 3, preprocess=preprocess) 278 | 279 | env_data = { 280 | "actions": th.ones(n_agents, 1).long(), 281 | "obs": th.ones(2, 3), 282 | "state": th.eye(3) 283 | } 284 | batch_data = { 285 | "actions": th.ones(bs, n_agents, 1).long(), 286 | "obs": th.ones(2, 3).unsqueeze(0).repeat(bs,1,1), 287 | "state": th.eye(3).unsqueeze(0).repeat(bs,1,1), 288 | } 289 | # bs=4 x t=3 x v=3*3 290 | 291 | ep_batch.update(env_data, 0, 0) 292 | 293 | ep_batch.update({"epsilon": th.ones(bs)*.05}) 294 | 295 | ep_batch[:, 1].update(batch_data) 296 | ep_batch.update(batch_data, ts=1) 297 | 298 | ep_batch.update(env_data, 0, 1) 299 | 300 | env_data = { 301 | "obs": th.ones(2, 3), 302 | "state": th.eye(3)*2 303 | } 304 | ep_batch.update(env_data, 3, 0) 305 | 306 | b2 = ep_batch[0, 1] 307 | b2.update(env_data, 0, 0) 308 | 309 | replay_buffer = ReplayBuffer(scheme, groups, 5, 3, preprocess=preprocess) 310 | 311 | replay_buffer.insert_episode_batch(ep_batch) 312 | 313 | replay_buffer.insert_episode_batch(ep_batch) 314 | 315 | sampled = replay_buffer.sample(3) 316 | 317 | print(sampled["actions_onehot"]) -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from functools import partial 3 | from math import ceil 4 | import imageio 5 | import os 6 | import pprint 7 | import time 8 | import json 9 | import threading 10 | import torch as th 11 | from numpy.random import RandomState 12 | from types import SimpleNamespace as SN 13 | from utils.logging import Logger 14 | from utils.timehelper import time_left, time_str 15 | from os.path import dirname, abspath, basename, join, splitext 16 | 17 | from learners import REGISTRY as le_REGISTRY 18 | from runners import REGISTRY as r_REGISTRY 19 | from controllers import REGISTRY as mac_REGISTRY 20 | from envs import s_REGISTRY 21 | from components.episode_buffer import ReplayBuffer 22 | from components.transforms import OneHot 23 | 24 | 25 | def run(_run, _config, _log): 26 | # check args sanity 27 | _config = args_sanity_check(_config, _log) 28 | 29 | args = SN(**_config) 30 | args.device = "cuda" if args.use_cuda else "cpu" 31 | 32 | # setup loggers 33 | logger = Logger(_log) 34 | 35 | _log.info("Experiment 
Parameters:") 36 | experiment_params = pprint.pformat(_config, 37 | indent=4, 38 | width=1) 39 | _log.info("\n\n" + experiment_params + "\n") 40 | 41 | # configure tensorboard logger 42 | unique_token = "{}__{}".format(args.name, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")) 43 | args.unique_token = unique_token 44 | if args.use_tensorboard: 45 | tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", args.tb_dirname) 46 | tb_exp_direc = os.path.join(tb_logs_direc, "{}").format(unique_token) 47 | logger.setup_tb(tb_exp_direc) 48 | 49 | # sacred is on by default 50 | logger.setup_sacred(_run) 51 | 52 | # Run and train 53 | run_sequential(args=args, logger=logger) 54 | 55 | # Clean up after finishing 56 | print("Exiting Main") 57 | 58 | print("Stopping all threads") 59 | for t in threading.enumerate(): 60 | if t.name != "MainThread": 61 | print("Thread {} is alive! Is daemon: {}".format(t.name, t.daemon)) 62 | t.join(timeout=1) 63 | print("Thread joined") 64 | 65 | print("Exiting script") 66 | 67 | # Making sure framework really exits 68 | os._exit(os.EX_OK) 69 | 70 | 71 | def evaluate_sequential(args, runner, logger): 72 | vw = None 73 | if args.video_path is not None: 74 | os.makedirs(dirname(args.video_path), exist_ok=True) 75 | vid_basename_split = splitext(basename(args.video_path)) 76 | if vid_basename_split[1] == '.mp4': 77 | vid_basename = ''.join(vid_basename_split) 78 | else: 79 | vid_basename = ''.join(vid_basename_split) + '.mp4' 80 | vid_filename = join(dirname(args.video_path), vid_basename) 81 | vw = imageio.get_writer(vid_filename, format='FFMPEG', mode='I', 82 | fps=args.fps, codec='h264', quality=10) 83 | 84 | if args.eval_path is not None: 85 | os.makedirs(dirname(args.eval_path), exist_ok=True) 86 | eval_basename_split = splitext(basename(args.eval_path)) 87 | if eval_basename_split[1] == '.json': 88 | eval_basename = ''.join(eval_basename_split) 89 | else: 90 | eval_basename = ''.join(eval_basename_split) + '.json' 91 | eval_filename = join(dirname(args.eval_path), eval_basename) 92 | 93 | res_dict = {} 94 | 95 | if args.eval_all_scen: 96 | if 'sc2' in args.env: 97 | dict_key = 'scenarios' 98 | else: 99 | raise Exception("Environment (%s) does not incorporate multiple scenarios") 100 | n_scen = len(args.env_args['scenario_dict'][dict_key]) 101 | else: 102 | n_scen = 1 103 | n_test_batches = max(1, args.test_nepisode // runner.batch_size) 104 | 105 | for i in range(n_scen): 106 | run_args = {'test_mode': True, 'vid_writer': vw, 107 | 'test_scen': True} 108 | if args.eval_all_scen: 109 | run_args['index'] = i 110 | for _ in range(n_test_batches): 111 | runner.run(**run_args) 112 | curr_stats = dict((k, v[-1][1]) for k, v in logger.stats.items()) 113 | if args.eval_all_scen: 114 | curr_scen = args.env_args['scenario_dict'][dict_key][i] 115 | # assumes that unique set of agents is a unique scenario 116 | if 'sc2' in args.env: 117 | scen_str = "-".join("%i%s" % (count, name[:3]) for count, name in sorted(curr_scen[0], key=lambda x: x[1])) 118 | else: 119 | scen_str = "".join(curr_scen[0]) 120 | res_dict[scen_str] = curr_stats 121 | else: 122 | res_dict.update(curr_stats) 123 | 124 | if vw is not None: 125 | vw.close() 126 | 127 | if args.eval_path is not None: 128 | with open(eval_filename, 'w') as f: 129 | json.dump(res_dict, f) 130 | 131 | if args.save_replay: 132 | runner.save_replay() 133 | 134 | runner.close_env() 135 | logger.print_stats_summary() 136 | 137 | 138 | def run_sequential(args, logger): 139 | # Init runner so we can get env 
info 140 | if 'entity_scheme' in args.env_args: 141 | args.entity_scheme = args.env_args['entity_scheme'] 142 | else: 143 | args.entity_scheme = False 144 | 145 | if ('sc2custom' in args.env): 146 | rs = RandomState(0) 147 | args.env_args['scenario_dict'] = s_REGISTRY[args.scenario](rs=rs) 148 | runner = r_REGISTRY[args.runner](args=args, logger=logger) 149 | 150 | # Set up schemes and groups here 151 | env_info = runner.get_env_info() 152 | args.n_agents = env_info["n_agents"] 153 | args.n_actions = env_info["n_actions"] 154 | if not args.entity_scheme: 155 | args.state_shape = env_info["state_shape"] 156 | # Default/Base scheme 157 | scheme = { 158 | "state": {"vshape": env_info["state_shape"]}, 159 | "obs": {"vshape": env_info["obs_shape"], "group": "agents"}, 160 | "actions": {"vshape": (1,), "group": "agents", "dtype": th.long}, 161 | "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int}, 162 | "reward": {"vshape": (1,)}, 163 | "terminated": {"vshape": (1,), "dtype": th.uint8}, 164 | } 165 | groups = { 166 | "agents": args.n_agents 167 | } 168 | if 'masks' in env_info: 169 | # masks that identify what part of observation/state spaces correspond to each entity 170 | args.obs_masks, args.state_masks = env_info['masks'] 171 | if 'unit_dim' in env_info: 172 | args.unit_dim = env_info['unit_dim'] 173 | else: 174 | args.entity_shape = env_info["entity_shape"] 175 | args.n_entities = env_info["n_entities"] 176 | args.gt_mask_avail = env_info.get("gt_mask_avail", False) 177 | # Entity scheme 178 | scheme = { 179 | "entities": {"vshape": env_info["entity_shape"], "group": "entities"}, 180 | "obs_mask": {"vshape": env_info["n_entities"], "group": "entities", "dtype": th.uint8}, 181 | "entity_mask": {"vshape": env_info["n_entities"], "dtype": th.uint8}, 182 | "actions": {"vshape": (1,), "group": "agents", "dtype": th.long}, 183 | "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int}, 184 | "reward": {"vshape": (1,)}, 185 | "terminated": {"vshape": (1,), "dtype": th.uint8}, 186 | } 187 | if args.gt_mask_avail: 188 | scheme["gt_mask"] = {"vshape": env_info["n_entities"], "group": "agents", "dtype": th.uint8} 189 | groups = { 190 | "agents": args.n_agents, 191 | "entities": args.n_entities 192 | } 193 | 194 | preprocess = { 195 | "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)]) 196 | } 197 | 198 | buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1, 199 | preprocess=preprocess, 200 | device="cpu" if args.buffer_cpu_only else args.device) 201 | 202 | # Setup multiagent controller here 203 | mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args) 204 | 205 | # Give runner the scheme 206 | runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac) 207 | 208 | # Learner 209 | learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args) 210 | 211 | if args.use_cuda: 212 | learner.cuda() 213 | 214 | if args.checkpoint_path != "": 215 | 216 | timesteps = [] 217 | timestep_to_load = 0 218 | 219 | if not os.path.isdir(args.checkpoint_path): 220 | logger.console_logger.info("Checkpoint directiory {} doesn't exist".format(args.checkpoint_path)) 221 | return 222 | 223 | # Go through all files in args.checkpoint_path 224 | for name in os.listdir(args.checkpoint_path): 225 | full_name = os.path.join(args.checkpoint_path, name) 226 | # Check if they are dirs the names of which are numbers 227 | if os.path.isdir(full_name) and name.isdigit(): 228 | 
timesteps.append(int(name)) 229 | 230 | if args.load_step == 0: 231 | # choose the max timestep 232 | timestep_to_load = max(timesteps) 233 | else: 234 | # choose the timestep closest to load_step 235 | timestep_to_load = min(timesteps, key=lambda x: abs(x - args.load_step)) 236 | 237 | model_path = os.path.join(args.checkpoint_path, str(timestep_to_load)) 238 | 239 | logger.console_logger.info("Loading model from {}".format(model_path)) 240 | learner.load_models(model_path, evaluate=args.evaluate) 241 | runner.t_env = timestep_to_load 242 | 243 | if args.evaluate or args.save_replay: 244 | evaluate_sequential(args, runner, logger) 245 | return 246 | 247 | # start training 248 | episode = 0 249 | last_test_T = -args.test_interval - 1 250 | last_log_T = 0 251 | model_save_time = 0 252 | 253 | start_time = time.time() 254 | last_time = start_time 255 | 256 | logger.console_logger.info("Beginning training for {} timesteps".format(args.t_max)) 257 | 258 | while runner.t_env <= args.t_max: 259 | 260 | # Run for a whole episode at a time 261 | episode_batch = runner.run(test_mode=False) 262 | buffer.insert_episode_batch(episode_batch) 263 | 264 | if buffer.can_sample(args.batch_size): 265 | for _ in range(args.training_iters): 266 | episode_sample = buffer.sample(args.batch_size) 267 | 268 | # Truncate batch to only filled timesteps 269 | max_ep_t = episode_sample.max_t_filled() 270 | episode_sample = episode_sample[:, :max_ep_t] 271 | 272 | if episode_sample.device != args.device: 273 | episode_sample.to(args.device) 274 | 275 | learner.train(episode_sample, runner.t_env, episode) 276 | 277 | # Execute test runs once in a while 278 | n_test_runs = max(1, args.test_nepisode // runner.batch_size) 279 | if (runner.t_env - last_test_T) / args.test_interval >= 1.0: 280 | 281 | logger.console_logger.info("t_env: {} / {}".format(runner.t_env, args.t_max)) 282 | logger.console_logger.info("Estimated time left: {}. Time passed: {}".format( 283 | time_left(last_time, last_test_T, runner.t_env, args.t_max), time_str(time.time() - start_time))) 284 | last_time = time.time() 285 | 286 | last_test_T = runner.t_env 287 | for _ in range(n_test_runs): 288 | runner.run(test_mode=True) 289 | 290 | if args.save_model and (runner.t_env - model_save_time >= args.save_model_interval or 291 | model_save_time == 0 or 292 | runner.t_env > args.t_max): 293 | model_save_time = runner.t_env 294 | save_path = os.path.join(args.local_results_path, "models", args.unique_token, str(runner.t_env)) 295 | #"results/models/{}".format(unique_token) 296 | os.makedirs(save_path, exist_ok=True) 297 | logger.console_logger.info("Saving models to {}".format(save_path)) 298 | 299 | # learner should handle saving/loading -- delegate actor save/load to mac, 300 | # use appropriate filenames to do critics, optimizer states 301 | learner.save_models(save_path) 302 | 303 | episode += args.batch_size_run 304 | 305 | if (runner.t_env - last_log_T) >= args.log_interval: 306 | logger.log_stat("episode", episode, runner.t_env) 307 | logger.print_recent_stats() 308 | last_log_T = runner.t_env 309 | 310 | runner.close_env() 311 | logger.console_logger.info("Finished Training") 312 | 313 | 314 | # TODO: Clean this up 315 | def args_sanity_check(config, _log): 316 | 317 | # set CUDA flags 318 | # config["use_cuda"] = True # Use cuda whenever possible! 
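# If CUDA was requested but no device is visible, fall back to CPU here rather than failing
# later when run() sets args.device and the learner/batches are moved onto it. The block after
# the CUDA check rounds test_nepisode to a multiple of batch_size_run so that test episodes can
# always be launched in full parallel batches (e.g. 30 becomes 24 when batch_size_run is 8).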
319 | if config["use_cuda"] and not th.cuda.is_available(): 320 | config["use_cuda"] = False 321 | _log.warning("CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!") 322 | 323 | if config["test_nepisode"] < config["batch_size_run"]: 324 | config["test_nepisode"] = config["batch_size_run"] 325 | else: 326 | config["test_nepisode"] = (config["test_nepisode"]//config["batch_size_run"]) * config["batch_size_run"] 327 | 328 | # assert (config["run_mode"] in ["parallel_subproc"] and config["use_replay_buffer"]) or (not config["run_mode"] in ["parallel_subproc"]), \ 329 | # "need to use replay buffer if running in parallel mode!" 330 | 331 | # assert not (not config["use_replay_buffer"] and (config["batch_size_run"]!=config["batch_size"]) ) , "if not using replay buffer, require batch_size and batch_size_run to be the same." 332 | 333 | # if config["learner"] == "coma": 334 | # assert (config["run_mode"] in ["parallel_subproc"] and config["batch_size_run"]==config["batch_size"]) or \ 335 | # (not config["run_mode"] in ["parallel_subproc"] and not config["use_replay_buffer"]), \ 336 | # "cannot use replay buffer for coma, unless in parallel mode, when it needs to have exactly have size batch_size." 337 | 338 | return config 339 | -------------------------------------------------------------------------------- /src/runners/parallel_runner.py: -------------------------------------------------------------------------------- 1 | from envs import REGISTRY as env_REGISTRY 2 | from functools import partial 3 | from components.episode_buffer import EpisodeBatch 4 | from multiprocessing import Pipe, Process 5 | import numpy as np 6 | import torch as th 7 | 8 | 9 | # Based (very) heavily on SubprocVecEnv from OpenAI Baselines 10 | # https://github.com/openai/baselines/blob/master/baselines/common/vec_env/subproc_vec_env.py 11 | class ParallelRunner: 12 | 13 | def __init__(self, args, logger): 14 | self.args = args 15 | self.logger = logger 16 | self.batch_size = self.args.batch_size_run 17 | 18 | # Make subprocesses for the envs 19 | # TODO: Add a delay when making sc2 envs 20 | self.parent_conns, self.worker_conns = zip(*[Pipe() for _ in range(self.batch_size)]) 21 | env_fn = env_REGISTRY[self.args.env] 22 | if ('sc2' in self.args.env) or ('group_matching' in self.args.env): 23 | base_seed = self.args.env_args.pop('seed') 24 | self.ps = [Process(target=env_worker, args=(worker_conn, self.args.entity_scheme, 25 | CloudpickleWrapper(partial(env_fn, seed=base_seed + rank, 26 | **self.args.env_args)))) 27 | for rank, worker_conn in enumerate(self.worker_conns)] 28 | else: 29 | self.ps = [Process(target=env_worker, args=(worker_conn, self.args.entity_scheme, 30 | CloudpickleWrapper(partial(env_fn, env_args=self.args.env_args, args=self.args)))) 31 | for worker_conn in self.worker_conns] 32 | 33 | for p in self.ps: 34 | p.daemon = True 35 | p.start() 36 | 37 | # TODO: Close stuff if appropriate 38 | 39 | self.parent_conns[0].send(("get_env_info", args)) 40 | self.env_info = self.parent_conns[0].recv() 41 | self.episode_limit = self.env_info["episode_limit"] 42 | 43 | # TODO: Will have to add stuff to episode batch for envs that terminate at different times to ensure filled is correct 44 | self.t = 0 45 | 46 | self.t_env = 0 47 | 48 | self.train_returns = [] 49 | self.test_returns = [] 50 | self.train_stats = {} 51 | self.test_stats = {} 52 | 53 | self.log_train_stats_t = -100000 54 | 55 | def setup(self, scheme, groups, preprocess, mac): 56 | self.new_batch = 
partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1, 57 | preprocess=preprocess, device=self.args.device) 58 | self.mac = mac 59 | # TODO: Remove these if the runner doesn't need them 60 | self.scheme = scheme 61 | self.groups = groups 62 | self.preprocess = preprocess 63 | 64 | def get_env_info(self): 65 | return self.env_info 66 | 67 | def save_replay(self): 68 | pass 69 | 70 | def close_env(self): 71 | for parent_conn in self.parent_conns: 72 | parent_conn.send(("close", None)) 73 | 74 | def reset(self, **kwargs): 75 | self.batch = self.new_batch() 76 | 77 | # Reset the envs 78 | for parent_conn in self.parent_conns: 79 | parent_conn.send(("reset", kwargs)) 80 | 81 | pre_transition_data = {} 82 | # Get the obs, state and avail_actions back 83 | for parent_conn in self.parent_conns: 84 | data = parent_conn.recv() 85 | for k, v in data.items(): 86 | if k in pre_transition_data: 87 | pre_transition_data[k].append(data[k]) 88 | else: 89 | pre_transition_data[k] = [data[k]] 90 | 91 | self.batch.update(pre_transition_data, ts=0) 92 | 93 | self.t = 0 94 | self.env_steps_this_run = 0 95 | 96 | def run(self, test_mode=False, test_scen=None, index=None, vid_writer=None): 97 | """ 98 | test_mode: whether to use greedy action selection or sample actions 99 | test_scen: whether to run on test scenarios. defaults to matching test_mode. 100 | vid_writer: imageio video writer object (not supported in parallel runner) 101 | """ 102 | if test_scen is None: 103 | test_scen = test_mode 104 | assert vid_writer is None, "Writing videos not supported for ParallelRunner" 105 | self.reset(test=test_scen, index=index) 106 | 107 | all_terminated = False 108 | episode_returns = [0 for _ in range(self.batch_size)] 109 | episode_lengths = [0 for _ in range(self.batch_size)] 110 | self.mac.init_hidden(batch_size=self.batch_size) 111 | # make sure things like dropout are disabled 112 | self.mac.eval() 113 | terminated = [False for _ in range(self.batch_size)] 114 | envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed] 115 | final_env_infos = [] # may store extra stats like battle won. 
this is filled in ORDER OF TERMINATION 116 | 117 | while True: 118 | 119 | # Pass the entire batch of experiences up till now to the agents 120 | # Receive the actions for each agent at this timestep in a batch for each un-terminated env 121 | actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, bs=envs_not_terminated, test_mode=test_mode) 122 | cpu_actions = actions.to("cpu").numpy() 123 | 124 | # Update the actions taken 125 | actions_chosen = { 126 | "actions": actions.unsqueeze(1) 127 | } 128 | self.batch.update(actions_chosen, bs=envs_not_terminated, ts=self.t, mark_filled=False) 129 | 130 | # Send actions to each env 131 | action_idx = 0 132 | for idx, parent_conn in enumerate(self.parent_conns): 133 | if idx in envs_not_terminated: # We produced actions for this env 134 | if not terminated[idx]: # Only send the actions to the env if it hasn't terminated 135 | parent_conn.send(("step", cpu_actions[action_idx])) 136 | action_idx += 1 # actions is not a list over every env 137 | 138 | # Post step data we will insert for the current timestep 139 | post_transition_data = { 140 | # "actions": actions.unsqueeze(1), 141 | "reward": [], 142 | "terminated": [] 143 | } 144 | # Data for the next step we will insert in order to select an action 145 | if self.args.entity_scheme: 146 | pre_transition_data = { 147 | "entities": [], 148 | "obs_mask": [], 149 | "entity_mask": [], 150 | "avail_actions": [] 151 | } 152 | else: 153 | pre_transition_data = { 154 | "state": [], 155 | "avail_actions": [], 156 | "obs": [] 157 | } 158 | 159 | # Update terminated envs after adding post_transition_data 160 | envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed] 161 | all_terminated = all(terminated) 162 | if all_terminated: 163 | break 164 | 165 | # Receive data back for each unterminated env 166 | for idx, parent_conn in enumerate(self.parent_conns): 167 | if not terminated[idx]: 168 | data = parent_conn.recv() 169 | # Remaining data for this current timestep 170 | post_transition_data["reward"].append((data["reward"],)) 171 | 172 | episode_returns[idx] += data["reward"] 173 | episode_lengths[idx] += 1 174 | if not test_mode: 175 | self.env_steps_this_run += 1 176 | 177 | env_terminated = False 178 | if data["terminated"]: 179 | final_env_infos.append(data["info"]) 180 | if data["terminated"] and not data["info"].get("episode_limit", False): 181 | env_terminated = True 182 | terminated[idx] = data["terminated"] 183 | post_transition_data["terminated"].append((env_terminated,)) 184 | 185 | # Data for the next timestep needed to select an action 186 | for k in pre_transition_data: 187 | pre_transition_data[k].append(data[k]) 188 | 189 | # Add post_transiton data into the batch 190 | self.batch.update(post_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=False) 191 | 192 | # Move onto the next timestep 193 | self.t += 1 194 | 195 | # Add the pre-transition data 196 | 197 | self.batch.update(pre_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=True) 198 | 199 | 200 | if not test_mode: 201 | self.t_env += self.env_steps_this_run 202 | 203 | # Get stats back for each env 204 | for parent_conn in self.parent_conns: 205 | parent_conn.send(("get_stats",None)) 206 | 207 | env_stats = [] 208 | for parent_conn in self.parent_conns: 209 | env_stat = parent_conn.recv() 210 | env_stats.append(env_stat) 211 | 212 | cur_stats = self.test_stats if test_mode else self.train_stats 213 | cur_returns = self.test_returns if test_mode else 
self.train_returns 214 | log_prefix = "test_" if test_mode else "" 215 | infos = [cur_stats] + final_env_infos 216 | cur_stats.update({k: sum(d.get(k, 0) for d in infos) for k in set.union(*[set(d) for d in infos])}) 217 | cur_stats["n_episodes"] = self.batch_size + cur_stats.get("n_episodes", 0) 218 | cur_stats["ep_length"] = sum(episode_lengths) + cur_stats.get("ep_length", 0) 219 | 220 | cur_returns.extend(episode_returns) 221 | 222 | n_test_runs = max(1, self.args.test_nepisode // self.batch_size) * self.batch_size 223 | if test_mode and (len(self.test_returns) == n_test_runs): 224 | self._log(cur_returns, cur_stats, log_prefix) 225 | elif not test_mode and self.t_env - self.log_train_stats_t >= self.args.runner_log_interval: 226 | self._log(cur_returns, cur_stats, log_prefix) 227 | if hasattr(self.mac.action_selector, "epsilon"): 228 | self.logger.log_stat("epsilon", self.mac.action_selector.epsilon, self.t_env) 229 | if 'sc2' in self.args.env: 230 | self.logger.log_stat("forced_restarts", 231 | sum(es['restarts'] for es in env_stats), 232 | self.t_env) 233 | self.log_train_stats_t = self.t_env 234 | 235 | return self.batch 236 | 237 | def _log(self, returns, stats, prefix): 238 | self.logger.log_stat(prefix + "return_mean", np.mean(returns), self.t_env) 239 | self.logger.log_stat(prefix + "return_std", np.std(returns), self.t_env) 240 | returns.clear() 241 | 242 | for k, v in stats.items(): 243 | if k != "n_episodes": 244 | self.logger.log_stat(prefix + k + "_mean", v / stats["n_episodes"], self.t_env) 245 | stats.clear() 246 | 247 | 248 | def env_worker(remote, entity_scheme, env_fn): 249 | # Make environment 250 | env = env_fn.x() 251 | while True: 252 | cmd, data = remote.recv() 253 | if cmd == "step": 254 | actions = data 255 | # Take a step in the environment 256 | reward, terminated, env_info = env.step(actions) 257 | send_dict = { 258 | "avail_actions": env.get_avail_actions(), 259 | # Rest of the data for the current timestep 260 | "reward": reward, 261 | "terminated": terminated, 262 | "info": env_info 263 | } 264 | if entity_scheme: 265 | masks = env.get_masks() 266 | if len(masks) == 2: 267 | obs_mask, entity_mask = masks 268 | gt_mask = None 269 | else: 270 | obs_mask, entity_mask, gt_mask = masks 271 | send_dict["obs_mask"] = obs_mask 272 | send_dict["entity_mask"] = entity_mask 273 | if gt_mask is not None: 274 | send_dict["gt_mask"] = gt_mask 275 | send_dict["entities"] = env.get_entities() 276 | else: 277 | # Data for the next timestep needed to pick an action 278 | send_dict["state"] = env.get_state() 279 | send_dict["obs"] = env.get_obs() 280 | remote.send(send_dict) 281 | elif cmd == "reset": 282 | env.reset(**data) 283 | if entity_scheme: 284 | masks = env.get_masks() 285 | if len(masks) == 2: 286 | obs_mask, entity_mask = masks 287 | gt_mask = None 288 | else: 289 | obs_mask, entity_mask, gt_mask = masks 290 | send_dict = { 291 | "entities": env.get_entities(), 292 | "avail_actions": env.get_avail_actions(), 293 | "obs_mask": obs_mask, 294 | "entity_mask": entity_mask 295 | } 296 | if gt_mask is not None: 297 | send_dict["gt_mask"] = gt_mask 298 | remote.send(send_dict) 299 | else: 300 | remote.send({ 301 | "state": env.get_state(), 302 | "avail_actions": env.get_avail_actions(), 303 | "obs": env.get_obs() 304 | }) 305 | elif cmd == "close": 306 | env.close() 307 | remote.close() 308 | break 309 | elif cmd == "get_env_info": 310 | remote.send(env.get_env_info(data)) 311 | elif cmd == "get_stats": 312 | remote.send(env.get_stats()) 313 | # TODO: unused now? 
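# (the parallel runner above only ever sends "get_env_info", "reset", "step", "get_stats" and
# "close", so the commented-out "agg_stats" handler below is likely dead code)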
314 | # elif cmd == "agg_stats": 315 | # agg_stats = env.get_agg_stats(data) 316 | # remote.send(agg_stats) 317 | else: 318 | raise NotImplementedError 319 | 320 | 321 | class CloudpickleWrapper(): 322 | """ 323 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 324 | """ 325 | def __init__(self, x): 326 | self.x = x 327 | def __getstate__(self): 328 | import cloudpickle 329 | return cloudpickle.dumps(self.x) 330 | def __setstate__(self, ob): 331 | import pickle 332 | self.x = pickle.loads(ob) 333 | 334 | --------------------------------------------------------------------------------
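A note on the CloudpickleWrapper above: multiprocessing pickles the arguments handed to each Process, and the environment constructor passed to env_worker is a functools.partial whose captured arguments (or the registered constructor itself) may not survive plain pickle. Wrapping it so that cloudpickle performs the serialisation sidesteps this, while the resulting byte stream can still be read back with the standard pickle module, which is exactly what __setstate__ does. The snippet below is a minimal, self-contained illustration of the same idea; the names make_env and demo are purely illustrative and are not part of this repository.

import pickle
import cloudpickle


def make_env(seed):
    # returns a closure; plain pickle cannot serialise local functions like this
    return lambda: {"seed": seed}


def demo():
    env_fn = make_env(7)
    try:
        pickle.dumps(env_fn)
    except Exception as err:  # typically AttributeError / PicklingError
        print("plain pickle fails:", type(err).__name__)
    blob = cloudpickle.dumps(env_fn)   # cloudpickle handles closures and lambdas
    restored = pickle.loads(blob)      # its output loads with plain pickle, as in CloudpickleWrapper
    print(restored())                  # {'seed': 7}


if __name__ == "__main__":
    demo()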