├── src ├── __init__.py ├── modules │ ├── __init__.py │ ├── critics │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── coma.cpython-38.pyc │ │ │ └── __init__.cpython-38.pyc │ │ └── coma.py │ ├── mixers │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── qmix.cpython-38.pyc │ │ │ ├── rem.cpython-38.pyc │ │ │ ├── vdn.cpython-38.pyc │ │ │ ├── mmdmix.cpython-38.pyc │ │ │ ├── qatten.cpython-38.pyc │ │ │ ├── qtran.cpython-38.pyc │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── gagamix.cpython-38.pyc │ │ │ ├── magamix.cpython-38.pyc │ │ │ └── mgatmix.cpython-38.pyc │ │ ├── vdn.py │ │ ├── qmix.py │ │ └── qtran.py │ ├── agents │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── rnn_agent.cpython-38.pyc │ │ └── rnn_agent.py │ ├── vi │ │ ├── __pycache__ │ │ │ ├── vae.cpython-38.pyc │ │ │ └── vgae.cpython-38.pyc │ │ ├── vae.py │ │ └── vgae.py │ └── __pycache__ │ │ └── __init__.cpython-38.pyc ├── components │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── transforms.cpython-38.pyc │ │ ├── episode_buffer.cpython-38.pyc │ │ ├── action_selectors.cpython-38.pyc │ │ └── epsilon_schedules.cpython-38.pyc │ ├── transforms.py │ ├── epsilon_schedules.py │ ├── action_selectors.py │ └── episode_buffer.py ├── .gitignore ├── __pycache__ │ └── run.cpython-38.pyc ├── controllers │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── basic_controller.cpython-38.pyc │ └── basic_controller.py ├── envs │ ├── __pycache__ │ │ └── __init__.cpython-38.pyc │ ├── starcraft2 │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── starcraft2.cpython-38.pyc │ │ ├── maps │ │ │ ├── __pycache__ │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ └── smac_maps.cpython-38.pyc │ │ │ ├── __init__.py │ │ │ └── smac_maps.py │ │ └── __init__.py │ ├── __init__.py │ └── multiagentenv.py ├── utils │ ├── __pycache__ │ │ ├── logging.cpython-38.pyc │ │ ├── rl_utils.cpython-38.pyc │ │ └── timehelper.cpython-38.pyc │ ├── dict2namedtuple.py │ ├── rl_utils.py │ ├── timehelper.py │ └── logging.py ├── runners │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── episode_runner.cpython-38.pyc │ │ └── parallel_runner.cpython-38.pyc │ ├── __init__.py │ ├── episode_runner.py │ └── parallel_runner.py ├── learners │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── q_learner.cpython-38.pyc │ │ ├── coma_learner.cpython-38.pyc │ │ ├── gaga_learner.cpython-38.pyc │ │ ├── maga_learner.cpython-38.pyc │ │ ├── mgat_learner.cpython-38.pyc │ │ ├── sqv_learner.cpython-38.pyc │ │ ├── mmd_q_learner.cpython-38.pyc │ │ ├── q_atten_learner.cpython-38.pyc │ │ └── qtran_learner.cpython-38.pyc │ ├── __init__.py │ ├── q_learner.py │ ├── coma_learner.py │ ├── qtran_learner.py │ └── side_learner.py ├── config │ ├── algs │ │ ├── vdn.yaml │ │ ├── iql.yaml │ │ ├── vdn_beta.yaml │ │ ├── iql_beta.yaml │ │ ├── qmix_beta.yaml │ │ ├── qmix.yaml │ │ ├── qtran.yaml │ │ ├── side.yaml │ │ └── coma.yaml │ ├── envs │ │ ├── sc2_beta.yaml │ │ ├── sc2.yaml │ │ └── sc2_po.yaml │ └── default.yaml ├── main.py └── run.py ├── docker ├── build.sh └── Dockerfile ├── run.sh ├── run_interactive.sh ├── requirements.txt ├── install_sc2.sh ├── README.md └── LICENSE /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /src/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/modules/critics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/modules/mixers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | tb_logs/ 2 | results/ 3 | -------------------------------------------------------------------------------- /src/modules/agents/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from .rnn_agent import RNNAgent 4 | REGISTRY["rnn"] = RNNAgent -------------------------------------------------------------------------------- /src/__pycache__/run.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/__pycache__/run.cpython-38.pyc -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo 'Building Dockerfile with image name pymarl:1.0' 4 | docker build -t pymarl:1.0 . 5 | -------------------------------------------------------------------------------- /src/controllers/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from .basic_controller import BasicMAC 4 | 5 | REGISTRY["basic_mac"] = BasicMAC -------------------------------------------------------------------------------- /src/envs/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/envs/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/vi/__pycache__/vae.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/vi/__pycache__/vae.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/logging.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/utils/__pycache__/logging.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/rl_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/utils/__pycache__/rl_utils.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/__pycache__/__init__.cpython-38.pyc 
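The package `__init__.py` files above (modules/agents and controllers, and likewise runners and learners further down) share one convention: each exposes a plain REGISTRY dict mapping a config string to a class, and components are instantiated by looking that string up. The following is an editor's illustrative sketch of that lookup, not a file from this repository; it assumes `src/` is on PYTHONPATH, and the shapes and hyperparameter values are made up.

# Illustrative sketch only: resolving and using the "rnn" agent through the
# REGISTRY defined in src/modules/agents/__init__.py. Values below are hypothetical.
import torch as th
from types import SimpleNamespace
from modules.agents import REGISTRY as agent_REGISTRY   # {"rnn": RNNAgent}

args = SimpleNamespace(rnn_hidden_dim=64, n_actions=9)    # made-up hyperparameters
agent = agent_REGISTRY["rnn"](input_shape=30, args=args)  # -> RNNAgent instance

hidden = agent.init_hidden()                    # zeros of shape (1, rnn_hidden_dim)
q, hidden = agent.forward(th.rand(1, 30), hidden)
print(q.shape)                                  # torch.Size([1, 9]): one Q-value per action

The same pattern covers controllers ("basic_mac"), runners ("episode", "parallel"), and learners ("q_learner", "coma_learner", "qtran_learner", "side_learner"), which is how the `runner:` and `learner:` keys in the YAML configs later in this dump are intended to select concrete classes.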
-------------------------------------------------------------------------------- /src/modules/vi/__pycache__/vgae.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/vi/__pycache__/vgae.cpython-38.pyc -------------------------------------------------------------------------------- /src/runners/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/runners/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/__pycache__/timehelper.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/utils/__pycache__/timehelper.cpython-38.pyc -------------------------------------------------------------------------------- /src/components/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/components/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/learners/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/learners/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/learners/__pycache__/q_learner.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/learners/__pycache__/q_learner.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/qmix.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/mixers/__pycache__/qmix.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/rem.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/mixers/__pycache__/rem.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/vdn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/mixers/__pycache__/vdn.cpython-38.pyc -------------------------------------------------------------------------------- /src/components/__pycache__/transforms.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/components/__pycache__/transforms.cpython-38.pyc -------------------------------------------------------------------------------- /src/controllers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/controllers/__pycache__/__init__.cpython-38.pyc 
-------------------------------------------------------------------------------- /src/learners/__pycache__/coma_learner.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/learners/__pycache__/coma_learner.cpython-38.pyc -------------------------------------------------------------------------------- /src/learners/__pycache__/gaga_learner.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/learners/__pycache__/gaga_learner.cpython-38.pyc -------------------------------------------------------------------------------- /src/learners/__pycache__/maga_learner.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/learners/__pycache__/maga_learner.cpython-38.pyc -------------------------------------------------------------------------------- /src/learners/__pycache__/mgat_learner.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/learners/__pycache__/mgat_learner.cpython-38.pyc -------------------------------------------------------------------------------- /src/learners/__pycache__/sqv_learner.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/learners/__pycache__/sqv_learner.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/critics/__pycache__/coma.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/critics/__pycache__/coma.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/mmdmix.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/mixers/__pycache__/mmdmix.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/qatten.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/mixers/__pycache__/qatten.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/qtran.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/mixers/__pycache__/qtran.cpython-38.pyc -------------------------------------------------------------------------------- /src/envs/starcraft2/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/envs/starcraft2/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/learners/__pycache__/mmd_q_learner.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/learners/__pycache__/mmd_q_learner.cpython-38.pyc -------------------------------------------------------------------------------- /src/learners/__pycache__/q_atten_learner.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/learners/__pycache__/q_atten_learner.cpython-38.pyc -------------------------------------------------------------------------------- /src/learners/__pycache__/qtran_learner.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/learners/__pycache__/qtran_learner.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/agents/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/agents/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/agents/__pycache__/rnn_agent.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/agents/__pycache__/rnn_agent.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/critics/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/critics/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/mixers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/gagamix.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/mixers/__pycache__/gagamix.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/magamix.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/mixers/__pycache__/magamix.cpython-38.pyc -------------------------------------------------------------------------------- /src/modules/mixers/__pycache__/mgatmix.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/modules/mixers/__pycache__/mgatmix.cpython-38.pyc -------------------------------------------------------------------------------- /src/runners/__pycache__/episode_runner.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/runners/__pycache__/episode_runner.cpython-38.pyc -------------------------------------------------------------------------------- /src/runners/__pycache__/parallel_runner.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/runners/__pycache__/parallel_runner.cpython-38.pyc -------------------------------------------------------------------------------- /src/components/__pycache__/episode_buffer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/components/__pycache__/episode_buffer.cpython-38.pyc -------------------------------------------------------------------------------- /src/envs/starcraft2/__pycache__/starcraft2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/envs/starcraft2/__pycache__/starcraft2.cpython-38.pyc -------------------------------------------------------------------------------- /src/components/__pycache__/action_selectors.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/components/__pycache__/action_selectors.cpython-38.pyc -------------------------------------------------------------------------------- /src/components/__pycache__/epsilon_schedules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/components/__pycache__/epsilon_schedules.cpython-38.pyc -------------------------------------------------------------------------------- /src/controllers/__pycache__/basic_controller.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/controllers/__pycache__/basic_controller.cpython-38.pyc -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/envs/starcraft2/maps/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/__pycache__/smac_maps.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deligentfool/SIDE/HEAD/src/envs/starcraft2/maps/__pycache__/smac_maps.cpython-38.pyc -------------------------------------------------------------------------------- /src/utils/dict2namedtuple.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | def convert(dictionary): 5 | return namedtuple('GenericDict', dictionary.keys())(**dictionary) 6 | -------------------------------------------------------------------------------- /src/runners/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from .episode_runner import EpisodeRunner 4 | REGISTRY["episode"] = EpisodeRunner 5 | 6 | from .parallel_runner import ParallelRunner 7 | REGISTRY["parallel"] = ParallelRunner 8 | -------------------------------------------------------------------------------- /src/envs/starcraft2/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 
from __future__ import division 3 | from __future__ import print_function 4 | 5 | from absl import flags 6 | FLAGS = flags.FLAGS 7 | FLAGS(['main.py']) 8 | -------------------------------------------------------------------------------- /src/modules/mixers/vdn.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | 4 | 5 | class VDNMixer(nn.Module): 6 | def __init__(self): 7 | super(VDNMixer, self).__init__() 8 | 9 | def forward(self, agent_qs, batch): 10 | return th.sum(agent_qs, dim=2, keepdim=True) -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from smac.env.starcraft2.maps import smac_maps 6 | 7 | 8 | def get_map_params(map_name): 9 | map_param_registry = smac_maps.get_smac_map_registry() 10 | return map_param_registry[map_name] 11 | -------------------------------------------------------------------------------- /src/learners/__init__.py: -------------------------------------------------------------------------------- 1 | from .q_learner import QLearner 2 | from .coma_learner import COMALearner 3 | from .qtran_learner import QLearner as QTranLearner 4 | from .side_learner import SIDELearner 5 | 6 | REGISTRY = {} 7 | 8 | REGISTRY["q_learner"] = QLearner 9 | REGISTRY["coma_learner"] = COMALearner 10 | REGISTRY["qtran_learner"] = QTranLearner 11 | REGISTRY["side_learner"] = SIDELearner -------------------------------------------------------------------------------- /src/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from .multiagentenv import MultiAgentEnv 3 | from .starcraft2.starcraft2 import StarCraft2Env 4 | import sys 5 | import os 6 | 7 | def env_fn(env, **kwargs) -> MultiAgentEnv: 8 | return env(**kwargs) 9 | 10 | REGISTRY = {} 11 | REGISTRY["sc2"] = partial(env_fn, env=StarCraft2Env) 12 | 13 | if sys.platform == "linux": 14 | os.environ.setdefault("SC2PATH", 15 | os.path.join(os.getcwd(), "3rdparty", "StarCraftII")) -------------------------------------------------------------------------------- /src/config/algs/vdn.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "episode" 10 | 11 | buffer_size: 5000 12 | 13 | # update the target network every {} episodes 14 | target_update_interval: 200 15 | 16 | # use the Q_Learner to train 17 | agent_output_type: "q" 18 | learner: "q_learner" 19 | double_q: True 20 | mixer: "vdn" 21 | 22 | name: "vdn" 23 | -------------------------------------------------------------------------------- /src/config/algs/iql.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "episode" 10 | 11 | buffer_size: 5000 12 | 13 | # update the target network every {} episodes 14 | target_update_interval: 200 15 | 16 | # use 
the Q_Learner to train 17 | agent_output_type: "q" 18 | learner: "q_learner" 19 | double_q: True 20 | mixer: # Mixer becomes None 21 | 22 | name: "iql" 23 | -------------------------------------------------------------------------------- /src/config/algs/vdn_beta.yaml: -------------------------------------------------------------------------------- 1 | # --- VDN specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "parallel" 10 | batch_size_run: 8 11 | 12 | buffer_size: 5000 13 | 14 | # update the target network every {} episodes 15 | target_update_interval: 200 16 | 17 | # use the Q_Learner to train 18 | agent_output_type: "q" 19 | learner: "q_learner" 20 | double_q: True 21 | mixer: "vdn" 22 | 23 | name: "vdn_smac_parallel" 24 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | HASH=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 4 | head -n 1) 3 | GPU=$1 4 | name=${USER}_pymarl_GPU_${GPU}_${HASH} 5 | 6 | echo "Launching container named '${name}' on GPU '${GPU}'" 7 | # Launches a docker container using our image, and runs the provided command 8 | 9 | if hash nvidia-docker 2>/dev/null; then 10 | cmd=nvidia-docker 11 | else 12 | cmd=docker 13 | fi 14 | 15 | NV_GPU="$GPU" ${cmd} run \ 16 | --name $name \ 17 | --user $(id -u):$(id -g) \ 18 | -v `pwd`:/pymarl \ 19 | -t pymarl:1.0 \ 20 | ${@:2} 21 | -------------------------------------------------------------------------------- /src/config/algs/iql_beta.yaml: -------------------------------------------------------------------------------- 1 | # --- IQL specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "parallel" 10 | batch_size_run: 8 11 | 12 | buffer_size: 5000 13 | 14 | # update the target network every {} episodes 15 | target_update_interval: 200 16 | 17 | # use the Q_Learner to train 18 | agent_output_type: "q" 19 | learner: "q_learner" 20 | double_q: True 21 | mixer: # Mixer becomes None 22 | 23 | name: "iql_smac_parallel" -------------------------------------------------------------------------------- /run_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | HASH=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 4 | head -n 1) 3 | GPU=$1 4 | name=${USER}_pymarl_GPU_${GPU}_${HASH} 5 | 6 | echo "Launching container named '${name}' on GPU '${GPU}'" 7 | # Launches a docker container using our image, and runs the provided command 8 | 9 | if hash nvidia-docker 2>/dev/null; then 10 | cmd=nvidia-docker 11 | else 12 | cmd=docker 13 | fi 14 | 15 | NV_GPU="$GPU" ${cmd} run -i \ 16 | --name $name \ 17 | --user $(id -u):$(id -g) \ 18 | -v `pwd`:/pymarl \ 19 | -t pymarl:1.0 \ 20 | ${@:2} 21 | -------------------------------------------------------------------------------- /src/config/algs/qmix_beta.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "parallel" 10 | batch_size_run: 8 11 | 12 | buffer_size: 
5000 13 | 14 | # update the target network every {} episodes 15 | target_update_interval: 200 16 | 17 | # use the Q_Learner to train 18 | agent_output_type: "q" 19 | learner: "q_learner" 20 | double_q: True 21 | mixer: "qmix" 22 | mixing_embed_dim: 32 23 | 24 | name: "qmix_smac_parallel" 25 | -------------------------------------------------------------------------------- /src/config/algs/qmix.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "episode" 10 | 11 | buffer_size: 5000 12 | 13 | # update the target network every {} episodes 14 | target_update_interval: 200 15 | 16 | # use the Q_Learner to train 17 | agent_output_type: "q" 18 | learner: "q_learner" 19 | double_q: True 20 | mixer: "qmix" 21 | mixing_embed_dim: 32 22 | hypernet_layers: 2 23 | hypernet_embed: 64 24 | 25 | name: "qmix" 26 | -------------------------------------------------------------------------------- /src/config/algs/qtran.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "episode" 10 | 11 | buffer_size: 5000 12 | 13 | # update the target network every {} episodes 14 | target_update_interval: 200 15 | 16 | # use the Q_Learner to train 17 | agent_output_type: "q" 18 | learner: "qtran_learner" 19 | double_q: True 20 | mixer: "qtran_base" 21 | mixing_embed_dim: 64 22 | qtran_arch: "qtran_paper" 23 | 24 | opt_loss: 1 25 | nopt_min_loss: 0.1 26 | 27 | network_size: small 28 | 29 | name: "qtran" 30 | -------------------------------------------------------------------------------- /src/config/algs/side.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "episode" 10 | 11 | buffer_size: 5000 12 | 13 | # update the target network every {} episodes 14 | target_update_interval: 200 15 | 16 | # use the Q_Learner to train 17 | agent_output_type: "q" 18 | learner: "side_learner" 19 | double_q: True 20 | mixer: "qmix" 21 | name: "side" 22 | 23 | hidden_dim: 128 24 | hypernet_layers: 2 25 | hypernet_embed: 64 26 | mixing_embed_dim: 32 27 | 28 | latent_dim: 64 29 | vgae: True 30 | prior: True -------------------------------------------------------------------------------- /src/components/transforms.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | class Transform: 5 | def transform(self, tensor): 6 | raise NotImplementedError 7 | 8 | def infer_output_info(self, vshape_in, dtype_in): 9 | raise NotImplementedError 10 | 11 | 12 | class OneHot(Transform): 13 | def __init__(self, out_dim): 14 | self.out_dim = out_dim 15 | 16 | def transform(self, tensor): 17 | y_onehot = tensor.new(*tensor.shape[:-1], self.out_dim).zero_() 18 | y_onehot.scatter_(-1, tensor.long(), 1) 19 | return y_onehot.float() 20 | 21 | def infer_output_info(self, vshape_in, dtype_in): 22 | return (self.out_dim,), th.float32 
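The OneHot transform defined above turns integer indices into one-hot vectors by allocating a zero tensor whose trailing dimension is `out_dim` and scattering 1s at the given indices; in PyMARL-style code it is typically registered as a replay-buffer preprocessing step that maps "actions" to "actions_onehot" (that wiring lives in run.py, which is not reproduced in this section). Below is an editor's illustrative sketch of the transform in isolation, assuming `src/` is on PYTHONPATH; the batch and agent sizes are made up.

# Illustrative usage sketch for components/transforms.py (not part of that file).
import torch as th
from components.transforms import OneHot

one_hot = OneHot(out_dim=5)

# (batch=2, agents=3, 1) tensor of integer action indices.
actions = th.tensor([[[0], [2], [4]],
                     [[1], [1], [3]]])

encoded = one_hot.transform(actions)
print(encoded.shape)                              # torch.Size([2, 3, 5])
print(one_hot.infer_output_info((1,), th.long))   # ((5,), torch.float32)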
-------------------------------------------------------------------------------- /src/config/algs/coma.yaml: -------------------------------------------------------------------------------- 1 | # --- COMA specific parameters --- 2 | 3 | action_selector: "multinomial" 4 | epsilon_start: .5 5 | epsilon_finish: .01 6 | epsilon_anneal_time: 100000 7 | mask_before_softmax: False 8 | 9 | runner: "parallel" 10 | 11 | buffer_size: 8 12 | batch_size_run: 8 13 | batch_size: 8 14 | 15 | env_args: 16 | state_last_action: False # critic adds last action internally 17 | 18 | # update the target network every {} training steps 19 | target_update_interval: 200 20 | 21 | lr: 0.0005 22 | critic_lr: 0.0005 23 | td_lambda: 0.8 24 | 25 | # use COMA 26 | agent_output_type: "pi_logits" 27 | learner: "coma_learner" 28 | critic_q_fn: "coma" 29 | critic_baseline_fn: "coma" 30 | critic_train_mode: "seq" 31 | critic_train_reps: 1 32 | q_nstep: 0 # 0 corresponds to default Q, 1 is r + gamma*Q, etc 33 | 34 | name: "coma" 35 | -------------------------------------------------------------------------------- /src/utils/rl_utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | def build_td_lambda_targets(rewards, terminated, mask, target_qs, n_agents, gamma, td_lambda): 5 | # Assumes <target_qs> in B*T*A and <rewards>, <terminated>, <mask> in (at least) B*T-1*1 6 | # Initialise last lambda-return for not terminated episodes 7 | ret = target_qs.new_zeros(*target_qs.shape) 8 | ret[:, -1] = target_qs[:, -1] * (1 - th.sum(terminated, dim=1)) 9 | # Backwards recursive update of the "forward view" 10 | for t in range(ret.shape[1] - 2, -1, -1): 11 | ret[:, t] = td_lambda * gamma * ret[:, t + 1] + mask[:, t] \ 12 | * (rewards[:, t] + (1 - td_lambda) * gamma * target_qs[:, t + 1] * (1 - terminated[:, t])) 13 | # Returns lambda-return from t=0 to t=T-1, i.e. 
in B*T-1*A 14 | return ret[:, 0:-1] 15 | 16 | -------------------------------------------------------------------------------- /src/modules/agents/rnn_agent.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class RNNAgent(nn.Module): 6 | def __init__(self, input_shape, args): 7 | super(RNNAgent, self).__init__() 8 | self.args = args 9 | 10 | self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim) 11 | self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim) 12 | self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions) 13 | 14 | def init_hidden(self): 15 | # make hidden states on same device as model 16 | return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_() 17 | 18 | def forward(self, inputs, hidden_state): 19 | x = F.relu(self.fc1(inputs)) 20 | h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim) 21 | h = self.rnn(x, h_in) 22 | q = self.fc2(h) 23 | return q, h 24 | -------------------------------------------------------------------------------- /src/components/epsilon_schedules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class DecayThenFlatSchedule(): 5 | 6 | def __init__(self, 7 | start, 8 | finish, 9 | time_length, 10 | decay="exp"): 11 | 12 | self.start = start 13 | self.finish = finish 14 | self.time_length = time_length 15 | self.delta = (self.start - self.finish) / self.time_length 16 | self.decay = decay 17 | 18 | if self.decay in ["exp"]: 19 | self.exp_scaling = (-1) * self.time_length / np.log(self.finish) if self.finish > 0 else 1 20 | 21 | def eval(self, T): 22 | if self.decay in ["linear"]: 23 | return max(self.finish, self.start - self.delta * T) 24 | elif self.decay in ["exp"]: 25 | return min(self.start, max(self.finish, np.exp(- T / self.exp_scaling))) 26 | pass 27 | -------------------------------------------------------------------------------- /src/config/envs/sc2_beta.yaml: -------------------------------------------------------------------------------- 1 | env: sc2 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "3m" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | state_last_action: True 27 | state_timestep_number: False 28 | step_mul: 8 29 | seed: null 30 | heuristic_ai: False 31 | debug: False 32 | 33 | learner_log_interval: 20000 34 | log_interval: 20000 35 | runner_log_interval: 20000 36 | t_max: 10050000 37 | test_interval: 20000 38 | test_nepisode: 24 39 | test_greedy: True 40 | -------------------------------------------------------------------------------- /src/config/envs/sc2.yaml: -------------------------------------------------------------------------------- 1 | env: sc2 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "3m" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | 
obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | state_last_action: True 27 | state_timestep_number: False 28 | step_mul: 8 29 | seed: null 30 | heuristic_ai: False 31 | heuristic_rest: False 32 | debug: False 33 | 34 | test_greedy: True 35 | test_nepisode: 32 36 | test_interval: 10000 37 | log_interval: 10000 38 | runner_log_interval: 10000 39 | learner_log_interval: 10000 40 | t_max: 2050000 41 | -------------------------------------------------------------------------------- /src/config/envs/sc2_po.yaml: -------------------------------------------------------------------------------- 1 | env: sc2 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "3m" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: True 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | state_last_action: True 27 | state_timestep_number: False 28 | step_mul: 8 29 | seed: null 30 | heuristic_ai: False 31 | heuristic_rest: False 32 | debug: False 33 | 34 | test_greedy: True 35 | test_nepisode: 32 36 | test_interval: 10000 37 | log_interval: 10000 38 | runner_log_interval: 10000 39 | learner_log_interval: 10000 40 | t_max: 5100000 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.5.0 2 | atomicwrites==1.2.1 3 | attrs==18.2.0 4 | certifi==2018.8.24 5 | chardet==3.0.4 6 | cycler==0.10.0 7 | docopt==0.6.2 8 | enum34==1.1.6 9 | future==0.16.0 10 | idna==2.7 11 | imageio==2.4.1 12 | jsonpickle==0.9.6 13 | kiwisolver==1.0.1 14 | matplotlib==3.0.0 15 | mock==2.0.0 16 | more-itertools==4.3.0 17 | mpyq==0.2.5 18 | munch==2.3.2 19 | numpy==1.15.2 20 | pathlib2==2.3.2 21 | pbr==4.3.0 22 | Pillow==6.2.0 23 | pluggy==0.7.1 24 | portpicker==1.2.0 25 | probscale==0.2.3 26 | protobuf==3.6.1 27 | py==1.6.0 28 | pygame==1.9.4 29 | pyparsing==2.2.2 30 | pysc2==3.0.0 31 | pytest==3.8.2 32 | python-dateutil==2.7.3 33 | PyYAML==3.13 34 | requests==2.20.0 35 | s2clientprotocol==4.10.1.75800.0 36 | sacred==0.7.2 37 | scipy==1.1.0 38 | six==1.11.0 39 | sk-video==1.1.10 40 | snakeviz==1.0.0 41 | tensorboard-logger==0.1.0 42 | torch==0.4.1 43 | torchvision==0.2.1 44 | tornado==5.1.1 45 | urllib3==1.24.2 46 | websocket-client==0.53.0 47 | whichcraft==0.5.2 48 | wrapt==1.10.11 49 | git+https://github.com/oxwhirl/smac.git 50 | -------------------------------------------------------------------------------- /install_sc2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Install SC2 and add the custom maps 3 | 4 | if [ -z "$EXP_DIR" ] 5 | then 6 | EXP_DIR=~ 7 | fi 8 | 9 | echo "EXP_DIR: $EXP_DIR" 10 | cd $EXP_DIR/pymarl 11 | 12 | mkdir 3rdparty 13 | cd 3rdparty 14 | 15 | export SC2PATH=`pwd`'/StarCraftII' 16 | echo 'SC2PATH is set to '$SC2PATH 17 | 18 | if [ ! 
-d $SC2PATH ]; then 19 | echo 'StarCraftII is not installed. Installing now ...'; 20 | wget -c http://blzdistsc2-a.akamaihd.net/Linux/SC2.4.6.2.69232.zip 21 | unzip -P iagreetotheeula SC2.4.6.2.69232.zip 22 | rm -rf SC2.4.6.2.69232.zip 23 | else 24 | echo 'StarCraftII is already installed.' 25 | fi 26 | 27 | echo 'Adding SMAC maps.' 28 | MAP_DIR="$SC2PATH/Maps/" 29 | echo 'MAP_DIR is set to '$MAP_DIR 30 | 31 | if [ ! -d $MAP_DIR ]; then 32 | mkdir -p $MAP_DIR 33 | fi 34 | 35 | cd .. 36 | wget https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip 37 | unzip SMAC_Maps.zip 38 | mv SMAC_Maps $MAP_DIR 39 | rm -rf SMAC_Maps.zip 40 | 41 | echo 'StarCraft II and SMAC are installed.' 42 | -------------------------------------------------------------------------------- /src/utils/timehelper.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | 5 | def print_time(start_time, T, t_max, episode, episode_rewards): 6 | time_elapsed = time.time() - start_time 7 | T = max(1, T) 8 | time_left = time_elapsed * (t_max - T) / T 9 | # Just in case its over 100 days 10 | time_left = min(time_left, 60 * 60 * 24 * 100) 11 | last_reward = "N\A" 12 | if len(episode_rewards) > 5: 13 | last_reward = "{:.2f}".format(np.mean(episode_rewards[-50:])) 14 | print("\033[F\033[F\x1b[KEp: {:,}, T: {:,}/{:,}, Reward: {}, \n\x1b[KElapsed: {}, Left: {}\n".format(episode, T, t_max, last_reward, time_str(time_elapsed), time_str(time_left)), " " * 10, end="\r") 15 | 16 | 17 | def time_left(start_time, t_start, t_current, t_max): 18 | if t_current >= t_max: 19 | return "-" 20 | time_elapsed = time.time() - start_time 21 | t_current = max(1, t_current) 22 | time_left = time_elapsed * (t_max - t_current) / (t_current - t_start) 23 | # Just in case its over 100 days 24 | time_left = min(time_left, 60 * 60 * 24 * 100) 25 | return time_str(time_left) 26 | 27 | 28 | def time_str(s): 29 | """ 30 | Convert seconds to a nicer string showing days, hours, minutes and seconds 31 | """ 32 | days, remainder = divmod(s, 60 * 60 * 24) 33 | hours, remainder = divmod(remainder, 60 * 60) 34 | minutes, seconds = divmod(remainder, 60) 35 | string = "" 36 | if days > 0: 37 | string += "{:d} days, ".format(int(days)) 38 | if hours > 0: 39 | string += "{:d} hours, ".format(int(hours)) 40 | if minutes > 0: 41 | string += "{:d} minutes, ".format(int(minutes)) 42 | string += "{:d} seconds".format(int(seconds)) 43 | return string 44 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:9.2-cudnn7-devel-ubuntu16.04 2 | MAINTAINER Tabish Rashid 3 | 4 | # CUDA includes 5 | ENV CUDA_PATH /usr/local/cuda 6 | ENV CUDA_INCLUDE_PATH /usr/local/cuda/include 7 | ENV CUDA_LIBRARY_PATH /usr/local/cuda/lib64 8 | 9 | # Ubuntu Packages 10 | RUN apt-get update -y && apt-get install software-properties-common -y && \ 11 | add-apt-repository -y multiverse && apt-get update -y && apt-get upgrade -y && \ 12 | apt-get install -y apt-utils nano vim man build-essential wget sudo && \ 13 | rm -rf /var/lib/apt/lists/* 14 | 15 | # Install curl and other dependencies 16 | RUN apt-get update -y && apt-get install -y curl libssl-dev openssl libopenblas-dev \ 17 | libhdf5-dev hdf5-helpers hdf5-tools libhdf5-serial-dev libprotobuf-dev protobuf-compiler git 18 | RUN curl -sk https://raw.githubusercontent.com/torch/distro/master/install-deps | bash 
&& \ 19 | rm -rf /var/lib/apt/lists/* 20 | 21 | # Install python3 pip3 22 | RUN apt-get update 23 | RUN apt-get -y install python3 24 | RUN apt-get -y install python3-pip 25 | RUN pip3 install --upgrade pip 26 | 27 | # Python packages we use (or used at one point...) 28 | RUN pip3 install numpy scipy pyyaml matplotlib 29 | RUN pip3 install imageio 30 | RUN pip3 install tensorboard-logger 31 | RUN pip3 install pygame 32 | 33 | RUN mkdir /install 34 | WORKDIR /install 35 | 36 | RUN pip3 install jsonpickle==0.9.6 37 | # install Sacred (from OxWhirl fork) 38 | RUN pip3 install setuptools 39 | RUN git clone https://github.com/oxwhirl/sacred.git /install/sacred && cd /install/sacred && python3 setup.py install 40 | 41 | #### ------------------------------------------------------------------- 42 | #### install pytorch 43 | #### ------------------------------------------------------------------- 44 | RUN pip3 install torch 45 | RUN pip3 install torchvision snakeviz pytest probscale 46 | 47 | ## -- SMAC 48 | RUN pip3 install git+https://github.com/oxwhirl/smac.git 49 | ENV SC2PATH /pymarl/3rdparty/StarCraftII 50 | 51 | WORKDIR /pymarl 52 | -------------------------------------------------------------------------------- /src/config/default.yaml: -------------------------------------------------------------------------------- 1 | # --- Defaults --- 2 | 3 | # --- pymarl options --- 4 | runner: "episode" # Runs 1 env for an episode 5 | mac: "basic_mac" # Basic controller 6 | env: "sc2" # Environment name 7 | env_args: {} # Arguments for the environment 8 | batch_size_run: 1 # Number of environments to run in parallel 9 | test_nepisode: 20 # Number of episodes to test for 10 | test_interval: 2000 # Test after {} timesteps have passed 11 | test_greedy: True # Use greedy evaluation (if False, will set epsilon floor to 0 12 | log_interval: 2000 # Log summary of stats after every {} timesteps 13 | runner_log_interval: 2000 # Log runner stats (not test stats) every {} timesteps 14 | learner_log_interval: 2000 # Log training stats every {} timesteps 15 | t_max: 10000 # Stop running after this many timesteps 16 | use_cuda: True # Use gpu by default unless it isn't available 17 | buffer_cpu_only: True # If true we won't keep all of the replay buffer in vram 18 | 19 | # --- Logging options --- 20 | use_tensorboard: False # Log results to tensorboard 21 | save_model: True # Save the models to disk 22 | save_model_interval: 1000000 # Save models after this many timesteps 23 | checkpoint_path: "" # Load a checkpoint from this path 24 | evaluate: False # Evaluate model for test_nepisode episodes and quit (no training) 25 | load_step: 0 # Load model trained on this many timesteps (0 if choose max possible) 26 | save_replay: False # Saving the replay of the model loaded from checkpoint_path 27 | local_results_path: "results" # Path for local results 28 | 29 | # --- RL hyperparameters --- 30 | gamma: 0.99 31 | batch_size: 32 # Number of episodes to train on 32 | buffer_size: 32 # Size of the replay buffer 33 | lr: 0.0005 # Learning rate for agents 34 | critic_lr: 0.0005 # Learning rate for critics 35 | optim_alpha: 0.99 # RMSProp alpha 36 | optim_eps: 0.00001 # RMSProp epsilon 37 | grad_norm_clip: 10 # Reduce magnitude of gradients above this L2 norm 38 | 39 | # --- Agent parameters --- 40 | agent: "rnn" # Default rnn agent 41 | rnn_hidden_dim: 64 # Size of hidden state for default rnn agent 42 | obs_agent_id: True # Include the agent's one_hot id in the observation 43 | obs_last_action: True # Include the agent's last 
action (one_hot) in the observation 44 | 45 | # --- Experiment running params --- 46 | repeat_id: 1 47 | label: "default_label" 48 | -------------------------------------------------------------------------------- /src/envs/multiagentenv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | class MultiAgentEnv(object): 7 | 8 | def step(self, actions): 9 | """Returns reward, terminated, info.""" 10 | raise NotImplementedError 11 | 12 | def get_obs(self): 13 | """Returns all agent observations in a list.""" 14 | raise NotImplementedError 15 | 16 | def get_obs_agent(self, agent_id): 17 | """Returns observation for agent_id.""" 18 | raise NotImplementedError 19 | 20 | def get_obs_size(self): 21 | """Returns the size of the observation.""" 22 | raise NotImplementedError 23 | 24 | def get_state(self): 25 | """Returns the global state.""" 26 | raise NotImplementedError 27 | 28 | def get_state_size(self): 29 | """Returns the size of the global state.""" 30 | raise NotImplementedError 31 | 32 | def get_avail_actions(self): 33 | """Returns the available actions of all agents in a list.""" 34 | raise NotImplementedError 35 | 36 | def get_avail_agent_actions(self, agent_id): 37 | """Returns the available actions for agent_id.""" 38 | raise NotImplementedError 39 | 40 | def get_total_actions(self): 41 | """Returns the total number of actions an agent could ever take.""" 42 | raise NotImplementedError 43 | 44 | def reset(self): 45 | """Returns initial observations and states.""" 46 | raise NotImplementedError 47 | 48 | def render(self): 49 | raise NotImplementedError 50 | 51 | def close(self): 52 | raise NotImplementedError 53 | 54 | def seed(self): 55 | raise NotImplementedError 56 | 57 | def save_replay(self): 58 | """Save a replay.""" 59 | raise NotImplementedError 60 | 61 | def get_env_info(self): 62 | env_info = {"state_shape": self.get_state_size(), 63 | "obs_shape": self.get_obs_size(), 64 | "n_actions": self.get_total_actions(), 65 | "n_agents": self.n_agents, 66 | "episode_limit": self.episode_limit, 67 | "n_self_feature": self.get_self_feature_size(), 68 | "n_enemies": self.n_enemies, 69 | "unit_dim": self.unit_dim} 70 | return env_info 71 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import logging 3 | import numpy as np 4 | 5 | class Logger: 6 | def __init__(self, console_logger): 7 | self.console_logger = console_logger 8 | 9 | self.use_tb = False 10 | self.use_sacred = False 11 | self.use_hdf = False 12 | 13 | self.stats = defaultdict(lambda: []) 14 | 15 | def setup_tb(self, directory_name): 16 | # Import here so it doesn't have to be installed if you don't use it 17 | from tensorboard_logger import configure, log_value 18 | configure(directory_name) 19 | self.tb_logger = log_value 20 | self.use_tb = True 21 | 22 | def setup_sacred(self, sacred_run_dict): 23 | self.sacred_info = sacred_run_dict.info 24 | self.use_sacred = True 25 | 26 | def log_stat(self, key, value, t, to_sacred=True): 27 | self.stats[key].append((t, value)) 28 | 29 | if self.use_tb: 30 | self.tb_logger(key, value, t) 31 | 32 | if self.use_sacred and to_sacred: 33 | if key in self.sacred_info: 34 | self.sacred_info["{}_T".format(key)].append(t) 35 | 
self.sacred_info[key].append(value) 36 | else: 37 | self.sacred_info["{}_T".format(key)] = [t] 38 | self.sacred_info[key] = [value] 39 | 40 | def print_recent_stats(self): 41 | log_str = "Recent Stats | t_env: {:>10} | Episode: {:>8}\n".format(*self.stats["episode"][-1]) 42 | i = 0 43 | for (k, v) in sorted(self.stats.items()): 44 | if k == "episode": 45 | continue 46 | i += 1 47 | window = 5 if k != "epsilon" else 1 48 | import torch as th 49 | item = "{:.4f}".format(th.mean(th.tensor([x[1] for x in self.stats[k][-window:]]))) 50 | log_str += "{:<25}{:>8}".format(k + ":", item) 51 | log_str += "\n" if i % 4 == 0 else "\t" 52 | self.console_logger.info(log_str) 53 | 54 | 55 | # set up a custom logger 56 | def get_logger(): 57 | logger = logging.getLogger() 58 | logger.handlers = [] 59 | ch = logging.StreamHandler() 60 | formatter = logging.Formatter('[%(levelname)s %(asctime)s] %(name)s %(message)s', '%H:%M:%S') 61 | ch.setFormatter(formatter) 62 | logger.addHandler(ch) 63 | logger.setLevel('DEBUG') 64 | 65 | return logger 66 | 67 | -------------------------------------------------------------------------------- /src/components/action_selectors.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch.distributions import Categorical 3 | from .epsilon_schedules import DecayThenFlatSchedule 4 | 5 | REGISTRY = {} 6 | 7 | 8 | class MultinomialActionSelector(): 9 | 10 | def __init__(self, args): 11 | self.args = args 12 | 13 | self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time, 14 | decay="linear") 15 | self.epsilon = self.schedule.eval(0) 16 | self.test_greedy = getattr(args, "test_greedy", True) 17 | 18 | def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False): 19 | masked_policies = agent_inputs.clone() 20 | masked_policies[avail_actions == 0.0] = 0.0 21 | 22 | self.epsilon = self.schedule.eval(t_env) 23 | 24 | if test_mode and self.test_greedy: 25 | picked_actions = masked_policies.max(dim=2)[1] 26 | else: 27 | picked_actions = Categorical(masked_policies).sample().long() 28 | 29 | return picked_actions 30 | 31 | 32 | REGISTRY["multinomial"] = MultinomialActionSelector 33 | 34 | 35 | class EpsilonGreedyActionSelector(): 36 | 37 | def __init__(self, args): 38 | self.args = args 39 | 40 | self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time, 41 | decay="linear") 42 | self.epsilon = self.schedule.eval(0) 43 | 44 | def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False): 45 | 46 | # Assuming agent_inputs is a batch of Q-Values for each agent bav 47 | self.epsilon = self.schedule.eval(t_env) 48 | 49 | if test_mode: 50 | # Greedy action selection only 51 | self.epsilon = 0.0 52 | 53 | # mask actions that are excluded from selection 54 | masked_q_values = agent_inputs.clone() 55 | masked_q_values[avail_actions == 0.0] = -float("inf") # should never be selected! 
56 | 57 | random_numbers = th.rand_like(agent_inputs[:, :, 0]) 58 | pick_random = (random_numbers < self.epsilon).long() 59 | random_actions = Categorical(avail_actions.float()).sample().long() 60 | 61 | picked_actions = pick_random * random_actions + (1 - pick_random) * masked_q_values.max(dim=2)[1] 62 | return picked_actions 63 | 64 | 65 | REGISTRY["epsilon_greedy"] = EpsilonGreedyActionSelector 66 | -------------------------------------------------------------------------------- /src/modules/mixers/qmix.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class QMixer(nn.Module): 8 | def __init__(self, args, state_dim=None): 9 | super(QMixer, self).__init__() 10 | 11 | self.args = args 12 | self.n_agents = args.n_agents 13 | if state_dim is None: 14 | self.state_dim = int(np.prod(args.state_shape)) 15 | else: 16 | self.state_dim = state_dim 17 | 18 | self.embed_dim = args.mixing_embed_dim 19 | 20 | if getattr(args, "hypernet_layers", 1) == 1: 21 | self.hyper_w_1 = nn.Linear(self.state_dim, self.embed_dim * self.n_agents) 22 | self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim) 23 | elif getattr(args, "hypernet_layers", 1) == 2: 24 | hypernet_embed = self.args.hypernet_embed 25 | self.hyper_w_1 = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed), 26 | nn.ReLU(), 27 | nn.Linear(hypernet_embed, self.embed_dim * self.n_agents)) 28 | self.hyper_w_final = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed), 29 | nn.ReLU(), 30 | nn.Linear(hypernet_embed, self.embed_dim)) 31 | elif getattr(args, "hypernet_layers", 1) > 2: 32 | raise Exception("Sorry >2 hypernet layers is not implemented!") 33 | else: 34 | raise Exception("Error setting number of hypernet layers.") 35 | 36 | # State dependent bias for hidden layer 37 | self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim) 38 | 39 | # V(s) instead of a bias for the last layers 40 | self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), 41 | nn.ReLU(), 42 | nn.Linear(self.embed_dim, 1)) 43 | 44 | def forward(self, agent_qs, states): 45 | bs = agent_qs.size(0) 46 | states = states.reshape(-1, self.state_dim) 47 | agent_qs = agent_qs.view(-1, 1, self.n_agents) 48 | # First layer 49 | w1 = th.abs(self.hyper_w_1(states)) 50 | b1 = self.hyper_b_1(states) 51 | w1 = w1.view(-1, self.n_agents, self.embed_dim) 52 | b1 = b1.view(-1, 1, self.embed_dim) 53 | hidden = F.elu(th.bmm(agent_qs, w1) + b1) 54 | # Second layer 55 | w_final = th.abs(self.hyper_w_final(states)) 56 | w_final = w_final.view(-1, self.embed_dim, 1) 57 | # State-dependent bias 58 | v = self.V(states).view(-1, 1, 1) 59 | # Compute final output 60 | y = th.bmm(hidden, w_final) + v 61 | # Reshape and return 62 | q_tot = y.view(bs, -1, 1) 63 | return q_tot 64 | -------------------------------------------------------------------------------- /src/modules/critics/coma.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class COMACritic(nn.Module): 7 | def __init__(self, scheme, args): 8 | super(COMACritic, self).__init__() 9 | 10 | self.args = args 11 | self.n_actions = args.n_actions 12 | self.n_agents = args.n_agents 13 | 14 | input_shape = self._get_input_shape(scheme) 15 | self.output_type = "q" 16 | 17 | # Set up network layers 18 | self.fc1 = nn.Linear(input_shape, 128) 19 | 
self.fc2 = nn.Linear(128, 128) 20 | self.fc3 = nn.Linear(128, self.n_actions) 21 | 22 | def forward(self, batch, t=None): 23 | inputs = self._build_inputs(batch, t=t) 24 | x = F.relu(self.fc1(inputs)) 25 | x = F.relu(self.fc2(x)) 26 | q = self.fc3(x) 27 | return q 28 | 29 | def _build_inputs(self, batch, t=None): 30 | bs = batch.batch_size 31 | max_t = batch.max_seq_length if t is None else 1 32 | ts = slice(None) if t is None else slice(t, t+1) 33 | inputs = [] 34 | # state 35 | inputs.append(batch["state"][:, ts].unsqueeze(2).repeat(1, 1, self.n_agents, 1)) 36 | 37 | # observation 38 | inputs.append(batch["obs"][:, ts]) 39 | 40 | # actions (masked out by agent) 41 | actions = batch["actions_onehot"][:, ts].view(bs, max_t, 1, -1).repeat(1, 1, self.n_agents, 1) 42 | agent_mask = (1 - th.eye(self.n_agents, device=batch.device)) 43 | agent_mask = agent_mask.view(-1, 1).repeat(1, self.n_actions).view(self.n_agents, -1) 44 | inputs.append(actions * agent_mask.unsqueeze(0).unsqueeze(0)) 45 | 46 | # last actions 47 | if t == 0: 48 | inputs.append(th.zeros_like(batch["actions_onehot"][:, 0:1]).view(bs, max_t, 1, -1).repeat(1, 1, self.n_agents, 1)) 49 | elif isinstance(t, int): 50 | inputs.append(batch["actions_onehot"][:, slice(t-1, t)].view(bs, max_t, 1, -1).repeat(1, 1, self.n_agents, 1)) 51 | else: 52 | last_actions = th.cat([th.zeros_like(batch["actions_onehot"][:, 0:1]), batch["actions_onehot"][:, :-1]], dim=1) 53 | last_actions = last_actions.view(bs, max_t, 1, -1).repeat(1, 1, self.n_agents, 1) 54 | inputs.append(last_actions) 55 | 56 | inputs.append(th.eye(self.n_agents, device=batch.device).unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1)) 57 | 58 | inputs = th.cat([x.reshape(bs, max_t, self.n_agents, -1) for x in inputs], dim=-1) 59 | return inputs 60 | 61 | def _get_input_shape(self, scheme): 62 | # state 63 | input_shape = scheme["state"]["vshape"] 64 | # observation 65 | input_shape += scheme["obs"]["vshape"] 66 | # actions and last actions 67 | input_shape += scheme["actions_onehot"]["vshape"][0] * self.n_agents * 2 68 | # agent id 69 | input_shape += self.n_agents 70 | return input_shape -------------------------------------------------------------------------------- /src/modules/vi/vae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class inner_state_autoencoder(nn.Module): 8 | def __init__(self, args): 9 | super(inner_state_autoencoder, self).__init__() 10 | self.state_input_dim = args.latent_dim * args.n_agents 11 | self.action_input_dim = args.n_actions 12 | self.latent_dim = args.latent_dim * args.n_agents 13 | self.hidden_dim = args.hidden_dim 14 | self.action_embedding_dim = 8 15 | self.n_agents = args.n_agents 16 | 17 | self.state_encoder = nn.Sequential( 18 | nn.Linear(self.state_input_dim, 128), 19 | nn.ReLU(), 20 | nn.Linear(128, self.hidden_dim), 21 | ) 22 | 23 | self.action_encoder = nn.Sequential( 24 | nn.Linear(self.action_input_dim, self.action_embedding_dim) 25 | ) 26 | 27 | self.cat_encoder = nn.Sequential( 28 | nn.Linear(self.hidden_dim + self.action_embedding_dim * args.n_agents, self.hidden_dim), 29 | nn.ReLU(), 30 | nn.Linear(self.hidden_dim, self.hidden_dim) 31 | ) 32 | self.mu = nn.Sequential( 33 | nn.Linear(self.hidden_dim, self.state_input_dim) 34 | ) 35 | 36 | self.logvar = nn.Sequential( 37 | nn.Linear(self.hidden_dim, self.state_input_dim) 38 | ) 39 | 40 | self.state_decoder = nn.Sequential( 41 | 
nn.Linear(self.state_input_dim, 128), 42 | nn.ReLU(), 43 | nn.Linear(128, 128), 44 | nn.ReLU(), 45 | nn.Linear(128, self.state_input_dim) 46 | ) 47 | 48 | self.action_decoder = nn.ModuleList([nn.Sequential( 49 | nn.Linear(self.state_input_dim, self.hidden_dim), 50 | nn.ReLU(), 51 | nn.Linear(self.hidden_dim, self.action_input_dim) 52 | ) for _ in range(args.n_agents)]) 53 | 54 | def encode(self, s_t_1, a): 55 | state_encoder = self.state_encoder(s_t_1) 56 | action_encoder = self.action_encoder(a) 57 | encoder = torch.cat([state_encoder, action_encoder.view(action_encoder.size(0), action_encoder.size(1), -1)], dim=-1) 58 | encoder = self.cat_encoder(encoder) 59 | mu, logvar = self.mu(encoder), self.logvar(encoder) 60 | return mu, logvar 61 | 62 | def sample_z(self, mu, logvar): 63 | eps = torch.randn_like(logvar) 64 | return mu + eps * torch.exp(0.5 * logvar) 65 | 66 | def decode(self, z): 67 | state_decoder = self.state_decoder(z) 68 | action_decoders = [] 69 | for n in range(self.n_agents): 70 | action_decoder = self.action_decoder[n](z) 71 | action_decoders.append(action_decoder) 72 | action_decoders = torch.stack(action_decoders, dim=-2) 73 | action_decoders = torch.log_softmax(action_decoders, dim=-1) 74 | return state_decoder, action_decoders 75 | 76 | def forward(self, s_t_1, a): 77 | mu, logvar = self.encode(s_t_1, a) 78 | z = self.sample_z(mu, logvar) 79 | state_decoder, action_decoders = self.decode(z) 80 | return state_decoder, action_decoders, mu, logvar 81 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import collections 4 | from os.path import dirname, abspath 5 | from copy import deepcopy 6 | from sacred import Experiment, SETTINGS 7 | from sacred.observers import FileStorageObserver 8 | from sacred.utils import apply_backspaces_and_linefeeds 9 | import sys 10 | import torch as th 11 | from utils.logging import get_logger 12 | import yaml 13 | 14 | from run import run 15 | 16 | SETTINGS['CAPTURE_MODE'] = "fd" # set to "no" if you want to see stdout/stderr in console 17 | logger = get_logger() 18 | 19 | ex = Experiment("pymarl") 20 | ex.logger = logger 21 | ex.captured_out_filter = apply_backspaces_and_linefeeds 22 | 23 | results_path = os.path.join(dirname(dirname(abspath(__file__))), "results") 24 | 25 | 26 | @ex.main 27 | def my_main(_run, _config, _log): 28 | # Setting the random seed throughout the modules 29 | config = config_copy(_config) 30 | config["seed"] = config['env_args']['seed'] 31 | np.random.seed(config["seed"]) 32 | th.manual_seed(config["seed"]) 33 | th.cuda.manual_seed(config["seed"]) 34 | th.backends.cudnn.benchmark = False 35 | th.backends.cudnn.deterministic = True 36 | 37 | # run the framework 38 | run(_run, config, _log) 39 | 40 | 41 | def _get_config(params, arg_name, subfolder): 42 | config_name = None 43 | for _i, _v in enumerate(params): 44 | if _v.split("=")[0] == arg_name: 45 | config_name = _v.split("=")[1] 46 | del params[_i] 47 | break 48 | 49 | if config_name is not None: 50 | with open(os.path.join(os.path.dirname(__file__), "config", subfolder, "{}.yaml".format(config_name)), "r") as f: 51 | try: 52 | config_dict = yaml.load(f) 53 | except yaml.YAMLError as exc: 54 | assert False, "{}.yaml error: {}".format(config_name, exc) 55 | return config_dict 56 | 57 | 58 | def recursive_dict_update(d, u): 59 | for k, v in u.items(): 60 | if isinstance(v, collections.Mapping): 61 | 
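# Added commentary (not original code): on Python >= 3.10 this check must use
# `collections.abc.Mapping` (`collections.Mapping` has been removed), and the
# `yaml.load(f)` calls in this file need an explicit loader on PyYAML >= 6,
# e.g. `yaml.load(f, Loader=yaml.FullLoader)` or `yaml.safe_load(f)`.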
d[k] = recursive_dict_update(d.get(k, {}), v) 62 | else: 63 | d[k] = v 64 | return d 65 | 66 | 67 | def config_copy(config): 68 | if isinstance(config, dict): 69 | return {k: config_copy(v) for k, v in config.items()} 70 | elif isinstance(config, list): 71 | return [config_copy(v) for v in config] 72 | else: 73 | return deepcopy(config) 74 | 75 | 76 | if __name__ == '__main__': 77 | params = deepcopy(sys.argv) 78 | 79 | # Get the defaults from default.yaml 80 | with open(os.path.join(os.path.dirname(__file__), "config", "default.yaml"), "r") as f: 81 | try: 82 | config_dict = yaml.load(f) 83 | except yaml.YAMLError as exc: 84 | assert False, "default.yaml error: {}".format(exc) 85 | 86 | # Load algorithm and env base configs 87 | env_config = _get_config(params, "--env-config", "envs") 88 | alg_config = _get_config(params, "--config", "algs") 89 | # config_dict = {**config_dict, **env_config, **alg_config} 90 | config_dict = recursive_dict_update(config_dict, env_config) 91 | config_dict = recursive_dict_update(config_dict, alg_config) 92 | 93 | # now add all the config to sacred 94 | ex.add_config(config_dict) 95 | 96 | # Save to disk by default for sacred 97 | logger.info("Saving to FileStorageObserver in results/sacred.") 98 | file_obs_path = os.path.join(results_path, "sacred") 99 | ex.observers.append(FileStorageObserver.create(file_obs_path)) 100 | 101 | ex.run_commandline(params) 102 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SIDE: State Inference for Partially Observable Cooperative Multi-Agent Reinforcement Learning 2 | 3 | ## Note 4 | SIDE is a novel value decomposition framework, named State Inference for value DEcomposition, which eliminates the need to know the true state by simultaneously seeking solutions to the two problems of optimal control and state inference. 5 | 6 | The implementation of the following methods can also be found in this codebase, which are finished by the authors of [PyMARL](https://github.com/oxwhirl/pymarl): 7 | - [**SIDE**: SIDE: State Inference for Partially Observable Cooperative Multi-Agent Reinforcement Learning](https://arxiv.org/abs/2105.06228) 8 | - [**QMIX**: QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning](https://arxiv.org/abs/1803.11485) 9 | - [**COMA**: Counterfactual Multi-Agent Policy Gradients](https://arxiv.org/abs/1705.08926) 10 | - [**VDN**: Value-Decomposition Networks For Cooperative Multi-Agent Learning](https://arxiv.org/abs/1706.05296) 11 | - [**IQL**: Independent Q-Learning](https://arxiv.org/abs/1511.08779) 12 | - [**QTRAN**: QTRAN: Learning to Factorize with Transformation for Cooperative Multi-Agent Reinforcement Learning](https://arxiv.org/abs/1905.05408) 13 | 14 | PyMARL is written in PyTorch and uses [SMAC](https://github.com/oxwhirl/smac) as its environment. 15 | 16 | ## Installation instructions 17 | 18 | Build the Dockerfile using 19 | ```shell 20 | cd docker 21 | bash build.sh 22 | ``` 23 | 24 | Set up StarCraft II and SMAC: 25 | ```shell 26 | bash install_sc2.sh 27 | ``` 28 | 29 | This will download SC2 into the 3rdparty folder and copy the maps necessary to run over. 30 | 31 | The requirements.txt file can be used to install the necessary packages into a virtual environment (not recomended). 32 | 33 | It is worth noting that we run the all experiments on **SC2.4.6.2.69232**, not SC2.4.10. 
Performance is *not* always comparable between versions. 34 | 35 | ## Run an experiment 36 | 37 | ```shell 38 | python3 src/main.py --config=side --env-config=sc2 with env_args.map_name=2s3z env_args.seed=1 39 | ``` 40 | 41 | The config files act as defaults for an algorithm or environment. 42 | 43 | They are all located in `src/config`. 44 | `--config` refers to the config files in `src/config/algs`. 45 | `--env-config` refers to the config files in `src/config/envs`. 46 | 47 | All results will be stored in the `results` folder. 48 | 49 | The config files previously used for the SMAC Beta have the suffix `_beta`. 50 | 51 | ## Saving and loading learnt models 52 | 53 | ### Saving models 54 | 55 | You can save the learnt models to disk by setting `save_model = True`, which is set to `False` by default. The frequency of saving models can be adjusted with the `save_model_interval` configuration. Models will be saved in the results directory, under the folder called *models*. The directory corresponding to each run will contain models saved throughout the experiment, each within a folder corresponding to the number of timesteps passed since the start of the learning process. 56 | 57 | ### Loading models 58 | 59 | Learnt models can be loaded using the `checkpoint_path` parameter, after which learning will proceed from the corresponding timestep. 60 | 61 | ## Watching StarCraft II replays 62 | 63 | The `save_replay` option allows saving replays of models which are loaded using `checkpoint_path`. Once the model is successfully loaded, `test_nepisode` episodes are run in test mode and a .SC2Replay file is saved in the Replay directory of StarCraft II. Please make sure to use the episode runner if you wish to save a replay, i.e., `runner=episode`. The name of the saved replay file starts with the given `env_args.save_replay_prefix` (the map name if empty), followed by the current timestamp. 64 | 65 | The saved replays can be watched by double-clicking on them or using the following command: 66 | 67 | ```shell 68 | python -m pysc2.bin.play --norender --rgb_minimap_size 0 --replay NAME.SC2Replay 69 | ``` 70 | 71 | **Note:** Replays cannot be watched using the Linux version of StarCraft II. Please use either the Mac or Windows version of the StarCraft II client.
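For example, assuming a previous run saved its models under `results/models/<unique_token>` (both that folder name and the map below are placeholders), a replay can be generated by loading the checkpoint and running the test episodes with the episode runner:

```shell
python3 src/main.py --config=side --env-config=sc2 with env_args.map_name=2s3z checkpoint_path=results/models/<unique_token> save_replay=True runner=episode batch_size_run=1 test_nepisode=8
```

Note that `runner=episode` requires `batch_size_run=1`, since the episode runner collects one episode at a time.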
72 | -------------------------------------------------------------------------------- /src/controllers/basic_controller.py: -------------------------------------------------------------------------------- 1 | from modules.agents import REGISTRY as agent_REGISTRY 2 | from components.action_selectors import REGISTRY as action_REGISTRY 3 | import torch as th 4 | 5 | 6 | # This multi-agent controller shares parameters between agents 7 | class BasicMAC: 8 | def __init__(self, scheme, groups, args): 9 | self.n_agents = args.n_agents 10 | self.args = args 11 | input_shape = self._get_input_shape(scheme) 12 | self._build_agents(input_shape) 13 | self.agent_output_type = args.agent_output_type 14 | 15 | self.action_selector = action_REGISTRY[args.action_selector](args) 16 | 17 | self.hidden_states = None 18 | 19 | def select_actions(self, ep_batch, t_ep, t_env, bs=slice(None), test_mode=False): 20 | # Only select actions for the selected batch elements in bs 21 | avail_actions = ep_batch["avail_actions"][:, t_ep] 22 | agent_outputs = self.forward(ep_batch, t_ep, test_mode=test_mode) 23 | chosen_actions = self.action_selector.select_action(agent_outputs[bs], avail_actions[bs], t_env, test_mode=test_mode) 24 | return chosen_actions 25 | 26 | def forward(self, ep_batch, t, test_mode=False): 27 | agent_inputs = self._build_inputs(ep_batch, t) 28 | avail_actions = ep_batch["avail_actions"][:, t] 29 | agent_outs, self.hidden_states = self.agent(agent_inputs, self.hidden_states) 30 | 31 | # Softmax the agent outputs if they're policy logits 32 | if self.agent_output_type == "pi_logits": 33 | 34 | if getattr(self.args, "mask_before_softmax", True): 35 | # Make the logits for unavailable actions very negative to minimise their affect on the softmax 36 | reshaped_avail_actions = avail_actions.reshape(ep_batch.batch_size * self.n_agents, -1) 37 | agent_outs[reshaped_avail_actions == 0] = -1e10 38 | 39 | agent_outs = th.nn.functional.softmax(agent_outs, dim=-1) 40 | if not test_mode: 41 | # Epsilon floor 42 | epsilon_action_num = agent_outs.size(-1) 43 | if getattr(self.args, "mask_before_softmax", True): 44 | # With probability epsilon, we will pick an available action uniformly 45 | epsilon_action_num = reshaped_avail_actions.sum(dim=1, keepdim=True).float() 46 | 47 | agent_outs = ((1 - self.action_selector.epsilon) * agent_outs 48 | + th.ones_like(agent_outs) * self.action_selector.epsilon/epsilon_action_num) 49 | 50 | if getattr(self.args, "mask_before_softmax", True): 51 | # Zero out the unavailable actions 52 | agent_outs[reshaped_avail_actions == 0] = 0.0 53 | 54 | return agent_outs.view(ep_batch.batch_size, self.n_agents, -1) 55 | 56 | def init_hidden(self, batch_size): 57 | self.hidden_states = self.agent.init_hidden().unsqueeze(0).expand(batch_size, self.n_agents, -1) # bav 58 | 59 | def parameters(self): 60 | return self.agent.parameters() 61 | 62 | def load_state(self, other_mac): 63 | self.agent.load_state_dict(other_mac.agent.state_dict()) 64 | 65 | def cuda(self): 66 | self.agent.cuda() 67 | 68 | def save_models(self, path): 69 | th.save(self.agent.state_dict(), "{}/agent.th".format(path)) 70 | 71 | def load_models(self, path): 72 | self.agent.load_state_dict(th.load("{}/agent.th".format(path), map_location=lambda storage, loc: storage)) 73 | 74 | def _build_agents(self, input_shape): 75 | self.agent = agent_REGISTRY[self.args.agent](input_shape, self.args) 76 | 77 | def _build_inputs(self, batch, t): 78 | # Assumes homogenous agents with flat observations. 79 | # Other MACs might want to e.g. 
delegate building inputs to each agent 80 | bs = batch.batch_size 81 | inputs = [] 82 | inputs.append(batch["obs"][:, t]) # b1av 83 | if self.args.obs_last_action: 84 | if t == 0: 85 | inputs.append(th.zeros_like(batch["actions_onehot"][:, t])) 86 | else: 87 | inputs.append(batch["actions_onehot"][:, t-1]) 88 | if self.args.obs_agent_id: 89 | inputs.append(th.eye(self.n_agents, device=batch.device).unsqueeze(0).expand(bs, -1, -1)) 90 | 91 | inputs = th.cat([x.reshape(bs*self.n_agents, -1) for x in inputs], dim=1) 92 | return inputs 93 | 94 | def _get_input_shape(self, scheme): 95 | input_shape = scheme["obs"]["vshape"] 96 | if self.args.obs_last_action: 97 | input_shape += scheme["actions_onehot"]["vshape"][0] 98 | if self.args.obs_agent_id: 99 | input_shape += self.n_agents 100 | 101 | return input_shape 102 | -------------------------------------------------------------------------------- /src/runners/episode_runner.py: -------------------------------------------------------------------------------- 1 | from envs import REGISTRY as env_REGISTRY 2 | from functools import partial 3 | from components.episode_buffer import EpisodeBatch 4 | import numpy as np 5 | 6 | 7 | class EpisodeRunner: 8 | 9 | def __init__(self, args, logger): 10 | self.args = args 11 | self.logger = logger 12 | self.batch_size = self.args.batch_size_run 13 | assert self.batch_size == 1 14 | 15 | self.env = env_REGISTRY[self.args.env](**self.args.env_args) 16 | self.episode_limit = self.env.episode_limit 17 | self.t = 0 18 | 19 | self.t_env = 0 20 | 21 | self.train_returns = [] 22 | self.test_returns = [] 23 | self.train_stats = {} 24 | self.test_stats = {} 25 | 26 | # Log the first run 27 | self.log_train_stats_t = -1000000 28 | 29 | def setup(self, scheme, groups, preprocess, mac): 30 | self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1, 31 | preprocess=preprocess, device=self.args.device) 32 | self.mac = mac 33 | 34 | def get_env_info(self): 35 | return self.env.get_env_info() 36 | 37 | def save_replay(self): 38 | self.env.save_replay() 39 | 40 | def close_env(self): 41 | self.env.close() 42 | 43 | def reset(self): 44 | self.batch = self.new_batch() 45 | self.env.reset() 46 | self.t = 0 47 | 48 | def run(self, test_mode=False): 49 | self.reset() 50 | 51 | terminated = False 52 | episode_return = 0 53 | self.mac.init_hidden(batch_size=self.batch_size) 54 | env_info = {"alive_allies_list": [1 for _ in range(self.env.n_agents)]} 55 | 56 | while not terminated: 57 | 58 | pre_transition_data = { 59 | "state": [self.env.get_state()], 60 | "avail_actions": [self.env.get_avail_actions()], 61 | "obs": [self.env.get_obs()], 62 | "alive_allies": self.env.get_alive_agents(), 63 | "visible_allies": self.env.get_visibility_matrix() 64 | } 65 | 66 | self.batch.update(pre_transition_data, ts=self.t) 67 | 68 | # Pass the entire batch of experiences up till now to the agents 69 | # Receive the actions for each agent at this timestep in a batch of size 1 70 | actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode) 71 | 72 | reward, terminated, env_info = self.env.step(actions[0]) 73 | episode_return += reward 74 | 75 | post_transition_data = { 76 | "actions": actions, 77 | "reward": [(reward,)], 78 | "terminated": [(terminated != env_info.get("episode_limit", False),)] 79 | } 80 | 81 | self.batch.update(post_transition_data, ts=self.t) 82 | 83 | self.t += 1 84 | 85 | last_data = { 86 | "state": [self.env.get_state()], 87 | "avail_actions": 
[self.env.get_avail_actions()], 88 | "obs": [self.env.get_obs()], 89 | "alive_allies": self.env.get_alive_agents(), 90 | "visible_allies": self.env.get_visibility_matrix() 91 | } 92 | self.batch.update(last_data, ts=self.t) 93 | 94 | # Select actions in the last stored state 95 | actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode) 96 | self.batch.update({"actions": actions}, ts=self.t) 97 | 98 | cur_stats = self.test_stats if test_mode else self.train_stats 99 | cur_returns = self.test_returns if test_mode else self.train_returns 100 | log_prefix = "test_" if test_mode else "" 101 | cur_stats.update({k: cur_stats.get(k, 0) + env_info.get(k, 0) for k in set(cur_stats) | set(env_info)}) 102 | cur_stats["n_episodes"] = 1 + cur_stats.get("n_episodes", 0) 103 | cur_stats["ep_length"] = self.t + cur_stats.get("ep_length", 0) 104 | 105 | if not test_mode: 106 | self.t_env += self.t 107 | 108 | cur_returns.append(episode_return) 109 | 110 | if test_mode and (len(self.test_returns) == self.args.test_nepisode): 111 | self._log(cur_returns, cur_stats, log_prefix) 112 | elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval: 113 | self._log(cur_returns, cur_stats, log_prefix) 114 | if hasattr(self.mac.action_selector, "epsilon"): 115 | self.logger.log_stat("epsilon", self.mac.action_selector.epsilon, self.t_env) 116 | self.log_train_stats_t = self.t_env 117 | 118 | return self.batch 119 | 120 | def _log(self, returns, stats, prefix): 121 | self.logger.log_stat(prefix + "return_mean", np.mean(returns), self.t_env) 122 | self.logger.log_stat(prefix + "return_std", np.std(returns), self.t_env) 123 | returns.clear() 124 | 125 | for k, v in stats.items(): 126 | if k != "n_episodes": 127 | self.logger.log_stat(prefix + k + "_mean" , v/stats["n_episodes"], self.t_env) 128 | stats.clear() 129 | -------------------------------------------------------------------------------- /src/modules/mixers/qtran.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class QTranBase(nn.Module): 8 | def __init__(self, args): 9 | super(QTranBase, self).__init__() 10 | 11 | self.args = args 12 | 13 | self.n_agents = args.n_agents 14 | self.n_actions = args.n_actions 15 | self.state_dim = int(np.prod(args.state_shape)) 16 | self.arch = self.args.qtran_arch # QTran architecture 17 | 18 | self.embed_dim = args.mixing_embed_dim 19 | 20 | # Q(s,u) 21 | if self.arch == "coma_critic": 22 | # Q takes [state, u] as input 23 | q_input_size = self.state_dim + (self.n_agents * self.n_actions) 24 | elif self.arch == "qtran_paper": 25 | # Q takes [state, agent_action_observation_encodings] 26 | q_input_size = self.state_dim + self.args.rnn_hidden_dim + self.n_actions 27 | else: 28 | raise Exception("{} is not a valid QTran architecture".format(self.arch)) 29 | 30 | if self.args.network_size == "small": 31 | self.Q = nn.Sequential(nn.Linear(q_input_size, self.embed_dim), 32 | nn.ReLU(), 33 | nn.Linear(self.embed_dim, self.embed_dim), 34 | nn.ReLU(), 35 | nn.Linear(self.embed_dim, 1)) 36 | 37 | # V(s) 38 | self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), 39 | nn.ReLU(), 40 | nn.Linear(self.embed_dim, self.embed_dim), 41 | nn.ReLU(), 42 | nn.Linear(self.embed_dim, 1)) 43 | ae_input = self.args.rnn_hidden_dim + self.n_actions 44 | self.action_encoding = nn.Sequential(nn.Linear(ae_input, ae_input), 45 | nn.ReLU(), 46 
| nn.Linear(ae_input, ae_input)) 47 | elif self.args.network_size == "big": 48 | self.Q = nn.Sequential(nn.Linear(q_input_size, self.embed_dim), 49 | nn.ReLU(), 50 | nn.Linear(self.embed_dim, self.embed_dim), 51 | nn.ReLU(), 52 | nn.Linear(self.embed_dim, self.embed_dim), 53 | nn.ReLU(), 54 | nn.Linear(self.embed_dim, 1)) 55 | # V(s) 56 | self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), 57 | nn.ReLU(), 58 | nn.Linear(self.embed_dim, self.embed_dim), 59 | nn.ReLU(), 60 | nn.Linear(self.embed_dim, self.embed_dim), 61 | nn.ReLU(), 62 | nn.Linear(self.embed_dim, 1)) 63 | ae_input = self.args.rnn_hidden_dim + self.n_actions 64 | self.action_encoding = nn.Sequential(nn.Linear(ae_input, ae_input), 65 | nn.ReLU(), 66 | nn.Linear(ae_input, ae_input)) 67 | else: 68 | assert False 69 | 70 | def forward(self, batch, hidden_states, actions=None): 71 | bs = batch.batch_size 72 | ts = batch.max_seq_length 73 | 74 | states = batch["state"].reshape(bs * ts, self.state_dim) 75 | 76 | if self.arch == "coma_critic": 77 | if actions is None: 78 | # Use the actions taken by the agents 79 | actions = batch["actions_onehot"].reshape(bs * ts, self.n_agents * self.n_actions) 80 | else: 81 | # It will arrive as (bs, ts, agents, actions), we need to reshape it 82 | actions = actions.reshape(bs * ts, self.n_agents * self.n_actions) 83 | inputs = th.cat([states, actions], dim=1) 84 | elif self.arch == "qtran_paper": 85 | if actions is None: 86 | # Use the actions taken by the agents 87 | actions = batch["actions_onehot"].reshape(bs * ts, self.n_agents, self.n_actions) 88 | else: 89 | # It will arrive as (bs, ts, agents, actions), we need to reshape it 90 | actions = actions.reshape(bs * ts, self.n_agents, self.n_actions) 91 | 92 | hidden_states = hidden_states.reshape(bs * ts, self.n_agents, -1) 93 | agent_state_action_input = th.cat([hidden_states, actions], dim=2) 94 | agent_state_action_encoding = self.action_encoding(agent_state_action_input.reshape(bs * ts * self.n_agents, -1)).reshape(bs * ts, self.n_agents, -1) 95 | agent_state_action_encoding = agent_state_action_encoding.sum(dim=1) # Sum across agents 96 | 97 | inputs = th.cat([states, agent_state_action_encoding], dim=1) 98 | 99 | q_outputs = self.Q(inputs) 100 | 101 | states = batch["state"].reshape(bs * ts, self.state_dim) 102 | v_outputs = self.V(states) 103 | 104 | return q_outputs, v_outputs 105 | 106 | -------------------------------------------------------------------------------- /src/modules/vi/vgae.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class state_autoencoder(nn.Module): 8 | def __init__(self, args): 9 | super(state_autoencoder, self).__init__() 10 | self.args = args 11 | #self.input_dim = int(np.prod(args.observation_shape)) + args.n_actions 12 | self.input_dim = args.rnn_hidden_dim 13 | self.latent_dim = args.latent_dim 14 | self.hidden_dim = args.hidden_dim 15 | self.n_agents = args.n_agents 16 | 17 | self.encoder_weight_1 = nn.Linear(self.input_dim, self.hidden_dim, bias=False) 18 | #self.gru_layer = nn.GRU(self.hidden_dim, self.hidden_dim, batch_first=True) 19 | self.weight_logvar = nn.Linear(self.hidden_dim, self.latent_dim, bias=False) 20 | self.weight_mu = nn.Linear(self.hidden_dim, self.latent_dim, bias=False) 21 | 22 | self.decoder_weight_1 = nn.Linear(self.latent_dim, self.hidden_dim, bias=False) 23 | self.decoder_weight_2 = nn.Linear(self.hidden_dim, 
self.input_dim, bias=False) 24 | 25 | def encode(self, node_features, adj_list): 26 | hidden = torch.zeros(1, node_features.size(0) * node_features.size(1), self.hidden_dim).to(node_features.device) 27 | 28 | a_tilde = adj_list + torch.eye(self.n_agents).unsqueeze(0).unsqueeze(0).expand_as(adj_list).to(adj_list.device) 29 | d_tilde_diag = a_tilde.sum(-1).view(-1, self.n_agents, 1) ** (-0.5) 30 | d_tilde = d_tilde_diag.bmm(torch.ones_like(d_tilde_diag).permute([0, 2, 1])) * torch.eye(self.n_agents).unsqueeze(0).to(a_tilde.device) 31 | encoder_factor = d_tilde.bmm(a_tilde.view(-1, self.n_agents, self.n_agents)).bmm(d_tilde) 32 | 33 | node_features = node_features.reshape([-1, self.n_agents, self.input_dim]) 34 | 35 | encoder_1 = self.encoder_weight_1(encoder_factor.bmm(node_features)) 36 | encoder_1 = F.relu(encoder_1) 37 | #encoder_1, _ = self.gru_layer(encoder_1, hidden) 38 | 39 | encoder_2 = encoder_factor.bmm(encoder_1) 40 | 41 | mu = self.weight_mu(encoder_2).view(adj_list.size(0), adj_list.size(1), self.n_agents, self.latent_dim) 42 | logvar = self.weight_logvar(encoder_2).view(adj_list.size(0), adj_list.size(1), self.n_agents, self.latent_dim) 43 | return mu, logvar 44 | 45 | def sample_z(self, mu, logvar): 46 | eps = torch.randn_like(logvar) 47 | return mu + eps * torch.exp(0.5 * logvar) 48 | 49 | def decode(self, z, adj_list): 50 | a_hat = -adj_list + 2 * torch.eye(self.n_agents).unsqueeze(0).unsqueeze(0).expand_as(adj_list).to(adj_list.device) 51 | d_hat_diag = (adj_list.sum(-1).view(-1, self.n_agents, 1) + 2) ** (-0.5) 52 | d_hat = d_hat_diag.bmm(torch.ones_like(d_hat_diag).permute([0, 2, 1])) * torch.eye(self.n_agents).unsqueeze(0).to(a_hat.device) 53 | decoder_factor = d_hat.bmm(a_hat.view(-1, self.n_agents, self.n_agents)).bmm(d_hat) 54 | 55 | z = z.view(-1, self.n_agents, self.latent_dim) 56 | 57 | decoder_1 = self.decoder_weight_1(decoder_factor.bmm(z)) 58 | decoder_1 = F.relu(decoder_1) 59 | 60 | recon_node_features = self.decoder_weight_2(decoder_factor.bmm(decoder_1)).view(adj_list.size(0), adj_list.size(1), self.n_agents, self.input_dim) 61 | return recon_node_features 62 | 63 | def forward(self, node_features, adj_list): 64 | adj_list = adj_list * (1 - torch.eye(self.n_agents)).unsqueeze(0).unsqueeze(0).to(adj_list.device) 65 | mu, logvar = self.encode(node_features, adj_list) 66 | z = self.sample_z(mu, logvar) 67 | recon_node_features = self.decode(z, adj_list) 68 | return recon_node_features, mu.view(adj_list.size(0), adj_list.size(1), -1), logvar.view(adj_list.size(0), adj_list.size(1), -1), z.view(adj_list.size(0), adj_list.size(1), -1) 69 | 70 | 71 | class flatten_state_autoencoder(nn.Module): 72 | def __init__(self, args): 73 | super(flatten_state_autoencoder, self).__init__() 74 | self.args = args 75 | self.input_dim = args.rnn_hidden_dim 76 | self.latent_dim = args.latent_dim 77 | self.hidden_dim = args.hidden_dim 78 | self.n_agents = args.n_agents 79 | 80 | self.encoder_weight_1 = nn.Linear(self.input_dim, self.hidden_dim) 81 | self.weight_logvar = nn.Linear(self.hidden_dim, self.latent_dim) 82 | self.weight_mu = nn.Linear(self.hidden_dim, self.latent_dim) 83 | 84 | self.decoder_weight_1 = nn.Linear(self.latent_dim, self.hidden_dim) 85 | self.decoder_weight_2 = nn.Linear(self.hidden_dim, self.input_dim) 86 | 87 | def encode(self, node_features, adj_list): 88 | bs = adj_list.size(0) 89 | sl = adj_list.size(1) 90 | encoder_1 = F.relu(self.encoder_weight_1(node_features)) 91 | mu = self.weight_mu(encoder_1).view(bs, sl, self.n_agents, self.latent_dim) 92 | logvar = 
self.weight_logvar(encoder_1).view(bs, sl, self.n_agents, self.latent_dim) 93 | return mu, logvar 94 | 95 | def sample_z(self, mu, logvar): 96 | eps = torch.randn_like(logvar) 97 | return mu + eps * torch.exp(0.5 * logvar) 98 | 99 | def decode(self, z, adj_list): 100 | decoder_1 = F.relu(self.decoder_weight_1(z)) 101 | recon_node_features = self.decoder_weight_2(decoder_1).view(adj_list.size(0), adj_list.size(1), self.n_agents, self.input_dim) 102 | return recon_node_features 103 | 104 | def forward(self, node_features, adj_list): 105 | adj_list = adj_list * (1 - torch.eye(self.n_agents)).unsqueeze(0).unsqueeze(0).to(adj_list.device) 106 | mu, logvar = self.encode(node_features, adj_list) 107 | z = self.sample_z(mu, logvar) 108 | recon_node_features = self.decode(z, adj_list) 109 | return recon_node_features, mu.view(adj_list.size(0), adj_list.size(1), -1), logvar.view(adj_list.size(0), adj_list.size(1), -1), z.view(adj_list.size(0), adj_list.size(1), -1) 110 | -------------------------------------------------------------------------------- /src/learners/q_learner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from components.episode_buffer import EpisodeBatch 3 | from modules.mixers.vdn import VDNMixer 4 | from modules.mixers.qmix import QMixer 5 | import torch as th 6 | from torch.optim import RMSprop 7 | 8 | 9 | class QLearner: 10 | def __init__(self, mac, scheme, logger, args): 11 | self.args = args 12 | self.mac = mac 13 | self.logger = logger 14 | 15 | self.params = list(mac.parameters()) 16 | 17 | self.last_target_update_episode = 0 18 | 19 | self.mixer = None 20 | if args.mixer is not None: 21 | if args.mixer == "vdn": 22 | self.mixer = VDNMixer() 23 | elif args.mixer == "qmix": 24 | self.mixer = QMixer(args) 25 | else: 26 | raise ValueError("Mixer {} not recognised.".format(args.mixer)) 27 | self.params += list(self.mixer.parameters()) 28 | self.target_mixer = copy.deepcopy(self.mixer) 29 | 30 | self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) 31 | 32 | # a little wasteful to deepcopy (e.g. 
duplicates action selector), but should work for any MAC 33 | self.target_mac = copy.deepcopy(mac) 34 | 35 | self.log_stats_t = -self.args.learner_log_interval - 1 36 | 37 | def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): 38 | # Get the relevant quantities 39 | rewards = batch["reward"][:, :-1] 40 | actions = batch["actions"][:, :-1] 41 | terminated = batch["terminated"][:, :-1].float() 42 | mask = batch["filled"][:, :-1].float() 43 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 44 | avail_actions = batch["avail_actions"] 45 | 46 | # Calculate estimated Q-Values 47 | mac_out = [] 48 | self.mac.init_hidden(batch.batch_size) 49 | for t in range(batch.max_seq_length): 50 | agent_outs = self.mac.forward(batch, t=t) 51 | mac_out.append(agent_outs) 52 | mac_out = th.stack(mac_out, dim=1) # Concat over time 53 | 54 | # Pick the Q-Values for the actions taken by each agent 55 | chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim 56 | 57 | # Calculate the Q-Values necessary for the target 58 | target_mac_out = [] 59 | self.target_mac.init_hidden(batch.batch_size) 60 | for t in range(batch.max_seq_length): 61 | target_agent_outs = self.target_mac.forward(batch, t=t) 62 | target_mac_out.append(target_agent_outs) 63 | 64 | # We don't need the first timesteps Q-Value estimate for calculating targets 65 | target_mac_out = th.stack(target_mac_out[1:], dim=1) # Concat across time 66 | 67 | # Mask out unavailable actions 68 | target_mac_out[avail_actions[:, 1:] == 0] = -9999999 69 | 70 | # Max over target Q-Values 71 | if self.args.double_q: 72 | # Get actions that maximise live Q (for double q-learning) 73 | mac_out_detach = mac_out.clone().detach() 74 | mac_out_detach[avail_actions == 0] = -9999999 75 | cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1] 76 | target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) 77 | else: 78 | target_max_qvals = target_mac_out.max(dim=3)[0] 79 | 80 | # Mix 81 | if self.mixer is not None: 82 | chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1]) 83 | target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:]) 84 | 85 | # Calculate 1-step Q-Learning targets 86 | targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals 87 | 88 | # Td-error 89 | td_error = (chosen_action_qvals - targets.detach()) 90 | 91 | mask = mask.expand_as(td_error) 92 | 93 | # 0-out the targets that came from padded data 94 | masked_td_error = td_error * mask 95 | 96 | # Normal L2 loss, take mean over actual data 97 | loss = (masked_td_error ** 2).sum() / mask.sum() 98 | 99 | # Optimise 100 | self.optimiser.zero_grad() 101 | loss.backward() 102 | grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) 103 | self.optimiser.step() 104 | 105 | if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0: 106 | self._update_targets() 107 | self.last_target_update_episode = episode_num 108 | 109 | if t_env - self.log_stats_t >= self.args.learner_log_interval: 110 | self.logger.log_stat("loss", loss.item(), t_env) 111 | self.logger.log_stat("grad_norm", grad_norm, t_env) 112 | mask_elems = mask.sum().item() 113 | self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env) 114 | self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents), t_env) 115 | self.logger.log_stat("target_mean", (targets 
* mask).sum().item()/(mask_elems * self.args.n_agents), t_env) 116 | self.log_stats_t = t_env 117 | 118 | def _update_targets(self): 119 | self.target_mac.load_state(self.mac) 120 | if self.mixer is not None: 121 | self.target_mixer.load_state_dict(self.mixer.state_dict()) 122 | self.logger.console_logger.info("Updated target network") 123 | 124 | def cuda(self): 125 | self.mac.cuda() 126 | self.target_mac.cuda() 127 | if self.mixer is not None: 128 | self.mixer.cuda() 129 | self.target_mixer.cuda() 130 | 131 | def save_models(self, path): 132 | self.mac.save_models(path) 133 | if self.mixer is not None: 134 | th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) 135 | th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) 136 | 137 | def load_models(self, path): 138 | self.mac.load_models(path) 139 | # Not quite right but I don't want to save target networks 140 | self.target_mac.load_models(path) 141 | if self.mixer is not None: 142 | self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) 143 | self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage)) -------------------------------------------------------------------------------- /src/envs/starcraft2/maps/smac_maps.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from pysc2.maps import lib 6 | 7 | 8 | class SMACMap(lib.Map): 9 | directory = "SMAC_Maps" 10 | download = "https://github.com/oxwhirl/smac#smac-maps" 11 | players = 2 12 | step_mul = 8 13 | game_steps_per_episode = 0 14 | 15 | 16 | map_param_registry = { 17 | "3m": { 18 | "n_agents": 3, 19 | "n_enemies": 3, 20 | "limit": 60, 21 | "a_race": "T", 22 | "b_race": "T", 23 | "unit_type_bits": 0, 24 | "map_type": "marines", 25 | }, 26 | "8m": { 27 | "n_agents": 8, 28 | "n_enemies": 8, 29 | "limit": 120, 30 | "a_race": "T", 31 | "b_race": "T", 32 | "unit_type_bits": 0, 33 | "map_type": "marines", 34 | }, 35 | "25m": { 36 | "n_agents": 25, 37 | "n_enemies": 25, 38 | "limit": 150, 39 | "a_race": "T", 40 | "b_race": "T", 41 | "unit_type_bits": 0, 42 | "map_type": "marines", 43 | }, 44 | "25m_modified": { 45 | "n_agents": 25, 46 | "n_enemies": 25, 47 | "limit": 150, 48 | "a_race": "T", 49 | "b_race": "T", 50 | "unit_type_bits": 0, 51 | "map_type": "marines", 52 | }, 53 | "5m_vs_6m": { 54 | "n_agents": 5, 55 | "n_enemies": 6, 56 | "limit": 70, 57 | "a_race": "T", 58 | "b_race": "T", 59 | "unit_type_bits": 0, 60 | "map_type": "marines", 61 | }, 62 | "8m_vs_9m": { 63 | "n_agents": 8, 64 | "n_enemies": 9, 65 | "limit": 120, 66 | "a_race": "T", 67 | "b_race": "T", 68 | "unit_type_bits": 0, 69 | "map_type": "marines", 70 | }, 71 | "10m_vs_11m": { 72 | "n_agents": 10, 73 | "n_enemies": 11, 74 | "limit": 150, 75 | "a_race": "T", 76 | "b_race": "T", 77 | "unit_type_bits": 0, 78 | "map_type": "marines", 79 | }, 80 | "10m_vs_11m_modified": { 81 | "n_agents": 10, 82 | "n_enemies": 11, 83 | "limit": 150, 84 | "a_race": "T", 85 | "b_race": "T", 86 | "unit_type_bits": 0, 87 | "map_type": "marines", 88 | }, 89 | "27m_vs_30m": { 90 | "n_agents": 27, 91 | "n_enemies": 30, 92 | "limit": 180, 93 | "a_race": "T", 94 | "b_race": "T", 95 | "unit_type_bits": 0, 96 | "map_type": "marines", 97 | }, 98 | "MMM": { 99 | "n_agents": 10, 100 | "n_enemies": 10, 101 | "limit": 150, 102 | "a_race": "T", 103 | "b_race": "T", 104 | 
"unit_type_bits": 3, 105 | "map_type": "MMM", 106 | }, 107 | "MMM_modified": { 108 | "n_agents": 10, 109 | "n_enemies": 10, 110 | "limit": 150, 111 | "a_race": "T", 112 | "b_race": "T", 113 | "unit_type_bits": 3, 114 | "map_type": "MMM", 115 | }, 116 | "MMM2": { 117 | "n_agents": 10, 118 | "n_enemies": 12, 119 | "limit": 180, 120 | "a_race": "T", 121 | "b_race": "T", 122 | "unit_type_bits": 3, 123 | "map_type": "MMM", 124 | }, 125 | "2s3z": { 126 | "n_agents": 5, 127 | "n_enemies": 5, 128 | "limit": 120, 129 | "a_race": "P", 130 | "b_race": "P", 131 | "unit_type_bits": 2, 132 | "map_type": "stalkers_and_zealots", 133 | }, 134 | "2s3z_modified": { 135 | "n_agents": 5, 136 | "n_enemies": 5, 137 | "limit": 120, 138 | "a_race": "P", 139 | "b_race": "P", 140 | "unit_type_bits": 2, 141 | "map_type": "stalkers_and_zealots", 142 | }, 143 | "3s5z": { 144 | "n_agents": 8, 145 | "n_enemies": 8, 146 | "limit": 150, 147 | "a_race": "P", 148 | "b_race": "P", 149 | "unit_type_bits": 2, 150 | "map_type": "stalkers_and_zealots", 151 | }, 152 | "3s5z_vs_3s6z": { 153 | "n_agents": 8, 154 | "n_enemies": 9, 155 | "limit": 170, 156 | "a_race": "P", 157 | "b_race": "P", 158 | "unit_type_bits": 2, 159 | "map_type": "stalkers_and_zealots", 160 | }, 161 | "3s_vs_3z": { 162 | "n_agents": 3, 163 | "n_enemies": 3, 164 | "limit": 150, 165 | "a_race": "P", 166 | "b_race": "P", 167 | "unit_type_bits": 0, 168 | "map_type": "stalkers", 169 | }, 170 | "3s_vs_4z": { 171 | "n_agents": 3, 172 | "n_enemies": 4, 173 | "limit": 200, 174 | "a_race": "P", 175 | "b_race": "P", 176 | "unit_type_bits": 0, 177 | "map_type": "stalkers", 178 | }, 179 | "3s_vs_5z": { 180 | "n_agents": 3, 181 | "n_enemies": 5, 182 | "limit": 250, 183 | "a_race": "P", 184 | "b_race": "P", 185 | "unit_type_bits": 0, 186 | "map_type": "stalkers", 187 | }, 188 | "1c3s5z": { 189 | "n_agents": 9, 190 | "n_enemies": 9, 191 | "limit": 180, 192 | "a_race": "P", 193 | "b_race": "P", 194 | "unit_type_bits": 3, 195 | "map_type": "colossi_stalkers_zealots", 196 | }, 197 | "2m_vs_1z": { 198 | "n_agents": 2, 199 | "n_enemies": 1, 200 | "limit": 150, 201 | "a_race": "T", 202 | "b_race": "P", 203 | "unit_type_bits": 0, 204 | "map_type": "marines", 205 | }, 206 | "corridor": { 207 | "n_agents": 6, 208 | "n_enemies": 24, 209 | "limit": 400, 210 | "a_race": "P", 211 | "b_race": "Z", 212 | "unit_type_bits": 0, 213 | "map_type": "zealots", 214 | }, 215 | "6h_vs_8z": { 216 | "n_agents": 6, 217 | "n_enemies": 8, 218 | "limit": 150, 219 | "a_race": "Z", 220 | "b_race": "P", 221 | "unit_type_bits": 0, 222 | "map_type": "hydralisks", 223 | }, 224 | "2s_vs_1sc": { 225 | "n_agents": 2, 226 | "n_enemies": 1, 227 | "limit": 300, 228 | "a_race": "P", 229 | "b_race": "Z", 230 | "unit_type_bits": 0, 231 | "map_type": "stalkers", 232 | }, 233 | "so_many_baneling": { 234 | "n_agents": 7, 235 | "n_enemies": 32, 236 | "limit": 100, 237 | "a_race": "P", 238 | "b_race": "Z", 239 | "unit_type_bits": 0, 240 | "map_type": "zealots", 241 | }, 242 | "bane_vs_bane": { 243 | "n_agents": 24, 244 | "n_enemies": 24, 245 | "limit": 200, 246 | "a_race": "Z", 247 | "b_race": "Z", 248 | "unit_type_bits": 2, 249 | "map_type": "bane", 250 | }, 251 | "2c_vs_64zg": { 252 | "n_agents": 2, 253 | "n_enemies": 64, 254 | "limit": 400, 255 | "a_race": "P", 256 | "b_race": "Z", 257 | "unit_type_bits": 0, 258 | "map_type": "colossus", 259 | }, 260 | } 261 | 262 | 263 | def get_smac_map_registry(): 264 | return map_param_registry 265 | 266 | 267 | for name in map_param_registry.keys(): 268 | globals()[name] = type(name, 
(SMACMap,), dict(filename=name)) 269 | -------------------------------------------------------------------------------- /src/learners/coma_learner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from components.episode_buffer import EpisodeBatch 3 | from modules.critics.coma import COMACritic 4 | from utils.rl_utils import build_td_lambda_targets 5 | import torch as th 6 | from torch.optim import RMSprop 7 | 8 | 9 | class COMALearner: 10 | def __init__(self, mac, scheme, logger, args): 11 | self.args = args 12 | self.n_agents = args.n_agents 13 | self.n_actions = args.n_actions 14 | self.mac = mac 15 | self.logger = logger 16 | 17 | self.last_target_update_step = 0 18 | self.critic_training_steps = 0 19 | 20 | self.log_stats_t = -self.args.learner_log_interval - 1 21 | 22 | self.critic = COMACritic(scheme, args) 23 | self.target_critic = copy.deepcopy(self.critic) 24 | 25 | self.agent_params = list(mac.parameters()) 26 | self.critic_params = list(self.critic.parameters()) 27 | self.params = self.agent_params + self.critic_params 28 | 29 | self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) 30 | self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.critic_lr, alpha=args.optim_alpha, eps=args.optim_eps) 31 | 32 | def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): 33 | # Get the relevant quantities 34 | bs = batch.batch_size 35 | max_t = batch.max_seq_length 36 | rewards = batch["reward"][:, :-1] 37 | actions = batch["actions"][:, :] 38 | terminated = batch["terminated"][:, :-1].float() 39 | mask = batch["filled"][:, :-1].float() 40 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 41 | avail_actions = batch["avail_actions"][:, :-1] 42 | 43 | critic_mask = mask.clone() 44 | 45 | mask = mask.repeat(1, 1, self.n_agents).view(-1) 46 | 47 | q_vals, critic_train_stats = self._train_critic(batch, rewards, terminated, actions, avail_actions, 48 | critic_mask, bs, max_t) 49 | 50 | actions = actions[:,:-1] 51 | 52 | mac_out = [] 53 | self.mac.init_hidden(batch.batch_size) 54 | for t in range(batch.max_seq_length - 1): 55 | agent_outs = self.mac.forward(batch, t=t) 56 | mac_out.append(agent_outs) 57 | mac_out = th.stack(mac_out, dim=1) # Concat over time 58 | 59 | # Mask out unavailable actions, renormalise (as in action selection) 60 | mac_out[avail_actions == 0] = 0 61 | mac_out = mac_out/mac_out.sum(dim=-1, keepdim=True) 62 | mac_out[avail_actions == 0] = 0 63 | 64 | # Calculated baseline 65 | q_vals = q_vals.reshape(-1, self.n_actions) 66 | pi = mac_out.view(-1, self.n_actions) 67 | baseline = (pi * q_vals).sum(-1).detach() 68 | 69 | # Calculate policy grad with mask 70 | q_taken = th.gather(q_vals, dim=1, index=actions.reshape(-1, 1)).squeeze(1) 71 | pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1, 1)).squeeze(1) 72 | pi_taken[mask == 0] = 1.0 73 | log_pi_taken = th.log(pi_taken) 74 | 75 | advantages = (q_taken - baseline).detach() 76 | 77 | coma_loss = - ((advantages * log_pi_taken) * mask).sum() / mask.sum() 78 | 79 | # Optimise agents 80 | self.agent_optimiser.zero_grad() 81 | coma_loss.backward() 82 | grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip) 83 | self.agent_optimiser.step() 84 | 85 | if (self.critic_training_steps - self.last_target_update_step) / self.args.target_update_interval >= 1.0: 86 | self._update_targets() 87 | self.last_target_update_step = self.critic_training_steps 88 | 89 | if 
t_env - self.log_stats_t >= self.args.learner_log_interval: 90 | ts_logged = len(critic_train_stats["critic_loss"]) 91 | for key in ["critic_loss", "critic_grad_norm", "td_error_abs", "q_taken_mean", "target_mean"]: 92 | self.logger.log_stat(key, sum(critic_train_stats[key])/ts_logged, t_env) 93 | 94 | self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env) 95 | self.logger.log_stat("coma_loss", coma_loss.item(), t_env) 96 | self.logger.log_stat("agent_grad_norm", grad_norm, t_env) 97 | self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(), t_env) 98 | self.log_stats_t = t_env 99 | 100 | def _train_critic(self, batch, rewards, terminated, actions, avail_actions, mask, bs, max_t): 101 | # Optimise critic 102 | target_q_vals = self.target_critic(batch)[:, :] 103 | targets_taken = th.gather(target_q_vals, dim=3, index=actions).squeeze(3) 104 | 105 | # Calculate td-lambda targets 106 | targets = build_td_lambda_targets(rewards, terminated, mask, targets_taken, self.n_agents, self.args.gamma, self.args.td_lambda) 107 | 108 | q_vals = th.zeros_like(target_q_vals)[:, :-1] 109 | 110 | running_log = { 111 | "critic_loss": [], 112 | "critic_grad_norm": [], 113 | "td_error_abs": [], 114 | "target_mean": [], 115 | "q_taken_mean": [], 116 | } 117 | 118 | for t in reversed(range(rewards.size(1))): 119 | mask_t = mask[:, t].expand(-1, self.n_agents) 120 | if mask_t.sum() == 0: 121 | continue 122 | 123 | q_t = self.critic(batch, t) 124 | q_vals[:, t] = q_t.view(bs, self.n_agents, self.n_actions) 125 | q_taken = th.gather(q_t, dim=3, index=actions[:, t:t+1]).squeeze(3).squeeze(1) 126 | targets_t = targets[:, t] 127 | 128 | td_error = (q_taken - targets_t.detach()) 129 | 130 | # 0-out the targets that came from padded data 131 | masked_td_error = td_error * mask_t 132 | 133 | # Normal L2 loss, take mean over actual data 134 | loss = (masked_td_error ** 2).sum() / mask_t.sum() 135 | self.critic_optimiser.zero_grad() 136 | loss.backward() 137 | grad_norm = th.nn.utils.clip_grad_norm_(self.critic_params, self.args.grad_norm_clip) 138 | self.critic_optimiser.step() 139 | self.critic_training_steps += 1 140 | 141 | running_log["critic_loss"].append(loss.item()) 142 | running_log["critic_grad_norm"].append(grad_norm) 143 | mask_elems = mask_t.sum().item() 144 | running_log["td_error_abs"].append((masked_td_error.abs().sum().item() / mask_elems)) 145 | running_log["q_taken_mean"].append((q_taken * mask_t).sum().item() / mask_elems) 146 | running_log["target_mean"].append((targets_t * mask_t).sum().item() / mask_elems) 147 | 148 | return q_vals, running_log 149 | 150 | def _update_targets(self): 151 | self.target_critic.load_state_dict(self.critic.state_dict()) 152 | self.logger.console_logger.info("Updated target network") 153 | 154 | def cuda(self): 155 | self.mac.cuda() 156 | self.critic.cuda() 157 | self.target_critic.cuda() 158 | 159 | def save_models(self, path): 160 | self.mac.save_models(path) 161 | th.save(self.critic.state_dict(), "{}/critic.th".format(path)) 162 | th.save(self.agent_optimiser.state_dict(), "{}/agent_opt.th".format(path)) 163 | th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path)) 164 | 165 | def load_models(self, path): 166 | self.mac.load_models(path) 167 | self.critic.load_state_dict(th.load("{}/critic.th".format(path), map_location=lambda storage, loc: storage)) 168 | # Not quite right but I don't want to save target networks 169 | 
self.target_critic.load_state_dict(self.critic.state_dict()) 170 | self.agent_optimiser.load_state_dict(th.load("{}/agent_opt.th".format(path), map_location=lambda storage, loc: storage)) 171 | self.critic_optimiser.load_state_dict(th.load("{}/critic_opt.th".format(path), map_location=lambda storage, loc: storage)) 172 | -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import pprint 4 | import time 5 | import threading 6 | import torch as th 7 | from types import SimpleNamespace as SN 8 | from utils.logging import Logger 9 | from utils.timehelper import time_left, time_str 10 | from os.path import dirname, abspath 11 | 12 | from learners import REGISTRY as le_REGISTRY 13 | from runners import REGISTRY as r_REGISTRY 14 | from controllers import REGISTRY as mac_REGISTRY 15 | from components.episode_buffer import ReplayBuffer 16 | from components.transforms import OneHot 17 | 18 | 19 | def run(_run, _config, _log): 20 | 21 | # check args sanity 22 | _config = args_sanity_check(_config, _log) 23 | 24 | args = SN(**_config) 25 | args.device = "cuda" if args.use_cuda else "cpu" 26 | 27 | # setup loggers 28 | logger = Logger(_log) 29 | 30 | _log.info("Experiment Parameters:") 31 | experiment_params = pprint.pformat(_config, 32 | indent=4, 33 | width=1) 34 | _log.info("\n\n" + experiment_params + "\n") 35 | 36 | # configure tensorboard logger 37 | unique_token = "{}__{}".format(args.name, datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) 38 | args.unique_token = unique_token 39 | if args.use_tensorboard: 40 | tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", "tb_logs") 41 | tb_exp_direc = os.path.join(tb_logs_direc, "{}").format(unique_token) 42 | logger.setup_tb(tb_exp_direc) 43 | 44 | # sacred is on by default 45 | logger.setup_sacred(_run) 46 | 47 | # Run and train 48 | run_sequential(args=args, logger=logger) 49 | 50 | # Clean up after finishing 51 | print("Exiting Main") 52 | 53 | print("Stopping all threads") 54 | for t in threading.enumerate(): 55 | if t.name != "MainThread": 56 | print("Thread {} is alive! 
Is daemon: {}".format(t.name, t.daemon)) 57 | t.join(timeout=1) 58 | print("Thread joined") 59 | 60 | print("Exiting script") 61 | 62 | # Making sure framework really exits 63 | os._exit(os.EX_OK) 64 | 65 | 66 | def evaluate_sequential(args, runner): 67 | 68 | for _ in range(args.test_nepisode): 69 | runner.run(test_mode=True) 70 | 71 | if args.save_replay: 72 | runner.save_replay() 73 | 74 | runner.close_env() 75 | 76 | def run_sequential(args, logger): 77 | 78 | # Init runner so we can get env info 79 | runner = r_REGISTRY[args.runner](args=args, logger=logger) 80 | 81 | # Set up schemes and groups here 82 | env_info = runner.get_env_info() 83 | args.n_agents = env_info["n_agents"] 84 | args.n_actions = env_info["n_actions"] 85 | args.state_shape = env_info["state_shape"] 86 | args.observation_shape = env_info["obs_shape"] 87 | 88 | # Default/Base scheme 89 | scheme = { 90 | "state": {"vshape": env_info["state_shape"]}, 91 | "obs": {"vshape": env_info["obs_shape"], "group": "agents"}, 92 | "actions": {"vshape": (1,), "group": "agents", "dtype": th.long}, 93 | "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int}, 94 | "reward": {"vshape": (1,)}, 95 | "terminated": {"vshape": (1,), "dtype": th.uint8}, 96 | "alive_allies": {"vshape": (env_info["n_agents"], env_info["n_agents"])}, 97 | "visible_allies": {"vshape": (env_info["n_agents"], env_info["n_agents"] + env_info["n_enemies"])} 98 | } 99 | groups = { 100 | "agents": args.n_agents 101 | } 102 | preprocess = { 103 | "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)]) 104 | } 105 | 106 | buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1, 107 | preprocess=preprocess, 108 | device="cpu" if args.buffer_cpu_only else args.device) 109 | 110 | # Setup multiagent controller here 111 | mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args) 112 | 113 | # Give runner the scheme 114 | runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac) 115 | 116 | # Learner 117 | learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args) 118 | 119 | if args.use_cuda: 120 | learner.cuda() 121 | 122 | if args.checkpoint_path != "": 123 | 124 | timesteps = [] 125 | timestep_to_load = 0 126 | 127 | if not os.path.isdir(args.checkpoint_path): 128 | logger.console_logger.info("Checkpoint directiory {} doesn't exist".format(args.checkpoint_path)) 129 | return 130 | 131 | # Go through all files in args.checkpoint_path 132 | for name in os.listdir(args.checkpoint_path): 133 | full_name = os.path.join(args.checkpoint_path, name) 134 | # Check if they are dirs the names of which are numbers 135 | if os.path.isdir(full_name) and name.isdigit(): 136 | timesteps.append(int(name)) 137 | 138 | if args.load_step == 0: 139 | # choose the max timestep 140 | timestep_to_load = max(timesteps) 141 | else: 142 | # choose the timestep closest to load_step 143 | timestep_to_load = min(timesteps, key=lambda x: abs(x - args.load_step)) 144 | 145 | model_path = os.path.join(args.checkpoint_path, str(timestep_to_load)) 146 | 147 | logger.console_logger.info("Loading model from {}".format(model_path)) 148 | learner.load_models(model_path) 149 | runner.t_env = timestep_to_load 150 | 151 | if args.evaluate or args.save_replay: 152 | evaluate_sequential(args, runner) 153 | return 154 | 155 | # start training 156 | episode = 0 157 | last_test_T = -args.test_interval - 1 158 | last_log_T = 0 159 | model_save_time = 0 160 | 161 | start_time = time.time() 162 | last_time = start_time 
163 | 164 | logger.console_logger.info("Beginning training for {} timesteps".format(args.t_max)) 165 | 166 | while runner.t_env <= args.t_max: 167 | 168 | # Run for a whole episode at a time 169 | episode_batch = runner.run(test_mode=False) 170 | buffer.insert_episode_batch(episode_batch) 171 | 172 | if buffer.can_sample(args.batch_size): 173 | episode_sample = buffer.sample(args.batch_size) 174 | 175 | # Truncate batch to only filled timesteps 176 | max_ep_t = episode_sample.max_t_filled() 177 | episode_sample = episode_sample[:, :max_ep_t] 178 | 179 | if episode_sample.device != args.device: 180 | episode_sample.to(args.device) 181 | 182 | learner.train(episode_sample, runner.t_env, episode) 183 | 184 | # Execute test runs once in a while 185 | n_test_runs = max(1, args.test_nepisode // runner.batch_size) 186 | if (runner.t_env - last_test_T) / args.test_interval >= 1.0: 187 | 188 | logger.console_logger.info("t_env: {} / {}".format(runner.t_env, args.t_max)) 189 | logger.console_logger.info("Estimated time left: {}. Time passed: {}".format( 190 | time_left(last_time, last_test_T, runner.t_env, args.t_max), time_str(time.time() - start_time))) 191 | last_time = time.time() 192 | 193 | last_test_T = runner.t_env 194 | for _ in range(n_test_runs): 195 | runner.run(test_mode=True) 196 | 197 | if args.save_model and (runner.t_env - model_save_time >= args.save_model_interval or model_save_time == 0): 198 | model_save_time = runner.t_env 199 | save_path = os.path.join(args.local_results_path, "models", args.unique_token, str(runner.t_env)) 200 | #"results/models/{}".format(unique_token) 201 | os.makedirs(save_path, exist_ok=True) 202 | logger.console_logger.info("Saving models to {}".format(save_path)) 203 | 204 | # learner should handle saving/loading -- delegate actor save/load to mac, 205 | # use appropriate filenames to do critics, optimizer states 206 | learner.save_models(save_path) 207 | 208 | episode += args.batch_size_run 209 | 210 | if (runner.t_env - last_log_T) >= args.log_interval: 211 | logger.log_stat("episode", episode, runner.t_env) 212 | logger.print_recent_stats() 213 | last_log_T = runner.t_env 214 | 215 | runner.close_env() 216 | logger.console_logger.info("Finished Training") 217 | 218 | 219 | def args_sanity_check(config, _log): 220 | 221 | # set CUDA flags 222 | # config["use_cuda"] = True # Use cuda whenever possible! 
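# Added commentary (not original code): the sanity checks below fall back to
# the CPU (with a warning) when use_cuda is set but no CUDA device is
# available, and they adjust test_nepisode so that it is at least
# batch_size_run and an exact multiple of it, which lets test episodes be
# split evenly across the runners.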
223 | if config["use_cuda"] and not th.cuda.is_available(): 224 | config["use_cuda"] = False 225 | _log.warning("CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!") 226 | 227 | if config["test_nepisode"] < config["batch_size_run"]: 228 | config["test_nepisode"] = config["batch_size_run"] 229 | else: 230 | config["test_nepisode"] = (config["test_nepisode"]//config["batch_size_run"]) * config["batch_size_run"] 231 | 232 | return config 233 | -------------------------------------------------------------------------------- /src/learners/qtran_learner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from components.episode_buffer import EpisodeBatch 3 | from modules.mixers.qtran import QTranBase 4 | import torch as th 5 | from torch.optim import RMSprop, Adam 6 | 7 | 8 | class QLearner: 9 | def __init__(self, mac, scheme, logger, args): 10 | self.args = args 11 | self.mac = mac 12 | self.logger = logger 13 | 14 | self.params = list(mac.parameters()) 15 | 16 | self.last_target_update_episode = 0 17 | 18 | self.mixer = None 19 | if args.mixer == "qtran_base": 20 | self.mixer = QTranBase(args) 21 | elif args.mixer == "qtran_alt": 22 | raise Exception("Not implemented here!") 23 | 24 | self.params += list(self.mixer.parameters()) 25 | self.target_mixer = copy.deepcopy(self.mixer) 26 | 27 | self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) 28 | 29 | # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC 30 | self.target_mac = copy.deepcopy(mac) 31 | 32 | self.log_stats_t = -self.args.learner_log_interval - 1 33 | 34 | def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): 35 | # Get the relevant quantities 36 | rewards = batch["reward"][:, :-1] 37 | actions = batch["actions"][:, :-1] 38 | terminated = batch["terminated"][:, :-1].float() 39 | mask = batch["filled"][:, :-1].float() 40 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 41 | avail_actions = batch["avail_actions"] 42 | 43 | # Calculate estimated Q-Values 44 | mac_out = [] 45 | mac_hidden_states = [] 46 | self.mac.init_hidden(batch.batch_size) 47 | for t in range(batch.max_seq_length): 48 | agent_outs = self.mac.forward(batch, t=t) 49 | mac_out.append(agent_outs) 50 | mac_hidden_states.append(self.mac.hidden_states) 51 | mac_out = th.stack(mac_out, dim=1) # Concat over time 52 | mac_hidden_states = th.stack(mac_hidden_states, dim=1) 53 | mac_hidden_states = mac_hidden_states.reshape(batch.batch_size, self.args.n_agents, batch.max_seq_length, -1).transpose(1,2) #btav 54 | 55 | # Pick the Q-Values for the actions taken by each agent 56 | chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim 57 | 58 | # Calculate the Q-Values necessary for the target 59 | target_mac_out = [] 60 | target_mac_hidden_states = [] 61 | self.target_mac.init_hidden(batch.batch_size) 62 | for t in range(batch.max_seq_length): 63 | target_agent_outs = self.target_mac.forward(batch, t=t) 64 | target_mac_out.append(target_agent_outs) 65 | target_mac_hidden_states.append(self.target_mac.hidden_states) 66 | 67 | # We don't need the first timesteps Q-Value estimate for calculating targets 68 | target_mac_out = th.stack(target_mac_out[:], dim=1) # Concat across time 69 | target_mac_hidden_states = th.stack(target_mac_hidden_states, dim=1) 70 | target_mac_hidden_states = target_mac_hidden_states.reshape(batch.batch_size, 
self.args.n_agents, batch.max_seq_length, -1).transpose(1,2) #btav 71 | 72 | # Mask out unavailable actions 73 | target_mac_out[avail_actions[:, :] == 0] = -9999999 # From OG deepmarl 74 | mac_out_maxs = mac_out.clone() 75 | mac_out_maxs[avail_actions == 0] = -9999999 76 | 77 | # Best joint action computed by target agents 78 | target_max_actions = target_mac_out.max(dim=3, keepdim=True)[1] 79 | # Best joint-action computed by regular agents 80 | max_actions_qvals, max_actions_current = mac_out_maxs[:, :].max(dim=3, keepdim=True) 81 | 82 | if self.args.mixer == "qtran_base": 83 | # -- TD Loss -- 84 | # Joint-action Q-Value estimates 85 | joint_qs, vs = self.mixer(batch[:, :-1], mac_hidden_states[:,:-1]) 86 | 87 | # Need to argmax across the target agents' actions to compute target joint-action Q-Values 88 | if self.args.double_q: 89 | max_actions_current_ = th.zeros(size=(batch.batch_size, batch.max_seq_length, self.args.n_agents, self.args.n_actions), device=batch.device) 90 | max_actions_current_onehot = max_actions_current_.scatter(3, max_actions_current[:, :], 1) 91 | max_actions_onehot = max_actions_current_onehot 92 | else: 93 | max_actions = th.zeros(size=(batch.batch_size, batch.max_seq_length, self.args.n_agents, self.args.n_actions), device=batch.device) 94 | max_actions_onehot = max_actions.scatter(3, target_max_actions[:, :], 1) 95 | target_joint_qs, target_vs = self.target_mixer(batch[:, 1:], hidden_states=target_mac_hidden_states[:,1:], actions=max_actions_onehot[:,1:]) 96 | 97 | # Td loss targets 98 | td_targets = rewards.reshape(-1,1) + self.args.gamma * (1 - terminated.reshape(-1, 1)) * target_joint_qs 99 | td_error = (joint_qs - td_targets.detach()) 100 | masked_td_error = td_error * mask.reshape(-1, 1) 101 | td_loss = (masked_td_error ** 2).sum() / mask.sum() 102 | # -- TD Loss -- 103 | 104 | # -- Opt Loss -- 105 | # Argmax across the current agents' actions 106 | if not self.args.double_q: # Already computed if we're doing double Q-Learning 107 | max_actions_current_ = th.zeros(size=(batch.batch_size, batch.max_seq_length, self.args.n_agents, self.args.n_actions), device=batch.device ) 108 | max_actions_current_onehot = max_actions_current_.scatter(3, max_actions_current[:, :], 1) 109 | max_joint_qs, _ = self.mixer(batch[:, :-1], mac_hidden_states[:,:-1], actions=max_actions_current_onehot[:,:-1]) # Don't use the target network and target agent max actions as per author's email 110 | 111 | # max_actions_qvals = th.gather(mac_out[:, :-1], dim=3, index=max_actions_current[:,:-1]) 112 | opt_error = max_actions_qvals[:,:-1].sum(dim=2).reshape(-1, 1) - max_joint_qs.detach() + vs 113 | masked_opt_error = opt_error * mask.reshape(-1, 1) 114 | opt_loss = (masked_opt_error ** 2).sum() / mask.sum() 115 | # -- Opt Loss -- 116 | 117 | # -- Nopt Loss -- 118 | # target_joint_qs, _ = self.target_mixer(batch[:, :-1]) 119 | nopt_values = chosen_action_qvals.sum(dim=2).reshape(-1, 1) - joint_qs.detach() + vs # Don't use target networks here either 120 | nopt_error = nopt_values.clamp(max=0) 121 | masked_nopt_error = nopt_error * mask.reshape(-1, 1) 122 | nopt_loss = (masked_nopt_error ** 2).sum() / mask.sum() 123 | # -- Nopt loss -- 124 | 125 | elif self.args.mixer == "qtran_alt": 126 | raise Exception("Not supported yet.") 127 | 128 | loss = td_loss + self.args.opt_loss * opt_loss + self.args.nopt_min_loss * nopt_loss 129 | 130 | # Optimise 131 | self.optimiser.zero_grad() 132 | loss.backward() 133 | grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) 134 | 
self.optimiser.step() 135 | 136 | if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0: 137 | self._update_targets() 138 | self.last_target_update_episode = episode_num 139 | 140 | if t_env - self.log_stats_t >= self.args.learner_log_interval: 141 | self.logger.log_stat("loss", loss.item(), t_env) 142 | self.logger.log_stat("td_loss", td_loss.item(), t_env) 143 | self.logger.log_stat("opt_loss", opt_loss.item(), t_env) 144 | self.logger.log_stat("nopt_loss", nopt_loss.item(), t_env) 145 | self.logger.log_stat("grad_norm", grad_norm, t_env) 146 | if self.args.mixer == "qtran_base": 147 | mask_elems = mask.sum().item() 148 | self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env) 149 | self.logger.log_stat("td_targets", ((masked_td_error).sum().item()/mask_elems), t_env) 150 | self.logger.log_stat("td_chosen_qs", (joint_qs.sum().item()/mask_elems), t_env) 151 | self.logger.log_stat("v_mean", (vs.sum().item()/mask_elems), t_env) 152 | self.logger.log_stat("agent_indiv_qs", ((chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents)), t_env) 153 | self.log_stats_t = t_env 154 | 155 | def _update_targets(self): 156 | self.target_mac.load_state(self.mac) 157 | if self.mixer is not None: 158 | self.target_mixer.load_state_dict(self.mixer.state_dict()) 159 | self.logger.console_logger.info("Updated target network") 160 | 161 | def cuda(self): 162 | self.mac.cuda() 163 | self.target_mac.cuda() 164 | if self.mixer is not None: 165 | self.mixer.cuda() 166 | self.target_mixer.cuda() 167 | 168 | def save_models(self, path): 169 | self.mac.save_models(path) 170 | if self.mixer is not None: 171 | th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) 172 | th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) 173 | 174 | def load_models(self, path): 175 | self.mac.load_models(path) 176 | # Not quite right but I don't want to save target networks 177 | self.target_mac.load_models(path) 178 | if self.mixer is not None: 179 | self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) 180 | self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage)) 181 | -------------------------------------------------------------------------------- /src/learners/side_learner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from components.episode_buffer import EpisodeBatch 3 | from modules.mixers.vdn import VDNMixer 4 | from modules.mixers.qmix import QMixer 5 | import torch as th 6 | from torch.optim import RMSprop 7 | from modules.vi.vae import inner_state_autoencoder 8 | from modules.vi.vgae import state_autoencoder, flatten_state_autoencoder 9 | import numpy as np 10 | 11 | 12 | class SIDELearner: 13 | def __init__(self, mac, scheme, logger, args): 14 | self.args = args 15 | self.mac = mac 16 | self.logger = logger 17 | 18 | self.params = list(mac.parameters()) 19 | 20 | self.last_target_update_episode = 0 21 | 22 | self.mixer = None 23 | if args.mixer is not None: 24 | self.mixer = QMixer(args, args.latent_dim * args.n_agents) 25 | self.params += list(self.mixer.parameters()) 26 | self.target_mixer = copy.deepcopy(self.mixer) 27 | 28 | if self.args.vgae: 29 | self.state_vae = state_autoencoder(self.args).to(args.device) 30 | else: 31 | self.state_vae = flatten_state_autoencoder(self.args).to(args.device) 32 | self.state_prior_vae = 
inner_state_autoencoder(self.args).to(args.device) 33 | 34 | if self.args.prior: 35 | self.params += list(self.state_prior_vae.parameters()) 36 | self.params += list(self.state_vae.parameters()) 37 | 38 | self.rl_optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) 39 | 40 | # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC 41 | self.target_mac = copy.deepcopy(mac) 42 | 43 | self.log_stats_t = -self.args.learner_log_interval - 1 44 | 45 | def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): 46 | # Get the relevant quantities 47 | rewards = batch["reward"][:, :-1] 48 | actions = batch["actions"][:, :-1] 49 | terminated = batch["terminated"][:, :-1].float() 50 | mask = batch["filled"][:, :-1].float() 51 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 52 | vae_mask = batch["filled"].float() 53 | prior_vae_mask = th.cat([th.zeros_like(vae_mask[:, 0]).unsqueeze(1), vae_mask[:, :-1]], dim=1) 54 | avail_actions = batch["avail_actions"] 55 | past_onehot_action = th.cat([th.zeros_like(batch["actions_onehot"][:, 0]).unsqueeze(1), batch["actions_onehot"][:, :-1]], dim=1) 56 | #past_reward = th.cat([th.zeros_like(batch["reward"][:, 0]).unsqueeze(1), batch["reward"][:, :-2]], dim=1) 57 | 58 | # Calculate estimated Q-Values 59 | mac_out = [] 60 | mac_hidden_states = [] 61 | self.mac.init_hidden(batch.batch_size) 62 | for t in range(batch.max_seq_length): 63 | agent_outs = self.mac.forward(batch, t=t) 64 | mac_out.append(agent_outs) 65 | mac_hidden_states.append(self.mac.hidden_states) 66 | mac_out = th.stack(mac_out, dim=1) # Concat over time 67 | mac_hidden_states = th.stack(mac_hidden_states, dim=1) 68 | mac_hidden_states = mac_hidden_states.reshape(batch.batch_size, self.args.n_agents, batch.max_seq_length, -1).transpose(1, 2).detach() #btav 69 | 70 | # Pick the Q-Values for the actions taken by each agent 71 | chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim 72 | 73 | # Calculate the Q-Values necessary for the target 74 | target_mac_out = [] 75 | target_mac_hidden_states = [] 76 | self.target_mac.init_hidden(batch.batch_size) 77 | for t in range(batch.max_seq_length): 78 | target_agent_outs = self.target_mac.forward(batch, t=t) 79 | target_mac_out.append(target_agent_outs) 80 | target_mac_hidden_states.append(self.target_mac.hidden_states) 81 | 82 | # We don't need the first timesteps Q-Value estimate for calculating targets 83 | target_mac_out = th.stack(target_mac_out[1:], dim=1) # Concat across time 84 | target_mac_hidden_states = th.stack(target_mac_hidden_states, dim=1) 85 | target_mac_hidden_states = target_mac_hidden_states.reshape(batch.batch_size, self.args.n_agents, batch.max_seq_length, -1).transpose(1, 2).detach() #btav 86 | 87 | # Mask out unavailable actions 88 | target_mac_out[avail_actions[:, 1:] == 0] = -9999999 89 | 90 | # Max over target Q-Values 91 | if self.args.double_q: 92 | # Get actions that maximise live Q (for double q-learning) 93 | mac_out_detach = mac_out.clone().detach() 94 | mac_out_detach[avail_actions == 0] = -9999999 95 | cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1] 96 | target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) 97 | else: 98 | target_max_qvals = target_mac_out.max(dim=3)[0] 99 | 100 | # * vae for get state 101 | recon_node_features, state_mu, state_logvar, sample_state = self.state_vae.forward(target_mac_hidden_states, batch['alive_allies']) 102 
| #sample_state = self.state_vae.sample_z(state_mu, state_logvar) 103 | past_sample_state = th.cat([th.zeros([sample_state.size(0), 1, sample_state.size(2)]).to(sample_state.device), sample_state[:, :-1]], dim=-2) 104 | recon_state, recon_action, prior_mu, prior_logvar = self.state_prior_vae.forward(past_sample_state, past_onehot_action) 105 | #inner_state = self.state_prior_vae.sample_z(prior_mu, prior_logvar) 106 | # Mix 107 | if self.mixer is not None: 108 | chosen_action_qvals = self.mixer(chosen_action_qvals, state_mu[:, :-1]) 109 | target_max_qvals = self.target_mixer(target_max_qvals, state_mu[:, 1:]) 110 | 111 | # Calculate 1-step Q-Learning targets 112 | targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals 113 | 114 | # Td-error 115 | td_error = (chosen_action_qvals - targets.detach()) 116 | 117 | mask = mask.expand_as(td_error) 118 | 119 | # 0-out the targets that came from padded data 120 | masked_td_error = td_error * mask 121 | 122 | prior_kld_loss = 1 + prior_logvar - prior_mu.pow(2) - prior_logvar.exp() 123 | if self.args.prior: 124 | kld_loss = state_logvar - prior_logvar + 1 - (state_logvar.exp() + (state_mu - prior_mu).pow(2)) / prior_logvar.exp() 125 | else: 126 | kld_loss = 1 + state_logvar - state_mu.pow(2) - state_logvar.exp() 127 | 128 | feature_mse_loss = (recon_node_features - target_mac_hidden_states).pow(2) 129 | prior_state_mse_loss = (recon_state - past_sample_state).pow(2) 130 | prior_action_loss = - past_onehot_action * recon_action 131 | 132 | # Normal L2 loss, take mean over actual data 133 | q_loss = (masked_td_error ** 2).sum() / mask.sum() 134 | 135 | state_loss = (feature_mse_loss * vae_mask.unsqueeze(-1).expand_as(feature_mse_loss)).sum() / vae_mask.unsqueeze(-1).expand_as(feature_mse_loss).sum() \ 136 | - 0.5 * (kld_loss * vae_mask.expand_as(kld_loss)).sum() / vae_mask.expand_as(kld_loss).sum() 137 | 138 | prior_loss = - 0.5 * (prior_kld_loss * prior_vae_mask.expand_as(prior_kld_loss)).sum() / prior_vae_mask.expand_as(prior_kld_loss).sum() \ 139 | + (prior_state_mse_loss * prior_vae_mask.expand_as(prior_state_mse_loss)).sum() / prior_vae_mask.expand_as(prior_state_mse_loss).sum() \ 140 | + (prior_action_loss * prior_vae_mask.unsqueeze(-1).expand_as(prior_action_loss)).sum() / prior_vae_mask.unsqueeze(-1).expand_as(prior_action_loss).sum() 141 | 142 | 143 | if not self.args.prior: 144 | prior_loss = 0 145 | loss = prior_loss + state_loss + q_loss 146 | 147 | # Optimise 148 | self.rl_optimiser.zero_grad() 149 | loss.backward() 150 | grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) 151 | self.rl_optimiser.step() 152 | 153 | 154 | if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0: 155 | self._update_targets() 156 | self.last_target_update_episode = episode_num 157 | 158 | if t_env - self.log_stats_t >= self.args.learner_log_interval: 159 | self.logger.log_stat("loss", loss.item(), t_env) 160 | self.logger.log_stat("grad_norm", grad_norm, t_env) 161 | mask_elems = mask.sum().item() 162 | self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env) 163 | self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents), t_env) 164 | self.logger.log_stat("target_mean", (targets * mask).sum().item()/(mask_elems * self.args.n_agents), t_env) 165 | self.log_stats_t = t_env 166 | 167 | def _update_targets(self): 168 | self.target_mac.load_state(self.mac) 169 | if self.mixer is not None: 170 | 
self.target_mixer.load_state_dict(self.mixer.state_dict()) 171 | self.logger.console_logger.info("Updated target network") 172 | 173 | def cuda(self): 174 | self.mac.cuda() 175 | self.target_mac.cuda() 176 | if self.mixer is not None: 177 | self.mixer.cuda() 178 | self.target_mixer.cuda() 179 | 180 | def save_models(self, path): 181 | self.mac.save_models(path) 182 | if self.mixer is not None: 183 | th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) 184 | th.save(self.rl_optimiser.state_dict(), "{}/opt.th".format(path)) 185 | 186 | def load_models(self, path): 187 | self.mac.load_models(path) 188 | # Not quite right but I don't want to save target networks 189 | self.target_mac.load_models(path) 190 | if self.mixer is not None: 191 | self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) 192 | self.rl_optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage)) 193 | -------------------------------------------------------------------------------- /src/runners/parallel_runner.py: -------------------------------------------------------------------------------- 1 | from envs import REGISTRY as env_REGISTRY 2 | from functools import partial 3 | from components.episode_buffer import EpisodeBatch 4 | from multiprocessing import Pipe, Process 5 | import numpy as np 6 | import torch as th 7 | 8 | 9 | # Based (very) heavily on SubprocVecEnv from OpenAI Baselines 10 | # https://github.com/openai/baselines/blob/master/baselines/common/vec_env/subproc_vec_env.py 11 | class ParallelRunner: 12 | 13 | def __init__(self, args, logger): 14 | self.args = args 15 | self.logger = logger 16 | self.batch_size = self.args.batch_size_run 17 | 18 | # Make subprocesses for the envs 19 | self.parent_conns, self.worker_conns = zip(*[Pipe() for _ in range(self.batch_size)]) 20 | env_fn = env_REGISTRY[self.args.env] 21 | self.ps = [Process(target=env_worker, args=(worker_conn, CloudpickleWrapper(partial(env_fn, **self.args.env_args)))) 22 | for worker_conn in self.worker_conns] 23 | 24 | for p in self.ps: 25 | p.daemon = True 26 | p.start() 27 | 28 | self.parent_conns[0].send(("get_env_info", None)) 29 | self.env_info = self.parent_conns[0].recv() 30 | self.episode_limit = self.env_info["episode_limit"] 31 | 32 | self.t = 0 33 | 34 | self.t_env = 0 35 | 36 | self.train_returns = [] 37 | self.test_returns = [] 38 | self.train_stats = {} 39 | self.test_stats = {} 40 | 41 | self.log_train_stats_t = -100000 42 | 43 | def setup(self, scheme, groups, preprocess, mac): 44 | self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1, 45 | preprocess=preprocess, device=self.args.device) 46 | self.mac = mac 47 | self.scheme = scheme 48 | self.groups = groups 49 | self.preprocess = preprocess 50 | 51 | def get_env_info(self): 52 | return self.env_info 53 | 54 | def save_replay(self): 55 | pass 56 | 57 | def close_env(self): 58 | for parent_conn in self.parent_conns: 59 | parent_conn.send(("close", None)) 60 | 61 | def reset(self): 62 | self.batch = self.new_batch() 63 | 64 | # Reset the envs 65 | for parent_conn in self.parent_conns: 66 | parent_conn.send(("reset", None)) 67 | 68 | pre_transition_data = { 69 | "state": [], 70 | "avail_actions": [], 71 | "obs": [] 72 | } 73 | # Get the obs, state and avail_actions back 74 | for parent_conn in self.parent_conns: 75 | data = parent_conn.recv() 76 | pre_transition_data["state"].append(data["state"]) 77 | 
pre_transition_data["avail_actions"].append(data["avail_actions"]) 78 | pre_transition_data["obs"].append(data["obs"]) 79 | 80 | self.batch.update(pre_transition_data, ts=0) 81 | 82 | self.t = 0 83 | self.env_steps_this_run = 0 84 | 85 | def run(self, test_mode=False): 86 | self.reset() 87 | 88 | all_terminated = False 89 | episode_returns = [0 for _ in range(self.batch_size)] 90 | episode_lengths = [0 for _ in range(self.batch_size)] 91 | self.mac.init_hidden(batch_size=self.batch_size) 92 | terminated = [False for _ in range(self.batch_size)] 93 | envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed] 94 | final_env_infos = [] # may store extra stats like battle won. this is filled in ORDER OF TERMINATION 95 | 96 | while True: 97 | 98 | # Pass the entire batch of experiences up till now to the agents 99 | # Receive the actions for each agent at this timestep in a batch for each un-terminated env 100 | actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, bs=envs_not_terminated, test_mode=test_mode) 101 | cpu_actions = actions.to("cpu").numpy() 102 | 103 | # Update the actions taken 104 | actions_chosen = { 105 | "actions": actions.unsqueeze(1) 106 | } 107 | self.batch.update(actions_chosen, bs=envs_not_terminated, ts=self.t, mark_filled=False) 108 | 109 | # Send actions to each env 110 | action_idx = 0 111 | for idx, parent_conn in enumerate(self.parent_conns): 112 | if idx in envs_not_terminated: # We produced actions for this env 113 | if not terminated[idx]: # Only send the actions to the env if it hasn't terminated 114 | parent_conn.send(("step", cpu_actions[action_idx])) 115 | action_idx += 1 # actions is not a list over every env 116 | 117 | # Update envs_not_terminated 118 | envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed] 119 | all_terminated = all(terminated) 120 | if all_terminated: 121 | break 122 | 123 | # Post step data we will insert for the current timestep 124 | post_transition_data = { 125 | "reward": [], 126 | "terminated": [] 127 | } 128 | # Data for the next step we will insert in order to select an action 129 | pre_transition_data = { 130 | "state": [], 131 | "avail_actions": [], 132 | "obs": [] 133 | } 134 | 135 | # Receive data back for each unterminated env 136 | for idx, parent_conn in enumerate(self.parent_conns): 137 | if not terminated[idx]: 138 | data = parent_conn.recv() 139 | # Remaining data for this current timestep 140 | post_transition_data["reward"].append((data["reward"],)) 141 | 142 | episode_returns[idx] += data["reward"] 143 | episode_lengths[idx] += 1 144 | if not test_mode: 145 | self.env_steps_this_run += 1 146 | 147 | env_terminated = False 148 | if data["terminated"]: 149 | final_env_infos.append(data["info"]) 150 | if data["terminated"] and not data["info"].get("episode_limit", False): 151 | env_terminated = True 152 | terminated[idx] = data["terminated"] 153 | post_transition_data["terminated"].append((env_terminated,)) 154 | 155 | # Data for the next timestep needed to select an action 156 | pre_transition_data["state"].append(data["state"]) 157 | pre_transition_data["avail_actions"].append(data["avail_actions"]) 158 | pre_transition_data["obs"].append(data["obs"]) 159 | 160 | # Add post_transiton data into the batch 161 | self.batch.update(post_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=False) 162 | 163 | # Move onto the next timestep 164 | self.t += 1 165 | 166 | # Add the pre-transition data 167 | 
self.batch.update(pre_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=True) 168 | 169 | if not test_mode: 170 | self.t_env += self.env_steps_this_run 171 | 172 | # Get stats back for each env 173 | for parent_conn in self.parent_conns: 174 | parent_conn.send(("get_stats",None)) 175 | 176 | env_stats = [] 177 | for parent_conn in self.parent_conns: 178 | env_stat = parent_conn.recv() 179 | env_stats.append(env_stat) 180 | 181 | cur_stats = self.test_stats if test_mode else self.train_stats 182 | cur_returns = self.test_returns if test_mode else self.train_returns 183 | log_prefix = "test_" if test_mode else "" 184 | infos = [cur_stats] + final_env_infos 185 | cur_stats.update({k: sum(d.get(k, 0) for d in infos) for k in set.union(*[set(d) for d in infos])}) 186 | cur_stats["n_episodes"] = self.batch_size + cur_stats.get("n_episodes", 0) 187 | cur_stats["ep_length"] = sum(episode_lengths) + cur_stats.get("ep_length", 0) 188 | 189 | cur_returns.extend(episode_returns) 190 | 191 | n_test_runs = max(1, self.args.test_nepisode // self.batch_size) * self.batch_size 192 | if test_mode and (len(self.test_returns) == n_test_runs): 193 | self._log(cur_returns, cur_stats, log_prefix) 194 | elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval: 195 | self._log(cur_returns, cur_stats, log_prefix) 196 | if hasattr(self.mac.action_selector, "epsilon"): 197 | self.logger.log_stat("epsilon", self.mac.action_selector.epsilon, self.t_env) 198 | self.log_train_stats_t = self.t_env 199 | 200 | return self.batch 201 | 202 | def _log(self, returns, stats, prefix): 203 | self.logger.log_stat(prefix + "return_mean", np.mean(returns), self.t_env) 204 | self.logger.log_stat(prefix + "return_std", np.std(returns), self.t_env) 205 | returns.clear() 206 | 207 | for k, v in stats.items(): 208 | if k != "n_episodes": 209 | self.logger.log_stat(prefix + k + "_mean" , v/stats["n_episodes"], self.t_env) 210 | stats.clear() 211 | 212 | 213 | def env_worker(remote, env_fn): 214 | # Make environment 215 | env = env_fn.x() 216 | while True: 217 | cmd, data = remote.recv() 218 | if cmd == "step": 219 | actions = data 220 | # Take a step in the environment 221 | reward, terminated, env_info = env.step(actions) 222 | # Return the observations, avail_actions and state to make the next action 223 | state = env.get_state() 224 | avail_actions = env.get_avail_actions() 225 | obs = env.get_obs() 226 | remote.send({ 227 | # Data for the next timestep needed to pick an action 228 | "state": state, 229 | "avail_actions": avail_actions, 230 | "obs": obs, 231 | # Rest of the data for the current timestep 232 | "reward": reward, 233 | "terminated": terminated, 234 | "info": env_info 235 | }) 236 | elif cmd == "reset": 237 | env.reset() 238 | remote.send({ 239 | "state": env.get_state(), 240 | "avail_actions": env.get_avail_actions(), 241 | "obs": env.get_obs() 242 | }) 243 | elif cmd == "close": 244 | env.close() 245 | remote.close() 246 | break 247 | elif cmd == "get_env_info": 248 | remote.send(env.get_env_info()) 249 | elif cmd == "get_stats": 250 | remote.send(env.get_stats()) 251 | else: 252 | raise NotImplementedError 253 | 254 | 255 | class CloudpickleWrapper(): 256 | """ 257 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 258 | """ 259 | def __init__(self, x): 260 | self.x = x 261 | def __getstate__(self): 262 | import cloudpickle 263 | return cloudpickle.dumps(self.x) 264 | def __setstate__(self, ob): 265 | import pickle 266 | self.x = pickle.loads(ob) 267 | 
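# Design note: the environment constructor handed to each worker process is a
# functools.partial over the registered env class. Plain pickle cannot always serialize
# such callables (for example when env_args contains locally defined objects), so the
# constructor is wrapped in CloudpickleWrapper and each worker rebuilds its own private
# environment instance via `env_fn.x()`.
#
# Minimal usage sketch (illustrative only; the concrete values below are assumptions
# based on the config files, not an exact invocation taken from this repo):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(batch_size_run=2, env="sc2", env_args={"map_name": "3m"},
#                          device="cpu", test_nepisode=32, runner_log_interval=10000)
#   runner = ParallelRunner(args=args, logger=logger)   # logger from utils/logging.py
#   runner.setup(scheme, groups, preprocess, mac)       # as built in run_sequential()
#   episode_batch = runner.run(test_mode=False)         # one batched episode rollout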
268 | -------------------------------------------------------------------------------- /src/components/episode_buffer.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | from types import SimpleNamespace as SN 4 | 5 | 6 | class EpisodeBatch: 7 | def __init__(self, 8 | scheme, 9 | groups, 10 | batch_size, 11 | max_seq_length, 12 | data=None, 13 | preprocess=None, 14 | device="cpu"): 15 | self.scheme = scheme.copy() 16 | self.groups = groups 17 | self.batch_size = batch_size 18 | self.max_seq_length = max_seq_length 19 | self.preprocess = {} if preprocess is None else preprocess 20 | self.device = device 21 | 22 | if data is not None: 23 | self.data = data 24 | else: 25 | self.data = SN() 26 | self.data.transition_data = {} 27 | self.data.episode_data = {} 28 | self._setup_data(self.scheme, self.groups, batch_size, max_seq_length, self.preprocess) 29 | 30 | def _setup_data(self, scheme, groups, batch_size, max_seq_length, preprocess): 31 | if preprocess is not None: 32 | for k in preprocess: 33 | assert k in scheme 34 | new_k = preprocess[k][0] 35 | transforms = preprocess[k][1] 36 | 37 | vshape = self.scheme[k]["vshape"] 38 | dtype = self.scheme[k]["dtype"] 39 | for transform in transforms: 40 | vshape, dtype = transform.infer_output_info(vshape, dtype) 41 | 42 | self.scheme[new_k] = { 43 | "vshape": vshape, 44 | "dtype": dtype 45 | } 46 | if "group" in self.scheme[k]: 47 | self.scheme[new_k]["group"] = self.scheme[k]["group"] 48 | if "episode_const" in self.scheme[k]: 49 | self.scheme[new_k]["episode_const"] = self.scheme[k]["episode_const"] 50 | 51 | assert "filled" not in scheme, '"filled" is a reserved key for masking.' 52 | scheme.update({ 53 | "filled": {"vshape": (1,), "dtype": th.long}, 54 | }) 55 | 56 | for field_key, field_info in scheme.items(): 57 | assert "vshape" in field_info, "Scheme must define vshape for {}".format(field_key) 58 | vshape = field_info["vshape"] 59 | episode_const = field_info.get("episode_const", False) 60 | group = field_info.get("group", None) 61 | dtype = field_info.get("dtype", th.float32) 62 | 63 | if isinstance(vshape, int): 64 | vshape = (vshape,) 65 | 66 | if group: 67 | assert group in groups, "Group {} must have its number of members defined in _groups_".format(group) 68 | shape = (groups[group], *vshape) 69 | else: 70 | shape = vshape 71 | 72 | if episode_const: 73 | self.data.episode_data[field_key] = th.zeros((batch_size, *shape), dtype=dtype, device=self.device) 74 | else: 75 | self.data.transition_data[field_key] = th.zeros((batch_size, max_seq_length, *shape), dtype=dtype, device=self.device) 76 | 77 | def extend(self, scheme, groups=None): 78 | self._setup_data(scheme, self.groups if groups is None else groups, self.batch_size, self.max_seq_length) 79 | 80 | def to(self, device): 81 | for k, v in self.data.transition_data.items(): 82 | self.data.transition_data[k] = v.to(device) 83 | for k, v in self.data.episode_data.items(): 84 | self.data.episode_data[k] = v.to(device) 85 | self.device = device 86 | 87 | def update(self, data, bs=slice(None), ts=slice(None), mark_filled=True): 88 | slices = self._parse_slices((bs, ts)) 89 | for k, v in data.items(): 90 | if k in self.data.transition_data: 91 | target = self.data.transition_data 92 | if mark_filled: 93 | target["filled"][slices] = 1 94 | mark_filled = False 95 | _slices = slices 96 | elif k in self.data.episode_data: 97 | target = self.data.episode_data 98 | _slices = slices[0] 99 | else: 100 | raise 
KeyError("{} not found in transition or episode data".format(k)) 101 | 102 | dtype = self.scheme[k].get("dtype", th.float32) 103 | v = th.tensor(v, dtype=dtype, device=self.device) 104 | self._check_safe_view(v, target[k][_slices]) 105 | target[k][_slices] = v.view_as(target[k][_slices]) 106 | 107 | if k in self.preprocess: 108 | new_k = self.preprocess[k][0] 109 | v = target[k][_slices] 110 | for transform in self.preprocess[k][1]: 111 | v = transform.transform(v) 112 | target[new_k][_slices] = v.view_as(target[new_k][_slices]) 113 | 114 | def _check_safe_view(self, v, dest): 115 | idx = len(v.shape) - 1 116 | for s in dest.shape[::-1]: 117 | if v.shape[idx] != s: 118 | if s != 1: 119 | raise ValueError("Unsafe reshape of {} to {}".format(v.shape, dest.shape)) 120 | else: 121 | idx -= 1 122 | 123 | def __getitem__(self, item): 124 | if isinstance(item, str): 125 | if item in self.data.episode_data: 126 | return self.data.episode_data[item] 127 | elif item in self.data.transition_data: 128 | return self.data.transition_data[item] 129 | else: 130 | raise ValueError 131 | elif isinstance(item, tuple) and all([isinstance(it, str) for it in item]): 132 | new_data = self._new_data_sn() 133 | for key in item: 134 | if key in self.data.transition_data: 135 | new_data.transition_data[key] = self.data.transition_data[key] 136 | elif key in self.data.episode_data: 137 | new_data.episode_data[key] = self.data.episode_data[key] 138 | else: 139 | raise KeyError("Unrecognised key {}".format(key)) 140 | 141 | # Update the scheme to only have the requested keys 142 | new_scheme = {key: self.scheme[key] for key in item} 143 | new_groups = {self.scheme[key]["group"]: self.groups[self.scheme[key]["group"]] 144 | for key in item if "group" in self.scheme[key]} 145 | ret = EpisodeBatch(new_scheme, new_groups, self.batch_size, self.max_seq_length, data=new_data, device=self.device) 146 | return ret 147 | else: 148 | item = self._parse_slices(item) 149 | new_data = self._new_data_sn() 150 | for k, v in self.data.transition_data.items(): 151 | new_data.transition_data[k] = v[item] 152 | for k, v in self.data.episode_data.items(): 153 | new_data.episode_data[k] = v[item[0]] 154 | 155 | ret_bs = self._get_num_items(item[0], self.batch_size) 156 | ret_max_t = self._get_num_items(item[1], self.max_seq_length) 157 | 158 | ret = EpisodeBatch(self.scheme, self.groups, ret_bs, ret_max_t, data=new_data, device=self.device) 159 | return ret 160 | 161 | def _get_num_items(self, indexing_item, max_size): 162 | if isinstance(indexing_item, list) or isinstance(indexing_item, np.ndarray): 163 | return len(indexing_item) 164 | elif isinstance(indexing_item, slice): 165 | _range = indexing_item.indices(max_size) 166 | return 1 + (_range[1] - _range[0] - 1)//_range[2] 167 | 168 | def _new_data_sn(self): 169 | new_data = SN() 170 | new_data.transition_data = {} 171 | new_data.episode_data = {} 172 | return new_data 173 | 174 | def _parse_slices(self, items): 175 | parsed = [] 176 | # Only batch slice given, add full time slice 177 | if (isinstance(items, slice) # slice a:b 178 | or isinstance(items, int) # int i 179 | or (isinstance(items, (list, np.ndarray, th.LongTensor, th.cuda.LongTensor))) # [a,b,c] 180 | ): 181 | items = (items, slice(None)) 182 | 183 | # Need the time indexing to be contiguous 184 | if isinstance(items[1], list): 185 | raise IndexError("Indexing across Time must be contiguous") 186 | 187 | for item in items: 188 | #TODO: stronger checks to ensure only supported options get through 189 | if isinstance(item, 
int): 190 | # Convert single indices to slices 191 | parsed.append(slice(item, item+1)) 192 | else: 193 | # Leave slices and lists as is 194 | parsed.append(item) 195 | return parsed 196 | 197 | def max_t_filled(self): 198 | return th.sum(self.data.transition_data["filled"], 1).max(0)[0] 199 | 200 | def __repr__(self): 201 | return "EpisodeBatch. Batch Size:{} Max_seq_len:{} Keys:{} Groups:{}".format(self.batch_size, 202 | self.max_seq_length, 203 | self.scheme.keys(), 204 | self.groups.keys()) 205 | 206 | 207 | class ReplayBuffer(EpisodeBatch): 208 | def __init__(self, scheme, groups, buffer_size, max_seq_length, preprocess=None, device="cpu"): 209 | super(ReplayBuffer, self).__init__(scheme, groups, buffer_size, max_seq_length, preprocess=preprocess, device=device) 210 | self.buffer_size = buffer_size # same as self.batch_size but more explicit 211 | self.buffer_index = 0 212 | self.episodes_in_buffer = 0 213 | 214 | def insert_episode_batch(self, ep_batch): 215 | if self.buffer_index + ep_batch.batch_size <= self.buffer_size: 216 | self.update(ep_batch.data.transition_data, 217 | slice(self.buffer_index, self.buffer_index + ep_batch.batch_size), 218 | slice(0, ep_batch.max_seq_length), 219 | mark_filled=False) 220 | self.update(ep_batch.data.episode_data, 221 | slice(self.buffer_index, self.buffer_index + ep_batch.batch_size)) 222 | self.buffer_index = (self.buffer_index + ep_batch.batch_size) 223 | self.episodes_in_buffer = max(self.episodes_in_buffer, self.buffer_index) 224 | self.buffer_index = self.buffer_index % self.buffer_size 225 | assert self.buffer_index < self.buffer_size 226 | else: 227 | buffer_left = self.buffer_size - self.buffer_index 228 | self.insert_episode_batch(ep_batch[0:buffer_left, :]) 229 | self.insert_episode_batch(ep_batch[buffer_left:, :]) 230 | 231 | def can_sample(self, batch_size): 232 | return self.episodes_in_buffer >= batch_size 233 | 234 | def sample(self, batch_size): 235 | assert self.can_sample(batch_size) 236 | if self.episodes_in_buffer == batch_size: 237 | return self[:batch_size] 238 | else: 239 | # Uniform sampling only atm 240 | ep_ids = np.random.choice(self.episodes_in_buffer, batch_size, replace=False) 241 | return self[ep_ids] 242 | 243 | def __repr__(self): 244 | return "ReplayBuffer. {}/{} episodes. Keys:{} Groups:{}".format(self.episodes_in_buffer, 245 | self.buffer_size, 246 | self.scheme.keys(), 247 | self.groups.keys()) 248 | 249 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2017 Google Inc. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | --------------------------------------------------------------------------------