├── src ├── __init__.py ├── components │ ├── __init__.py │ ├── transforms.py │ ├── epsilon_schedules.py │ ├── action_selectors.py │ └── episode_buffer.py ├── modules │ ├── __init__.py │ ├── critics │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── fop.cpython-37.pyc │ │ │ ├── coma.cpython-37.pyc │ │ │ └── __init__.cpython-37.pyc │ │ ├── fop.py │ │ └── coma.py │ ├── mixers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── fop.cpython-37.pyc │ │ │ ├── qmix.cpython-37.pyc │ │ │ ├── vdn.cpython-37.pyc │ │ │ ├── qtran.cpython-37.pyc │ │ │ └── __init__.cpython-37.pyc │ │ ├── vdn.py │ │ ├── qmix.py │ │ ├── fop.py │ │ └── qtran.py │ └── agents │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── rnn_agent.cpython-37.pyc │ │ └── rnn_agent.py ├── controllers │ ├── __init__.py │ └── basic_controller.py ├── utils │ ├── dist2namedtuple.py │ ├── rl_utils.py │ ├── timehelper.py │ └── logging.py ├── runners │ ├── __init__.py │ ├── episode_runner.py │ └── parallel_runner.py ├── learners │ ├── __init__.py │ ├── q_learner.py │ ├── coma_learner.py │ ├── qtran_learner.py │ └── fop_learner.py ├── envs │ ├── __init__.py │ └── multiagentenv.py ├── config │ ├── algs │ │ ├── vdn.yaml │ │ ├── iql.yaml │ │ ├── qmix.yaml │ │ ├── fop.yaml │ │ ├── qtran.yaml │ │ └── coma.yaml │ ├── envs │ │ ├── sc2_beta.yaml │ │ └── sc2.yaml │ └── default.yaml ├── main.py └── run.py ├── run.sh ├── run_interactive.sh ├── requirements.txt ├── install_sc2.sh ├── README.md └── LICENSE /src/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /src/components/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /src/modules/critics/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /src/modules/mixers/__init__.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- /src/modules/agents/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from .rnn_agent import RNNAgent 4 | 5 | REGISTRY["rnn"] = RNNAgent 6 | -------------------------------------------------------------------------------- /src/controllers/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from .basic_controller import BasicMAC 4 | 5 | REGISTRY["basic_mac"] = BasicMAC 6 | -------------------------------------------------------------------------------- /src/modules/critics/__pycache__/fop.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyheng/FOP/HEAD/src/modules/critics/__pycache__/fop.cpython-37.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/fop.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/liyheng/FOP/HEAD/src/modules/mixers/__pycache__/fop.cpython-37.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/qmix.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyheng/FOP/HEAD/src/modules/mixers/__pycache__/qmix.cpython-37.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/vdn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyheng/FOP/HEAD/src/modules/mixers/__pycache__/vdn.cpython-37.pyc -------------------------------------------------------------------------------- /src/modules/critics/__pycache__/coma.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyheng/FOP/HEAD/src/modules/critics/__pycache__/coma.cpython-37.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/qtran.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyheng/FOP/HEAD/src/modules/mixers/__pycache__/qtran.cpython-37.pyc -------------------------------------------------------------------------------- /src/modules/agents/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyheng/FOP/HEAD/src/modules/agents/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/modules/agents/__pycache__/rnn_agent.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyheng/FOP/HEAD/src/modules/agents/__pycache__/rnn_agent.cpython-37.pyc -------------------------------------------------------------------------------- /src/modules/critics/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyheng/FOP/HEAD/src/modules/critics/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/liyheng/FOP/HEAD/src/modules/mixers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/utils/dist2namedtuple.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | def convert(dictionary): 5 | return namedtuple('GenericDict', dictionary.keys())(**dictionary) 6 | -------------------------------------------------------------------------------- /src/runners/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from .episode_runner import EpisodeRunner 4 | REGISTRY["episode"] = EpisodeRunner 5 | 6 | from .parallel_runner import ParallelRunner 7 | REGISTRY["parallel"] = ParallelRunner 8 | -------------------------------------------------------------------------------- /src/modules/mixers/vdn.py: 
-------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | 4 | 5 | class VDNMixer(nn.Module): 6 | def __init__(self): 7 | super(VDNMixer, self).__init__() 8 | 9 | def forward(self, agent_qs, batch): 10 | return th.sum(agent_qs, dim=2, keepdim=True) -------------------------------------------------------------------------------- /src/learners/__init__.py: -------------------------------------------------------------------------------- 1 | from .q_learner import QLearner 2 | from .coma_learner import COMALearner 3 | from .qtran_learner import QLearner as QTranLearner 4 | from .fop_learner import FOP_Learner 5 | REGISTRY = {} 6 | 7 | REGISTRY["q_learner"] = QLearner 8 | REGISTRY["coma_learner"] = COMALearner 9 | REGISTRY["qtran_learner"] = QTranLearner 10 | REGISTRY["fop_learner"] = FOP_Learner -------------------------------------------------------------------------------- /src/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from smac.env import MultiAgentEnv, StarCraft2Env 3 | import sys 4 | import os 5 | 6 | def env_fn(env, **kwargs) -> MultiAgentEnv: 7 | return env(**kwargs) 8 | 9 | REGISTRY = { 10 | "sc2": partial(env_fn, env=StarCraft2Env) 11 | 12 | } 13 | 14 | 15 | if sys.platform == "linux": 16 | os.environ.setdefault("SC2PATH", 17 | os.path.join(os.getcwd(), "3rdparty", "StarCraftII")) 18 | -------------------------------------------------------------------------------- /src/config/algs/vdn.yaml: -------------------------------------------------------------------------------- 1 | # use epsilon greedy action selector 2 | action_selector: "epsilon_greedy" 3 | epsilon_start: 1.0 4 | epsilon_finish: 0.05 5 | epsilon_anneal_time: 50000 6 | 7 | runner: "episode" 8 | 9 | buffer_size: 5000 10 | 11 | # update the target network every {} episodes 12 | target_update_interval: 200 13 | 14 | # use the Q_Learner to train 15 | agent_output_type: "q" 16 | learner: "q_learner" 17 | double_q: True 18 | mixer: "vdn" 19 | 20 | name: "vdn" 21 | -------------------------------------------------------------------------------- /src/config/algs/iql.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "episode" 10 | 11 | buffer_size: 5000 12 | 13 | # update the target network every {} episodes 14 | target_update_interval: 200 15 | 16 | # use the Q_Learner to train 17 | agent_output_type: "q" 18 | learner: "q_learner" 19 | double_q: True 20 | mixer: # Mixer becomes None 21 | 22 | name: "iql" 23 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | HASH=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 4 | head -n 1) 3 | GPU=$1 4 | name=${USER}_pymarl_GPU_${GPU}_${HASH} 5 | 6 | echo "Launching container named '${name}' on GPU '${GPU}'" 7 | # Launches a docker container using our image, and runs the provided command 8 | 9 | if hash nvidia-docker 2>/dev/null; then 10 | cmd=nvidia-docker 11 | else 12 | cmd=docker 13 | fi 14 | 15 | NV_GPU="$GPU" ${cmd} run \ 16 | --name $name \ 17 | --user $(id -u):$(id -g) \ 18 | -v `pwd`:/pymarl \ 19 | -t pymarl:1.0 \ 20 | ${@:2} 21 | 
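# Example invocation (matches the README; the GPU index "0" and the map name are illustrative only,
# and the `pymarl:1.0` image is assumed to have been built from the repo's Dockerfile beforehand):
#   bash run.sh 0 python3 src/main.py --config=fop --env-config=sc2 with env_args.map_name=2c_vs_64zg
# Everything after the GPU index is forwarded into the container via ${@:2}.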
-------------------------------------------------------------------------------- /src/config/algs/qmix.yaml: -------------------------------------------------------------------------------- 1 | # use epsilon greedy action selector 2 | action_selector: "epsilon_greedy" 3 | epsilon_start: 1.0 4 | epsilon_finish: 0.05 5 | epsilon_anneal_time: 50000 6 | 7 | runner: "episode" 8 | 9 | buffer_size: 5000 10 | 11 | # update the target network every {} episodes 12 | target_update_interval: 200 13 | 14 | # use the Q_Learner to train 15 | agent_output_type: "q" 16 | learner: "q_learner" 17 | double_q: True 18 | mixer: "qmix" 19 | mixing_embed_dim: 32 20 | hypernet_layers: 2 21 | hypernet_embed: 64 22 | 23 | name: "qmix" 24 | -------------------------------------------------------------------------------- /run_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | HASH=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 4 | head -n 1) 3 | GPU=$1 4 | name=${USER}_pymarl_GPU_${GPU}_${HASH} 5 | 6 | echo "Launching container named '${name}' on GPU '${GPU}'" 7 | # Launches a docker container using our image, and runs the provided command 8 | 9 | if hash nvidia-docker 2>/dev/null; then 10 | cmd=nvidia-docker 11 | else 12 | cmd=docker 13 | fi 14 | 15 | NV_GPU="$GPU" ${cmd} run -i \ 16 | --name $name \ 17 | --user $(id -u):$(id -g) \ 18 | -v `pwd`:/pymarl \ 19 | -t pymarl:1.0 \ 20 | ${@:2} 21 | -------------------------------------------------------------------------------- /src/config/algs/fop.yaml: -------------------------------------------------------------------------------- 1 | # multinomial action selector 2 | action_selector: "multinomial" 3 | epsilon_start: 1.0 4 | epsilon_finish: .05 5 | epsilon_anneal_time: 50000 6 | mask_before_softmax: False 7 | 8 | runner: "episode" 9 | 10 | # update the target network every {} training steps 11 | target_update_interval: 200 12 | 13 | lr: 0.0005 14 | c_lr: 0.0005 15 | 16 | agent_output_type: "pi_logits" 17 | td_lambda: 0.8 18 | learner: "fop_learner" 19 | 20 | name: "fop" 21 | buffer_size: 5000 22 | 23 | mixing_embed_dim: 32 24 | n_head: 4 25 | burn_in_period: 100 26 | -------------------------------------------------------------------------------- /src/config/algs/qtran.yaml: -------------------------------------------------------------------------------- 1 | # use epsilon greedy action selector 2 | action_selector: "epsilon_greedy" 3 | epsilon_start: 1.0 4 | epsilon_finish: 0.05 5 | epsilon_anneal_time: 50000 6 | 7 | runner: "episode" 8 | 9 | buffer_size: 5000 10 | 11 | # update the target network every {} episodes 12 | target_update_interval: 200 13 | 14 | # use the Q_Learner to train 15 | agent_output_type: "q" 16 | learner: "qtran_learner" 17 | double_q: True 18 | mixer: "qtran_base" 19 | mixing_embed_dim: 64 20 | qtran_arch: "qtran_paper" 21 | 22 | opt_loss: 1 23 | nopt_min_loss: 0.1 24 | 25 | network_size: small 26 | 27 | name: "qtran" 28 | -------------------------------------------------------------------------------- /src/components/transforms.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | class Transform: 5 | def transform(self, tensor): 6 | raise NotImplementedError 7 | 8 | def infer_output_info(self, vshape_in, dtype_in): 9 | raise NotImplementedError 10 | 11 | 12 | class OneHot(Transform): 13 | def __init__(self, out_dim): 14 | self.out_dim = out_dim 15 | 16 | def transform(self, tensor): 17 | y_onehot = 
tensor.new(*tensor.shape[:-1], self.out_dim).zero_() 18 | y_onehot.scatter_(-1, tensor.long(), 1) 19 | return y_onehot.float() 20 | 21 | def infer_output_info(self, vshape_in, dtype_in): 22 | return (self.out_dim,), th.float32 -------------------------------------------------------------------------------- /src/utils/rl_utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | def build_td_lambda_targets(rewards, terminated, mask, target_qs, n_agents, gamma, td_lambda): 5 | # Assumes in B*T*A and , , in (at least) B*T-1*1 6 | # Initialise last lambda -return for not terminated episodes 7 | ret = target_qs.new_zeros(*target_qs.shape) 8 | ret[:, -1] = target_qs[:, -1] * (1 - th.sum(terminated, dim=1)) 9 | # Backwards recursive update of the "forward view" 10 | for t in range(ret.shape[1] - 2, -1, -1): 11 | ret[:, t] = td_lambda * gamma * ret[:, t + 1] + mask[:, t] \ 12 | * (rewards[:, t] + (1 - td_lambda) * gamma * target_qs[:, t + 1] * (1 - terminated[:, t])) 13 | # Returns lambda-return from t=0 to t=T-1, i.e. in B*T-1*A 14 | return ret[:, 0:-1] 15 | -------------------------------------------------------------------------------- /src/config/algs/coma.yaml: -------------------------------------------------------------------------------- 1 | # --- COMA specific parameters --- 2 | 3 | action_selector: "multinomial" 4 | epsilon_start: .5 5 | epsilon_finish: .01 6 | epsilon_anneal_time: 100000 7 | mask_before_softmax: False 8 | 9 | runner: "parallel" 10 | 11 | buffer_size: 8 12 | batch_size_run: 8 13 | batch_size: 8 14 | 15 | env_args: 16 | state_last_action: False # critic adds last action internally 17 | 18 | # update the target network every {} training steps 19 | target_update_interval: 200 20 | 21 | lr: 0.0005 22 | critic_lr: 0.0005 23 | td_lambda: 0.8 24 | 25 | # use COMA 26 | agent_output_type: "pi_logits" 27 | learner: "coma_learner" 28 | critic_q_fn: "coma" 29 | critic_baseline_fn: "coma" 30 | critic_train_mode: "seq" 31 | critic_train_reps: 1 32 | q_nstep: 0 # 0 corresponds to default Q, 1 is r + gamma*Q, etc 33 | 34 | name: "coma" 35 | -------------------------------------------------------------------------------- /src/modules/agents/rnn_agent.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class RNNAgent(nn.Module): 6 | def __init__(self, input_shape, args): 7 | super(RNNAgent, self).__init__() 8 | self.args = args 9 | 10 | self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim) 11 | self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim) 12 | self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions) 13 | 14 | def init_hidden(self): 15 | # make hidden states on same device as model 16 | return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_() 17 | 18 | def forward(self, inputs, hidden_state): 19 | x = F.relu(self.fc1(inputs)) 20 | h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim) 21 | h = self.rnn(x, h_in) 22 | q = self.fc2(h) 23 | return q, h 24 | -------------------------------------------------------------------------------- /src/components/epsilon_schedules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class DecayThenFlatSchedule(): 5 | 6 | def __init__(self, 7 | start, 8 | finish, 9 | time_length, 10 | decay="exp"): 11 | 12 | self.start = start 13 | self.finish = finish 14 | self.time_length = time_length 
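        # The "linear" branch below decays epsilon by a fixed delta per environment step and then
        # stays flat at `finish`. For example, with the values used in the alg configs
        # (start=1.0, finish=0.05, time_length=50000), eval(25000) = max(0.05, 1.0 - 1.9e-5 * 25000) = 0.525,
        # and eval(T) = 0.05 for any T >= 50000.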
15 | self.delta = (self.start - self.finish) / self.time_length 16 | self.decay = decay 17 | 18 | if self.decay in ["exp"]: 19 | self.exp_scaling = (-1) * self.time_length / np.log(self.finish) if self.finish > 0 else 1 20 | 21 | def eval(self, T): 22 | if self.decay in ["linear"]: 23 | return max(self.finish, self.start - self.delta * T) 24 | elif self.decay in ["exp"]: 25 | return min(self.start, max(self.finish, np.exp(- T / self.exp_scaling))) 26 | pass 27 | -------------------------------------------------------------------------------- /src/config/envs/sc2_beta.yaml: -------------------------------------------------------------------------------- 1 | env: sc2 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "3m" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | state_last_action: True 27 | state_timestep_number: False 28 | step_mul: 8 29 | seed: null 30 | heuristic_ai: False 31 | debug: False 32 | 33 | learner_log_interval: 20000 34 | log_interval: 20000 35 | runner_log_interval: 20000 36 | t_max: 40050000 37 | test_interval: 20000 38 | test_nepisode: 24 39 | test_greedy: True 40 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.5.0 2 | atomicwrites==1.2.1 3 | attrs==18.2.0 4 | certifi==2018.8.24 5 | chardet==3.0.4 6 | cycler==0.10.0 7 | docopt==0.6.2 8 | enum34==1.1.6 9 | future==0.16.0 10 | idna==2.7 11 | imageio==2.4.1 12 | jsonpickle==0.9.6 13 | kiwisolver==1.0.1 14 | matplotlib==3.0.0 15 | mock==2.0.0 16 | more-itertools==4.3.0 17 | mpyq==0.2.5 18 | munch==2.3.2 19 | numpy==1.15.2 20 | pathlib2==2.3.2 21 | pbr==4.3.0 22 | Pillow==6.2.0 23 | pluggy==0.7.1 24 | portpicker==1.2.0 25 | probscale==0.2.3 26 | protobuf==3.6.1 27 | py==1.6.0 28 | pygame==1.9.4 29 | pyparsing==2.2.2 30 | pysc2==3.0.0 31 | pytest==3.8.2 32 | python-dateutil==2.7.3 33 | PyYAML==3.13 34 | requests==2.20.0 35 | s2clientprotocol==4.10.1.75800.0 36 | sacred==0.7.2 37 | scipy==1.1.0 38 | six==1.11.0 39 | sk-video==1.1.10 40 | snakeviz==1.0.0 41 | tensorboard-logger==0.1.0 42 | torch==0.4.1 43 | torchvision==0.2.1 44 | tornado==5.1.1 45 | urllib3==1.24.2 46 | websocket-client==0.53.0 47 | whichcraft==0.5.2 48 | wrapt==1.10.11 49 | git+https://github.com/oxwhirl/smac.git 50 | -------------------------------------------------------------------------------- /src/config/envs/sc2.yaml: -------------------------------------------------------------------------------- 1 | env: sc2 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "3m" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 
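  # Note (assumption based on the SMAC implementation, not something set by this file): when
  # reward_scale is True, SMAC divides the shaped reward by max_reward / reward_scale_rate, so with
  # the values above the maximum attainable episode return is roughly reward_scale_rate = 20.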
23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | state_last_action: True 27 | state_timestep_number: False 28 | step_mul: 8 29 | seed: null 30 | heuristic_ai: False 31 | heuristic_rest: False 32 | debug: False 33 | 34 | test_greedy: True 35 | test_nepisode: 32 36 | test_interval: 10000 37 | log_interval: 10000 38 | runner_log_interval: 10000 39 | learner_log_interval: 10000 40 | t_max: 2050000 41 | -------------------------------------------------------------------------------- /install_sc2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Install SC2 and add the custom maps 3 | 4 | if [ -z "$EXP_DIR" ] 5 | then 6 | EXP_DIR=~ 7 | fi 8 | 9 | echo "EXP_DIR: $EXP_DIR" 10 | cd $EXP_DIR/pymarl 11 | 12 | mkdir 3rdparty 13 | cd 3rdparty 14 | 15 | export SC2PATH=`pwd`'/StarCraftII' 16 | echo 'SC2PATH is set to '$SC2PATH 17 | 18 | if [ ! -d $SC2PATH ]; then 19 | echo 'StarCraftII is not installed. Installing now ...'; 20 | wget http://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip 21 | unzip -P iagreetotheeula SC2.4.10.zip 22 | rm -rf SC2.4.10.zip 23 | else 24 | echo 'StarCraftII is already installed.' 25 | fi 26 | 27 | echo 'Adding SMAC maps.' 28 | MAP_DIR="$SC2PATH/Maps/" 29 | echo 'MAP_DIR is set to '$MAP_DIR 30 | 31 | if [ ! -d $MAP_DIR ]; then 32 | mkdir -p $MAP_DIR 33 | fi 34 | 35 | cd .. 36 | wget https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip 37 | unzip SMAC_Maps.zip 38 | mv SMAC_Maps $MAP_DIR 39 | rm -rf SMAC_Maps.zip 40 | 41 | echo 'StarCraft II and SMAC are installed.' 42 | 43 | -------------------------------------------------------------------------------- /src/utils/timehelper.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | 5 | def print_time(start_time, T, t_max, episode, episode_rewards): 6 | time_elapsed = time.time() - start_time 7 | T = max(1, T) 8 | time_left = time_elapsed * (t_max - T) / T 9 | # Just in case its over 100 days 10 | time_left = min(time_left, 60 * 60 * 24 * 100) 11 | last_reward = "N\A" 12 | if len(episode_rewards) > 5: 13 | last_reward = "{:.2f}".format(np.mean(episode_rewards[-50:])) 14 | print("\033[F\033[F\x1b[KEp: {:,}, T: {:,}/{:,}, Reward: {}, \n\x1b[KElapsed: {}, Left: {}\n".format(episode, T, t_max, last_reward, time_str(time_elapsed), time_str(time_left)), " " * 10, end="\r") 15 | 16 | 17 | def time_left(start_time, t_start, t_current, t_max): 18 | if t_current >= t_max: 19 | return "-" 20 | time_elapsed = time.time() - start_time 21 | t_current = max(1, t_current) 22 | time_left = time_elapsed * (t_max - t_current) / (t_current - t_start) 23 | # Just in case its over 100 days 24 | time_left = min(time_left, 60 * 60 * 24 * 100) 25 | return time_str(time_left) 26 | 27 | 28 | def time_str(s): 29 | """ 30 | Convert seconds to a nicer string showing days, hours, minutes and seconds 31 | """ 32 | days, remainder = divmod(s, 60 * 60 * 24) 33 | hours, remainder = divmod(remainder, 60 * 60) 34 | minutes, seconds = divmod(remainder, 60) 35 | string = "" 36 | if days > 0: 37 | string += "{:d} days, ".format(int(days)) 38 | if hours > 0: 39 | string += "{:d} hours, ".format(int(hours)) 40 | if minutes > 0: 41 | string += "{:d} minutes, ".format(int(minutes)) 42 | string += "{:d} seconds".format(int(seconds)) 43 | return string 44 | -------------------------------------------------------------------------------- /src/envs/multiagentenv.py: 
-------------------------------------------------------------------------------- 1 | class MultiAgentEnv(object): 2 | 3 | def step(self, actions): 4 | """ Returns reward, terminated, info """ 5 | raise NotImplementedError 6 | 7 | def get_obs(self): 8 | """ Returns all agent observations in a list """ 9 | raise NotImplementedError 10 | 11 | def get_obs_agent(self, agent_id): 12 | """ Returns observation for agent_id """ 13 | raise NotImplementedError 14 | 15 | def get_obs_size(self): 16 | """ Returns the shape of the observation """ 17 | raise NotImplementedError 18 | 19 | def get_state(self): 20 | raise NotImplementedError 21 | 22 | def get_state_size(self): 23 | """ Returns the shape of the state""" 24 | raise NotImplementedError 25 | 26 | def get_avail_actions(self): 27 | raise NotImplementedError 28 | 29 | def get_avail_agent_actions(self, agent_id): 30 | """ Returns the available actions for agent_id """ 31 | raise NotImplementedError 32 | 33 | def get_total_actions(self): 34 | """ Returns the total number of actions an agent could ever take """ 35 | # TODO: This is only suitable for a discrete 1 dimensional action space for each agent 36 | raise NotImplementedError 37 | 38 | def reset(self): 39 | """ Returns initial observations and states""" 40 | raise NotImplementedError 41 | 42 | def render(self): 43 | raise NotImplementedError 44 | 45 | def close(self): 46 | raise NotImplementedError 47 | 48 | def seed(self): 49 | raise NotImplementedError 50 | 51 | def save_replay(self): 52 | raise NotImplementedError 53 | 54 | def get_env_info(self): 55 | env_info = {"state_shape": self.get_state_size(), 56 | "obs_shape": self.get_obs_size(), 57 | "n_actions": self.get_total_actions(), 58 | "n_agents": self.n_agents, 59 | "episode_limit": self.episode_limit} 60 | return env_info 61 | -------------------------------------------------------------------------------- /src/modules/critics/fop.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FOPCritic(nn.Module): 7 | def __init__(self, scheme, args): 8 | super(FOPCritic, self).__init__() 9 | 10 | self.args = args 11 | self.n_actions = args.n_actions 12 | self.n_agents = args.n_agents 13 | 14 | input_shape = self._get_input_shape(scheme) 15 | self.output_type = "q" 16 | 17 | # Set up network layers 18 | self.fc1 = nn.Linear(input_shape, 64) 19 | self.fc2 = nn.Linear(64, 64) 20 | self.fc3 = nn.Linear(64, self.n_actions) 21 | 22 | def forward(self, inputs): 23 | x = F.relu(self.fc1(inputs)) 24 | x = F.relu(self.fc2(x)) 25 | q = self.fc3(x) 26 | return q 27 | 28 | def _build_inputs(self, batch, bs, max_t): 29 | inputs = [] 30 | # state, obs, action 31 | #inputs.append(batch["state"][:].unsqueeze(2).repeat(1, 1, self.n_agents, 1)) 32 | inputs.append(batch["obs"][:]) 33 | # last actions 34 | #if self.args.obs_last_action: 35 | # last_action = [] 36 | # last_action.append(actions[:, 0:1].squeeze(2)) 37 | # last_action.append(actions[:, :-1].squeeze(2)) 38 | # last_action = th.cat([x for x in last_action], dim = 1) 39 | # inputs.append(last_action) 40 | #agent id 41 | inputs.append(th.eye(self.n_agents, device=batch.device).unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1)) 42 | inputs = th.cat([x.reshape(bs, max_t, self.n_agents, -1) for x in inputs], dim=-1) 43 | return inputs 44 | 45 | def _get_input_shape(self, scheme): 46 | # state 47 | #input_shape = scheme["state"]["vshape"] 48 | # observation 49 | input_shape = 
scheme["obs"]["vshape"] 50 | # actions and last actions 51 | #if self.args.obs_last_action: 52 | # input_shape += scheme["actions_onehot"]["vshape"][0] * self.n_agents 53 | # agent id 54 | input_shape += self.n_agents 55 | return input_shape 56 | -------------------------------------------------------------------------------- /src/config/default.yaml: -------------------------------------------------------------------------------- 1 | # --- Defaults --- 2 | 3 | # --- pymarl options --- 4 | runner: "episode" # Runs 1 env for an episode 5 | mac: "basic_mac" # Basic controller 6 | env: "sc2" # Environment name 7 | env_args: {} # Arguments for the environment 8 | batch_size_run: 1 # Number of environments to run in parallel 9 | test_nepisode: 20 # Number of episodes to test for 10 | test_interval: 2000 # Test after {} timesteps have passed 11 | test_greedy: True # Use greedy evaluation (if False, will set epsilon floor to 0 12 | log_interval: 2000 # Log summary of stats after every {} timesteps 13 | runner_log_interval: 2000 # Log runner stats (not test stats) every {} timesteps 14 | learner_log_interval: 2000 # Log training stats every {} timesteps 15 | t_max: 10000 # Stop running after this many timesteps 16 | use_cuda: True # Use gpu by default unless it isn't available 17 | buffer_cpu_only: True # If true we won't keep all of the replay buffer in vram 18 | 19 | # --- Logging options --- 20 | use_tensorboard: False # Log results to tensorboard 21 | save_model: False # Save the models to disk 22 | save_model_interval: 2000000 # Save models after this many timesteps 23 | checkpoint_path: "" # Load a checkpoint from this path 24 | evaluate: False # Evaluate model for test_nepisode episodes and quit (no training) 25 | load_step: 0 # Load model trained on this many timesteps (0 if choose max possible) 26 | save_replay: False # Saving the replay of the model loaded from checkpoint_path 27 | local_results_path: "results" # Path for local results 28 | 29 | # --- RL hyperparameters --- 30 | gamma: 0.99 31 | batch_size: 32 # Number of episodes to train on 32 | buffer_size: 32 # Size of the replay buffer 33 | lr: 0.0005 # Learning rate for agents 34 | critic_lr: 0.0005 # Learning rate for critics 35 | optim_alpha: 0.99 # RMSProp alpha 36 | optim_eps: 0.00001 # RMSProp epsilon 37 | grad_norm_clip: 10 # Reduce magnitude of gradients above this L2 norm 38 | 39 | # --- Agent parameters --- 40 | agent: "rnn" # Default rnn agent 41 | rnn_hidden_dim: 64 # Size of hidden state for default rnn agent 42 | obs_agent_id: True # Include the agent's one_hot id in the observation 43 | obs_last_action: True # Include the agent's last action (one_hot) in the observation 44 | 45 | # --- Experiment running params --- 46 | repeat_id: 1 47 | label: "default_label" 48 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import logging 3 | import numpy as np 4 | 5 | class Logger: 6 | def __init__(self, console_logger): 7 | self.console_logger = console_logger 8 | 9 | self.use_tb = False 10 | self.use_sacred = False 11 | self.use_hdf = False 12 | 13 | self.stats = defaultdict(lambda: []) 14 | 15 | def setup_tb(self, directory_name): 16 | # Import here so it doesn't have to be installed if you don't use it 17 | from tensorboard_logger import configure, log_value 18 | configure(directory_name) 19 | self.tb_logger = log_value 20 | self.use_tb = True 21 
| 22 | def setup_sacred(self, sacred_run_dict): 23 | self.sacred_info = sacred_run_dict.info 24 | self.use_sacred = True 25 | 26 | def log_stat(self, key, value, t, to_sacred=True): 27 | self.stats[key].append((t, value)) 28 | 29 | if self.use_tb: 30 | self.tb_logger(key, value, t) 31 | 32 | if self.use_sacred and to_sacred: 33 | if key in self.sacred_info: 34 | self.sacred_info["{}_T".format(key)].append(t) 35 | self.sacred_info[key].append(value) 36 | else: 37 | self.sacred_info["{}_T".format(key)] = [t] 38 | self.sacred_info[key] = [value] 39 | 40 | def print_recent_stats(self): 41 | log_str = "Recent Stats | t_env: {:>10} | Episode: {:>8}\n".format(*self.stats["episode"][-1]) 42 | i = 0 43 | for (k, v) in sorted(self.stats.items()): 44 | if k == "episode": 45 | continue 46 | i += 1 47 | window = 5 if k != "epsilon" else 1 48 | item = "{:.4f}".format(np.mean([x[1] for x in self.stats[k][-window:]])) 49 | log_str += "{:<25}{:>8}".format(k + ":", item) 50 | log_str += "\n" if i % 4 == 0 else "\t" 51 | self.console_logger.info(log_str) 52 | 53 | 54 | # set up a custom logger 55 | def get_logger(): 56 | logger = logging.getLogger() 57 | logger.handlers = [] 58 | ch = logging.StreamHandler() 59 | formatter = logging.Formatter('[%(levelname)s %(asctime)s] %(name)s %(message)s', '%H:%M:%S') 60 | ch.setFormatter(formatter) 61 | logger.addHandler(ch) 62 | logger.setLevel('DEBUG') 63 | 64 | return logger 65 | 66 | -------------------------------------------------------------------------------- /src/modules/mixers/qmix.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class QMixer(nn.Module): 8 | def __init__(self, args): 9 | super(QMixer, self).__init__() 10 | 11 | self.args = args 12 | self.n_agents = args.n_agents 13 | self.state_dim = int(np.prod(args.state_shape)) 14 | 15 | self.embed_dim = args.mixing_embed_dim 16 | 17 | if getattr(args, "hypernet_layers", 1) == 1: 18 | self.hyper_w_1 = nn.Linear(self.state_dim, self.embed_dim * self.n_agents) 19 | self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim) 20 | elif getattr(args, "hypernet_layers", 1) == 2: 21 | hypernet_embed = self.args.hypernet_embed 22 | self.hyper_w_1 = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed), 23 | nn.ReLU(), 24 | nn.Linear(hypernet_embed, self.embed_dim * self.n_agents)) 25 | self.hyper_w_final = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed), 26 | nn.ReLU(), 27 | nn.Linear(hypernet_embed, self.embed_dim)) 28 | elif getattr(args, "hypernet_layers", 1) > 2: 29 | raise Exception("Sorry >2 hypernet layers is not implemented!") 30 | else: 31 | raise Exception("Error setting number of hypernet layers.") 32 | 33 | # State dependent bias for hidden layer 34 | self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim) 35 | 36 | # V(s) instead of a bias for the last layers 37 | self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), 38 | nn.ReLU(), 39 | nn.Linear(self.embed_dim, 1)) 40 | 41 | def forward(self, agent_qs, states): 42 | bs = agent_qs.size(0) 43 | states = states.reshape(-1, self.state_dim) 44 | agent_qs = agent_qs.view(-1, 1, self.n_agents) 45 | # First layer 46 | w1 = th.abs(self.hyper_w_1(states)) 47 | b1 = self.hyper_b_1(states) 48 | w1 = w1.view(-1, self.n_agents, self.embed_dim) 49 | b1 = b1.view(-1, 1, self.embed_dim) 50 | hidden = F.elu(th.bmm(agent_qs, w1) + b1) 51 | # Second layer 52 | w_final = 
th.abs(self.hyper_w_final(states)) 53 | w_final = w_final.view(-1, self.embed_dim, 1) 54 | # State-dependent bias 55 | v = self.V(states).view(-1, 1, 1) 56 | # Compute final output 57 | y = th.bmm(hidden, w_final) + v 58 | # Reshape and return 59 | q_tot = y.view(bs, -1, 1) 60 | return q_tot 61 | -------------------------------------------------------------------------------- /src/modules/critics/coma.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class COMACritic(nn.Module): 7 | def __init__(self, scheme, args): 8 | super(COMACritic, self).__init__() 9 | 10 | self.args = args 11 | self.n_actions = args.n_actions 12 | self.n_agents = args.n_agents 13 | 14 | input_shape = self._get_input_shape(scheme) 15 | self.output_type = "q" 16 | 17 | # Set up network layers 18 | self.fc1 = nn.Linear(input_shape, 128) 19 | self.fc2 = nn.Linear(128, 128) 20 | self.fc3 = nn.Linear(128, self.n_actions) 21 | 22 | def forward(self, batch, t=None): 23 | inputs = self._build_inputs(batch, t=t) 24 | x = F.relu(self.fc1(inputs)) 25 | x = F.relu(self.fc2(x)) 26 | q = self.fc3(x) 27 | return q 28 | 29 | def _build_inputs(self, batch, t=None): 30 | bs = batch.batch_size 31 | max_t = batch.max_seq_length if t is None else 1 32 | ts = slice(None) if t is None else slice(t, t+1) 33 | inputs = [] 34 | # state 35 | inputs.append(batch["state"][:, ts].unsqueeze(2).repeat(1, 1, self.n_agents, 1)) 36 | 37 | # observation 38 | inputs.append(batch["obs"][:, ts]) 39 | 40 | # actions (masked out by agent) 41 | actions = batch["actions_onehot"][:, ts].view(bs, max_t, 1, -1).repeat(1, 1, self.n_agents, 1) 42 | agent_mask = (1 - th.eye(self.n_agents, device=batch.device)) 43 | agent_mask = agent_mask.view(-1, 1).repeat(1, self.n_actions).view(self.n_agents, -1) 44 | inputs.append(actions * agent_mask.unsqueeze(0).unsqueeze(0)) 45 | 46 | # last actions 47 | if t == 0: 48 | inputs.append(th.zeros_like(batch["actions_onehot"][:, 0:1]).view(bs, max_t, 1, -1).repeat(1, 1, self.n_agents, 1)) 49 | elif isinstance(t, int): 50 | inputs.append(batch["actions_onehot"][:, slice(t-1, t)].view(bs, max_t, 1, -1).repeat(1, 1, self.n_agents, 1)) 51 | else: 52 | last_actions = th.cat([th.zeros_like(batch["actions_onehot"][:, 0:1]), batch["actions_onehot"][:, :-1]], dim=1) 53 | last_actions = last_actions.view(bs, max_t, 1, -1).repeat(1, 1, self.n_agents, 1) 54 | inputs.append(last_actions) 55 | 56 | inputs.append(th.eye(self.n_agents, device=batch.device).unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1)) 57 | 58 | inputs = th.cat([x.reshape(bs, max_t, self.n_agents, -1) for x in inputs], dim=-1) 59 | return inputs 60 | 61 | def _get_input_shape(self, scheme): 62 | # state 63 | input_shape = scheme["state"]["vshape"] 64 | # observation 65 | input_shape += scheme["obs"]["vshape"] 66 | # actions and last actions 67 | input_shape += scheme["actions_onehot"]["vshape"][0] * self.n_agents * 2 68 | # agent id 69 | input_shape += self.n_agents 70 | return input_shape -------------------------------------------------------------------------------- /src/modules/mixers/fop.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class FOPMixer(nn.Module): 8 | def __init__(self, args): 9 | super(FOPMixer, self).__init__() 10 | self.args = args 11 | self.n_agents = args.n_agents 
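        # What this mixer computes (see forward() and lambda_weight() below):
        #   Q_tot = sum_i q_i + V(s) + sum_i (lambda_i(s, a) - 1) * (q_i - v_i),
        # where the non-negative weights lambda_i come from the multi-head extractors defined next
        # and the per-agent advantage (q_i - v_i) is detached so no gradient flows through it.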
12 | self.n_actions = args.n_actions 13 | self.state_dim = int(np.prod(args.state_shape)) 14 | self.action_dim = args.n_agents * self.n_actions 15 | self.state_action_dim = self.state_dim + self.action_dim 16 | self.n_head = args.n_head 17 | self.embed_dim = args.mixing_embed_dim 18 | 19 | self.key_extractors = nn.ModuleList() 20 | self.agents_extractors = nn.ModuleList() 21 | self.action_extractors = nn.ModuleList() 22 | 23 | for i in range(self.n_head): # multi-head attention 24 | self.key_extractors.append(nn.Linear(self.state_dim, 1)) 25 | self.agents_extractors.append(nn.Linear(self.state_dim, self.n_agents)) 26 | self.action_extractors.append(nn.Linear(self.state_action_dim, self.n_agents)) 27 | 28 | self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), 29 | nn.ReLU(), 30 | nn.Linear(self.embed_dim, 1)) 31 | 32 | def forward(self, agent_qs, states, actions=None, vs=None): 33 | bs = agent_qs.size(0) 34 | 35 | v = self.V(states).reshape(-1, 1).repeat(1, self.n_agents) / self.n_agents 36 | 37 | agent_qs = agent_qs.reshape(-1, self.n_agents) 38 | vs = vs.reshape(-1, self.n_agents) 39 | 40 | adv_q = (agent_qs - vs).detach() 41 | lambda_weight = self.lambda_weight(states, actions)-1 42 | 43 | adv_tot = th.sum(adv_q * lambda_weight, dim=1).reshape(bs, -1, 1) 44 | v_tot = th.sum(agent_qs + v, dim=-1).reshape(bs, -1, 1) 45 | 46 | return adv_tot + v_tot 47 | 48 | def lambda_weight(self, states, actions): 49 | states = states.reshape(-1, self.state_dim) 50 | actions = actions.reshape(-1, self.action_dim) 51 | state_actions = th.cat([states, actions], dim=1) 52 | 53 | head_keys = [k_ext(states) for k_ext in self.key_extractors] 54 | head_agents = [k_ext(states) for k_ext in self.agents_extractors] 55 | head_actions = [sel_ext(state_actions) for sel_ext in self.action_extractors] 56 | 57 | lambda_weights = [] 58 | 59 | for head_key, head_agents, head_action in zip(head_keys, head_agents, head_actions): 60 | key = th.abs(head_key).repeat(1, self.n_agents) + 1e-10 61 | agents = F.sigmoid(head_agents) 62 | action = F.sigmoid(head_action) 63 | weights = key * agents * action 64 | lambda_weights.append(weights) 65 | 66 | lambdas = th.stack(lambda_weights, dim=1) 67 | lambdas = lambdas.reshape(-1, self.n_head, self.n_agents).sum(dim=1) 68 | 69 | return lambdas.reshape(-1, self.n_agents) 70 | 71 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import collections 4 | from os.path import dirname, abspath 5 | from copy import deepcopy 6 | from sacred import Experiment, SETTINGS 7 | from sacred.observers import FileStorageObserver 8 | from sacred.utils import apply_backspaces_and_linefeeds 9 | import sys 10 | import torch as th 11 | from utils.logging import get_logger 12 | import yaml 13 | 14 | from run import run 15 | 16 | SETTINGS['CAPTURE_MODE'] = "fd" # set to "no" if you want to see stdout/stderr in console 17 | logger = get_logger() 18 | 19 | ex = Experiment("pymarl") 20 | ex.logger = logger 21 | ex.captured_out_filter = apply_backspaces_and_linefeeds 22 | 23 | results_path = os.path.join(dirname(dirname(abspath(__file__))), "results") 24 | 25 | 26 | @ex.main 27 | def my_main(_run, _config, _log): 28 | # Setting the random seed throughout the modules 29 | config = config_copy(_config) 30 | np.random.seed(config["seed"]) 31 | th.manual_seed(config["seed"]) 32 | config['env_args']['seed'] = config["seed"] 33 | 34 | # run 
the framework 35 | run(_run, config, _log) 36 | 37 | 38 | def _get_config(params, arg_name, subfolder): 39 | config_name = None 40 | for _i, _v in enumerate(params): 41 | if _v.split("=")[0] == arg_name: 42 | config_name = _v.split("=")[1] 43 | del params[_i] 44 | break 45 | 46 | if config_name is not None: 47 | with open(os.path.join(os.path.dirname(__file__), "config", subfolder, "{}.yaml".format(config_name)), "r") as f: 48 | try: 49 | config_dict = yaml.load(f) 50 | except yaml.YAMLError as exc: 51 | assert False, "{}.yaml error: {}".format(config_name, exc) 52 | return config_dict 53 | 54 | 55 | def recursive_dict_update(d, u): 56 | for k, v in u.items(): 57 | if isinstance(v, collections.Mapping): 58 | d[k] = recursive_dict_update(d.get(k, {}), v) 59 | else: 60 | d[k] = v 61 | return d 62 | 63 | 64 | def config_copy(config): 65 | if isinstance(config, dict): 66 | return {k: config_copy(v) for k, v in config.items()} 67 | elif isinstance(config, list): 68 | return [config_copy(v) for v in config] 69 | else: 70 | return deepcopy(config) 71 | 72 | 73 | if __name__ == '__main__': 74 | params = deepcopy(sys.argv) 75 | 76 | # Get the defaults from default.yaml 77 | with open(os.path.join(os.path.dirname(__file__), "config", "default.yaml"), "r") as f: 78 | try: 79 | config_dict = yaml.load(f) 80 | except yaml.YAMLError as exc: 81 | assert False, "default.yaml error: {}".format(exc) 82 | 83 | # Load algorithm and env base configs 84 | env_config = _get_config(params, "--env-config", "envs") 85 | alg_config = _get_config(params, "--config", "algs") 86 | # config_dict = {**config_dict, **env_config, **alg_config} 87 | config_dict = recursive_dict_update(config_dict, env_config) 88 | config_dict = recursive_dict_update(config_dict, alg_config) 89 | 90 | # now add all the config to sacred 91 | ex.add_config(config_dict) 92 | 93 | # Save to disk by default for sacred 94 | logger.info("Saving to FileStorageObserver in results/sacred.") 95 | file_obs_path = os.path.join(results_path, "sacred") 96 | ex.observers.append(FileStorageObserver.create(file_obs_path)) 97 | 98 | ex.run_commandline(params) 99 | 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FOP: Factorizing Optimal Joint Policy of Maximum-Entropy Multi-Agent Reinforcement Learning 2 | 3 | ## Note 4 | 5 | This codebase accompanies the paper "FOP: Factorizing Optimal Joint Policy of Maximum-Entropy Multi-Agent Reinforcement Learning" ([link](http://proceedings.mlr.press/v139/zhang21m.html)). The implementation is based on the open-sourced [PyMARL](https://github.com/oxwhirl/pymarl) and [SMAC](https://github.com/oxwhirl/smac) codebases.
6 | 7 | Implementations of the following methods, written by the authors of [PyMARL](https://github.com/oxwhirl/pymarl), can also be found in this codebase: 8 | - [**QMIX**: QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning](https://arxiv.org/abs/1803.11485) 9 | - [**COMA**: Counterfactual Multi-Agent Policy Gradients](https://arxiv.org/abs/1705.08926) 10 | - [**VDN**: Value-Decomposition Networks For Cooperative Multi-Agent Learning](https://arxiv.org/abs/1706.05296) 11 | - [**IQL**: Independent Q-Learning](https://arxiv.org/abs/1511.08779) 12 | - [**QTRAN**: QTRAN: Learning to Factorize with Transformation for Cooperative Multi-Agent Reinforcement Learning](https://arxiv.org/abs/1905.05408) 13 | 14 | ## Installation instructions 15 | 16 | Build the Docker image using 17 | ```shell 18 | cd docker 19 | bash build.sh 20 | ``` 21 | 22 | Set up StarCraft II and SMAC: 23 | ```shell 24 | bash install_sc2.sh 25 | ``` 26 | 27 | This will download StarCraft II into the `3rdparty` folder and copy over the SMAC maps needed to run the experiments. 28 | 29 | The `requirements.txt` file can be used to install the necessary packages into a virtual environment (not recommended). 30 | 31 | ## Run an experiment 32 | 33 | ```shell 34 | python3 src/main.py --config=fop --env-config=sc2 with env_args.map_name=2c_vs_64zg 35 | ``` 36 | 37 | The config files act as defaults for an algorithm or environment. 38 | 39 | They are all located in `src/config`. 40 | `--config` refers to the config files in `src/config/algs`. 41 | `--env-config` refers to the config files in `src/config/envs`. 42 | 43 | To run experiments using the Docker container: 44 | ```shell 45 | bash run.sh $GPU python3 src/main.py --config=fop --env-config=sc2 with env_args.map_name=2c_vs_64zg 46 | ``` 47 | 48 | All results will be stored in the `results` folder. 49 | 50 | The previous config files used for the SMAC Beta have the suffix `_beta`. 51 | 52 | ## Saving and loading learnt models 53 | 54 | ### Saving models 55 | 56 | You can save the learnt models to disk by setting `save_model = True`, which is set to `False` by default. The frequency of saving models can be adjusted with the `save_model_interval` configuration. Models will be saved in the results directory, under the folder called *models*. The directory corresponding to each run will contain models saved throughout the experiment, each within a folder named after the number of timesteps elapsed since the start of training. 57 | 58 | ### Loading models 59 | 60 | Learnt models can be loaded using the `checkpoint_path` parameter, after which learning will proceed from the corresponding timestep. 61 | 62 | ## Watching StarCraft II replays 63 | 64 | The `save_replay` option allows saving replays of models which are loaded using `checkpoint_path`. Once the model is successfully loaded, `test_nepisode` episodes are run in test mode and a .SC2Replay file is saved in the Replay directory of StarCraft II. Please make sure to use the episode runner if you wish to save a replay, i.e., `runner=episode`. The name of the saved replay file starts with the given `env_args.save_replay_prefix` (map_name if empty), followed by the current timestamp. 65 | 66 | The saved replays can be watched by double-clicking on them or using the following command: 67 | 68 | ```shell 69 | python -m pysc2.bin.play --norender --rgb_minimap_size 0 --replay NAME.SC2Replay 70 | ``` 71 | 72 | **Note:** Replays cannot be watched using the Linux version of StarCraft II. 
Please use either the Mac or Windows version of the StarCraft II client. 73 | 74 | -------------------------------------------------------------------------------- /src/components/action_selectors.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch.distributions import Categorical 3 | from .epsilon_schedules import DecayThenFlatSchedule 4 | import torch.nn.functional as F 5 | 6 | REGISTRY = {} 7 | 8 | ''' 9 | class GumbelSoftmax(): 10 | def __init__(self, args): 11 | self.args = args 12 | self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time, 13 | decay="linear") 14 | self.epsilon = self.schedule.eval(0) 15 | self.eps = 1e-10 16 | def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False): 17 | masked_policies = agent_inputs.clone() 18 | masked_policies[avail_actions == 0.0] = 0.0 19 | self.epsilon = self.schedule.eval(t_env) 20 | 21 | if test_mode: 22 | picked_actions = masked_policies.max(dim=2)[1] 23 | else: 24 | U = th.rand(masked_policies.size()).cuda() 25 | y = masked_policies - th.log(-th.log(U + self.eps) + self.eps) 26 | y = F.softmax(y / 1, dim=-1) 27 | y[avail_actions == 0.0] = 0.0 28 | picked_actions = y.max(dim=2)[1] 29 | 30 | return picked_actions 31 | 32 | REGISTRY["gumbel"] = GumbelSoftmax 33 | ''' 34 | 35 | class MultinomialActionSelector(): 36 | 37 | def __init__(self, args): 38 | self.args = args 39 | 40 | self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time, 41 | decay="linear") 42 | self.epsilon = self.schedule.eval(0) 43 | self.test_greedy = getattr(args, "test_greedy", True) 44 | 45 | def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False): 46 | masked_policies = agent_inputs.clone() 47 | masked_policies[avail_actions == 0.0] = 0.0 48 | 49 | self.epsilon = self.schedule.eval(t_env) 50 | 51 | if test_mode and self.test_greedy: 52 | picked_actions = masked_policies.max(dim=2)[1] 53 | else: 54 | picked_actions = Categorical(masked_policies).sample().long() 55 | 56 | random_numbers = th.rand_like(agent_inputs[:, :, 0]) 57 | pick_random = (random_numbers < self.epsilon).long() 58 | random_actions = Categorical(avail_actions.float()).sample().long() 59 | picked_actions = pick_random * random_actions + (1 - pick_random) * picked_actions 60 | 61 | if not (th.gather(avail_actions, dim=2, index=picked_actions.unsqueeze(2)) > 0.99).all(): 62 | return self.select_action(agent_inputs, avail_actions, t_env, test_mode) 63 | 64 | return picked_actions 65 | 66 | 67 | REGISTRY["multinomial"] = MultinomialActionSelector 68 | 69 | 70 | class EpsilonGreedyActionSelector(): 71 | 72 | def __init__(self, args): 73 | self.args = args 74 | 75 | self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time, 76 | decay="linear") 77 | self.epsilon = self.schedule.eval(0) 78 | 79 | def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False): 80 | 81 | # Assuming agent_inputs is a batch of Q-Values for each agent bav 82 | self.epsilon = self.schedule.eval(t_env) 83 | 84 | if test_mode: 85 | # Greedy action selection only 86 | self.epsilon = 0.0 87 | 88 | # mask actions that are excluded from selection 89 | masked_q_values = agent_inputs.clone() 90 | masked_q_values[avail_actions == 0.0] = -float("inf") # should never be selected! 
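        # Epsilon-greedy mixing: with probability epsilon an action is sampled uniformly from the
        # available actions; otherwise the argmax of the masked Q-values is taken.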
91 | 92 | random_numbers = th.rand_like(agent_inputs[:, :, 0]) 93 | pick_random = (random_numbers < self.epsilon).long() 94 | random_actions = Categorical(avail_actions.float()).sample().long() 95 | 96 | picked_actions = pick_random * random_actions + (1 - pick_random) * masked_q_values.max(dim=2)[1] 97 | return picked_actions 98 | 99 | 100 | REGISTRY["epsilon_greedy"] = EpsilonGreedyActionSelector 101 | -------------------------------------------------------------------------------- /src/controllers/basic_controller.py: -------------------------------------------------------------------------------- 1 | from modules.agents import REGISTRY as agent_REGISTRY 2 | from components.action_selectors import REGISTRY as action_REGISTRY 3 | import torch as th 4 | 5 | 6 | # This multi-agent controller shares parameters between agents 7 | class BasicMAC: 8 | def __init__(self, scheme, groups, args): 9 | self.n_agents = args.n_agents 10 | self.args = args 11 | input_shape = self._get_input_shape(scheme) 12 | self._build_agents(input_shape) 13 | self.agent_output_type = args.agent_output_type 14 | 15 | self.action_selector = action_REGISTRY[args.action_selector](args) 16 | 17 | self.hidden_states = None 18 | 19 | def select_actions(self, ep_batch, t_ep, t_env, bs=slice(None), test_mode=False): 20 | # Only select actions for the selected batch elements in bs 21 | avail_actions = ep_batch["avail_actions"][:, t_ep] 22 | agent_outputs = self.forward(ep_batch, t_ep, test_mode=test_mode) 23 | chosen_actions = self.action_selector.select_action(agent_outputs[bs], avail_actions[bs], t_env, test_mode=test_mode) 24 | return chosen_actions 25 | 26 | def forward(self, ep_batch, t, test_mode=False): 27 | agent_inputs = self._build_inputs(ep_batch, t) 28 | avail_actions = ep_batch["avail_actions"][:, t] 29 | agent_outs, self.hidden_states = self.agent(agent_inputs, self.hidden_states) 30 | 31 | # Softmax the agent outputs if they're policy logits 32 | if self.agent_output_type == "pi_logits": 33 | 34 | if getattr(self.args, "mask_before_softmax", True): 35 | # Make the logits for unavailable actions very negative to minimise their affect on the softmax 36 | reshaped_avail_actions = avail_actions.reshape(ep_batch.batch_size * self.n_agents, -1) 37 | agent_outs[reshaped_avail_actions == 0] = -1e11 38 | 39 | agent_outs = th.nn.functional.softmax(agent_outs, dim=-1) 40 | if not test_mode: 41 | # Epsilon floor 42 | epsilon_action_num = agent_outs.size(-1) 43 | if getattr(self.args, "mask_before_softmax", True): 44 | # With probability epsilon, we will pick an available action uniformly 45 | epsilon_action_num = reshaped_avail_actions.sum(dim=1, keepdim=True).float() 46 | 47 | agent_outs = ((1 - self.action_selector.epsilon) * agent_outs 48 | + th.ones_like(agent_outs) * self.action_selector.epsilon/epsilon_action_num) 49 | 50 | if getattr(self.args, "mask_before_softmax", True): 51 | # Zero out the unavailable actions 52 | agent_outs[reshaped_avail_actions == 0] = 0.0 53 | 54 | return agent_outs.view(ep_batch.batch_size, self.n_agents, -1) 55 | 56 | def init_hidden(self, batch_size): 57 | self.hidden_states = self.agent.init_hidden().unsqueeze(0).expand(batch_size, self.n_agents, -1) # bav 58 | 59 | def parameters(self): 60 | return self.agent.parameters() 61 | 62 | def load_state(self, other_mac): 63 | self.agent.load_state_dict(other_mac.agent.state_dict()) 64 | 65 | def cuda(self): 66 | self.agent.cuda() 67 | 68 | def save_models(self, path): 69 | th.save(self.agent.state_dict(), "{}/agent.th".format(path)) 70 | 71 | 
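    # Counterpart to save_models() above: restores the shared agent network from "{path}/agent.th".
    # map_location keeps tensors on the CPU, so checkpoints trained on GPU also load on CPU-only machines.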
def load_models(self, path): 72 | self.agent.load_state_dict(th.load("{}/agent.th".format(path), map_location=lambda storage, loc: storage)) 73 | 74 | def _build_agents(self, input_shape): 75 | self.agent = agent_REGISTRY[self.args.agent](input_shape, self.args) 76 | 77 | def _build_inputs(self, batch, t): 78 | # Assumes homogenous agents with flat observations. 79 | # Other MACs might want to e.g. delegate building inputs to each agent 80 | bs = batch.batch_size 81 | inputs = [] 82 | inputs.append(batch["obs"][:, t]) # b1av 83 | if self.args.obs_last_action: 84 | if t == 0: 85 | inputs.append(th.zeros_like(batch["actions_onehot"][:, t])) 86 | else: 87 | inputs.append(batch["actions_onehot"][:, t-1]) 88 | if self.args.obs_agent_id: 89 | inputs.append(th.eye(self.n_agents, device=batch.device).unsqueeze(0).expand(bs, -1, -1)) 90 | 91 | inputs = th.cat([x.reshape(bs*self.n_agents, -1) for x in inputs], dim=1) 92 | return inputs 93 | 94 | def _get_input_shape(self, scheme): 95 | input_shape = scheme["obs"]["vshape"] 96 | if self.args.obs_last_action: 97 | input_shape += scheme["actions_onehot"]["vshape"][0] 98 | if self.args.obs_agent_id: 99 | input_shape += self.n_agents 100 | 101 | return input_shape 102 | -------------------------------------------------------------------------------- /src/runners/episode_runner.py: -------------------------------------------------------------------------------- 1 | from envs import REGISTRY as env_REGISTRY 2 | from functools import partial 3 | from components.episode_buffer import EpisodeBatch 4 | import numpy as np 5 | 6 | 7 | class EpisodeRunner: 8 | 9 | def __init__(self, args, logger): 10 | self.args = args 11 | self.logger = logger 12 | self.batch_size = self.args.batch_size_run 13 | assert self.batch_size == 1 14 | 15 | self.env = env_REGISTRY[self.args.env](**self.args.env_args) 16 | self.episode_limit = self.env.episode_limit 17 | self.t = 0 18 | 19 | self.t_env = 0 20 | 21 | self.train_returns = [] 22 | self.test_returns = [] 23 | self.train_stats = {} 24 | self.test_stats = {} 25 | 26 | # Log the first run 27 | self.log_train_stats_t = -1000000 28 | 29 | def setup(self, scheme, groups, preprocess, mac): 30 | self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1, 31 | preprocess=preprocess, device=self.args.device) 32 | self.mac = mac 33 | 34 | def get_env_info(self): 35 | return self.env.get_env_info() 36 | 37 | def save_replay(self): 38 | self.env.save_replay() 39 | 40 | def close_env(self): 41 | self.env.close() 42 | 43 | def reset(self): 44 | self.batch = self.new_batch() 45 | self.env.reset() 46 | self.t = 0 47 | 48 | def run(self, test_mode=False): 49 | self.reset() 50 | 51 | terminated = False 52 | episode_return = 0 53 | self.mac.init_hidden(batch_size=self.batch_size) 54 | 55 | while not terminated: 56 | 57 | pre_transition_data = { 58 | "state": [self.env.get_state()], 59 | "avail_actions": [self.env.get_avail_actions()], 60 | "obs": [self.env.get_obs()] 61 | } 62 | 63 | self.batch.update(pre_transition_data, ts=self.t) 64 | 65 | # Pass the entire batch of experiences up till now to the agents 66 | # Receive the actions for each agent at this timestep in a batch of size 1 67 | actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode) 68 | 69 | reward, terminated, env_info = self.env.step(actions[0]) 70 | episode_return += reward 71 | 72 | post_transition_data = { 73 | "actions": actions, 74 | "reward": [(reward,)], 75 | "terminated": [(terminated != 
env_info.get("episode_limit", False),)], 76 | } 77 | 78 | self.batch.update(post_transition_data, ts=self.t) 79 | 80 | self.t += 1 81 | 82 | last_data = { 83 | "state": [self.env.get_state()], 84 | "avail_actions": [self.env.get_avail_actions()], 85 | "obs": [self.env.get_obs()] 86 | } 87 | self.batch.update(last_data, ts=self.t) 88 | 89 | # Select actions in the last stored state 90 | actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode) 91 | self.batch.update({"actions": actions}, ts=self.t) 92 | 93 | cur_stats = self.test_stats if test_mode else self.train_stats 94 | cur_returns = self.test_returns if test_mode else self.train_returns 95 | log_prefix = "test_" if test_mode else "" 96 | cur_stats.update({k: cur_stats.get(k, 0) + env_info.get(k, 0) for k in set(cur_stats) | set(env_info)}) 97 | cur_stats["n_episodes"] = 1 + cur_stats.get("n_episodes", 0) 98 | cur_stats["ep_length"] = self.t + cur_stats.get("ep_length", 0) 99 | 100 | if not test_mode: 101 | self.t_env += self.t 102 | 103 | cur_returns.append(episode_return) 104 | 105 | if test_mode and (len(self.test_returns) == self.args.test_nepisode): 106 | self._log(cur_returns, cur_stats, log_prefix) 107 | elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval: 108 | self._log(cur_returns, cur_stats, log_prefix) 109 | if hasattr(self.mac.action_selector, "epsilon"): 110 | self.logger.log_stat("epsilon", self.mac.action_selector.epsilon, self.t_env) 111 | self.log_train_stats_t = self.t_env 112 | 113 | return self.batch 114 | 115 | def _log(self, returns, stats, prefix): 116 | self.logger.log_stat(prefix + "return_mean", np.mean(returns), self.t_env) 117 | self.logger.log_stat(prefix + "return_std", np.std(returns), self.t_env) 118 | returns.clear() 119 | 120 | for k, v in stats.items(): 121 | if k != "n_episodes": 122 | self.logger.log_stat(prefix + k + "_mean" , v/stats["n_episodes"], self.t_env) 123 | stats.clear() 124 | -------------------------------------------------------------------------------- /src/modules/mixers/qtran.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class QTranBase(nn.Module): 8 | def __init__(self, args): 9 | super(QTranBase, self).__init__() 10 | 11 | self.args = args 12 | 13 | self.n_agents = args.n_agents 14 | self.n_actions = args.n_actions 15 | self.state_dim = int(np.prod(args.state_shape)) 16 | self.arch = self.args.qtran_arch # QTran architecture 17 | 18 | self.embed_dim = args.mixing_embed_dim 19 | 20 | # Q(s,u) 21 | if self.arch == "coma_critic": 22 | # Q takes [state, u] as input 23 | q_input_size = self.state_dim + (self.n_agents * self.n_actions) 24 | elif self.arch == "qtran_paper": 25 | # Q takes [state, agent_action_observation_encodings] 26 | q_input_size = self.state_dim + self.args.rnn_hidden_dim + self.n_actions 27 | else: 28 | raise Exception("{} is not a valid QTran architecture".format(self.arch)) 29 | 30 | if self.args.network_size == "small": 31 | self.Q = nn.Sequential(nn.Linear(q_input_size, self.embed_dim), 32 | nn.ReLU(), 33 | nn.Linear(self.embed_dim, self.embed_dim), 34 | nn.ReLU(), 35 | nn.Linear(self.embed_dim, 1)) 36 | 37 | # V(s) 38 | self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), 39 | nn.ReLU(), 40 | nn.Linear(self.embed_dim, self.embed_dim), 41 | nn.ReLU(), 42 | nn.Linear(self.embed_dim, 1)) 43 | ae_input = self.args.rnn_hidden_dim + 
self.n_actions 44 | self.action_encoding = nn.Sequential(nn.Linear(ae_input, ae_input), 45 | nn.ReLU(), 46 | nn.Linear(ae_input, ae_input)) 47 | elif self.args.network_size == "big": 48 | self.Q = nn.Sequential(nn.Linear(q_input_size, self.embed_dim), 49 | nn.ReLU(), 50 | nn.Linear(self.embed_dim, self.embed_dim), 51 | nn.ReLU(), 52 | nn.Linear(self.embed_dim, self.embed_dim), 53 | nn.ReLU(), 54 | nn.Linear(self.embed_dim, 1)) 55 | # V(s) 56 | self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), 57 | nn.ReLU(), 58 | nn.Linear(self.embed_dim, self.embed_dim), 59 | nn.ReLU(), 60 | nn.Linear(self.embed_dim, self.embed_dim), 61 | nn.ReLU(), 62 | nn.Linear(self.embed_dim, 1)) 63 | ae_input = self.args.rnn_hidden_dim + self.n_actions 64 | self.action_encoding = nn.Sequential(nn.Linear(ae_input, ae_input), 65 | nn.ReLU(), 66 | nn.Linear(ae_input, ae_input)) 67 | else: 68 | assert False 69 | 70 | def forward(self, batch, hidden_states, actions=None): 71 | bs = batch.batch_size 72 | ts = batch.max_seq_length 73 | 74 | states = batch["state"].reshape(bs * ts, self.state_dim) 75 | 76 | if self.arch == "coma_critic": 77 | if actions is None: 78 | # Use the actions taken by the agents 79 | actions = batch["actions_onehot"].reshape(bs * ts, self.n_agents * self.n_actions) 80 | else: 81 | # It will arrive as (bs, ts, agents, actions), we need to reshape it 82 | actions = actions.reshape(bs * ts, self.n_agents * self.n_actions) 83 | inputs = th.cat([states, actions], dim=1) 84 | elif self.arch == "qtran_paper": 85 | if actions is None: 86 | # Use the actions taken by the agents 87 | actions = batch["actions_onehot"].reshape(bs * ts, self.n_agents, self.n_actions) 88 | else: 89 | # It will arrive as (bs, ts, agents, actions), we need to reshape it 90 | actions = actions.reshape(bs * ts, self.n_agents, self.n_actions) 91 | 92 | hidden_states = hidden_states.reshape(bs * ts, self.n_agents, -1) 93 | agent_state_action_input = th.cat([hidden_states, actions], dim=2) 94 | agent_state_action_encoding = self.action_encoding(agent_state_action_input.reshape(bs * ts * self.n_agents, -1)).reshape(bs * ts, self.n_agents, -1) 95 | agent_state_action_encoding = agent_state_action_encoding.sum(dim=1) # Sum across agents 96 | 97 | inputs = th.cat([states, agent_state_action_encoding], dim=1) 98 | 99 | q_outputs = self.Q(inputs) 100 | 101 | states = batch["state"].reshape(bs * ts, self.state_dim) 102 | v_outputs = self.V(states) 103 | 104 | return q_outputs, v_outputs 105 | 106 | -------------------------------------------------------------------------------- /src/learners/q_learner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from components.episode_buffer import EpisodeBatch 3 | from modules.mixers.vdn import VDNMixer 4 | from modules.mixers.qmix import QMixer 5 | import torch as th 6 | from torch.optim import RMSprop 7 | 8 | 9 | class QLearner: 10 | def __init__(self, mac, scheme, logger, args): 11 | self.args = args 12 | self.mac = mac 13 | self.logger = logger 14 | 15 | self.params = list(mac.parameters()) 16 | 17 | self.last_target_update_episode = 0 18 | 19 | self.mixer = None 20 | if args.mixer is not None: 21 | if args.mixer == "vdn": 22 | self.mixer = VDNMixer() 23 | elif args.mixer == "qmix": 24 | self.mixer = QMixer(args) 25 | else: 26 | raise ValueError("Mixer {} not recognised.".format(args.mixer)) 27 | self.params += list(self.mixer.parameters()) 28 | self.target_mixer = copy.deepcopy(self.mixer) 29 | 30 | self.optimiser = 
RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) 31 | 32 | # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC 33 | self.target_mac = copy.deepcopy(mac) 34 | 35 | self.log_stats_t = -self.args.learner_log_interval - 1 36 | 37 | def train(self, batch: EpisodeBatch, t_env: int, episode_num: int, show_demo=False, save_data=None): 38 | # Get the relevant quantities 39 | rewards = batch["reward"][:, :-1] 40 | actions = batch["actions"][:, :-1] 41 | terminated = batch["terminated"][:, :-1].float() 42 | mask = batch["filled"][:, :-1].float() 43 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 44 | avail_actions = batch["avail_actions"] 45 | 46 | # Calculate estimated Q-Values 47 | mac_out = [] 48 | self.mac.init_hidden(batch.batch_size) 49 | for t in range(batch.max_seq_length): 50 | agent_outs = self.mac.forward(batch, t=t) 51 | mac_out.append(agent_outs) 52 | mac_out = th.stack(mac_out, dim=1) # Concat over time 53 | 54 | # Pick the Q-Values for the actions taken by each agent 55 | chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim 56 | 57 | # Calculate the Q-Values necessary for the target 58 | target_mac_out = [] 59 | self.target_mac.init_hidden(batch.batch_size) 60 | for t in range(batch.max_seq_length): 61 | target_agent_outs = self.target_mac.forward(batch, t=t) 62 | target_mac_out.append(target_agent_outs) 63 | 64 | # We don't need the first timesteps Q-Value estimate for calculating targets 65 | target_mac_out = th.stack(target_mac_out[1:], dim=1) # Concat across time 66 | 67 | # Mask out unavailable actions 68 | target_mac_out[avail_actions[:, 1:] == 0] = -9999999 69 | 70 | # Max over target Q-Values 71 | if self.args.double_q: 72 | # Get actions that maximise live Q (for double q-learning) 73 | mac_out_detach = mac_out.clone().detach() 74 | mac_out_detach[avail_actions == 0] = -9999999 75 | cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1] 76 | target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) 77 | else: 78 | target_max_qvals = target_mac_out.max(dim=3)[0] 79 | 80 | # Mix 81 | if self.mixer is not None: 82 | chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1]) 83 | target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:]) 84 | 85 | # Calculate 1-step Q-Learning targets 86 | targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals 87 | 88 | # Td-error 89 | td_error = (chosen_action_qvals - targets.detach()) 90 | 91 | mask = mask.expand_as(td_error) 92 | 93 | # 0-out the targets that came from padded data 94 | masked_td_error = td_error * mask 95 | 96 | # Normal L2 loss, take mean over actual data 97 | loss = (masked_td_error ** 2).sum() / mask.sum() 98 | 99 | # Optimise 100 | self.optimiser.zero_grad() 101 | loss.backward() 102 | grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) 103 | self.optimiser.step() 104 | 105 | if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0: 106 | self._update_targets() 107 | self.last_target_update_episode = episode_num 108 | 109 | if t_env - self.log_stats_t >= self.args.learner_log_interval: 110 | self.logger.log_stat("loss", loss.item(), t_env) 111 | self.logger.log_stat("grad_norm", grad_norm, t_env) 112 | mask_elems = mask.sum().item() 113 | self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env) 114 | 
self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents), t_env) 115 | self.logger.log_stat("target_mean", (targets * mask).sum().item()/(mask_elems * self.args.n_agents), t_env) 116 | self.log_stats_t = t_env 117 | 118 | def _update_targets(self): 119 | self.target_mac.load_state(self.mac) 120 | if self.mixer is not None: 121 | self.target_mixer.load_state_dict(self.mixer.state_dict()) 122 | self.logger.console_logger.info("Updated target network") 123 | 124 | def cuda(self): 125 | self.mac.cuda() 126 | self.target_mac.cuda() 127 | if self.mixer is not None: 128 | self.mixer.cuda() 129 | self.target_mixer.cuda() 130 | 131 | def save_models(self, path): 132 | self.mac.save_models(path) 133 | if self.mixer is not None: 134 | th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) 135 | th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) 136 | 137 | def load_models(self, path): 138 | self.mac.load_models(path) 139 | # Not quite right but I don't want to save target networks 140 | self.target_mac.load_models(path) 141 | if self.mixer is not None: 142 | self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) 143 | self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage)) 144 | -------------------------------------------------------------------------------- /src/learners/coma_learner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from components.episode_buffer import EpisodeBatch 3 | from modules.critics.coma import COMACritic 4 | from utils.rl_utils import build_td_lambda_targets 5 | import torch as th 6 | from torch.optim import RMSprop 7 | 8 | 9 | class COMALearner: 10 | def __init__(self, mac, scheme, logger, args): 11 | self.args = args 12 | self.n_agents = args.n_agents 13 | self.n_actions = args.n_actions 14 | self.mac = mac 15 | self.logger = logger 16 | 17 | self.last_target_update_step = 0 18 | self.critic_training_steps = 0 19 | 20 | self.log_stats_t = -self.args.learner_log_interval - 1 21 | 22 | self.critic = COMACritic(scheme, args) 23 | self.target_critic = copy.deepcopy(self.critic) 24 | 25 | self.agent_params = list(mac.parameters()) 26 | self.critic_params = list(self.critic.parameters()) 27 | self.params = self.agent_params + self.critic_params 28 | 29 | self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) 30 | self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.critic_lr, alpha=args.optim_alpha, eps=args.optim_eps) 31 | 32 | def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): 33 | # Get the relevant quantities 34 | bs = batch.batch_size 35 | max_t = batch.max_seq_length 36 | rewards = batch["reward"][:, :-1] 37 | actions = batch["actions"][:, :] 38 | terminated = batch["terminated"][:, :-1].float() 39 | mask = batch["filled"][:, :-1].float() 40 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 41 | avail_actions = batch["avail_actions"][:, :-1] 42 | 43 | critic_mask = mask.clone() 44 | 45 | mask = mask.repeat(1, 1, self.n_agents).view(-1) 46 | 47 | q_vals, critic_train_stats = self._train_critic(batch, rewards, terminated, actions, avail_actions, 48 | critic_mask, bs, max_t) 49 | 50 | actions = actions[:,:-1] 51 | 52 | mac_out = [] 53 | self.mac.init_hidden(batch.batch_size) 54 | for t in range(batch.max_seq_length - 1): 55 | agent_outs = 
self.mac.forward(batch, t=t) 56 | mac_out.append(agent_outs) 57 | mac_out = th.stack(mac_out, dim=1) # Concat over time 58 | 59 | # Mask out unavailable actions, renormalise (as in action selection) 60 | mac_out[avail_actions == 0] = 0 61 | mac_out = mac_out/mac_out.sum(dim=-1, keepdim=True) 62 | mac_out[avail_actions == 0] = 0 63 | 64 | # Calculated baseline 65 | q_vals = q_vals.reshape(-1, self.n_actions) 66 | pi = mac_out.view(-1, self.n_actions) 67 | baseline = (pi * q_vals).sum(-1).detach() 68 | 69 | # Calculate policy grad with mask 70 | q_taken = th.gather(q_vals, dim=1, index=actions.reshape(-1, 1)).squeeze(1) 71 | pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1, 1)).squeeze(1) 72 | pi_taken[mask == 0] = 1.0 73 | log_pi_taken = th.log(pi_taken) 74 | 75 | advantages = (q_taken - baseline).detach() 76 | 77 | coma_loss = - ((advantages * log_pi_taken) * mask).sum() / mask.sum() 78 | 79 | # Optimise agents 80 | self.agent_optimiser.zero_grad() 81 | coma_loss.backward() 82 | grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip) 83 | self.agent_optimiser.step() 84 | 85 | if (self.critic_training_steps - self.last_target_update_step) / self.args.target_update_interval >= 1.0: 86 | self._update_targets() 87 | self.last_target_update_step = self.critic_training_steps 88 | 89 | if t_env - self.log_stats_t >= self.args.learner_log_interval: 90 | ts_logged = len(critic_train_stats["critic_loss"]) 91 | for key in ["critic_loss", "critic_grad_norm", "td_error_abs", "q_taken_mean", "target_mean"]: 92 | self.logger.log_stat(key, sum(critic_train_stats[key])/ts_logged, t_env) 93 | 94 | self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env) 95 | self.logger.log_stat("coma_loss", coma_loss.item(), t_env) 96 | self.logger.log_stat("agent_grad_norm", grad_norm, t_env) 97 | self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(), t_env) 98 | self.log_stats_t = t_env 99 | 100 | def _train_critic(self, batch, rewards, terminated, actions, avail_actions, mask, bs, max_t): 101 | # Optimise critic 102 | target_q_vals = self.target_critic(batch)[:, :] 103 | targets_taken = th.gather(target_q_vals, dim=3, index=actions).squeeze(3) 104 | 105 | # Calculate td-lambda targets 106 | targets = build_td_lambda_targets(rewards, terminated, mask, targets_taken, self.n_agents, self.args.gamma, self.args.td_lambda) 107 | 108 | q_vals = th.zeros_like(target_q_vals)[:, :-1] 109 | 110 | running_log = { 111 | "critic_loss": [], 112 | "critic_grad_norm": [], 113 | "td_error_abs": [], 114 | "target_mean": [], 115 | "q_taken_mean": [], 116 | } 117 | 118 | for t in reversed(range(rewards.size(1))): 119 | mask_t = mask[:, t].expand(-1, self.n_agents) 120 | if mask_t.sum() == 0: 121 | continue 122 | 123 | q_t = self.critic(batch, t) 124 | q_vals[:, t] = q_t.view(bs, self.n_agents, self.n_actions) 125 | q_taken = th.gather(q_t, dim=3, index=actions[:, t:t+1]).squeeze(3).squeeze(1) 126 | targets_t = targets[:, t] 127 | 128 | td_error = (q_taken - targets_t.detach()) 129 | 130 | # 0-out the targets that came from padded data 131 | masked_td_error = td_error * mask_t 132 | 133 | # Normal L2 loss, take mean over actual data 134 | loss = (masked_td_error ** 2).sum() / mask_t.sum() 135 | self.critic_optimiser.zero_grad() 136 | loss.backward() 137 | grad_norm = th.nn.utils.clip_grad_norm_(self.critic_params, self.args.grad_norm_clip) 138 | self.critic_optimiser.step() 139 | self.critic_training_steps += 1 140 | 141 | 
running_log["critic_loss"].append(loss.item()) 142 | running_log["critic_grad_norm"].append(grad_norm) 143 | mask_elems = mask_t.sum().item() 144 | running_log["td_error_abs"].append((masked_td_error.abs().sum().item() / mask_elems)) 145 | running_log["q_taken_mean"].append((q_taken * mask_t).sum().item() / mask_elems) 146 | running_log["target_mean"].append((targets_t * mask_t).sum().item() / mask_elems) 147 | 148 | return q_vals, running_log 149 | 150 | def _update_targets(self): 151 | self.target_critic.load_state_dict(self.critic.state_dict()) 152 | self.logger.console_logger.info("Updated target network") 153 | 154 | def cuda(self): 155 | self.mac.cuda() 156 | self.critic.cuda() 157 | self.target_critic.cuda() 158 | 159 | def save_models(self, path): 160 | self.mac.save_models(path) 161 | th.save(self.critic.state_dict(), "{}/critic.th".format(path)) 162 | th.save(self.agent_optimiser.state_dict(), "{}/agent_opt.th".format(path)) 163 | th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path)) 164 | 165 | def load_models(self, path): 166 | self.mac.load_models(path) 167 | self.critic.load_state_dict(th.load("{}/critic.th".format(path), map_location=lambda storage, loc: storage)) 168 | # Not quite right but I don't want to save target networks 169 | self.target_critic.load_state_dict(self.critic.state_dict()) 170 | self.agent_optimiser.load_state_dict(th.load("{}/agent_opt.th".format(path), map_location=lambda storage, loc: storage)) 171 | self.critic_optimiser.load_state_dict(th.load("{}/critic_opt.th".format(path), map_location=lambda storage, loc: storage)) 172 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import pprint 4 | import time 5 | import threading 6 | import torch as th 7 | from types import SimpleNamespace as SN 8 | from utils.logging import Logger 9 | from utils.timehelper import time_left, time_str 10 | from os.path import dirname, abspath 11 | 12 | from learners import REGISTRY as le_REGISTRY 13 | from runners import REGISTRY as r_REGISTRY 14 | from controllers import REGISTRY as mac_REGISTRY 15 | from components.episode_buffer import ReplayBuffer 16 | from components.transforms import OneHot 17 | 18 | def run(_run, _config, _log): 19 | 20 | # check args sanity 21 | _config = args_sanity_check(_config, _log) 22 | args = SN(**_config) 23 | args.device = "cuda" if args.use_cuda else "cpu" 24 | 25 | # setup loggers 26 | logger = Logger(_log) 27 | 28 | _log.info("Experiment Parameters:") 29 | experiment_params = pprint.pformat(_config, 30 | indent=4, 31 | width=1) 32 | _log.info("\n\n" + experiment_params + "\n") 33 | 34 | # configure tensorboard logger 35 | unique_token = "{}__{}".format(args.name, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) 36 | args.unique_token = unique_token 37 | if args.use_tensorboard: 38 | tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", "tb_logs") 39 | tb_exp_direc = os.path.join(tb_logs_direc, "{}").format(unique_token) 40 | logger.setup_tb(tb_exp_direc) 41 | 42 | # sacred is on by default 43 | logger.setup_sacred(_run) 44 | 45 | # Run and train 46 | run_sequential(args=args, logger=logger) 47 | 48 | # Clean up after finishing 49 | print("Exiting Main") 50 | 51 | print("Stopping all threads") 52 | for t in threading.enumerate(): 53 | if t.name != "MainThread": 54 | print("Thread {} is alive! 
Is daemon: {}".format(t.name, t.daemon)) 55 | t.join(timeout=1) 56 | print("Thread joined") 57 | 58 | print("Exiting script") 59 | 60 | # Making sure framework really exits 61 | os._exit(os.EX_OK) 62 | 63 | def evaluate_sequential(args, runner): 64 | 65 | for _ in range(args.test_nepisode): 66 | runner.run(test_mode=True) 67 | 68 | if args.save_replay: 69 | runner.save_replay() 70 | 71 | runner.close_env() 72 | 73 | def run_sequential(args, logger): 74 | 75 | # Init runner so we can get env info 76 | runner = r_REGISTRY[args.runner](args=args, logger=logger) 77 | 78 | # Set up schemes and groups here 79 | env_info = runner.get_env_info() 80 | args.episode_limit = env_info["episode_limit"] 81 | args.n_agents = env_info["n_agents"] 82 | args.n_actions = env_info["n_actions"] 83 | args.state_shape = env_info["state_shape"] 84 | args.unit_dim = 16#env_info["unit_dim"] 85 | 86 | # Default/Base scheme 87 | scheme = { 88 | "state": {"vshape": env_info["state_shape"]}, 89 | "obs": {"vshape": env_info["obs_shape"], "group": "agents"}, 90 | "actions": {"vshape": (1,), "group": "agents", "dtype": th.long}, 91 | "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int}, 92 | "reward": {"vshape": (1,)}, 93 | "terminated": {"vshape": (1,), "dtype": th.uint8}, 94 | } 95 | groups = { 96 | "agents": args.n_agents 97 | } 98 | preprocess = { 99 | "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)]) 100 | } 101 | 102 | buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1, 103 | args.burn_in_period, 104 | preprocess=preprocess, 105 | device="cpu" if args.buffer_cpu_only else args.device) 106 | 107 | # Setup multiagent controller here 108 | mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args) 109 | 110 | # Give runner the scheme 111 | runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac) 112 | 113 | # Learner 114 | learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args) 115 | 116 | if args.use_cuda: 117 | learner.cuda() 118 | 119 | if args.checkpoint_path != "": 120 | 121 | timesteps = [] 122 | timestep_to_load = 0 123 | 124 | if not os.path.isdir(args.checkpoint_path): 125 | logger.console_logger.info("Checkpoint directiory {} doesn't exist".format(args.checkpoint_path)) 126 | return 127 | 128 | # Go through all files in args.checkpoint_path 129 | for name in os.listdir(args.checkpoint_path): 130 | full_name = os.path.join(args.checkpoint_path, name) 131 | # Check if they are dirs the names of which are numbers 132 | if os.path.isdir(full_name) and name.isdigit(): 133 | timesteps.append(int(name)) 134 | 135 | if args.load_step == 0: 136 | # choose the max timestep 137 | timestep_to_load = max(timesteps) 138 | else: 139 | # choose the timestep closest to load_step 140 | timestep_to_load = min(timesteps, key=lambda x: abs(x - args.load_step)) 141 | 142 | model_path = os.path.join(args.checkpoint_path, str(timestep_to_load)) 143 | 144 | logger.console_logger.info("Loading model from {}".format(model_path)) 145 | learner.load_models(model_path) 146 | runner.t_env = timestep_to_load 147 | 148 | if args.evaluate or args.save_replay: 149 | evaluate_sequential(args, runner) 150 | return 151 | 152 | # start training 153 | episode = 0 154 | last_test_T = -args.test_interval - 1 155 | last_log_T = 0 156 | model_save_time = 0 157 | 158 | start_time = time.time() 159 | last_time = start_time 160 | 161 | logger.console_logger.info("Beginning training for {} timesteps".format(args.t_max)) 162 | 163 | while 
runner.t_env <= args.t_max: 164 | 165 | # Run for a whole episode at a time 166 | episode_batch = runner.run(test_mode=False) 167 | buffer.insert_episode_batch(episode_batch) 168 | 169 | if buffer.can_sample(args.batch_size): 170 | episode_sample = buffer.sample(args.batch_size) 171 | 172 | # Truncate batch to only filled timesteps 173 | max_ep_t = episode_sample.max_t_filled() 174 | episode_sample = episode_sample[:, :max_ep_t] 175 | 176 | if episode_sample.device != args.device: 177 | episode_sample.to(args.device) 178 | 179 | learner.train(episode_sample, runner.t_env, episode) 180 | 181 | # Execute test runs once in a while 182 | n_test_runs = max(1, args.test_nepisode // runner.batch_size) 183 | if (runner.t_env - last_test_T) / args.test_interval >= 1.0 : 184 | 185 | logger.console_logger.info("t_env: {} / {}".format(runner.t_env, args.t_max)) 186 | logger.console_logger.info("Estimated time left: {}. Time passed: {}".format( 187 | time_left(last_time, last_test_T, runner.t_env, args.t_max), time_str(time.time() - start_time))) 188 | last_time = time.time() 189 | 190 | last_test_T = runner.t_env 191 | for _ in range(n_test_runs): 192 | runner.run(test_mode=True) 193 | 194 | if args.save_model and (runner.t_env - model_save_time >= args.save_model_interval or model_save_time == 0): 195 | model_save_time = runner.t_env 196 | save_path = os.path.join(args.local_results_path, "models", args.unique_token, str(runner.t_env)) 197 | #"results/models/{}".format(unique_token) 198 | os.makedirs(save_path, exist_ok=True) 199 | logger.console_logger.info("Saving models to {}".format(save_path)) 200 | 201 | # learner should handle saving/loading -- delegate actor save/load to mac, 202 | # use appropriate filenames to do critics, optimizer states 203 | learner.save_models(save_path) 204 | 205 | episode += args.batch_size_run 206 | 207 | if (runner.t_env - last_log_T) >= args.log_interval: 208 | logger.log_stat("episode", episode, runner.t_env) 209 | logger.print_recent_stats() 210 | last_log_T = runner.t_env 211 | 212 | runner.close_env() 213 | logger.console_logger.info("Finished Training") 214 | 215 | 216 | def args_sanity_check(config, _log): 217 | 218 | # set CUDA flags 219 | # config["use_cuda"] = True # Use cuda whenever possible! 
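Before the checks below, a minimal standalone sketch (with made-up numbers, not defaults from any config here) of what args_sanity_check does to test_nepisode: it is raised to batch_size_run if smaller, otherwise rounded down so that the number of test episodes is a multiple of the number of parallel environments.

# Standalone sketch of the test_nepisode adjustment performed below; the numbers are assumptions.
test_nepisode, batch_size_run = 32, 10
if test_nepisode < batch_size_run:
    test_nepisode = batch_size_run
else:
    test_nepisode = (test_nepisode // batch_size_run) * batch_size_run
print(test_nepisode)  # 30 -- rounded down to a multiple of batch_size_run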
220 | if config["use_cuda"] and not th.cuda.is_available(): 221 | config["use_cuda"] = False 222 | _log.warning("CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!") 223 | 224 | if config["test_nepisode"] < config["batch_size_run"]: 225 | config["test_nepisode"] = config["batch_size_run"] 226 | else: 227 | config["test_nepisode"] = (config["test_nepisode"]//config["batch_size_run"]) * config["batch_size_run"] 228 | 229 | return config 230 | -------------------------------------------------------------------------------- /src/learners/qtran_learner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from components.episode_buffer import EpisodeBatch 3 | from modules.mixers.qtran import QTranBase 4 | import torch as th 5 | from torch.optim import RMSprop, Adam 6 | 7 | 8 | class QLearner: 9 | def __init__(self, mac, scheme, logger, args): 10 | self.args = args 11 | self.mac = mac 12 | self.logger = logger 13 | 14 | self.params = list(mac.parameters()) 15 | 16 | self.last_target_update_episode = 0 17 | 18 | self.mixer = None 19 | if args.mixer == "qtran_base": 20 | self.mixer = QTranBase(args) 21 | elif args.mixer == "qtran_alt": 22 | raise Exception("Not implemented here!") 23 | 24 | self.params += list(self.mixer.parameters()) 25 | self.target_mixer = copy.deepcopy(self.mixer) 26 | 27 | self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) 28 | 29 | # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC 30 | self.target_mac = copy.deepcopy(mac) 31 | 32 | self.log_stats_t = -self.args.learner_log_interval - 1 33 | 34 | def train(self, batch: EpisodeBatch, t_env: int, episode_num: int, show_demo=False, save_data=None): 35 | # Get the relevant quantities 36 | rewards = batch["reward"][:, :-1] 37 | actions = batch["actions"][:, :-1] 38 | terminated = batch["terminated"][:, :-1].float() 39 | mask = batch["filled"][:, :-1].float() 40 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 41 | avail_actions = batch["avail_actions"] 42 | 43 | # Calculate estimated Q-Values 44 | mac_out = [] 45 | mac_hidden_states = [] 46 | self.mac.init_hidden(batch.batch_size) 47 | for t in range(batch.max_seq_length): 48 | agent_outs = self.mac.forward(batch, t=t) 49 | mac_out.append(agent_outs) 50 | mac_hidden_states.append(self.mac.hidden_states) 51 | mac_out = th.stack(mac_out, dim=1) # Concat over time 52 | mac_hidden_states = th.stack(mac_hidden_states, dim=1) 53 | mac_hidden_states = mac_hidden_states.reshape(batch.batch_size, self.args.n_agents, batch.max_seq_length, -1).transpose(1,2) #btav 54 | 55 | # Pick the Q-Values for the actions taken by each agent 56 | chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim 57 | 58 | # Calculate the Q-Values necessary for the target 59 | target_mac_out = [] 60 | target_mac_hidden_states = [] 61 | self.target_mac.init_hidden(batch.batch_size) 62 | for t in range(batch.max_seq_length): 63 | target_agent_outs = self.target_mac.forward(batch, t=t) 64 | target_mac_out.append(target_agent_outs) 65 | target_mac_hidden_states.append(self.target_mac.hidden_states) 66 | 67 | # We don't need the first timesteps Q-Value estimate for calculating targets 68 | target_mac_out = th.stack(target_mac_out[:], dim=1) # Concat across time 69 | target_mac_hidden_states = th.stack(target_mac_hidden_states, dim=1) 70 | target_mac_hidden_states = 
target_mac_hidden_states.reshape(batch.batch_size, self.args.n_agents, batch.max_seq_length, -1).transpose(1,2) #btav 71 | 72 | # Mask out unavailable actions 73 | target_mac_out[avail_actions[:, :] == 0] = -9999999 # From OG deepmarl 74 | mac_out_maxs = mac_out.clone() 75 | mac_out_maxs[avail_actions == 0] = -9999999 76 | 77 | # Best joint action computed by target agents 78 | target_max_actions = target_mac_out.max(dim=3, keepdim=True)[1] 79 | # Best joint-action computed by regular agents 80 | max_actions_qvals, max_actions_current = mac_out_maxs[:, :].max(dim=3, keepdim=True) 81 | 82 | if self.args.mixer == "qtran_base": 83 | # -- TD Loss -- 84 | # Joint-action Q-Value estimates 85 | joint_qs, vs = self.mixer(batch[:, :-1], mac_hidden_states[:,:-1]) 86 | 87 | # Need to argmax across the target agents' actions to compute target joint-action Q-Values 88 | if self.args.double_q: 89 | max_actions_current_ = th.zeros(size=(batch.batch_size, batch.max_seq_length, self.args.n_agents, self.args.n_actions), device=batch.device) 90 | max_actions_current_onehot = max_actions_current_.scatter(3, max_actions_current[:, :], 1) 91 | max_actions_onehot = max_actions_current_onehot 92 | else: 93 | max_actions = th.zeros(size=(batch.batch_size, batch.max_seq_length, self.args.n_agents, self.args.n_actions), device=batch.device) 94 | max_actions_onehot = max_actions.scatter(3, target_max_actions[:, :], 1) 95 | target_joint_qs, target_vs = self.target_mixer(batch[:, 1:], hidden_states=target_mac_hidden_states[:,1:], actions=max_actions_onehot[:,1:]) 96 | 97 | # Td loss targets 98 | td_targets = rewards.reshape(-1,1) + self.args.gamma * (1 - terminated.reshape(-1, 1)) * target_joint_qs 99 | td_error = (joint_qs - td_targets.detach()) 100 | masked_td_error = td_error * mask.reshape(-1, 1) 101 | td_loss = (masked_td_error ** 2).sum() / mask.sum() 102 | # -- TD Loss -- 103 | 104 | # -- Opt Loss -- 105 | # Argmax across the current agents' actions 106 | if not self.args.double_q: # Already computed if we're doing double Q-Learning 107 | max_actions_current_ = th.zeros(size=(batch.batch_size, batch.max_seq_length, self.args.n_agents, self.args.n_actions), device=batch.device ) 108 | max_actions_current_onehot = max_actions_current_.scatter(3, max_actions_current[:, :], 1) 109 | max_joint_qs, _ = self.mixer(batch[:, :-1], mac_hidden_states[:,:-1], actions=max_actions_current_onehot[:,:-1]) # Don't use the target network and target agent max actions as per author's email 110 | 111 | # max_actions_qvals = th.gather(mac_out[:, :-1], dim=3, index=max_actions_current[:,:-1]) 112 | opt_error = max_actions_qvals[:,:-1].sum(dim=2).reshape(-1, 1) - max_joint_qs.detach() + vs 113 | masked_opt_error = opt_error * mask.reshape(-1, 1) 114 | opt_loss = (masked_opt_error ** 2).sum() / mask.sum() 115 | # -- Opt Loss -- 116 | 117 | # -- Nopt Loss -- 118 | # target_joint_qs, _ = self.target_mixer(batch[:, :-1]) 119 | nopt_values = chosen_action_qvals.sum(dim=2).reshape(-1, 1) - joint_qs.detach() + vs # Don't use target networks here either 120 | nopt_error = nopt_values.clamp(max=0) 121 | masked_nopt_error = nopt_error * mask.reshape(-1, 1) 122 | nopt_loss = (masked_nopt_error ** 2).sum() / mask.sum() 123 | # -- Nopt loss -- 124 | 125 | elif self.args.mixer == "qtran_alt": 126 | raise Exception("Not supported yet.") 127 | 128 | loss = td_loss + self.args.opt_loss * opt_loss + self.args.nopt_min_loss * nopt_loss 129 | 130 | # Optimise 131 | self.optimiser.zero_grad() 132 | loss.backward() 133 | grad_norm = 
th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) 134 | self.optimiser.step() 135 | 136 | if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0: 137 | self._update_targets() 138 | self.last_target_update_episode = episode_num 139 | 140 | if t_env - self.log_stats_t >= self.args.learner_log_interval: 141 | self.logger.log_stat("loss", loss.item(), t_env) 142 | self.logger.log_stat("td_loss", td_loss.item(), t_env) 143 | self.logger.log_stat("opt_loss", opt_loss.item(), t_env) 144 | self.logger.log_stat("nopt_loss", nopt_loss.item(), t_env) 145 | self.logger.log_stat("grad_norm", grad_norm, t_env) 146 | if self.args.mixer == "qtran_base": 147 | mask_elems = mask.sum().item() 148 | self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env) 149 | self.logger.log_stat("td_targets", ((masked_td_error).sum().item()/mask_elems), t_env) 150 | self.logger.log_stat("td_chosen_qs", (joint_qs.sum().item()/mask_elems), t_env) 151 | self.logger.log_stat("v_mean", (vs.sum().item()/mask_elems), t_env) 152 | self.logger.log_stat("agent_indiv_qs", ((chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents)), t_env) 153 | self.log_stats_t = t_env 154 | 155 | def _update_targets(self): 156 | self.target_mac.load_state(self.mac) 157 | if self.mixer is not None: 158 | self.target_mixer.load_state_dict(self.mixer.state_dict()) 159 | self.logger.console_logger.info("Updated target network") 160 | 161 | def cuda(self): 162 | self.mac.cuda() 163 | self.target_mac.cuda() 164 | if self.mixer is not None: 165 | self.mixer.cuda() 166 | self.target_mixer.cuda() 167 | 168 | def save_models(self, path): 169 | self.mac.save_models(path) 170 | if self.mixer is not None: 171 | th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) 172 | th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) 173 | 174 | def load_models(self, path): 175 | self.mac.load_models(path) 176 | # Not quite right but I don't want to save target networks 177 | self.target_mac.load_models(path) 178 | if self.mixer is not None: 179 | self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) 180 | self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage)) 181 | -------------------------------------------------------------------------------- /src/runners/parallel_runner.py: -------------------------------------------------------------------------------- 1 | from envs import REGISTRY as env_REGISTRY 2 | from functools import partial 3 | from components.episode_buffer import EpisodeBatch 4 | from multiprocessing import Pipe, Process 5 | import numpy as np 6 | import torch as th 7 | 8 | 9 | # Based (very) heavily on SubprocVecEnv from OpenAI Baselines 10 | # https://github.com/openai/baselines/blob/master/baselines/common/vec_env/subproc_vec_env.py 11 | class ParallelRunner: 12 | 13 | def __init__(self, args, logger): 14 | self.args = args 15 | self.logger = logger 16 | self.batch_size = self.args.batch_size_run 17 | 18 | # Make subprocesses for the envs 19 | self.parent_conns, self.worker_conns = zip(*[Pipe() for _ in range(self.batch_size)]) 20 | env_fn = env_REGISTRY[self.args.env] 21 | self.ps = [Process(target=env_worker, args=(worker_conn, CloudpickleWrapper(partial(env_fn, **self.args.env_args)))) 22 | for worker_conn in self.worker_conns] 23 | 24 | for p in self.ps: 25 | p.daemon = True 26 | p.start() 27 | 28 | 
self.parent_conns[0].send(("get_env_info", None)) 29 | self.env_info = self.parent_conns[0].recv() 30 | self.episode_limit = self.env_info["episode_limit"] 31 | 32 | self.t = 0 33 | 34 | self.t_env = 0 35 | 36 | self.train_returns = [] 37 | self.test_returns = [] 38 | self.train_stats = {} 39 | self.test_stats = {} 40 | 41 | self.log_train_stats_t = -100000 42 | 43 | def setup(self, scheme, groups, preprocess, mac): 44 | self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1, 45 | preprocess=preprocess, device=self.args.device) 46 | self.mac = mac 47 | self.scheme = scheme 48 | self.groups = groups 49 | self.preprocess = preprocess 50 | 51 | def get_env_info(self): 52 | return self.env_info 53 | 54 | def save_replay(self): 55 | pass 56 | 57 | def close_env(self): 58 | for parent_conn in self.parent_conns: 59 | parent_conn.send(("close", None)) 60 | 61 | def reset(self): 62 | self.batch = self.new_batch() 63 | 64 | # Reset the envs 65 | for parent_conn in self.parent_conns: 66 | parent_conn.send(("reset", None)) 67 | 68 | pre_transition_data = { 69 | "state": [], 70 | "avail_actions": [], 71 | "obs": [] 72 | } 73 | # Get the obs, state and avail_actions back 74 | for parent_conn in self.parent_conns: 75 | data = parent_conn.recv() 76 | pre_transition_data["state"].append(data["state"]) 77 | pre_transition_data["avail_actions"].append(data["avail_actions"]) 78 | pre_transition_data["obs"].append(data["obs"]) 79 | 80 | self.batch.update(pre_transition_data, ts=0) 81 | 82 | self.t = 0 83 | self.env_steps_this_run = 0 84 | 85 | def run(self, test_mode=False): 86 | self.reset() 87 | 88 | all_terminated = False 89 | episode_returns = [0 for _ in range(self.batch_size)] 90 | episode_lengths = [0 for _ in range(self.batch_size)] 91 | self.mac.init_hidden(batch_size=self.batch_size) 92 | terminated = [False for _ in range(self.batch_size)] 93 | envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed] 94 | final_env_infos = [] # may store extra stats like battle won. 
this is filled in ORDER OF TERMINATION 95 | 96 | while True: 97 | 98 | # Pass the entire batch of experiences up till now to the agents 99 | # Receive the actions for each agent at this timestep in a batch for each un-terminated env 100 | actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, bs=envs_not_terminated, test_mode=test_mode) 101 | cpu_actions = actions.to("cpu").numpy() 102 | 103 | # Update the actions taken 104 | actions_chosen = { 105 | "actions": actions.unsqueeze(1) 106 | } 107 | self.batch.update(actions_chosen, bs=envs_not_terminated, ts=self.t, mark_filled=False) 108 | 109 | # Send actions to each env 110 | action_idx = 0 111 | for idx, parent_conn in enumerate(self.parent_conns): 112 | if idx in envs_not_terminated: # We produced actions for this env 113 | if not terminated[idx]: # Only send the actions to the env if it hasn't terminated 114 | parent_conn.send(("step", cpu_actions[action_idx])) 115 | action_idx += 1 # actions is not a list over every env 116 | 117 | # Update envs_not_terminated 118 | envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed] 119 | all_terminated = all(terminated) 120 | if all_terminated: 121 | break 122 | 123 | # Post step data we will insert for the current timestep 124 | post_transition_data = { 125 | "reward": [], 126 | "terminated": [] 127 | } 128 | # Data for the next step we will insert in order to select an action 129 | pre_transition_data = { 130 | "state": [], 131 | "avail_actions": [], 132 | "obs": [] 133 | } 134 | 135 | # Receive data back for each unterminated env 136 | for idx, parent_conn in enumerate(self.parent_conns): 137 | if not terminated[idx]: 138 | data = parent_conn.recv() 139 | # Remaining data for this current timestep 140 | post_transition_data["reward"].append((data["reward"],)) 141 | 142 | episode_returns[idx] += data["reward"] 143 | episode_lengths[idx] += 1 144 | if not test_mode: 145 | self.env_steps_this_run += 1 146 | 147 | env_terminated = False 148 | if data["terminated"]: 149 | final_env_infos.append(data["info"]) 150 | if data["terminated"] and not data["info"].get("episode_limit", False): 151 | env_terminated = True 152 | terminated[idx] = data["terminated"] 153 | post_transition_data["terminated"].append((env_terminated,)) 154 | 155 | # Data for the next timestep needed to select an action 156 | pre_transition_data["state"].append(data["state"]) 157 | pre_transition_data["avail_actions"].append(data["avail_actions"]) 158 | pre_transition_data["obs"].append(data["obs"]) 159 | 160 | # Add post_transiton data into the batch 161 | self.batch.update(post_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=False) 162 | 163 | # Move onto the next timestep 164 | self.t += 1 165 | 166 | # Add the pre-transition data 167 | self.batch.update(pre_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=True) 168 | 169 | if not test_mode: 170 | self.t_env += self.env_steps_this_run 171 | 172 | # Get stats back for each env 173 | for parent_conn in self.parent_conns: 174 | parent_conn.send(("get_stats",None)) 175 | 176 | env_stats = [] 177 | for parent_conn in self.parent_conns: 178 | env_stat = parent_conn.recv() 179 | env_stats.append(env_stat) 180 | 181 | cur_stats = self.test_stats if test_mode else self.train_stats 182 | cur_returns = self.test_returns if test_mode else self.train_returns 183 | log_prefix = "test_" if test_mode else "" 184 | infos = [cur_stats] + final_env_infos 185 | cur_stats.update({k: sum(d.get(k, 0) for d in infos) 
for k in set.union(*[set(d) for d in infos])}) 186 | cur_stats["n_episodes"] = self.batch_size + cur_stats.get("n_episodes", 0) 187 | cur_stats["ep_length"] = sum(episode_lengths) + cur_stats.get("ep_length", 0) 188 | 189 | cur_returns.extend(episode_returns) 190 | 191 | n_test_runs = max(1, self.args.test_nepisode // self.batch_size) * self.batch_size 192 | if test_mode and (len(self.test_returns) == n_test_runs): 193 | self._log(cur_returns, cur_stats, log_prefix) 194 | elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval: 195 | self._log(cur_returns, cur_stats, log_prefix) 196 | if hasattr(self.mac.action_selector, "epsilon"): 197 | self.logger.log_stat("epsilon", self.mac.action_selector.epsilon, self.t_env) 198 | self.log_train_stats_t = self.t_env 199 | 200 | return self.batch 201 | 202 | def _log(self, returns, stats, prefix): 203 | self.logger.log_stat(prefix + "return_mean", np.mean(returns), self.t_env) 204 | self.logger.log_stat(prefix + "return_std", np.std(returns), self.t_env) 205 | returns.clear() 206 | 207 | for k, v in stats.items(): 208 | if k != "n_episodes": 209 | self.logger.log_stat(prefix + k + "_mean" , v/stats["n_episodes"], self.t_env) 210 | stats.clear() 211 | 212 | 213 | def env_worker(remote, env_fn): 214 | # Make environment 215 | env = env_fn.x() 216 | while True: 217 | cmd, data = remote.recv() 218 | if cmd == "step": 219 | actions = data 220 | # Take a step in the environment 221 | reward, terminated, env_info = env.step(actions) 222 | # Return the observations, avail_actions and state to make the next action 223 | state = env.get_state() 224 | avail_actions = env.get_avail_actions() 225 | obs = env.get_obs() 226 | remote.send({ 227 | # Data for the next timestep needed to pick an action 228 | "state": state, 229 | "avail_actions": avail_actions, 230 | "obs": obs, 231 | # Rest of the data for the current timestep 232 | "reward": reward, 233 | "terminated": terminated, 234 | "info": env_info 235 | }) 236 | elif cmd == "reset": 237 | env.reset() 238 | remote.send({ 239 | "state": env.get_state(), 240 | "avail_actions": env.get_avail_actions(), 241 | "obs": env.get_obs() 242 | }) 243 | elif cmd == "close": 244 | env.close() 245 | remote.close() 246 | break 247 | elif cmd == "get_env_info": 248 | remote.send(env.get_env_info()) 249 | elif cmd == "get_stats": 250 | remote.send(env.get_stats()) 251 | else: 252 | raise NotImplementedError 253 | 254 | 255 | class CloudpickleWrapper(): 256 | """ 257 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 258 | """ 259 | def __init__(self, x): 260 | self.x = x 261 | def __getstate__(self): 262 | import cloudpickle 263 | return cloudpickle.dumps(self.x) 264 | def __setstate__(self, ob): 265 | import pickle 266 | self.x = pickle.loads(ob) 267 | 268 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 
15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2017 Google Inc. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /src/components/episode_buffer.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | from types import SimpleNamespace as SN 4 | 5 | 6 | class EpisodeBatch: 7 | def __init__(self, 8 | scheme, 9 | groups, 10 | batch_size, 11 | max_seq_length, 12 | data=None, 13 | preprocess=None, 14 | device="cpu"): 15 | self.scheme = scheme.copy() 16 | self.groups = groups 17 | self.batch_size = batch_size 18 | self.max_seq_length = max_seq_length 19 | self.preprocess = {} if preprocess is None else preprocess 20 | self.device = device 21 | 22 | if data is not None: 23 | self.data = data 24 | else: 25 | self.data = SN() 26 | self.data.transition_data = {} 27 | self.data.episode_data = {} 28 | self._setup_data(self.scheme, self.groups, batch_size, max_seq_length, self.preprocess) 29 | 30 | def _setup_data(self, scheme, groups, batch_size, max_seq_length, preprocess): 31 | if preprocess is not None: 32 | for k in preprocess: 33 | assert k in scheme 34 | new_k = preprocess[k][0] 35 | transforms = preprocess[k][1] 36 | 37 | vshape = self.scheme[k]["vshape"] 38 | dtype = self.scheme[k]["dtype"] 39 | for transform in transforms: 40 | vshape, dtype = transform.infer_output_info(vshape, dtype) 41 | 42 | self.scheme[new_k] = { 43 | "vshape": vshape, 44 | "dtype": dtype 45 | } 46 | if "group" in self.scheme[k]: 47 | self.scheme[new_k]["group"] = self.scheme[k]["group"] 48 | if "episode_const" in self.scheme[k]: 49 | self.scheme[new_k]["episode_const"] = self.scheme[k]["episode_const"] 50 | 51 | assert "filled" not in scheme, '"filled" is a reserved key for masking.' 
52 | scheme.update({ 53 | "filled": {"vshape": (1,), "dtype": th.long}, 54 | }) 55 | 56 | for field_key, field_info in scheme.items(): 57 | assert "vshape" in field_info, "Scheme must define vshape for {}".format(field_key) 58 | vshape = field_info["vshape"] 59 | episode_const = field_info.get("episode_const", False) 60 | group = field_info.get("group", None) 61 | dtype = field_info.get("dtype", th.float32) 62 | 63 | if isinstance(vshape, int): 64 | vshape = (vshape,) 65 | 66 | if group: 67 | assert group in groups, "Group {} must have its number of members defined in _groups_".format(group) 68 | shape = (groups[group], *vshape) 69 | else: 70 | shape = vshape 71 | 72 | if episode_const: 73 | self.data.episode_data[field_key] = th.zeros((batch_size, *shape), dtype=dtype, device=self.device) 74 | else: 75 | self.data.transition_data[field_key] = th.zeros((batch_size, max_seq_length, *shape), dtype=dtype, device=self.device) 76 | 77 | def extend(self, scheme, groups=None): 78 | self._setup_data(scheme, self.groups if groups is None else groups, self.batch_size, self.max_seq_length) 79 | 80 | def to(self, device): 81 | for k, v in self.data.transition_data.items(): 82 | self.data.transition_data[k] = v.to(device) 83 | for k, v in self.data.episode_data.items(): 84 | self.data.episode_data[k] = v.to(device) 85 | self.device = device 86 | 87 | def update(self, data, bs=slice(None), ts=slice(None), mark_filled=True): 88 | slices = self._parse_slices((bs, ts)) 89 | for k, v in data.items(): 90 | if k in self.data.transition_data: 91 | target = self.data.transition_data 92 | if mark_filled: 93 | target["filled"][slices] = 1 94 | mark_filled = False 95 | _slices = slices 96 | elif k in self.data.episode_data: 97 | target = self.data.episode_data 98 | _slices = slices[0] 99 | else: 100 | raise KeyError("{} not found in transition or episode data".format(k)) 101 | 102 | dtype = self.scheme[k].get("dtype", th.float32) 103 | v = th.tensor(v, dtype=dtype, device=self.device) 104 | self._check_safe_view(v, target[k][_slices]) 105 | target[k][_slices] = v.view_as(target[k][_slices]) 106 | 107 | if k in self.preprocess: 108 | new_k = self.preprocess[k][0] 109 | v = target[k][_slices] 110 | for transform in self.preprocess[k][1]: 111 | v = transform.transform(v) 112 | target[new_k][_slices] = v.view_as(target[new_k][_slices]) 113 | 114 | def _check_safe_view(self, v, dest): 115 | idx = len(v.shape) - 1 116 | for s in dest.shape[::-1]: 117 | if v.shape[idx] != s: 118 | if s != 1: 119 | raise ValueError("Unsafe reshape of {} to {}".format(v.shape, dest.shape)) 120 | else: 121 | idx -= 1 122 | 123 | def __getitem__(self, item): 124 | if isinstance(item, str): 125 | if item in self.data.episode_data: 126 | return self.data.episode_data[item] 127 | elif item in self.data.transition_data: 128 | return self.data.transition_data[item] 129 | else: 130 | raise ValueError 131 | elif isinstance(item, tuple) and all([isinstance(it, str) for it in item]): 132 | new_data = self._new_data_sn() 133 | for key in item: 134 | if key in self.data.transition_data: 135 | new_data.transition_data[key] = self.data.transition_data[key] 136 | elif key in self.data.episode_data: 137 | new_data.episode_data[key] = self.data.episode_data[key] 138 | else: 139 | raise KeyError("Unrecognised key {}".format(key)) 140 | 141 | # Update the scheme to only have the requested keys 142 | new_scheme = {key: self.scheme[key] for key in item} 143 | new_groups = {self.scheme[key]["group"]: self.groups[self.scheme[key]["group"]] 144 | for key in item 
if "group" in self.scheme[key]} 145 | ret = EpisodeBatch(new_scheme, new_groups, self.batch_size, self.max_seq_length, data=new_data, device=self.device) 146 | return ret 147 | else: 148 | item = self._parse_slices(item) 149 | new_data = self._new_data_sn() 150 | for k, v in self.data.transition_data.items(): 151 | new_data.transition_data[k] = v[item] 152 | for k, v in self.data.episode_data.items(): 153 | new_data.episode_data[k] = v[item[0]] 154 | 155 | ret_bs = self._get_num_items(item[0], self.batch_size) 156 | ret_max_t = self._get_num_items(item[1], self.max_seq_length) 157 | 158 | ret = EpisodeBatch(self.scheme, self.groups, ret_bs, ret_max_t, data=new_data, device=self.device) 159 | return ret 160 | 161 | def _get_num_items(self, indexing_item, max_size): 162 | if isinstance(indexing_item, list) or isinstance(indexing_item, np.ndarray): 163 | return len(indexing_item) 164 | elif isinstance(indexing_item, slice): 165 | _range = indexing_item.indices(max_size) 166 | return 1 + (_range[1] - _range[0] - 1)//_range[2] 167 | 168 | def _new_data_sn(self): 169 | new_data = SN() 170 | new_data.transition_data = {} 171 | new_data.episode_data = {} 172 | return new_data 173 | 174 | def _parse_slices(self, items): 175 | parsed = [] 176 | # Only batch slice given, add full time slice 177 | if (isinstance(items, slice) # slice a:b 178 | or isinstance(items, int) # int i 179 | or (isinstance(items, (list, np.ndarray, th.LongTensor, th.cuda.LongTensor))) # [a,b,c] 180 | ): 181 | items = (items, slice(None)) 182 | 183 | # Need the time indexing to be contiguous 184 | if isinstance(items[1], list): 185 | raise IndexError("Indexing across Time must be contiguous") 186 | 187 | for item in items: 188 | #TODO: stronger checks to ensure only supported options get through 189 | if isinstance(item, int): 190 | # Convert single indices to slices 191 | parsed.append(slice(item, item+1)) 192 | else: 193 | # Leave slices and lists as is 194 | parsed.append(item) 195 | return parsed 196 | 197 | def max_t_filled(self): 198 | return th.sum(self.data.transition_data["filled"], 1).max(0)[0] 199 | 200 | def __repr__(self): 201 | return "EpisodeBatch. 
Batch Size:{} Max_seq_len:{} Keys:{} Groups:{}".format(self.batch_size, 202 | self.max_seq_length, 203 | self.scheme.keys(), 204 | self.groups.keys()) 205 | 206 | 207 | class ReplayBuffer(EpisodeBatch): 208 | def __init__(self, scheme, groups, buffer_size, max_seq_length, burn_in_period, preprocess=None, device="cpu"): 209 | super(ReplayBuffer, self).__init__(scheme, groups, buffer_size, max_seq_length, preprocess=preprocess, device=device) 210 | self.buffer_size = buffer_size # same as self.batch_size but more explicit 211 | self.buffer_index = 0 212 | self.episodes_in_buffer = 0 213 | self.burn_in_period = burn_in_period 214 | 215 | def insert_episode_batch(self, ep_batch): 216 | if self.buffer_index + ep_batch.batch_size <= self.buffer_size: 217 | self.update(ep_batch.data.transition_data, 218 | slice(self.buffer_index, self.buffer_index + ep_batch.batch_size), 219 | slice(0, ep_batch.max_seq_length), 220 | mark_filled=False) 221 | self.update(ep_batch.data.episode_data, 222 | slice(self.buffer_index, self.buffer_index + ep_batch.batch_size)) 223 | self.buffer_index = (self.buffer_index + ep_batch.batch_size) 224 | self.episodes_in_buffer = max(self.episodes_in_buffer, self.buffer_index) 225 | self.buffer_index = self.buffer_index % self.buffer_size 226 | assert self.buffer_index < self.buffer_size 227 | else: 228 | buffer_left = self.buffer_size - self.buffer_index 229 | self.insert_episode_batch(ep_batch[0:buffer_left, :]) 230 | self.insert_episode_batch(ep_batch[buffer_left:, :]) 231 | 232 | def can_sample(self, batch_size): 233 | return self.episodes_in_buffer >= self.burn_in_period 234 | 235 | def sample(self, batch_size): 236 | assert self.can_sample(batch_size) 237 | if self.episodes_in_buffer == batch_size: 238 | return self[:batch_size] 239 | else: 240 | # Uniform sampling only atm 241 | ep_ids = np.random.choice(self.episodes_in_buffer, batch_size, replace=False) 242 | return self[ep_ids] 243 | 244 | def __repr__(self): 245 | return "ReplayBuffer. {}/{} episodes. 
Keys:{} Groups:{}".format(self.episodes_in_buffer, 246 | self.buffer_size, 247 | self.scheme.keys(), 248 | self.groups.keys()) -------------------------------------------------------------------------------- /src/learners/fop_learner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from components.episode_buffer import EpisodeBatch 3 | from modules.mixers.fop import FOPMixer 4 | import torch.nn.functional as F 5 | import torch as th 6 | from torch.optim import RMSprop 7 | import numpy as np 8 | from torch.distributions import Categorical 9 | from modules.critics.fop import FOPCritic 10 | from utils.rl_utils import build_td_lambda_targets 11 | 12 | class FOP_Learner: 13 | def __init__(self, mac, scheme, logger, args): 14 | self.args = args 15 | self.mac = mac 16 | self.logger = logger 17 | self.n_agents = args.n_agents 18 | self.n_actions = args.n_actions 19 | self.last_target_update_episode = 0 20 | self.critic_training_steps = 0 21 | 22 | self.log_stats_t = -self.args.learner_log_interval - 1 23 | 24 | self.critic1 = FOPCritic(scheme, args) 25 | self.critic2 = FOPCritic(scheme, args) 26 | 27 | self.mixer1 = FOPMixer(args) 28 | self.mixer2 = FOPMixer(args) 29 | 30 | self.target_mixer1 = copy.deepcopy(self.mixer1) 31 | self.target_mixer2 = copy.deepcopy(self.mixer2) 32 | 33 | self.target_critic1 = copy.deepcopy(self.critic1) 34 | self.target_critic2 = copy.deepcopy(self.critic2) 35 | 36 | self.agent_params = list(mac.parameters()) 37 | self.critic_params1 = list(self.critic1.parameters()) + list(self.mixer1.parameters()) 38 | self.critic_params2 = list(self.critic2.parameters()) + list(self.mixer2.parameters()) 39 | 40 | self.p_optimiser = RMSprop(params=self.agent_params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) 41 | self.c_optimiser1 = RMSprop(params=self.critic_params1, lr=args.c_lr, alpha=args.optim_alpha, eps=args.optim_eps) 42 | self.c_optimiser2 = RMSprop(params=self.critic_params2, lr=args.c_lr, alpha=args.optim_alpha, eps=args.optim_eps) 43 | 44 | def train_actor(self, batch: EpisodeBatch, t_env: int, episode_num: int): 45 | bs = batch.batch_size 46 | max_t = batch.max_seq_length 47 | terminated = batch["terminated"][:, :-1].float() 48 | mask = batch["filled"][:, :-1].float() 49 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 50 | mask = mask.repeat(1, 1, self.n_agents).view(-1) 51 | avail_actions = batch["avail_actions"] 52 | 53 | mac = self.mac 54 | alpha = max(0.05, 0.5 - t_env / 200000) # linear decay 55 | 56 | mac_out = [] 57 | mac.init_hidden(batch.batch_size) 58 | for t in range(batch.max_seq_length): 59 | agent_outs = mac.forward(batch, t=t) 60 | mac_out.append(agent_outs) 61 | mac_out = th.stack(mac_out, dim=1) # Concat over time 62 | 63 | # Mask out unavailable actions, renormalise (as in action selection) 64 | mac_out[avail_actions == 0] = 1e-10 65 | mac_out = mac_out/mac_out.sum(dim=-1, keepdim=True) 66 | mac_out[avail_actions == 0] = 1e-10 67 | 68 | pi = mac_out[:,:-1].clone() 69 | pi = pi.reshape(-1, self.n_actions) 70 | log_pi = th.log(pi) 71 | 72 | inputs = self.critic1._build_inputs(batch, bs, max_t) 73 | q_vals1 = self.critic1.forward(inputs) 74 | q_vals2 = self.critic2.forward(inputs) 75 | q_vals = th.min(q_vals1, q_vals2) 76 | 77 | pi = mac_out[:,:-1].reshape(-1, self.n_actions) 78 | entropies = - (pi * log_pi).sum(dim=-1) 79 | 80 | # policy target for discrete actions (from Soft Actor-Critic for Discrete Action Settings) 81 | pol_target = (pi * (alpha * log_pi - 
q_vals[:,:-1].reshape(-1, self.n_actions))).sum(dim=-1) 82 | 83 | policy_loss = (pol_target * mask).sum() / mask.sum() 84 | 85 | # Optimise 86 | self.p_optimiser.zero_grad() 87 | policy_loss.backward() 88 | agent_grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip) 89 | self.p_optimiser.step() 90 | 91 | if t_env - self.log_stats_t >= self.args.learner_log_interval: 92 | self.logger.log_stat("policy_loss", policy_loss.item(), t_env) 93 | self.logger.log_stat("agent_grad_norm", agent_grad_norm, t_env) 94 | self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(), t_env) 95 | self.logger.log_stat("alpha", alpha, t_env) 96 | self.logger.log_stat("ent", entropies.mean().item(), t_env) 97 | 98 | def train(self, batch: EpisodeBatch, t_env: int, episode_num: int, show_demo=False, save_data=None): 99 | self.train_actor(batch, t_env, episode_num) 100 | self.train_critic(batch, t_env) 101 | if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0: 102 | self._update_targets() 103 | self.last_target_update_episode = episode_num 104 | 105 | def train_critic(self, batch, t_env): 106 | bs = batch.batch_size 107 | max_t = batch.max_seq_length 108 | rewards = batch["reward"][:, :-1] 109 | actions = batch["actions"][:, :-1] 110 | terminated = batch["terminated"][:, :-1].float() 111 | mask = batch["filled"][:, :-1].float() 112 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 113 | avail_actions = batch["avail_actions"] 114 | actions_onehot = batch["actions_onehot"][:, :-1] 115 | states = batch["state"] 116 | 117 | mac = self.mac 118 | mixer1 = self.mixer1 119 | mixer2 = self.mixer2 120 | alpha = max(0.05, 0.5 - t_env / 200000) # linear decay 121 | 122 | mac_out = [] 123 | mac.init_hidden(batch.batch_size) 124 | for t in range(batch.max_seq_length): 125 | agent_outs = mac.forward(batch, t=t) 126 | mac_out.append(agent_outs) 127 | mac_out = th.stack(mac_out, dim=1) # Concat over time 128 | 129 | # Mask out unavailable actions 130 | mac_out[avail_actions == 0] = 0.0 131 | mac_out = mac_out/mac_out.sum(dim=-1, keepdim=True) 132 | mac_out[avail_actions == 0] = 1e-10 133 | 134 | t_mac_out = mac_out.clone().detach() 135 | pi = t_mac_out 136 | 137 | # sample actions for next timesteps 138 | next_actions = Categorical(pi).sample().long().unsqueeze(3) 139 | next_actions_onehot = th.zeros(next_actions.squeeze(3).shape + (self.n_actions,)) 140 | if self.args.use_cuda: 141 | next_actions_onehot = next_actions_onehot.cuda() 142 | next_actions_onehot = next_actions_onehot.scatter_(3, next_actions, 1) 143 | 144 | pi_taken = th.gather(pi, dim=3, index=next_actions).squeeze(3)[:,1:] 145 | pi_taken[mask.expand_as(pi_taken) == 0] = 1.0 146 | log_pi_taken = th.log(pi_taken) 147 | 148 | target_inputs = self.target_critic1._build_inputs(batch, bs, max_t) 149 | target_q_vals1 = self.target_critic1.forward(target_inputs).detach() 150 | target_q_vals2 = self.target_critic2.forward(target_inputs).detach() 151 | 152 | # directly caculate the values by definition 153 | next_vs1 = th.logsumexp(target_q_vals1 / alpha, dim=-1) * alpha 154 | next_vs2 = th.logsumexp(target_q_vals2 / alpha, dim=-1) * alpha 155 | 156 | next_chosen_qvals1 = th.gather(target_q_vals1, dim=3, index=next_actions).squeeze(3) 157 | next_chosen_qvals2 = th.gather(target_q_vals2, dim=3, index=next_actions).squeeze(3) 158 | 159 | target_qvals1 = self.target_mixer1(next_chosen_qvals1, states, actions=next_actions_onehot, vs=next_vs1) 160 | target_qvals2 = 
self.target_mixer2(next_chosen_qvals2, states, actions=next_actions_onehot, vs=next_vs2) 161 | 162 | target_qvals = th.min(target_qvals1, target_qvals2) 163 | 164 | # Calculate td-lambda targets 165 | target_v = build_td_lambda_targets(rewards, terminated, mask, target_qvals, self.n_agents, self.args.gamma, self.args.td_lambda) 166 | targets = target_v - alpha * log_pi_taken.mean(dim=-1, keepdim=True) 167 | 168 | inputs = self.critic1._build_inputs(batch, bs, max_t) 169 | q_vals1 = self.critic1.forward(inputs) 170 | q_vals2 = self.critic2.forward(inputs) 171 | 172 | # directly caculate the values by definition 173 | vs1 = th.logsumexp(q_vals1 / alpha, dim=-1) * alpha 174 | vs2 = th.logsumexp(q_vals2 / alpha, dim=-1) * alpha 175 | 176 | q_taken1 = th.gather(q_vals1[:,:-1], dim=3, index=actions).squeeze(3) 177 | q_taken2 = th.gather(q_vals2[:,:-1], dim=3, index=actions).squeeze(3) 178 | 179 | q_taken1 = mixer1(q_taken1, states[:, :-1], actions=actions_onehot, vs=vs1[:, :-1]) 180 | q_taken2 = mixer2(q_taken2, states[:, :-1], actions=actions_onehot, vs=vs2[:, :-1]) 181 | 182 | td_error1 = q_taken1 - targets.detach() 183 | td_error2 = q_taken2 - targets.detach() 184 | 185 | mask = mask.expand_as(td_error1) 186 | 187 | # 0-out the targets that came from padded data 188 | masked_td_error1 = td_error1 * mask 189 | loss1 = (masked_td_error1 ** 2).sum() / mask.sum() 190 | masked_td_error2 = td_error2 * mask 191 | loss2 = (masked_td_error2 ** 2).sum() / mask.sum() 192 | 193 | # Optimise 194 | self.c_optimiser1.zero_grad() 195 | loss1.backward() 196 | grad_norm = th.nn.utils.clip_grad_norm_(self.critic_params1, self.args.grad_norm_clip) 197 | self.c_optimiser1.step() 198 | 199 | self.c_optimiser2.zero_grad() 200 | loss2.backward() 201 | grad_norm = th.nn.utils.clip_grad_norm_(self.critic_params2, self.args.grad_norm_clip) 202 | self.c_optimiser2.step() 203 | 204 | 205 | if t_env - self.log_stats_t >= self.args.learner_log_interval: 206 | self.logger.log_stat("loss", loss1.item(), t_env) 207 | self.logger.log_stat("grad_norm", grad_norm, t_env) 208 | mask_elems = mask.sum().item() 209 | self.logger.log_stat("td_error_abs", (masked_td_error1.abs().sum().item() / mask_elems), t_env) 210 | self.logger.log_stat("q_taken_mean", 211 | (q_taken1 * mask).sum().item() / (mask_elems * self.args.n_agents), t_env) 212 | self.logger.log_stat("target_mean", (targets * mask).sum().item() / (mask_elems * self.args.n_agents), 213 | t_env) 214 | self.log_stats_t = t_env 215 | 216 | def _update_targets(self): 217 | self.target_critic1.load_state_dict(self.critic1.state_dict()) 218 | self.target_critic2.load_state_dict(self.critic2.state_dict()) 219 | self.target_mixer1.load_state_dict(self.mixer1.state_dict()) 220 | self.target_mixer2.load_state_dict(self.mixer2.state_dict()) 221 | self.logger.console_logger.info("Updated target network") 222 | 223 | def cuda(self): 224 | self.mac.cuda() 225 | self.critic1.cuda() 226 | self.mixer1.cuda() 227 | self.target_critic1.cuda() 228 | self.target_mixer1.cuda() 229 | self.critic2.cuda() 230 | self.mixer2.cuda() 231 | self.target_critic2.cuda() 232 | self.target_mixer2.cuda() 233 | 234 | def save_models(self, path): 235 | self.mac.save_models(path) 236 | th.save(self.critic1.state_dict(), "{}/critic1.th".format(path)) 237 | th.save(self.mixer1.state_dict(), "{}/mixer1.th".format(path)) 238 | th.save(self.critic2.state_dict(), "{}/critic2.th".format(path)) 239 | th.save(self.mixer2.state_dict(), "{}/mixer2.th".format(path)) 240 | th.save(self.p_optimiser.state_dict(), 
"{}/agent_opt.th".format(path)) 241 | th.save(self.c_optimiser1.state_dict(), "{}/critic_opt1.th".format(path)) 242 | th.save(self.c_optimiser2.state_dict(), "{}/critic_opt2.th".format(path)) 243 | 244 | def load_models(self, path): 245 | self.mac.load_models(path) 246 | self.critic1.load_state_dict(th.load("{}/critic1.th".format(path), map_location=lambda storage, loc: storage)) 247 | self.critic2.load_state_dict(th.load("{}/critic2.th".format(path), map_location=lambda storage, loc: storage)) 248 | # Not quite right but I don't want to save target networks 249 | self.target_critic1.load_state_dict(self.critic1.state_dict()) 250 | self.target_critic2.load_state_dict(self.critic2.state_dict()) 251 | 252 | self.mixer1.load_state_dict(th.load("{}/mixer1.th".format(path), map_location=lambda storage, loc: storage)) 253 | self.mixer2.load_state_dict(th.load("{}/mixer2.th".format(path), map_location=lambda storage, loc: storage)) 254 | 255 | self.p_optimiser.load_state_dict(th.load("{}/agent_opt.th".format(path), map_location=lambda storage, loc: storage)) 256 | self.critic_optimiser1.load_state_dict(th.load("{}/critic_opt1.th".format(path), map_location=lambda storage, loc: storage)) 257 | self.critic_optimiser2.load_state_dict(th.load("{}/critic_opt2.th".format(path), map_location=lambda storage, loc: storage)) 258 | 259 | def build_inputs(self, batch, bs, max_t, actions_onehot): 260 | inputs = [] 261 | inputs.append(batch["obs"][:]) 262 | actions = actions_onehot[:].reshape(bs, max_t, self.n_agents, -1) 263 | inputs.append(actions) 264 | inputs.append(th.eye(self.n_agents, device=batch.device).unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1)) 265 | inputs = th.cat([x.reshape(bs, max_t, self.n_agents, -1) for x in inputs], dim=-1) 266 | return inputs 267 | --------------------------------------------------------------------------------