├── .gitignore ├── LICENSE ├── README.md ├── docker ├── Dockerfile └── build.sh ├── install_sc2.sh ├── multi.png ├── requirements.txt ├── run.sh ├── run_interactive.sh ├── single.png └── src ├── .gitignore ├── __init__.py ├── components ├── __init__.py ├── action_selectors.py ├── episode_buffer.py ├── epsilon_schedules.py └── transforms.py ├── config ├── algs │ ├── coma.yaml │ ├── iql.yaml │ ├── iql_beta.yaml │ ├── qmix.yaml │ ├── qmix_beta.yaml │ ├── qtran.yaml │ ├── vdn.yaml │ └── vdn_beta.yaml ├── default.yaml └── envs │ ├── sc2.yaml │ └── sc2_beta.yaml ├── controllers ├── __init__.py └── basic_controller.py ├── envs ├── __init__.py └── multiagentenv.py ├── learners ├── __init__.py ├── coma_learner.py ├── q_learner.py └── qtran_learner.py ├── main.py ├── modules ├── __init__.py ├── agents │ ├── __init__.py │ ├── rnn_agent.py │ ├── transformer_agg_agent.py │ └── updet_agent.py ├── critics │ ├── __init__.py │ └── coma.py └── mixers │ ├── __init__.py │ ├── qmix.py │ ├── qtran.py │ └── vdn.py ├── run.py ├── runners ├── __init__.py ├── episode_runner.py └── parallel_runner.py └── utils ├── dict2namedtuple.py ├── logging.py ├── rl_utils.py └── timehelper.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 hhhusiyi
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UPDeT
2 | Official Implementation of [UPDeT: Universal Multi-agent Reinforcement Learning via Policy Decoupling with Transformers](https://openreview.net/forum?id=v9c7hr9ADKx) (ICLR 2021 spotlight)
3 |
4 | The framework is inherited from [PyMARL](https://github.com/oxwhirl/pymarl). [UPDeT](https://github.com/hhhusiyi-monash/UPDeT) is written in [PyTorch](https://pytorch.org) and uses [SMAC](https://github.com/oxwhirl/smac) as its environment.
5 |
6 | ## Installation instructions
7 |
8 | #### Installing dependencies:
9 |
10 | ```shell
11 | pip install -r requirements.txt
12 | ```
13 |
14 | #### Download SC2 into the `3rdparty/` folder and copy over the maps needed for the experiments.
15 |
16 | ```shell
17 | bash install_sc2.sh
18 | ```
19 |
20 |
21 | ## Run an experiment
22 |
23 | Before training your own transformer-based multi-agent model, there are a few things to note.
24 |
25 | - Currently, this repository only supports marine-based battle scenarios, e.g. `3m`, `8m`, `5m_vs_6m`.
26 | - If you are interested in training with a different unit type, carefully modify the `Transformer Parameters` block in `src/config/default.yaml` and revise the `_build_inputs_transformer` function in `src/controllers/basic_controller.py`.
27 | - Before running an experiment, check the agent type in the `Agent Parameters` block in `src/config/default.yaml`.
28 | - This repository contains the two transformer-based agents from the [UPDeT paper](https://arxiv.org/pdf/2101.08001.pdf):
29 | - Standard UPDeT
30 | - Aggregation Transformer
31 |
32 | #### Training script
33 |
34 | ```shell
35 | python3 src/main.py --config=vdn --env-config=sc2 with env_args.map_name=5m_vs_6m
36 | ```
37 | All results will be stored in the `results/` folder.
38 |
39 | ## Performance
40 |
41 | #### Single battle scenario
42 | UPDeT surpasses the GRU baseline on the hard `5m_vs_6m` map when combined with:
43 | - [**QMIX**: QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning](https://arxiv.org/abs/1803.11485)
44 | - [**VDN**: Value-Decomposition Networks For Cooperative Multi-Agent Learning](https://arxiv.org/abs/1706.05296)
45 | - [**QTRAN**: QTRAN: Learning to Factorize with Transformation for Cooperative Multi-Agent Reinforcement Learning](https://arxiv.org/abs/1905.05408)
46 |
47 | ![](https://github.com/hhhusiyi-monash/UPDeT/blob/main/single.png)
48 |
49 | #### Multiple battle scenarios
50 |
51 | UPDeT generalizes zero-shot to different tasks:
52 |
53 | - Results on `7m-5m-3m` transfer learning.
54 |
55 | ![](https://github.com/hhhusiyi-monash/UPDeT/blob/main/multi.png)
56 |
57 | **Note:** Only UPDeT can be deployed to other scenarios without changing the model's architecture.
58 |
59 | **For more details, please refer to the [UPDeT paper](https://arxiv.org/pdf/2101.08001.pdf).**
60 |
61 | ## Bibtex
62 |
63 | ```tex
64 | @article{hu2021updet,
65 | title={UPDeT: Universal Multi-agent Reinforcement Learning via Policy Decoupling with Transformers},
66 | author={Hu, Siyi and Zhu, Fengda and Chang, Xiaojun and Liang, Xiaodan},
67 | journal={arXiv preprint arXiv:2101.08001},
68 | year={2021}
69 | }
70 | ```
71 |
72 | ## License
73 |
74 | The MIT License
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:9.2-cudnn7-devel-ubuntu16.04
2 | MAINTAINER Tabish Rashid
3 |
4 | # CUDA includes
5 | ENV CUDA_PATH /usr/local/cuda
6 | ENV CUDA_INCLUDE_PATH /usr/local/cuda/include
7 | ENV CUDA_LIBRARY_PATH /usr/local/cuda/lib64
8 |
9 | # Ubuntu Packages
10 | RUN apt-get update -y && apt-get install software-properties-common -y && \
11 | add-apt-repository -y multiverse && apt-get update -y && apt-get upgrade -y && \
12 | apt-get install -y apt-utils nano vim man build-essential wget sudo && \
13 | rm -rf /var/lib/apt/lists/*
14 |
15 | # Install curl and other dependencies
16 | RUN apt-get update -y && apt-get install -y curl libssl-dev openssl libopenblas-dev \
17 | libhdf5-dev hdf5-helpers hdf5-tools libhdf5-serial-dev libprotobuf-dev protobuf-compiler git
18 | RUN curl -sk https://raw.githubusercontent.com/torch/distro/master/install-deps | bash && \
19 | rm -rf /var/lib/apt/lists/*
20 |
21 | # Install python3 pip3
22 | RUN apt-get update
23 | RUN apt-get -y install python3
24 | RUN apt-get -y install python3-pip
25 | RUN pip3 install --upgrade pip
26 |
27 | # Python packages we use (or used at one point...)
28 | RUN pip3 install numpy scipy pyyaml matplotlib 29 | RUN pip3 install imageio 30 | RUN pip3 install tensorboard-logger 31 | RUN pip3 install pygame 32 | 33 | RUN mkdir /install 34 | WORKDIR /install 35 | 36 | RUN pip3 install jsonpickle==0.9.6 37 | # install Sacred (from OxWhirl fork) 38 | RUN pip3 install setuptools 39 | RUN git clone https://github.com/oxwhirl/sacred.git /install/sacred && cd /install/sacred && python3 setup.py install 40 | 41 | #### ------------------------------------------------------------------- 42 | #### install pytorch 43 | #### ------------------------------------------------------------------- 44 | RUN pip3 install torch 45 | RUN pip3 install torchvision snakeviz pytest probscale 46 | 47 | ## -- SMAC 48 | RUN pip3 install git+https://github.com/oxwhirl/smac.git 49 | ENV SC2PATH /pymarl/3rdparty/StarCraftII 50 | 51 | WORKDIR /pymarl 52 | -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | echo 'Building Dockerfile with image name pymarl:1.0' 4 | docker build -t pymarl:1.0 . 5 | -------------------------------------------------------------------------------- /install_sc2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Install SC2 and add the custom maps 3 | 4 | if [ -z "$EXP_DIR" ] 5 | then 6 | EXP_DIR=~ 7 | fi 8 | 9 | echo "EXP_DIR: $EXP_DIR" 10 | cd $EXP_DIR/pymarl 11 | 12 | mkdir 3rdparty 13 | cd 3rdparty 14 | 15 | export SC2PATH=`pwd`'/StarCraftII' 16 | echo 'SC2PATH is set to '$SC2PATH 17 | 18 | if [ ! -d $SC2PATH ]; then 19 | echo 'StarCraftII is not installed. Installing now ...'; 20 | wget http://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip 21 | unzip -P iagreetotheeula SC2.4.10.zip 22 | rm -rf SC2.4.10.zip 23 | else 24 | echo 'StarCraftII is already installed.' 25 | fi 26 | 27 | echo 'Adding SMAC maps.' 28 | MAP_DIR="$SC2PATH/Maps/" 29 | echo 'MAP_DIR is set to '$MAP_DIR 30 | 31 | if [ ! -d $MAP_DIR ]; then 32 | mkdir -p $MAP_DIR 33 | fi 34 | 35 | cd .. 36 | wget https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip 37 | unzip SMAC_Maps.zip 38 | mv SMAC_Maps $MAP_DIR 39 | rm -rf SMAC_Maps.zip 40 | 41 | echo 'StarCraft II and SMAC are installed.' 
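# Hedged note: if StarCraft II is already installed elsewhere, the download
# above can probably be skipped by pointing SC2PATH at the existing copy before
# running experiments (src/envs/__init__.py only falls back to
# $PWD/3rdparty/StarCraftII when SC2PATH is unset); the path below is purely an
# illustrative example, and the SMAC_Maps folder still needs to end up under
# $SC2PATH/Maps/.
# export SC2PATH="$HOME/StarCraftII"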
42 | 43 | -------------------------------------------------------------------------------- /multi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Theohhhu/UPDeT/94b3db7a05e6c366596b76dd9c48981473cf6823/multi.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.5.0 2 | atomicwrites==1.2.1 3 | attrs==18.2.0 4 | certifi==2018.8.24 5 | chardet==3.0.4 6 | cycler==0.10.0 7 | docopt==0.6.2 8 | enum34==1.1.6 9 | future==0.16.0 10 | idna==2.7 11 | imageio==2.4.1 12 | jsonpickle==0.9.6 13 | kiwisolver==1.0.1 14 | matplotlib==3.0.0 15 | mock==2.0.0 16 | more-itertools==4.3.0 17 | mpyq==0.2.5 18 | munch==2.3.2 19 | numpy==1.15.2 20 | pathlib2==2.3.2 21 | pbr==4.3.0 22 | Pillow==6.2.0 23 | pluggy==0.7.1 24 | portpicker==1.2.0 25 | probscale==0.2.3 26 | protobuf==3.6.1 27 | py==1.6.0 28 | pygame==1.9.4 29 | pyparsing==2.2.2 30 | pysc2==3.0.0 31 | pytest==3.8.2 32 | python-dateutil==2.7.3 33 | PyYAML==3.13 34 | requests==2.20.0 35 | s2clientprotocol==4.10.1.75800.0 36 | sacred==0.7.2 37 | scipy==1.1.0 38 | six==1.11.0 39 | sk-video==1.1.10 40 | snakeviz==1.0.0 41 | tensorboard-logger==0.1.0 42 | torch==0.4.1 43 | torchvision==0.2.1 44 | tornado==5.1.1 45 | urllib3==1.24.2 46 | websocket-client==0.53.0 47 | whichcraft==0.5.2 48 | wrapt==1.10.11 49 | git+https://github.com/oxwhirl/smac.git 50 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | HASH=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 4 | head -n 1) 3 | GPU=$1 4 | name=${USER}_pymarl_GPU_${GPU}_${HASH} 5 | 6 | echo "Launching container named '${name}' on GPU '${GPU}'" 7 | # Launches a docker container using our image, and runs the provided command 8 | 9 | if hash nvidia-docker 2>/dev/null; then 10 | cmd=nvidia-docker 11 | else 12 | cmd=docker 13 | fi 14 | 15 | NV_GPU="$GPU" ${cmd} run \ 16 | --name $name \ 17 | --user $(id -u):$(id -g) \ 18 | -v `pwd`:/pymarl \ 19 | -t pymarl:1.0 \ 20 | ${@:2} 21 | -------------------------------------------------------------------------------- /run_interactive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | HASH=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 4 | head -n 1) 3 | GPU=$1 4 | name=${USER}_pymarl_GPU_${GPU}_${HASH} 5 | 6 | echo "Launching container named '${name}' on GPU '${GPU}'" 7 | # Launches a docker container using our image, and runs the provided command 8 | 9 | if hash nvidia-docker 2>/dev/null; then 10 | cmd=nvidia-docker 11 | else 12 | cmd=docker 13 | fi 14 | 15 | NV_GPU="$GPU" ${cmd} run -i \ 16 | --name $name \ 17 | --user $(id -u):$(id -g) \ 18 | -v `pwd`:/pymarl \ 19 | -t pymarl:1.0 \ 20 | ${@:2} 21 | -------------------------------------------------------------------------------- /single.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Theohhhu/UPDeT/94b3db7a05e6c366596b76dd9c48981473cf6823/single.png -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | tb_logs/ 2 | results/ 3 | -------------------------------------------------------------------------------- /src/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Theohhhu/UPDeT/94b3db7a05e6c366596b76dd9c48981473cf6823/src/__init__.py -------------------------------------------------------------------------------- /src/components/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Theohhhu/UPDeT/94b3db7a05e6c366596b76dd9c48981473cf6823/src/components/__init__.py -------------------------------------------------------------------------------- /src/components/action_selectors.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch.distributions import Categorical 3 | from .epsilon_schedules import DecayThenFlatSchedule 4 | 5 | REGISTRY = {} 6 | 7 | 8 | class MultinomialActionSelector(): 9 | 10 | def __init__(self, args): 11 | self.args = args 12 | 13 | self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time, 14 | decay="linear") 15 | self.epsilon = self.schedule.eval(0) 16 | self.test_greedy = getattr(args, "test_greedy", True) 17 | 18 | def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False): 19 | masked_policies = agent_inputs.clone() 20 | masked_policies[avail_actions == 0.0] = 0.0 21 | 22 | self.epsilon = self.schedule.eval(t_env) 23 | 24 | if test_mode and self.test_greedy: 25 | picked_actions = masked_policies.max(dim=2)[1] 26 | else: 27 | picked_actions = Categorical(masked_policies).sample().long() 28 | 29 | return picked_actions 30 | 31 | 32 | REGISTRY["multinomial"] = MultinomialActionSelector 33 | 34 | 35 | class EpsilonGreedyActionSelector(): 36 | 37 | def __init__(self, args): 38 | self.args = args 39 | 40 | self.schedule = DecayThenFlatSchedule(args.epsilon_start, args.epsilon_finish, args.epsilon_anneal_time, 41 | decay="linear") 42 | self.epsilon = self.schedule.eval(0) 43 | 44 | def select_action(self, agent_inputs, avail_actions, t_env, test_mode=False): 45 | 46 | # Assuming agent_inputs is a batch of Q-Values for each agent bav 47 | self.epsilon = self.schedule.eval(t_env) 48 | 49 | if test_mode: 50 | # Greedy action selection only 51 | self.epsilon = 0.0 52 | 53 | # mask actions that are excluded from selection 54 | masked_q_values = agent_inputs.clone() 55 | masked_q_values[avail_actions == 0.0] = -float("inf") # should never be selected! 
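# The lines that follow mix exploration and exploitation: one uniform random
# number is drawn per agent and compared against epsilon to build a 0/1
# pick_random indicator; the random branch samples from a Categorical over the
# avail_actions mask (so only available actions can be drawn), while the greedy
# branch takes the argmax of the masked Q-values.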
56 | 57 | random_numbers = th.rand_like(agent_inputs[:, :, 0]) 58 | pick_random = (random_numbers < self.epsilon).long() 59 | random_actions = Categorical(avail_actions.float()).sample().long() 60 | 61 | picked_actions = pick_random * random_actions + (1 - pick_random) * masked_q_values.max(dim=2)[1] 62 | return picked_actions 63 | 64 | 65 | REGISTRY["epsilon_greedy"] = EpsilonGreedyActionSelector 66 | -------------------------------------------------------------------------------- /src/components/episode_buffer.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import numpy as np 3 | from types import SimpleNamespace as SN 4 | 5 | 6 | class EpisodeBatch: 7 | def __init__(self, 8 | scheme, 9 | groups, 10 | batch_size, 11 | max_seq_length, 12 | data=None, 13 | preprocess=None, 14 | device="cpu"): 15 | self.scheme = scheme.copy() 16 | self.groups = groups 17 | self.batch_size = batch_size 18 | self.max_seq_length = max_seq_length 19 | self.preprocess = {} if preprocess is None else preprocess 20 | self.device = device 21 | 22 | if data is not None: 23 | self.data = data 24 | else: 25 | self.data = SN() 26 | self.data.transition_data = {} 27 | self.data.episode_data = {} 28 | self._setup_data(self.scheme, self.groups, batch_size, max_seq_length, self.preprocess) 29 | 30 | def _setup_data(self, scheme, groups, batch_size, max_seq_length, preprocess): 31 | if preprocess is not None: 32 | for k in preprocess: 33 | assert k in scheme 34 | new_k = preprocess[k][0] 35 | transforms = preprocess[k][1] 36 | 37 | vshape = self.scheme[k]["vshape"] 38 | dtype = self.scheme[k]["dtype"] 39 | for transform in transforms: 40 | vshape, dtype = transform.infer_output_info(vshape, dtype) 41 | 42 | self.scheme[new_k] = { 43 | "vshape": vshape, 44 | "dtype": dtype 45 | } 46 | if "group" in self.scheme[k]: 47 | self.scheme[new_k]["group"] = self.scheme[k]["group"] 48 | if "episode_const" in self.scheme[k]: 49 | self.scheme[new_k]["episode_const"] = self.scheme[k]["episode_const"] 50 | 51 | assert "filled" not in scheme, '"filled" is a reserved key for masking.' 
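# "filled" is added to the scheme below as a per-timestep flag that marks which
# entries of the padded (batch_size, max_seq_length) tensors contain real data;
# the learners later read batch["filled"] to build their loss masks.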
52 | scheme.update({ 53 | "filled": {"vshape": (1,), "dtype": th.long}, 54 | }) 55 | 56 | for field_key, field_info in scheme.items(): 57 | assert "vshape" in field_info, "Scheme must define vshape for {}".format(field_key) 58 | vshape = field_info["vshape"] 59 | episode_const = field_info.get("episode_const", False) 60 | group = field_info.get("group", None) 61 | dtype = field_info.get("dtype", th.float32) 62 | 63 | if isinstance(vshape, int): 64 | vshape = (vshape,) 65 | 66 | if group: 67 | assert group in groups, "Group {} must have its number of members defined in _groups_".format(group) 68 | shape = (groups[group], *vshape) 69 | else: 70 | shape = vshape 71 | 72 | if episode_const: 73 | self.data.episode_data[field_key] = th.zeros((batch_size, *shape), dtype=dtype, device=self.device) 74 | else: 75 | self.data.transition_data[field_key] = th.zeros((batch_size, max_seq_length, *shape), dtype=dtype, device=self.device) 76 | 77 | def extend(self, scheme, groups=None): 78 | self._setup_data(scheme, self.groups if groups is None else groups, self.batch_size, self.max_seq_length) 79 | 80 | def to(self, device): 81 | for k, v in self.data.transition_data.items(): 82 | self.data.transition_data[k] = v.to(device) 83 | for k, v in self.data.episode_data.items(): 84 | self.data.episode_data[k] = v.to(device) 85 | self.device = device 86 | 87 | def update(self, data, bs=slice(None), ts=slice(None), mark_filled=True): 88 | slices = self._parse_slices((bs, ts)) 89 | for k, v in data.items(): 90 | if k in self.data.transition_data: 91 | target = self.data.transition_data 92 | if mark_filled: 93 | target["filled"][slices] = 1 94 | mark_filled = False 95 | _slices = slices 96 | elif k in self.data.episode_data: 97 | target = self.data.episode_data 98 | _slices = slices[0] 99 | else: 100 | raise KeyError("{} not found in transition or episode data".format(k)) 101 | 102 | dtype = self.scheme[k].get("dtype", th.float32) 103 | v = th.tensor(v, dtype=dtype, device=self.device) 104 | self._check_safe_view(v, target[k][_slices]) 105 | target[k][_slices] = v.view_as(target[k][_slices]) 106 | 107 | if k in self.preprocess: 108 | new_k = self.preprocess[k][0] 109 | v = target[k][_slices] 110 | for transform in self.preprocess[k][1]: 111 | v = transform.transform(v) 112 | target[new_k][_slices] = v.view_as(target[new_k][_slices]) 113 | 114 | def _check_safe_view(self, v, dest): 115 | idx = len(v.shape) - 1 116 | for s in dest.shape[::-1]: 117 | if v.shape[idx] != s: 118 | if s != 1: 119 | raise ValueError("Unsafe reshape of {} to {}".format(v.shape, dest.shape)) 120 | else: 121 | idx -= 1 122 | 123 | def __getitem__(self, item): 124 | if isinstance(item, str): 125 | if item in self.data.episode_data: 126 | return self.data.episode_data[item] 127 | elif item in self.data.transition_data: 128 | return self.data.transition_data[item] 129 | else: 130 | raise ValueError 131 | elif isinstance(item, tuple) and all([isinstance(it, str) for it in item]): 132 | new_data = self._new_data_sn() 133 | for key in item: 134 | if key in self.data.transition_data: 135 | new_data.transition_data[key] = self.data.transition_data[key] 136 | elif key in self.data.episode_data: 137 | new_data.episode_data[key] = self.data.episode_data[key] 138 | else: 139 | raise KeyError("Unrecognised key {}".format(key)) 140 | 141 | # Update the scheme to only have the requested keys 142 | new_scheme = {key: self.scheme[key] for key in item} 143 | new_groups = {self.scheme[key]["group"]: self.groups[self.scheme[key]["group"]] 144 | for key in item 
if "group" in self.scheme[key]} 145 | ret = EpisodeBatch(new_scheme, new_groups, self.batch_size, self.max_seq_length, data=new_data, device=self.device) 146 | return ret 147 | else: 148 | item = self._parse_slices(item) 149 | new_data = self._new_data_sn() 150 | for k, v in self.data.transition_data.items(): 151 | new_data.transition_data[k] = v[item] 152 | for k, v in self.data.episode_data.items(): 153 | new_data.episode_data[k] = v[item[0]] 154 | 155 | ret_bs = self._get_num_items(item[0], self.batch_size) 156 | ret_max_t = self._get_num_items(item[1], self.max_seq_length) 157 | 158 | ret = EpisodeBatch(self.scheme, self.groups, ret_bs, ret_max_t, data=new_data, device=self.device) 159 | return ret 160 | 161 | def _get_num_items(self, indexing_item, max_size): 162 | if isinstance(indexing_item, list) or isinstance(indexing_item, np.ndarray): 163 | return len(indexing_item) 164 | elif isinstance(indexing_item, slice): 165 | _range = indexing_item.indices(max_size) 166 | return 1 + (_range[1] - _range[0] - 1)//_range[2] 167 | 168 | def _new_data_sn(self): 169 | new_data = SN() 170 | new_data.transition_data = {} 171 | new_data.episode_data = {} 172 | return new_data 173 | 174 | def _parse_slices(self, items): 175 | parsed = [] 176 | # Only batch slice given, add full time slice 177 | if (isinstance(items, slice) # slice a:b 178 | or isinstance(items, int) # int i 179 | or (isinstance(items, (list, np.ndarray, th.LongTensor, th.cuda.LongTensor))) # [a,b,c] 180 | ): 181 | items = (items, slice(None)) 182 | 183 | # Need the time indexing to be contiguous 184 | if isinstance(items[1], list): 185 | raise IndexError("Indexing across Time must be contiguous") 186 | 187 | for item in items: 188 | #TODO: stronger checks to ensure only supported options get through 189 | if isinstance(item, int): 190 | # Convert single indices to slices 191 | parsed.append(slice(item, item+1)) 192 | else: 193 | # Leave slices and lists as is 194 | parsed.append(item) 195 | return parsed 196 | 197 | def max_t_filled(self): 198 | return th.sum(self.data.transition_data["filled"], 1).max(0)[0] 199 | 200 | def __repr__(self): 201 | return "EpisodeBatch. 
Batch Size:{} Max_seq_len:{} Keys:{} Groups:{}".format(self.batch_size, 202 | self.max_seq_length, 203 | self.scheme.keys(), 204 | self.groups.keys()) 205 | 206 | 207 | class ReplayBuffer(EpisodeBatch): 208 | def __init__(self, scheme, groups, buffer_size, max_seq_length, preprocess=None, device="cpu"): 209 | super(ReplayBuffer, self).__init__(scheme, groups, buffer_size, max_seq_length, preprocess=preprocess, device=device) 210 | self.buffer_size = buffer_size # same as self.batch_size but more explicit 211 | self.buffer_index = 0 212 | self.episodes_in_buffer = 0 213 | 214 | def insert_episode_batch(self, ep_batch): 215 | if self.buffer_index + ep_batch.batch_size <= self.buffer_size: 216 | self.update(ep_batch.data.transition_data, 217 | slice(self.buffer_index, self.buffer_index + ep_batch.batch_size), 218 | slice(0, ep_batch.max_seq_length), 219 | mark_filled=False) 220 | self.update(ep_batch.data.episode_data, 221 | slice(self.buffer_index, self.buffer_index + ep_batch.batch_size)) 222 | self.buffer_index = (self.buffer_index + ep_batch.batch_size) 223 | self.episodes_in_buffer = max(self.episodes_in_buffer, self.buffer_index) 224 | self.buffer_index = self.buffer_index % self.buffer_size 225 | assert self.buffer_index < self.buffer_size 226 | else: 227 | buffer_left = self.buffer_size - self.buffer_index 228 | self.insert_episode_batch(ep_batch[0:buffer_left, :]) 229 | self.insert_episode_batch(ep_batch[buffer_left:, :]) 230 | 231 | def can_sample(self, batch_size): 232 | return self.episodes_in_buffer >= batch_size 233 | 234 | def sample(self, batch_size): 235 | assert self.can_sample(batch_size) 236 | if self.episodes_in_buffer == batch_size: 237 | return self[:batch_size] 238 | else: 239 | # Uniform sampling only atm 240 | ep_ids = np.random.choice(self.episodes_in_buffer, batch_size, replace=False) 241 | return self[ep_ids] 242 | 243 | def __repr__(self): 244 | return "ReplayBuffer. {}/{} episodes. 
Keys:{} Groups:{}".format(self.episodes_in_buffer, 245 | self.buffer_size, 246 | self.scheme.keys(), 247 | self.groups.keys()) 248 | 249 | -------------------------------------------------------------------------------- /src/components/epsilon_schedules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class DecayThenFlatSchedule(): 5 | 6 | def __init__(self, 7 | start, 8 | finish, 9 | time_length, 10 | decay="exp"): 11 | 12 | self.start = start 13 | self.finish = finish 14 | self.time_length = time_length 15 | self.delta = (self.start - self.finish) / self.time_length 16 | self.decay = decay 17 | 18 | if self.decay in ["exp"]: 19 | self.exp_scaling = (-1) * self.time_length / np.log(self.finish) if self.finish > 0 else 1 20 | 21 | def eval(self, T): 22 | if self.decay in ["linear"]: 23 | return max(self.finish, self.start - self.delta * T) 24 | elif self.decay in ["exp"]: 25 | return min(self.start, max(self.finish, np.exp(- T / self.exp_scaling))) 26 | pass 27 | -------------------------------------------------------------------------------- /src/components/transforms.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | class Transform: 5 | def transform(self, tensor): 6 | raise NotImplementedError 7 | 8 | def infer_output_info(self, vshape_in, dtype_in): 9 | raise NotImplementedError 10 | 11 | 12 | class OneHot(Transform): 13 | def __init__(self, out_dim): 14 | self.out_dim = out_dim 15 | 16 | def transform(self, tensor): 17 | y_onehot = tensor.new(*tensor.shape[:-1], self.out_dim).zero_() 18 | y_onehot.scatter_(-1, tensor.long(), 1) 19 | return y_onehot.float() 20 | 21 | def infer_output_info(self, vshape_in, dtype_in): 22 | return (self.out_dim,), th.float32 -------------------------------------------------------------------------------- /src/config/algs/coma.yaml: -------------------------------------------------------------------------------- 1 | # --- COMA specific parameters --- 2 | 3 | action_selector: "multinomial" 4 | epsilon_start: .5 5 | epsilon_finish: .01 6 | epsilon_anneal_time: 100000 7 | mask_before_softmax: False 8 | 9 | runner: "parallel" 10 | 11 | buffer_size: 8 12 | batch_size_run: 8 13 | batch_size: 8 14 | 15 | env_args: 16 | state_last_action: False # critic adds last action internally 17 | 18 | # update the target network every {} training steps 19 | target_update_interval: 200 20 | 21 | lr: 0.0005 22 | critic_lr: 0.0005 23 | td_lambda: 0.8 24 | 25 | # use COMA 26 | agent_output_type: "pi_logits" 27 | learner: "coma_learner" 28 | critic_q_fn: "coma" 29 | critic_baseline_fn: "coma" 30 | critic_train_mode: "seq" 31 | critic_train_reps: 1 32 | q_nstep: 0 # 0 corresponds to default Q, 1 is r + gamma*Q, etc 33 | 34 | name: "coma" 35 | -------------------------------------------------------------------------------- /src/config/algs/iql.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "episode" 10 | 11 | buffer_size: 5000 12 | 13 | # update the target network every {} episodes 14 | target_update_interval: 200 15 | 16 | # use the Q_Learner to train 17 | agent_output_type: "q" 18 | learner: "q_learner" 19 | double_q: True 20 | mixer: # Mixer becomes None 21 | 22 | name: "iql" 23 | 
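# These algorithm configs are merged on top of src/config/default.yaml by
# src/main.py (inherited from PyMARL) and are selected with the --config flag,
# in the same way as the README's training script; the command below is only an
# illustrative example, and any key here or in default.yaml can also be
# overridden after `with` using Sacred's update syntax:
#   python3 src/main.py --config=iql --env-config=sc2 with env_args.map_name=3m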
-------------------------------------------------------------------------------- /src/config/algs/iql_beta.yaml: -------------------------------------------------------------------------------- 1 | # --- IQL specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "parallel" 10 | batch_size_run: 8 11 | 12 | buffer_size: 5000 13 | 14 | # update the target network every {} episodes 15 | target_update_interval: 200 16 | 17 | # use the Q_Learner to train 18 | agent_output_type: "q" 19 | learner: "q_learner" 20 | double_q: True 21 | mixer: # Mixer becomes None 22 | 23 | name: "iql_smac_parallel" -------------------------------------------------------------------------------- /src/config/algs/qmix.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "episode" 10 | 11 | buffer_size: 5000 12 | 13 | # update the target network every {} episodes 14 | target_update_interval: 200 15 | 16 | # use the Q_Learner to train 17 | agent_output_type: "q" 18 | learner: "q_learner" 19 | double_q: True 20 | mixer: "qmix" 21 | mixing_embed_dim: 32 22 | hypernet_layers: 2 23 | hypernet_embed: 64 24 | 25 | name: "qmix" 26 | -------------------------------------------------------------------------------- /src/config/algs/qmix_beta.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "parallel" 10 | batch_size_run: 8 11 | 12 | buffer_size: 5000 13 | 14 | # update the target network every {} episodes 15 | target_update_interval: 200 16 | 17 | # use the Q_Learner to train 18 | agent_output_type: "q" 19 | learner: "q_learner" 20 | double_q: True 21 | mixer: "qmix" 22 | mixing_embed_dim: 32 23 | 24 | name: "qmix_smac_parallel" 25 | -------------------------------------------------------------------------------- /src/config/algs/qtran.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "episode" 10 | 11 | buffer_size: 5000 12 | 13 | # update the target network every {} episodes 14 | target_update_interval: 200 15 | 16 | # use the Q_Learner to train 17 | agent_output_type: "q" 18 | learner: "qtran_learner" 19 | double_q: True 20 | mixer: "qtran_base" 21 | mixing_embed_dim: 64 22 | qtran_arch: "qtran_paper" 23 | 24 | opt_loss: 1 25 | nopt_min_loss: 0.1 26 | 27 | network_size: small 28 | 29 | name: "qtran" 30 | -------------------------------------------------------------------------------- /src/config/algs/vdn.yaml: -------------------------------------------------------------------------------- 1 | # --- QMIX specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "episode" 10 | 11 | buffer_size: 
5000 12 | 13 | # update the target network every {} episodes 14 | target_update_interval: 200 15 | 16 | # use the Q_Learner to train 17 | agent_output_type: "q" 18 | learner: "q_learner" 19 | double_q: True 20 | mixer: "vdn" 21 | 22 | name: "vdn" 23 | -------------------------------------------------------------------------------- /src/config/algs/vdn_beta.yaml: -------------------------------------------------------------------------------- 1 | # --- VDN specific parameters --- 2 | 3 | # use epsilon greedy action selector 4 | action_selector: "epsilon_greedy" 5 | epsilon_start: 1.0 6 | epsilon_finish: 0.05 7 | epsilon_anneal_time: 50000 8 | 9 | runner: "parallel" 10 | batch_size_run: 8 11 | 12 | buffer_size: 5000 13 | 14 | # update the target network every {} episodes 15 | target_update_interval: 200 16 | 17 | # use the Q_Learner to train 18 | agent_output_type: "q" 19 | learner: "q_learner" 20 | double_q: True 21 | mixer: "vdn" 22 | 23 | name: "vdn_smac_parallel" 24 | -------------------------------------------------------------------------------- /src/config/default.yaml: -------------------------------------------------------------------------------- 1 | # --- Defaults --- 2 | 3 | # --- pymarl options --- 4 | runner: "episode" # Runs 1 env for an episode 5 | mac: "basic_mac" # Basic controller 6 | env: "sc2" # Environment name 7 | env_args: {} # Arguments for the environment 8 | batch_size_run: 1 # Number of environments to run in parallel 9 | test_nepisode: 20 # Number of episodes to test for 10 | test_interval: 2000 # Test after {} timesteps have passed 11 | test_greedy: True # Use greedy evaluation (if False, will set epsilon floor to 0 12 | log_interval: 2000 # Log summary of stats after every {} timesteps 13 | runner_log_interval: 2000 # Log runner stats (not test stats) every {} timesteps 14 | learner_log_interval: 2000 # Log training stats every {} timesteps 15 | t_max: 10000 # Stop running after this many timesteps 16 | use_cuda: True # Use gpu by default unless it isn't available 17 | buffer_cpu_only: True # If true we won't keep all of the replay buffer in vram 18 | 19 | # --- Logging options --- 20 | use_tensorboard: True # Log results to tensorboard 21 | save_model: True # Save the models to disk 22 | save_model_interval: 2000000 # Save models after this many timesteps 23 | checkpoint_path: "" # Load a checkpoint from this path 24 | evaluate: False # Evaluate model for test_nepisode episodes and quit (no training) 25 | load_step: 0 # Load model trained on this many timesteps (0 if choose max possible) 26 | save_replay: False # Saving the replay of the model loaded from checkpoint_path 27 | local_results_path: "results" # Path for local results 28 | 29 | # --- RL hyperparameters --- 30 | gamma: 0.99 31 | batch_size: 32 # Number of episodes to train on 32 | buffer_size: 32 # Size of the replay buffer 33 | lr: 0.0005 # Learning rate for agents 34 | critic_lr: 0.0005 # Learning rate for critics 35 | optim_alpha: 0.99 # RMSProp alpha 36 | optim_eps: 0.00001 # RMSProp epsilon 37 | grad_norm_clip: 10 # Reduce magnitude of gradients above this L2 norm 38 | 39 | # --- Agent parameters. Should be set manually. --- 40 | agent: "updet" # Options [updet, transformer_aggregation, rnn] 41 | rnn_hidden_dim: 64 # Size of hidden state for default rnn agent 42 | obs_agent_id: False # Include the agent's one_hot id in the observation 43 | obs_last_action: False # Include the agent's last action (one_hot) in the observation 44 | 45 | # --- Transformer parameters. Should be set manually. 
---
46 | token_dim: 5 # Marines. For other unit types (e.g. Zealots) this number can be different (6).
47 | emb: 32 # embedding dimension of the transformer
48 | heads: 3 # number of attention heads in the transformer
49 | depth: 2 # number of transformer blocks
50 | ally_num: 5 # number of allies (5m_vs_6m)
51 | enemy_num: 6 # number of enemies (5m_vs_6m)
52 |
53 | # --- Experiment running params ---
54 | repeat_id: 1
55 | label: "default_label"
56 |
--------------------------------------------------------------------------------
/src/config/envs/sc2.yaml:
--------------------------------------------------------------------------------
1 | env: sc2
2 |
3 | env_args:
4 | continuing_episode: False
5 | difficulty: "7"
6 | game_version: null
7 | map_name: "3m"
8 | move_amount: 2
9 | obs_all_health: True
10 | obs_instead_of_state: False
11 | obs_last_action: False
12 | obs_own_health: True
13 | obs_pathing_grid: False
14 | obs_terrain_height: False
15 | obs_timestep_number: False
16 | reward_death_value: 10
17 | reward_defeat: 0
18 | reward_negative_scale: 0.5
19 | reward_only_positive: True
20 | reward_scale: True
21 | reward_scale_rate: 20
22 | reward_sparse: False
23 | reward_win: 200
24 | replay_dir: ""
25 | replay_prefix: ""
26 | state_last_action: True
27 | state_timestep_number: False
28 | step_mul: 8
29 | seed: null
30 | heuristic_ai: False
31 | heuristic_rest: False
32 | debug: False
33 |
34 | test_greedy: True
35 | test_nepisode: 32
36 | test_interval: 10000
37 | log_interval: 10000
38 | runner_log_interval: 10000
39 | learner_log_interval: 10000
40 | t_max: 2050000
41 |
--------------------------------------------------------------------------------
/src/config/envs/sc2_beta.yaml:
--------------------------------------------------------------------------------
1 | env: sc2
2 |
3 | env_args:
4 | continuing_episode: False
5 | difficulty: "7"
6 | game_version: null
7 | map_name: "3m"
8 | move_amount: 2
9 | obs_all_health: True
10 | obs_instead_of_state: False
11 | obs_last_action: False
12 | obs_own_health: True
13 | obs_pathing_grid: False
14 | obs_terrain_height: False
15 | obs_timestep_number: False
16 | reward_death_value: 10
17 | reward_defeat: 0
18 | reward_negative_scale: 0.5
19 | reward_only_positive: True
20 | reward_scale: True
21 | reward_scale_rate: 20
22 | reward_sparse: False
23 | reward_win: 200
24 | replay_dir: ""
25 | replay_prefix: ""
26 | state_last_action: True
27 | state_timestep_number: False
28 | step_mul: 8
29 | seed: null
30 | heuristic_ai: False
31 | debug: False
32 |
33 | learner_log_interval: 20000
34 | log_interval: 20000
35 | runner_log_interval: 20000
36 | t_max: 10050000
37 | test_interval: 20000
38 | test_nepisode: 24
39 | test_greedy: True
40 |
--------------------------------------------------------------------------------
/src/controllers/__init__.py:
--------------------------------------------------------------------------------
1 | REGISTRY = {}
2 |
3 | from .basic_controller import BasicMAC
4 |
5 | REGISTRY["basic_mac"] = BasicMAC
--------------------------------------------------------------------------------
/src/controllers/basic_controller.py:
--------------------------------------------------------------------------------
1 | from modules.agents import REGISTRY as agent_REGISTRY
2 | from components.action_selectors import REGISTRY as action_REGISTRY
3 | import torch as th
4 |
5 |
6 | # This multi-agent controller shares parameters between agents
7 | class BasicMAC:
8 | def __init__(self, scheme, groups, args):
9 | self.n_agents = args.n_agents
10 | self.args = args
11 |
input_shape = self._get_input_shape(scheme) 12 | self._build_agents(input_shape) 13 | self.agent_output_type = args.agent_output_type 14 | 15 | self.action_selector = action_REGISTRY[args.action_selector](args) 16 | 17 | self.hidden_states = None 18 | 19 | def select_actions(self, ep_batch, t_ep, t_env, bs=slice(None), test_mode=False): 20 | # Only select actions for the selected batch elements in bs 21 | avail_actions = ep_batch["avail_actions"][:, t_ep] 22 | agent_outputs = self.forward(ep_batch, t_ep, test_mode=test_mode) 23 | chosen_actions = self.action_selector.select_action(agent_outputs[bs], avail_actions[bs], t_env, test_mode=test_mode) 24 | return chosen_actions 25 | 26 | def forward(self, ep_batch, t, test_mode=False): 27 | 28 | # rnn based agent 29 | if self.args.agent not in ['updet', 'transformer_aggregation']: 30 | agent_inputs = self._build_inputs(ep_batch, t) 31 | avail_actions = ep_batch["avail_actions"][:, t] 32 | agent_outs, self.hidden_states = self.agent(agent_inputs, self.hidden_states) 33 | 34 | # Softmax the agent outputs if they're policy logits 35 | if self.agent_output_type == "pi_logits": 36 | 37 | if getattr(self.args, "mask_before_softmax", True): 38 | # Make the logits for unavailable actions very negative to minimise their affect on the softmax 39 | reshaped_avail_actions = avail_actions.reshape(ep_batch.batch_size * self.n_agents, -1) 40 | agent_outs[reshaped_avail_actions == 0] = -1e10 41 | 42 | agent_outs = th.nn.functional.softmax(agent_outs, dim=-1) 43 | if not test_mode: 44 | # Epsilon floor 45 | epsilon_action_num = agent_outs.size(-1) 46 | if getattr(self.args, "mask_before_softmax", True): 47 | # With probability epsilon, we will pick an available action uniformly 48 | epsilon_action_num = reshaped_avail_actions.sum(dim=1, keepdim=True).float() 49 | 50 | agent_outs = ((1 - self.action_selector.epsilon) * agent_outs 51 | + th.ones_like(agent_outs) * self.action_selector.epsilon/epsilon_action_num) 52 | 53 | if getattr(self.args, "mask_before_softmax", True): 54 | # Zero out the unavailable actions 55 | agent_outs[reshaped_avail_actions == 0] = 0.0 56 | 57 | # transformer based agent 58 | else: 59 | agent_inputs = self._build_inputs_transformer(ep_batch, t) 60 | agent_outs, self.hidden_states = self.agent(agent_inputs, 61 | self.hidden_states.reshape(-1, 1, self.args.emb), 62 | self.args.enemy_num, self.args.ally_num) 63 | 64 | return agent_outs.view(ep_batch.batch_size, self.n_agents, -1) 65 | 66 | def init_hidden(self, batch_size): 67 | if self.args.agent not in ['updet', 'transformer_aggregation']: 68 | self.hidden_states = self.agent.init_hidden().unsqueeze(0).expand(batch_size, self.n_agents, -1) # bav 69 | else: 70 | self.hidden_states = self.agent.init_hidden().unsqueeze(0).expand(batch_size, self.n_agents, 1, -1) 71 | 72 | 73 | def parameters(self): 74 | return self.agent.parameters() 75 | 76 | def load_state(self, other_mac): 77 | self.agent.load_state_dict(other_mac.agent.state_dict()) 78 | 79 | def cuda(self): 80 | self.agent.cuda() 81 | 82 | def save_models(self, path): 83 | th.save(self.agent.state_dict(), "{}/agent.th".format(path)) 84 | 85 | def load_models(self, path): 86 | self.agent.load_state_dict(th.load("{}/agent.th".format(path), map_location=lambda storage, loc: storage)) 87 | 88 | def _build_agents(self, input_shape): 89 | self.agent = agent_REGISTRY[self.args.agent](input_shape, self.args) 90 | 91 | def _build_inputs(self, batch, t): 92 | # Assumes homogenous agents with flat observations. 93 | # Other MACs might want to e.g. 
delegate building inputs to each agent 94 | bs = batch.batch_size 95 | inputs = [] 96 | inputs.append(batch["obs"][:, t]) # b1av 97 | if self.args.obs_last_action: 98 | if t == 0: 99 | inputs.append(th.zeros_like(batch["actions_onehot"][:, t])) 100 | else: 101 | inputs.append(batch["actions_onehot"][:, t-1]) 102 | if self.args.obs_agent_id: 103 | inputs.append(th.eye(self.n_agents, device=batch.device).unsqueeze(0).expand(bs, -1, -1)) 104 | 105 | inputs = th.cat([x.reshape(bs*self.n_agents, -1) for x in inputs], dim=1) 106 | return inputs 107 | 108 | def _build_inputs_transformer(self, batch, t): 109 | # currently we only support battles with marines (e.g. 3m 8m 5m_vs_6m) 110 | # you can implement your own with any other agent type. 111 | inputs = [] 112 | raw_obs = batch["obs"][:, t] 113 | arranged_obs = th.cat((raw_obs[:, :, -1:], raw_obs[:, :, :-1]), 2) 114 | reshaped_obs = arranged_obs.view(-1, 1 + (self.args.enemy_num - 1) + self.args.ally_num, self.args.token_dim) 115 | inputs.append(reshaped_obs) 116 | inputs = th.cat(inputs, dim=1).cuda() 117 | return inputs 118 | 119 | def _get_input_shape(self, scheme): 120 | input_shape = scheme["obs"]["vshape"] 121 | if self.args.obs_last_action: 122 | input_shape += scheme["actions_onehot"]["vshape"][0] 123 | if self.args.obs_agent_id: 124 | input_shape += self.n_agents 125 | 126 | return input_shape 127 | -------------------------------------------------------------------------------- /src/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from smac.env import MultiAgentEnv, StarCraft2Env 3 | import sys 4 | import os 5 | 6 | def env_fn(env, **kwargs) -> MultiAgentEnv: 7 | return env(**kwargs) 8 | 9 | REGISTRY = {} 10 | REGISTRY["sc2"] = partial(env_fn, env=StarCraft2Env) 11 | 12 | if sys.platform == "linux": 13 | os.environ.setdefault("SC2PATH", 14 | os.path.join(os.getcwd(), "3rdparty", "StarCraftII")) 15 | -------------------------------------------------------------------------------- /src/envs/multiagentenv.py: -------------------------------------------------------------------------------- 1 | class MultiAgentEnv(object): 2 | 3 | def step(self, actions): 4 | """ Returns reward, terminated, info """ 5 | raise NotImplementedError 6 | 7 | def get_obs(self): 8 | """ Returns all agent observations in a list """ 9 | raise NotImplementedError 10 | 11 | def get_obs_agent(self, agent_id): 12 | """ Returns observation for agent_id """ 13 | raise NotImplementedError 14 | 15 | def get_obs_size(self): 16 | """ Returns the shape of the observation """ 17 | raise NotImplementedError 18 | 19 | def get_state(self): 20 | raise NotImplementedError 21 | 22 | def get_state_size(self): 23 | """ Returns the shape of the state""" 24 | raise NotImplementedError 25 | 26 | def get_avail_actions(self): 27 | raise NotImplementedError 28 | 29 | def get_avail_agent_actions(self, agent_id): 30 | """ Returns the available actions for agent_id """ 31 | raise NotImplementedError 32 | 33 | def get_total_actions(self): 34 | """ Returns the total number of actions an agent could ever take """ 35 | # TODO: This is only suitable for a discrete 1 dimensional action space for each agent 36 | raise NotImplementedError 37 | 38 | def reset(self): 39 | """ Returns initial observations and states""" 40 | raise NotImplementedError 41 | 42 | def render(self): 43 | raise NotImplementedError 44 | 45 | def close(self): 46 | raise NotImplementedError 47 | 48 | def seed(self): 49 | raise 
NotImplementedError 50 | 51 | def save_replay(self): 52 | raise NotImplementedError 53 | 54 | def get_env_info(self): 55 | env_info = {"state_shape": self.get_state_size(), 56 | "obs_shape": self.get_obs_size(), 57 | "n_actions": self.get_total_actions(), 58 | "n_agents": self.n_agents, 59 | "episode_limit": self.episode_limit} 60 | return env_info 61 | -------------------------------------------------------------------------------- /src/learners/__init__.py: -------------------------------------------------------------------------------- 1 | from .q_learner import QLearner 2 | from .coma_learner import COMALearner 3 | from .qtran_learner import QLearner as QTranLearner 4 | 5 | REGISTRY = {} 6 | 7 | REGISTRY["q_learner"] = QLearner 8 | REGISTRY["coma_learner"] = COMALearner 9 | REGISTRY["qtran_learner"] = QTranLearner 10 | -------------------------------------------------------------------------------- /src/learners/coma_learner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from components.episode_buffer import EpisodeBatch 3 | from modules.critics.coma import COMACritic 4 | from utils.rl_utils import build_td_lambda_targets 5 | import torch as th 6 | from torch.optim import RMSprop 7 | 8 | 9 | class COMALearner: 10 | def __init__(self, mac, scheme, logger, args): 11 | self.args = args 12 | self.n_agents = args.n_agents 13 | self.n_actions = args.n_actions 14 | self.mac = mac 15 | self.logger = logger 16 | 17 | self.last_target_update_step = 0 18 | self.critic_training_steps = 0 19 | 20 | self.log_stats_t = -self.args.learner_log_interval - 1 21 | 22 | self.critic = COMACritic(scheme, args) 23 | self.target_critic = copy.deepcopy(self.critic) 24 | 25 | self.agent_params = list(mac.parameters()) 26 | self.critic_params = list(self.critic.parameters()) 27 | self.params = self.agent_params + self.critic_params 28 | 29 | self.agent_optimiser = RMSprop(params=self.agent_params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) 30 | self.critic_optimiser = RMSprop(params=self.critic_params, lr=args.critic_lr, alpha=args.optim_alpha, eps=args.optim_eps) 31 | 32 | def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): 33 | # Get the relevant quantities 34 | bs = batch.batch_size 35 | max_t = batch.max_seq_length 36 | rewards = batch["reward"][:, :-1] 37 | actions = batch["actions"][:, :] 38 | terminated = batch["terminated"][:, :-1].float() 39 | mask = batch["filled"][:, :-1].float() 40 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 41 | avail_actions = batch["avail_actions"][:, :-1] 42 | 43 | critic_mask = mask.clone() 44 | 45 | mask = mask.repeat(1, 1, self.n_agents).view(-1) 46 | 47 | q_vals, critic_train_stats = self._train_critic(batch, rewards, terminated, actions, avail_actions, 48 | critic_mask, bs, max_t) 49 | 50 | actions = actions[:,:-1] 51 | 52 | mac_out = [] 53 | self.mac.init_hidden(batch.batch_size) 54 | for t in range(batch.max_seq_length - 1): 55 | agent_outs = self.mac.forward(batch, t=t) 56 | mac_out.append(agent_outs) 57 | mac_out = th.stack(mac_out, dim=1) # Concat over time 58 | 59 | # Mask out unavailable actions, renormalise (as in action selection) 60 | mac_out[avail_actions == 0] = 0 61 | mac_out = mac_out/mac_out.sum(dim=-1, keepdim=True) 62 | mac_out[avail_actions == 0] = 0 63 | 64 | # Calculated baseline 65 | q_vals = q_vals.reshape(-1, self.n_actions) 66 | pi = mac_out.view(-1, self.n_actions) 67 | baseline = (pi * q_vals).sum(-1).detach() 68 | 69 | # Calculate policy grad with mask 
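# The advantage for each agent is Q(s, a_taken) minus the counterfactual
# baseline sum_a pi(a|s) * Q(s, a) computed above; the loss below is the
# negative advantage-weighted log-probability of the taken actions, averaged
# over non-padded entries where mask == 1.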
70 | q_taken = th.gather(q_vals, dim=1, index=actions.reshape(-1, 1)).squeeze(1) 71 | pi_taken = th.gather(pi, dim=1, index=actions.reshape(-1, 1)).squeeze(1) 72 | pi_taken[mask == 0] = 1.0 73 | log_pi_taken = th.log(pi_taken) 74 | 75 | advantages = (q_taken - baseline).detach() 76 | 77 | coma_loss = - ((advantages * log_pi_taken) * mask).sum() / mask.sum() 78 | 79 | # Optimise agents 80 | self.agent_optimiser.zero_grad() 81 | coma_loss.backward() 82 | grad_norm = th.nn.utils.clip_grad_norm_(self.agent_params, self.args.grad_norm_clip) 83 | self.agent_optimiser.step() 84 | 85 | if (self.critic_training_steps - self.last_target_update_step) / self.args.target_update_interval >= 1.0: 86 | self._update_targets() 87 | self.last_target_update_step = self.critic_training_steps 88 | 89 | if t_env - self.log_stats_t >= self.args.learner_log_interval: 90 | ts_logged = len(critic_train_stats["critic_loss"]) 91 | for key in ["critic_loss", "critic_grad_norm", "td_error_abs", "q_taken_mean", "target_mean"]: 92 | self.logger.log_stat(key, sum(critic_train_stats[key])/ts_logged, t_env) 93 | 94 | self.logger.log_stat("advantage_mean", (advantages * mask).sum().item() / mask.sum().item(), t_env) 95 | self.logger.log_stat("coma_loss", coma_loss.item(), t_env) 96 | self.logger.log_stat("agent_grad_norm", grad_norm, t_env) 97 | self.logger.log_stat("pi_max", (pi.max(dim=1)[0] * mask).sum().item() / mask.sum().item(), t_env) 98 | self.log_stats_t = t_env 99 | 100 | def _train_critic(self, batch, rewards, terminated, actions, avail_actions, mask, bs, max_t): 101 | # Optimise critic 102 | target_q_vals = self.target_critic(batch)[:, :] 103 | targets_taken = th.gather(target_q_vals, dim=3, index=actions).squeeze(3) 104 | 105 | # Calculate td-lambda targets 106 | targets = build_td_lambda_targets(rewards, terminated, mask, targets_taken, self.n_agents, self.args.gamma, self.args.td_lambda) 107 | 108 | q_vals = th.zeros_like(target_q_vals)[:, :-1] 109 | 110 | running_log = { 111 | "critic_loss": [], 112 | "critic_grad_norm": [], 113 | "td_error_abs": [], 114 | "target_mean": [], 115 | "q_taken_mean": [], 116 | } 117 | 118 | for t in reversed(range(rewards.size(1))): 119 | mask_t = mask[:, t].expand(-1, self.n_agents) 120 | if mask_t.sum() == 0: 121 | continue 122 | 123 | q_t = self.critic(batch, t) 124 | q_vals[:, t] = q_t.view(bs, self.n_agents, self.n_actions) 125 | q_taken = th.gather(q_t, dim=3, index=actions[:, t:t+1]).squeeze(3).squeeze(1) 126 | targets_t = targets[:, t] 127 | 128 | td_error = (q_taken - targets_t.detach()) 129 | 130 | # 0-out the targets that came from padded data 131 | masked_td_error = td_error * mask_t 132 | 133 | # Normal L2 loss, take mean over actual data 134 | loss = (masked_td_error ** 2).sum() / mask_t.sum() 135 | self.critic_optimiser.zero_grad() 136 | loss.backward() 137 | grad_norm = th.nn.utils.clip_grad_norm_(self.critic_params, self.args.grad_norm_clip) 138 | self.critic_optimiser.step() 139 | self.critic_training_steps += 1 140 | 141 | running_log["critic_loss"].append(loss.item()) 142 | running_log["critic_grad_norm"].append(grad_norm) 143 | mask_elems = mask_t.sum().item() 144 | running_log["td_error_abs"].append((masked_td_error.abs().sum().item() / mask_elems)) 145 | running_log["q_taken_mean"].append((q_taken * mask_t).sum().item() / mask_elems) 146 | running_log["target_mean"].append((targets_t * mask_t).sum().item() / mask_elems) 147 | 148 | return q_vals, running_log 149 | 150 | def _update_targets(self): 151 | 
self.target_critic.load_state_dict(self.critic.state_dict()) 152 | self.logger.console_logger.info("Updated target network") 153 | 154 | def cuda(self): 155 | self.mac.cuda() 156 | self.critic.cuda() 157 | self.target_critic.cuda() 158 | 159 | def save_models(self, path): 160 | self.mac.save_models(path) 161 | th.save(self.critic.state_dict(), "{}/critic.th".format(path)) 162 | th.save(self.agent_optimiser.state_dict(), "{}/agent_opt.th".format(path)) 163 | th.save(self.critic_optimiser.state_dict(), "{}/critic_opt.th".format(path)) 164 | 165 | def load_models(self, path): 166 | self.mac.load_models(path) 167 | self.critic.load_state_dict(th.load("{}/critic.th".format(path), map_location=lambda storage, loc: storage)) 168 | # Not quite right but I don't want to save target networks 169 | self.target_critic.load_state_dict(self.critic.state_dict()) 170 | self.agent_optimiser.load_state_dict(th.load("{}/agent_opt.th".format(path), map_location=lambda storage, loc: storage)) 171 | self.critic_optimiser.load_state_dict(th.load("{}/critic_opt.th".format(path), map_location=lambda storage, loc: storage)) 172 | -------------------------------------------------------------------------------- /src/learners/q_learner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from components.episode_buffer import EpisodeBatch 3 | from modules.mixers.vdn import VDNMixer 4 | from modules.mixers.qmix import QMixer 5 | import torch as th 6 | from torch.optim import RMSprop 7 | 8 | 9 | class QLearner: 10 | def __init__(self, mac, scheme, logger, args): 11 | self.args = args 12 | self.mac = mac 13 | self.logger = logger 14 | 15 | self.params = list(mac.parameters()) 16 | 17 | self.last_target_update_episode = 0 18 | 19 | self.mixer = None 20 | if args.mixer is not None: 21 | if args.mixer == "vdn": 22 | self.mixer = VDNMixer() 23 | elif args.mixer == "qmix": 24 | self.mixer = QMixer(args) 25 | else: 26 | raise ValueError("Mixer {} not recognised.".format(args.mixer)) 27 | self.params += list(self.mixer.parameters()) 28 | self.target_mixer = copy.deepcopy(self.mixer) 29 | 30 | self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) 31 | 32 | # a little wasteful to deepcopy (e.g. 
duplicates action selector), but should work for any MAC 33 | self.target_mac = copy.deepcopy(mac) 34 | 35 | self.log_stats_t = -self.args.learner_log_interval - 1 36 | 37 | def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): 38 | # Get the relevant quantities 39 | rewards = batch["reward"][:, :-1] 40 | actions = batch["actions"][:, :-1] 41 | terminated = batch["terminated"][:, :-1].float() 42 | mask = batch["filled"][:, :-1].float() 43 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 44 | avail_actions = batch["avail_actions"] 45 | 46 | # Calculate estimated Q-Values 47 | mac_out = [] 48 | self.mac.init_hidden(batch.batch_size) 49 | for t in range(batch.max_seq_length): 50 | agent_outs = self.mac.forward(batch, t=t) 51 | mac_out.append(agent_outs) 52 | mac_out = th.stack(mac_out, dim=1) # Concat over time 53 | 54 | # Pick the Q-Values for the actions taken by each agent 55 | chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim 56 | 57 | # Calculate the Q-Values necessary for the target 58 | target_mac_out = [] 59 | self.target_mac.init_hidden(batch.batch_size) 60 | for t in range(batch.max_seq_length): 61 | target_agent_outs = self.target_mac.forward(batch, t=t) 62 | target_mac_out.append(target_agent_outs) 63 | 64 | # We don't need the first timesteps Q-Value estimate for calculating targets 65 | target_mac_out = th.stack(target_mac_out[1:], dim=1) # Concat across time 66 | 67 | # Mask out unavailable actions 68 | target_mac_out[avail_actions[:, 1:] == 0] = -9999999 69 | 70 | # Max over target Q-Values 71 | if self.args.double_q: 72 | # Get actions that maximise live Q (for double q-learning) 73 | mac_out_detach = mac_out.clone().detach() 74 | mac_out_detach[avail_actions == 0] = -9999999 75 | cur_max_actions = mac_out_detach[:, 1:].max(dim=3, keepdim=True)[1] 76 | target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3) 77 | else: 78 | target_max_qvals = target_mac_out.max(dim=3)[0] 79 | 80 | # Mix 81 | if self.mixer is not None: 82 | chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1]) 83 | target_max_qvals = self.target_mixer(target_max_qvals, batch["state"][:, 1:]) 84 | 85 | # Calculate 1-step Q-Learning targets 86 | targets = rewards + self.args.gamma * (1 - terminated) * target_max_qvals 87 | 88 | # Td-error 89 | td_error = (chosen_action_qvals - targets.detach()) 90 | 91 | mask = mask.expand_as(td_error) 92 | 93 | # 0-out the targets that came from padded data 94 | masked_td_error = td_error * mask 95 | 96 | # Normal L2 loss, take mean over actual data 97 | loss = (masked_td_error ** 2).sum() / mask.sum() 98 | 99 | # Optimise 100 | self.optimiser.zero_grad() 101 | loss.backward() 102 | grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) 103 | self.optimiser.step() 104 | 105 | if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0: 106 | self._update_targets() 107 | self.last_target_update_episode = episode_num 108 | 109 | if t_env - self.log_stats_t >= self.args.learner_log_interval: 110 | self.logger.log_stat("loss", loss.item(), t_env) 111 | self.logger.log_stat("grad_norm", grad_norm, t_env) 112 | mask_elems = mask.sum().item() 113 | self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env) 114 | self.logger.log_stat("q_taken_mean", (chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents), t_env) 115 | self.logger.log_stat("target_mean", (targets 
* mask).sum().item()/(mask_elems * self.args.n_agents), t_env) 116 | self.log_stats_t = t_env 117 | 118 | def _update_targets(self): 119 | self.target_mac.load_state(self.mac) 120 | if self.mixer is not None: 121 | self.target_mixer.load_state_dict(self.mixer.state_dict()) 122 | self.logger.console_logger.info("Updated target network") 123 | 124 | def cuda(self): 125 | self.mac.cuda() 126 | self.target_mac.cuda() 127 | if self.mixer is not None: 128 | self.mixer.cuda() 129 | self.target_mixer.cuda() 130 | 131 | def save_models(self, path): 132 | self.mac.save_models(path) 133 | if self.mixer is not None: 134 | th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) 135 | th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) 136 | 137 | def load_models(self, path): 138 | self.mac.load_models(path) 139 | # Not quite right but I don't want to save target networks 140 | self.target_mac.load_models(path) 141 | if self.mixer is not None: 142 | self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) 143 | self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage)) 144 | -------------------------------------------------------------------------------- /src/learners/qtran_learner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from components.episode_buffer import EpisodeBatch 3 | from modules.mixers.qtran import QTranBase 4 | import torch as th 5 | from torch.optim import RMSprop, Adam 6 | 7 | 8 | class QLearner: 9 | def __init__(self, mac, scheme, logger, args): 10 | self.args = args 11 | self.mac = mac 12 | self.logger = logger 13 | 14 | self.params = list(mac.parameters()) 15 | 16 | self.last_target_update_episode = 0 17 | 18 | self.mixer = None 19 | if args.mixer == "qtran_base": 20 | self.mixer = QTranBase(args) 21 | elif args.mixer == "qtran_alt": 22 | raise Exception("Not implemented here!") 23 | 24 | self.params += list(self.mixer.parameters()) 25 | self.target_mixer = copy.deepcopy(self.mixer) 26 | 27 | self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps) 28 | 29 | # a little wasteful to deepcopy (e.g. 
duplicates action selector), but should work for any MAC 30 | self.target_mac = copy.deepcopy(mac) 31 | 32 | self.log_stats_t = -self.args.learner_log_interval - 1 33 | 34 | def train(self, batch: EpisodeBatch, t_env: int, episode_num: int): 35 | # Get the relevant quantities 36 | rewards = batch["reward"][:, :-1] 37 | actions = batch["actions"][:, :-1] 38 | terminated = batch["terminated"][:, :-1].float() 39 | mask = batch["filled"][:, :-1].float() 40 | mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1]) 41 | avail_actions = batch["avail_actions"] 42 | 43 | # Calculate estimated Q-Values 44 | mac_out = [] 45 | mac_hidden_states = [] 46 | self.mac.init_hidden(batch.batch_size) 47 | for t in range(batch.max_seq_length): 48 | agent_outs = self.mac.forward(batch, t=t) 49 | mac_out.append(agent_outs) 50 | mac_hidden_states.append(self.mac.hidden_states) 51 | mac_out = th.stack(mac_out, dim=1) # Concat over time 52 | mac_hidden_states = th.stack(mac_hidden_states, dim=1) 53 | mac_hidden_states = mac_hidden_states.reshape(batch.batch_size, self.args.n_agents, batch.max_seq_length, -1).transpose(1,2) #btav 54 | 55 | # Pick the Q-Values for the actions taken by each agent 56 | chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(3) # Remove the last dim 57 | 58 | # Calculate the Q-Values necessary for the target 59 | target_mac_out = [] 60 | target_mac_hidden_states = [] 61 | self.target_mac.init_hidden(batch.batch_size) 62 | for t in range(batch.max_seq_length): 63 | target_agent_outs = self.target_mac.forward(batch, t=t) 64 | target_mac_out.append(target_agent_outs) 65 | target_mac_hidden_states.append(self.target_mac.hidden_states) 66 | 67 | # We don't need the first timesteps Q-Value estimate for calculating targets 68 | target_mac_out = th.stack(target_mac_out[:], dim=1) # Concat across time 69 | target_mac_hidden_states = th.stack(target_mac_hidden_states, dim=1) 70 | target_mac_hidden_states = target_mac_hidden_states.reshape(batch.batch_size, self.args.n_agents, batch.max_seq_length, -1).transpose(1,2) #btav 71 | 72 | # Mask out unavailable actions 73 | target_mac_out[avail_actions[:, :] == 0] = -9999999 # From OG deepmarl 74 | mac_out_maxs = mac_out.clone() 75 | mac_out_maxs[avail_actions == 0] = -9999999 76 | 77 | # Best joint action computed by target agents 78 | target_max_actions = target_mac_out.max(dim=3, keepdim=True)[1] 79 | # Best joint-action computed by regular agents 80 | max_actions_qvals, max_actions_current = mac_out_maxs[:, :].max(dim=3, keepdim=True) 81 | 82 | if self.args.mixer == "qtran_base": 83 | # -- TD Loss -- 84 | # Joint-action Q-Value estimates 85 | joint_qs, vs = self.mixer(batch[:, :-1], mac_hidden_states[:,:-1]) 86 | 87 | # Need to argmax across the target agents' actions to compute target joint-action Q-Values 88 | if self.args.double_q: 89 | max_actions_current_ = th.zeros(size=(batch.batch_size, batch.max_seq_length, self.args.n_agents, self.args.n_actions), device=batch.device) 90 | max_actions_current_onehot = max_actions_current_.scatter(3, max_actions_current[:, :], 1) 91 | max_actions_onehot = max_actions_current_onehot 92 | else: 93 | max_actions = th.zeros(size=(batch.batch_size, batch.max_seq_length, self.args.n_agents, self.args.n_actions), device=batch.device) 94 | max_actions_onehot = max_actions.scatter(3, target_max_actions[:, :], 1) 95 | target_joint_qs, target_vs = self.target_mixer(batch[:, 1:], hidden_states=target_mac_hidden_states[:,1:], actions=max_actions_onehot[:,1:]) 96 | 97 | # Td loss targets 98 | 
td_targets = rewards.reshape(-1,1) + self.args.gamma * (1 - terminated.reshape(-1, 1)) * target_joint_qs 99 | td_error = (joint_qs - td_targets.detach()) 100 | masked_td_error = td_error * mask.reshape(-1, 1) 101 | td_loss = (masked_td_error ** 2).sum() / mask.sum() 102 | # -- TD Loss -- 103 | 104 | # -- Opt Loss -- 105 | # Argmax across the current agents' actions 106 | if not self.args.double_q: # Already computed if we're doing double Q-Learning 107 | max_actions_current_ = th.zeros(size=(batch.batch_size, batch.max_seq_length, self.args.n_agents, self.args.n_actions), device=batch.device ) 108 | max_actions_current_onehot = max_actions_current_.scatter(3, max_actions_current[:, :], 1) 109 | max_joint_qs, _ = self.mixer(batch[:, :-1], mac_hidden_states[:,:-1], actions=max_actions_current_onehot[:,:-1]) # Don't use the target network and target agent max actions as per author's email 110 | 111 | # max_actions_qvals = th.gather(mac_out[:, :-1], dim=3, index=max_actions_current[:,:-1]) 112 | opt_error = max_actions_qvals[:,:-1].sum(dim=2).reshape(-1, 1) - max_joint_qs.detach() + vs 113 | masked_opt_error = opt_error * mask.reshape(-1, 1) 114 | opt_loss = (masked_opt_error ** 2).sum() / mask.sum() 115 | # -- Opt Loss -- 116 | 117 | # -- Nopt Loss -- 118 | # target_joint_qs, _ = self.target_mixer(batch[:, :-1]) 119 | nopt_values = chosen_action_qvals.sum(dim=2).reshape(-1, 1) - joint_qs.detach() + vs # Don't use target networks here either 120 | nopt_error = nopt_values.clamp(max=0) 121 | masked_nopt_error = nopt_error * mask.reshape(-1, 1) 122 | nopt_loss = (masked_nopt_error ** 2).sum() / mask.sum() 123 | # -- Nopt loss -- 124 | 125 | elif self.args.mixer == "qtran_alt": 126 | raise Exception("Not supported yet.") 127 | 128 | loss = td_loss + self.args.opt_loss * opt_loss + self.args.nopt_min_loss * nopt_loss 129 | 130 | # Optimise 131 | self.optimiser.zero_grad() 132 | loss.backward() 133 | grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip) 134 | self.optimiser.step() 135 | 136 | if (episode_num - self.last_target_update_episode) / self.args.target_update_interval >= 1.0: 137 | self._update_targets() 138 | self.last_target_update_episode = episode_num 139 | 140 | if t_env - self.log_stats_t >= self.args.learner_log_interval: 141 | self.logger.log_stat("loss", loss.item(), t_env) 142 | self.logger.log_stat("td_loss", td_loss.item(), t_env) 143 | self.logger.log_stat("opt_loss", opt_loss.item(), t_env) 144 | self.logger.log_stat("nopt_loss", nopt_loss.item(), t_env) 145 | self.logger.log_stat("grad_norm", grad_norm, t_env) 146 | if self.args.mixer == "qtran_base": 147 | mask_elems = mask.sum().item() 148 | self.logger.log_stat("td_error_abs", (masked_td_error.abs().sum().item()/mask_elems), t_env) 149 | self.logger.log_stat("td_targets", ((masked_td_error).sum().item()/mask_elems), t_env) 150 | self.logger.log_stat("td_chosen_qs", (joint_qs.sum().item()/mask_elems), t_env) 151 | self.logger.log_stat("v_mean", (vs.sum().item()/mask_elems), t_env) 152 | self.logger.log_stat("agent_indiv_qs", ((chosen_action_qvals * mask).sum().item()/(mask_elems * self.args.n_agents)), t_env) 153 | self.log_stats_t = t_env 154 | 155 | def _update_targets(self): 156 | self.target_mac.load_state(self.mac) 157 | if self.mixer is not None: 158 | self.target_mixer.load_state_dict(self.mixer.state_dict()) 159 | self.logger.console_logger.info("Updated target network") 160 | 161 | def cuda(self): 162 | self.mac.cuda() 163 | self.target_mac.cuda() 164 | if self.mixer is not None: 165 | 
self.mixer.cuda() 166 | self.target_mixer.cuda() 167 | 168 | def save_models(self, path): 169 | self.mac.save_models(path) 170 | if self.mixer is not None: 171 | th.save(self.mixer.state_dict(), "{}/mixer.th".format(path)) 172 | th.save(self.optimiser.state_dict(), "{}/opt.th".format(path)) 173 | 174 | def load_models(self, path): 175 | self.mac.load_models(path) 176 | # Not quite right but I don't want to save target networks 177 | self.target_mac.load_models(path) 178 | if self.mixer is not None: 179 | self.mixer.load_state_dict(th.load("{}/mixer.th".format(path), map_location=lambda storage, loc: storage)) 180 | self.optimiser.load_state_dict(th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage)) 181 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import collections 4 | from os.path import dirname, abspath 5 | from copy import deepcopy 6 | from sacred import Experiment, SETTINGS 7 | from sacred.observers import FileStorageObserver 8 | from sacred.utils import apply_backspaces_and_linefeeds 9 | import sys 10 | import torch as th 11 | from utils.logging import get_logger 12 | import yaml 13 | 14 | from run import run 15 | 16 | SETTINGS['CAPTURE_MODE'] = "fd" # set to "no" if you want to see stdout/stderr in console 17 | logger = get_logger() 18 | 19 | ex = Experiment("pymarl") 20 | ex.logger = logger 21 | ex.captured_out_filter = apply_backspaces_and_linefeeds 22 | 23 | results_path = os.path.join(dirname(dirname(abspath(__file__))), "results") 24 | 25 | 26 | @ex.main 27 | def my_main(_run, _config, _log): 28 | # Setting the random seed throughout the modules 29 | config = config_copy(_config) 30 | np.random.seed(config["seed"]) 31 | th.manual_seed(config["seed"]) 32 | config['env_args']['seed'] = config["seed"] 33 | 34 | # run the framework 35 | run(_run, config, _log) 36 | 37 | 38 | def _get_config(params, arg_name, subfolder): 39 | config_name = None 40 | for _i, _v in enumerate(params): 41 | if _v.split("=")[0] == arg_name: 42 | config_name = _v.split("=")[1] 43 | del params[_i] 44 | break 45 | 46 | if config_name is not None: 47 | with open(os.path.join(os.path.dirname(__file__), "config", subfolder, "{}.yaml".format(config_name)), "r") as f: 48 | try: 49 | config_dict = yaml.load(f) 50 | except yaml.YAMLError as exc: 51 | assert False, "{}.yaml error: {}".format(config_name, exc) 52 | return config_dict 53 | 54 | 55 | def recursive_dict_update(d, u): 56 | for k, v in u.items(): 57 | if isinstance(v, collections.Mapping): 58 | d[k] = recursive_dict_update(d.get(k, {}), v) 59 | else: 60 | d[k] = v 61 | return d 62 | 63 | 64 | def config_copy(config): 65 | if isinstance(config, dict): 66 | return {k: config_copy(v) for k, v in config.items()} 67 | elif isinstance(config, list): 68 | return [config_copy(v) for v in config] 69 | else: 70 | return deepcopy(config) 71 | 72 | 73 | if __name__ == '__main__': 74 | params = deepcopy(sys.argv) 75 | 76 | # Get the defaults from default.yaml 77 | with open(os.path.join(os.path.dirname(__file__), "config", "default.yaml"), "r") as f: 78 | try: 79 | config_dict = yaml.load(f) 80 | except yaml.YAMLError as exc: 81 | assert False, "default.yaml error: {}".format(exc) 82 | 83 | # Load algorithm and env base configs 84 | env_config = _get_config(params, "--env-config", "envs") 85 | alg_config = _get_config(params, "--config", "algs") 86 | # config_dict = 
{**config_dict, **env_config, **alg_config} 87 | config_dict = recursive_dict_update(config_dict, env_config) 88 | config_dict = recursive_dict_update(config_dict, alg_config) 89 | 90 | # now add all the config to sacred 91 | ex.add_config(config_dict) 92 | 93 | # Save to disk by default for sacred 94 | logger.info("Saving to FileStorageObserver in results/sacred.") 95 | file_obs_path = os.path.join(results_path, "sacred") 96 | ex.observers.append(FileStorageObserver.create(file_obs_path)) 97 | 98 | ex.run_commandline(params) 99 | 100 | -------------------------------------------------------------------------------- /src/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Theohhhu/UPDeT/94b3db7a05e6c366596b76dd9c48981473cf6823/src/modules/__init__.py -------------------------------------------------------------------------------- /src/modules/agents/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from .rnn_agent import RNNAgent 4 | REGISTRY["rnn"] = RNNAgent 5 | 6 | from .updet_agent import UPDeT 7 | REGISTRY['updet'] = UPDeT 8 | 9 | from .transformer_agg_agent import TransformerAggregationAgent 10 | REGISTRY['transformer_aggregation'] = TransformerAggregationAgent -------------------------------------------------------------------------------- /src/modules/agents/rnn_agent.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class RNNAgent(nn.Module): 6 | def __init__(self, input_shape, args): 7 | super(RNNAgent, self).__init__() 8 | self.args = args 9 | 10 | self.fc1 = nn.Linear(input_shape, args.rnn_hidden_dim) 11 | self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim) 12 | self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions) 13 | 14 | def init_hidden(self): 15 | # make hidden states on same device as model 16 | return self.fc1.weight.new(1, self.args.rnn_hidden_dim).zero_() 17 | 18 | def forward(self, inputs, hidden_state): 19 | x = F.relu(self.fc1(inputs)) 20 | h_in = hidden_state.reshape(-1, self.args.rnn_hidden_dim) 21 | h = self.rnn(x, h_in) 22 | q = self.fc2(h) 23 | return q, h 24 | -------------------------------------------------------------------------------- /src/modules/agents/transformer_agg_agent.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | import argparse 5 | 6 | 7 | class TransformerAggregationAgent(nn.Module): 8 | def __init__(self, input_shape, args): 9 | super(TransformerAggregationAgent, self).__init__() 10 | self.args = args 11 | self.transformer = Transformer(args.token_dim, args.emb, args.heads, args.depth, args.emb) 12 | self.q_linear = nn.Linear(args.emb, 6 + args.enemy_num) 13 | 14 | def init_hidden(self): 15 | # make hidden states on same device as model 16 | return torch.zeros(1, self.args.emb).cuda() 17 | 18 | def forward(self, inputs, hidden_state, task_enemy_num, task_ally_num): 19 | outputs, _ = self.transformer.forward(inputs, hidden_state, None) 20 | 21 | # last output for hidden state 22 | h = outputs[:,-1:,:] 23 | q_agg = torch.mean(outputs, 1) 24 | q = self.q_linear(q_agg) 25 | 26 | return q, h 27 | 28 | 29 | class SelfAttention(nn.Module): 30 | def __init__(self, emb, heads=8, mask=False): 31 | 32 | super().__init__() 33 | 34 | self.emb = emb 35 | self.heads 
= heads 36 | self.mask = mask 37 | 38 | self.tokeys = nn.Linear(emb, emb * heads, bias=False) 39 | self.toqueries = nn.Linear(emb, emb * heads, bias=False) 40 | self.tovalues = nn.Linear(emb, emb * heads, bias=False) 41 | 42 | self.unifyheads = nn.Linear(heads * emb, emb) 43 | 44 | def forward(self, x, mask): 45 | 46 | b, t, e = x.size() 47 | h = self.heads 48 | keys = self.tokeys(x).view(b, t, h, e) 49 | queries = self.toqueries(x).view(b, t, h, e) 50 | values = self.tovalues(x).view(b, t, h, e) 51 | 52 | # compute scaled dot-product self-attention 53 | 54 | # - fold heads into the batch dimension 55 | keys = keys.transpose(1, 2).contiguous().view(b * h, t, e) 56 | queries = queries.transpose(1, 2).contiguous().view(b * h, t, e) 57 | values = values.transpose(1, 2).contiguous().view(b * h, t, e) 58 | 59 | queries = queries / (e ** (1 / 4)) 60 | keys = keys / (e ** (1 / 4)) 61 | # - Instead of dividing the dot products by sqrt(e), we scale the keys and values. 62 | # This should be more memory efficient 63 | 64 | # - get dot product of queries and keys, and scale 65 | dot = torch.bmm(queries, keys.transpose(1, 2)) 66 | 67 | assert dot.size() == (b * h, t, t) 68 | 69 | if self.mask: # mask out the upper half of the dot matrix, excluding the diagonal 70 | mask_(dot, maskval=float('-inf'), mask_diagonal=False) 71 | 72 | if mask is not None: 73 | dot = dot.masked_fill(mask == 0, -1e9) 74 | 75 | dot = F.softmax(dot, dim=2) 76 | # - dot now has row-wise self-attention probabilities 77 | 78 | # apply the self attention to the values 79 | out = torch.bmm(dot, values).view(b, h, t, e) 80 | 81 | # swap h, t back, unify heads 82 | out = out.transpose(1, 2).contiguous().view(b, t, h * e) 83 | 84 | return self.unifyheads(out) 85 | 86 | class TransformerBlock(nn.Module): 87 | 88 | def __init__(self, emb, heads, mask, ff_hidden_mult=4, dropout=0.0): 89 | super().__init__() 90 | 91 | self.attention = SelfAttention(emb, heads=heads, mask=mask) 92 | self.mask = mask 93 | 94 | self.norm1 = nn.LayerNorm(emb) 95 | self.norm2 = nn.LayerNorm(emb) 96 | 97 | self.ff = nn.Sequential( 98 | nn.Linear(emb, ff_hidden_mult * emb), 99 | nn.ReLU(), 100 | nn.Linear(ff_hidden_mult * emb, emb) 101 | ) 102 | 103 | self.do = nn.Dropout(dropout) 104 | 105 | def forward(self, x_mask): 106 | x, mask = x_mask 107 | 108 | attended = self.attention(x, mask) 109 | 110 | x = self.norm1(attended + x) 111 | 112 | x = self.do(x) 113 | 114 | fedforward = self.ff(x) 115 | 116 | x = self.norm2(fedforward + x) 117 | 118 | x = self.do(x) 119 | 120 | return x, mask 121 | 122 | 123 | class Transformer(nn.Module): 124 | 125 | def __init__(self, input_dim, emb, heads, depth, output_dim): 126 | super().__init__() 127 | 128 | self.num_tokens = output_dim 129 | 130 | self.token_embedding = nn.Linear(input_dim, emb) 131 | 132 | tblocks = [] 133 | for i in range(depth): 134 | tblocks.append( 135 | TransformerBlock(emb=emb, heads=heads, mask=False)) 136 | 137 | self.tblocks = nn.Sequential(*tblocks) 138 | 139 | self.toprobs = nn.Linear(emb, output_dim) 140 | 141 | def forward(self, x, h, mask): 142 | 143 | tokens = self.token_embedding(x) 144 | tokens = torch.cat((tokens, h), 1) 145 | 146 | b, t, e = tokens.size() 147 | 148 | x, mask = self.tblocks((tokens, mask)) 149 | 150 | x = self.toprobs(x.view(b * t, e)).view(b, t, self.num_tokens) 151 | 152 | return x, tokens 153 | 154 | def mask_(matrices, maskval=0.0, mask_diagonal=True): 155 | 156 | b, h, w = matrices.size() 157 | indices = torch.triu_indices(h, w, offset=0 if mask_diagonal else 1) 158 | 
matrices[:, indices[0], indices[1]] = maskval 159 | 160 | 161 | if __name__ == '__main__': 162 | parser = argparse.ArgumentParser(description='Unit Testing') 163 | parser.add_argument('--token_dim', default='5', type=int) 164 | parser.add_argument('--emb', default='32', type=int) 165 | parser.add_argument('--heads', default='3', type=int) 166 | parser.add_argument('--depth', default='2', type=int) 167 | parser.add_argument('--ally_num', default='5', type=int) 168 | parser.add_argument('--enemy_num', default='5', type=int) 169 | parser.add_argument('--episode', default='20', type=int) 170 | args = parser.parse_args() 171 | 172 | 173 | # testing the agent 174 | agent = TransformerAggregationAgent(None, args).cuda() 175 | hidden_state = agent.init_hidden().cuda().expand(args.ally_num, 1, -1) 176 | tensor = torch.rand(args.ally_num, args.ally_num+args.enemy_num, args.token_dim).cuda() 177 | q_list = [] 178 | for _ in range(args.episode): 179 | q, hidden_state = agent.forward(tensor, hidden_state, args.ally_num, args.enemy_num) 180 | q_list.append(q) -------------------------------------------------------------------------------- /src/modules/agents/updet_agent.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | import argparse 5 | 6 | 7 | class UPDeT(nn.Module): 8 | def __init__(self, input_shape, args): 9 | super(UPDeT, self).__init__() 10 | self.args = args 11 | self.transformer = Transformer(args.token_dim, args.emb, args.heads, args.depth, args.emb) 12 | self.q_basic = nn.Linear(args.emb, 6) 13 | 14 | def init_hidden(self): 15 | # make hidden states on same device as model 16 | return torch.zeros(1, self.args.emb).cuda() 17 | 18 | def forward(self, inputs, hidden_state, task_enemy_num, task_ally_num): 19 | outputs, _ = self.transformer.forward(inputs, hidden_state, None) 20 | # first output for 6 action (no_op stop up down left right) 21 | q_basic_actions = self.q_basic(outputs[:, 0, :]) 22 | 23 | # last dim for hidden state 24 | h = outputs[:, -1:, :] 25 | 26 | q_enemies_list = [] 27 | 28 | # each enemy has an output Q 29 | for i in range(task_enemy_num): 30 | q_enemy = self.q_basic(outputs[:, 1 + i, :]) 31 | q_enemy_mean = torch.mean(q_enemy, 1, True) 32 | q_enemies_list.append(q_enemy_mean) 33 | 34 | # concat enemy Q over all enemies 35 | q_enemies = torch.stack(q_enemies_list, dim=1).squeeze() 36 | 37 | # concat basic action Q with enemy attack Q 38 | q = torch.cat((q_basic_actions, q_enemies), 1) 39 | 40 | return q, h 41 | 42 | class SelfAttention(nn.Module): 43 | def __init__(self, emb, heads=8, mask=False): 44 | 45 | super().__init__() 46 | 47 | self.emb = emb 48 | self.heads = heads 49 | self.mask = mask 50 | 51 | self.tokeys = nn.Linear(emb, emb * heads, bias=False) 52 | self.toqueries = nn.Linear(emb, emb * heads, bias=False) 53 | self.tovalues = nn.Linear(emb, emb * heads, bias=False) 54 | 55 | self.unifyheads = nn.Linear(heads * emb, emb) 56 | 57 | def forward(self, x, mask): 58 | 59 | b, t, e = x.size() 60 | h = self.heads 61 | keys = self.tokeys(x).view(b, t, h, e) 62 | queries = self.toqueries(x).view(b, t, h, e) 63 | values = self.tovalues(x).view(b, t, h, e) 64 | 65 | # compute scaled dot-product self-attention 66 | 67 | # - fold heads into the batch dimension 68 | keys = keys.transpose(1, 2).contiguous().view(b * h, t, e) 69 | queries = queries.transpose(1, 2).contiguous().view(b * h, t, e) 70 | values = values.transpose(1, 2).contiguous().view(b * h, t, e) 
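        # At this point keys, queries and values all have shape (b*h, t, e):
        # the heads are folded into the batch dimension so a single bmm call
        # scores every head at once. The two divisions by e**(1/4) below are
        # together equivalent to the usual 1/sqrt(e) scaling of the attention
        # logits, applied symmetrically to the queries and the keys.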
71 | 72 | queries = queries / (e ** (1 / 4)) 73 | keys = keys / (e ** (1 / 4)) 74 | # - Instead of dividing the dot products by sqrt(e), we scale the keys and values. 75 | # This should be more memory efficient 76 | 77 | # - get dot product of queries and keys, and scale 78 | dot = torch.bmm(queries, keys.transpose(1, 2)) 79 | 80 | assert dot.size() == (b * h, t, t) 81 | 82 | if self.mask: # mask out the upper half of the dot matrix, excluding the diagonal 83 | mask_(dot, maskval=float('-inf'), mask_diagonal=False) 84 | 85 | if mask is not None: 86 | dot = dot.masked_fill(mask == 0, -1e9) 87 | 88 | dot = F.softmax(dot, dim=2) 89 | # - dot now has row-wise self-attention probabilities 90 | 91 | # apply the self attention to the values 92 | out = torch.bmm(dot, values).view(b, h, t, e) 93 | 94 | # swap h, t back, unify heads 95 | out = out.transpose(1, 2).contiguous().view(b, t, h * e) 96 | 97 | return self.unifyheads(out) 98 | 99 | class TransformerBlock(nn.Module): 100 | 101 | def __init__(self, emb, heads, mask, ff_hidden_mult=4, dropout=0.0): 102 | super().__init__() 103 | 104 | self.attention = SelfAttention(emb, heads=heads, mask=mask) 105 | self.mask = mask 106 | 107 | self.norm1 = nn.LayerNorm(emb) 108 | self.norm2 = nn.LayerNorm(emb) 109 | 110 | self.ff = nn.Sequential( 111 | nn.Linear(emb, ff_hidden_mult * emb), 112 | nn.ReLU(), 113 | nn.Linear(ff_hidden_mult * emb, emb) 114 | ) 115 | 116 | self.do = nn.Dropout(dropout) 117 | 118 | def forward(self, x_mask): 119 | x, mask = x_mask 120 | 121 | attended = self.attention(x, mask) 122 | 123 | x = self.norm1(attended + x) 124 | 125 | x = self.do(x) 126 | 127 | fedforward = self.ff(x) 128 | 129 | x = self.norm2(fedforward + x) 130 | 131 | x = self.do(x) 132 | 133 | return x, mask 134 | 135 | 136 | class Transformer(nn.Module): 137 | 138 | def __init__(self, input_dim, emb, heads, depth, output_dim): 139 | super().__init__() 140 | 141 | self.num_tokens = output_dim 142 | 143 | self.token_embedding = nn.Linear(input_dim, emb) 144 | 145 | tblocks = [] 146 | for i in range(depth): 147 | tblocks.append( 148 | TransformerBlock(emb=emb, heads=heads, mask=False)) 149 | 150 | self.tblocks = nn.Sequential(*tblocks) 151 | 152 | self.toprobs = nn.Linear(emb, output_dim) 153 | 154 | def forward(self, x, h, mask): 155 | 156 | tokens = self.token_embedding(x) 157 | tokens = torch.cat((tokens, h), 1) 158 | 159 | b, t, e = tokens.size() 160 | 161 | x, mask = self.tblocks((tokens, mask)) 162 | 163 | x = self.toprobs(x.view(b * t, e)).view(b, t, self.num_tokens) 164 | 165 | return x, tokens 166 | 167 | def mask_(matrices, maskval=0.0, mask_diagonal=True): 168 | 169 | b, h, w = matrices.size() 170 | indices = torch.triu_indices(h, w, offset=0 if mask_diagonal else 1) 171 | matrices[:, indices[0], indices[1]] = maskval 172 | 173 | 174 | if __name__ == '__main__': 175 | parser = argparse.ArgumentParser(description='Unit Testing') 176 | parser.add_argument('--token_dim', default='5', type=int) 177 | parser.add_argument('--emb', default='32', type=int) 178 | parser.add_argument('--heads', default='3', type=int) 179 | parser.add_argument('--depth', default='2', type=int) 180 | parser.add_argument('--ally_num', default='5', type=int) 181 | parser.add_argument('--enemy_num', default='5', type=int) 182 | parser.add_argument('--episode', default='20', type=int) 183 | args = parser.parse_args() 184 | 185 | 186 | # testing the agent 187 | agent = UPDeT(None, args).cuda() 188 | hidden_state = agent.init_hidden().cuda().expand(args.ally_num, 1, -1) 189 | tensor = 
torch.rand(args.ally_num, args.ally_num+args.enemy_num, args.token_dim).cuda() 190 | q_list = [] 191 | for _ in range(args.episode): 192 | q, hidden_state = agent.forward(tensor, hidden_state, args.ally_num, args.enemy_num) 193 | q_list.append(q) 194 | -------------------------------------------------------------------------------- /src/modules/critics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Theohhhu/UPDeT/94b3db7a05e6c366596b76dd9c48981473cf6823/src/modules/critics/__init__.py -------------------------------------------------------------------------------- /src/modules/critics/coma.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class COMACritic(nn.Module): 7 | def __init__(self, scheme, args): 8 | super(COMACritic, self).__init__() 9 | 10 | self.args = args 11 | self.n_actions = args.n_actions 12 | self.n_agents = args.n_agents 13 | 14 | input_shape = self._get_input_shape(scheme) 15 | self.output_type = "q" 16 | 17 | # Set up network layers 18 | self.fc1 = nn.Linear(input_shape, 128) 19 | self.fc2 = nn.Linear(128, 128) 20 | self.fc3 = nn.Linear(128, self.n_actions) 21 | 22 | def forward(self, batch, t=None): 23 | inputs = self._build_inputs(batch, t=t) 24 | x = F.relu(self.fc1(inputs)) 25 | x = F.relu(self.fc2(x)) 26 | q = self.fc3(x) 27 | return q 28 | 29 | def _build_inputs(self, batch, t=None): 30 | bs = batch.batch_size 31 | max_t = batch.max_seq_length if t is None else 1 32 | ts = slice(None) if t is None else slice(t, t+1) 33 | inputs = [] 34 | # state 35 | inputs.append(batch["state"][:, ts].unsqueeze(2).repeat(1, 1, self.n_agents, 1)) 36 | 37 | # observation 38 | inputs.append(batch["obs"][:, ts]) 39 | 40 | # actions (masked out by agent) 41 | actions = batch["actions_onehot"][:, ts].view(bs, max_t, 1, -1).repeat(1, 1, self.n_agents, 1) 42 | agent_mask = (1 - th.eye(self.n_agents, device=batch.device)) 43 | agent_mask = agent_mask.view(-1, 1).repeat(1, self.n_actions).view(self.n_agents, -1) 44 | inputs.append(actions * agent_mask.unsqueeze(0).unsqueeze(0)) 45 | 46 | # last actions 47 | if t == 0: 48 | inputs.append(th.zeros_like(batch["actions_onehot"][:, 0:1]).view(bs, max_t, 1, -1).repeat(1, 1, self.n_agents, 1)) 49 | elif isinstance(t, int): 50 | inputs.append(batch["actions_onehot"][:, slice(t-1, t)].view(bs, max_t, 1, -1).repeat(1, 1, self.n_agents, 1)) 51 | else: 52 | last_actions = th.cat([th.zeros_like(batch["actions_onehot"][:, 0:1]), batch["actions_onehot"][:, :-1]], dim=1) 53 | last_actions = last_actions.view(bs, max_t, 1, -1).repeat(1, 1, self.n_agents, 1) 54 | inputs.append(last_actions) 55 | 56 | inputs.append(th.eye(self.n_agents, device=batch.device).unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1)) 57 | 58 | inputs = th.cat([x.reshape(bs, max_t, self.n_agents, -1) for x in inputs], dim=-1) 59 | return inputs 60 | 61 | def _get_input_shape(self, scheme): 62 | # state 63 | input_shape = scheme["state"]["vshape"] 64 | # observation 65 | input_shape += scheme["obs"]["vshape"] 66 | # actions and last actions 67 | input_shape += scheme["actions_onehot"]["vshape"][0] * self.n_agents * 2 68 | # agent id 69 | input_shape += self.n_agents 70 | return input_shape -------------------------------------------------------------------------------- /src/modules/mixers/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Theohhhu/UPDeT/94b3db7a05e6c366596b76dd9c48981473cf6823/src/modules/mixers/__init__.py -------------------------------------------------------------------------------- /src/modules/mixers/qmix.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class QMixer(nn.Module): 8 | def __init__(self, args): 9 | super(QMixer, self).__init__() 10 | 11 | self.args = args 12 | self.n_agents = args.n_agents 13 | self.state_dim = int(np.prod(args.state_shape)) 14 | 15 | self.embed_dim = args.mixing_embed_dim 16 | 17 | if getattr(args, "hypernet_layers", 1) == 1: 18 | self.hyper_w_1 = nn.Linear(self.state_dim, self.embed_dim * self.n_agents) 19 | self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim) 20 | elif getattr(args, "hypernet_layers", 1) == 2: 21 | hypernet_embed = self.args.hypernet_embed 22 | self.hyper_w_1 = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed), 23 | nn.ReLU(), 24 | nn.Linear(hypernet_embed, self.embed_dim * self.n_agents)) 25 | self.hyper_w_final = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed), 26 | nn.ReLU(), 27 | nn.Linear(hypernet_embed, self.embed_dim)) 28 | elif getattr(args, "hypernet_layers", 1) > 2: 29 | raise Exception("Sorry >2 hypernet layers is not implemented!") 30 | else: 31 | raise Exception("Error setting number of hypernet layers.") 32 | 33 | # State dependent bias for hidden layer 34 | self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim) 35 | 36 | # V(s) instead of a bias for the last layers 37 | self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), 38 | nn.ReLU(), 39 | nn.Linear(self.embed_dim, 1)) 40 | 41 | def forward(self, agent_qs, states): 42 | bs = agent_qs.size(0) 43 | states = states.reshape(-1, self.state_dim) 44 | agent_qs = agent_qs.view(-1, 1, self.n_agents) 45 | # First layer 46 | w1 = th.abs(self.hyper_w_1(states)) 47 | b1 = self.hyper_b_1(states) 48 | w1 = w1.view(-1, self.n_agents, self.embed_dim) 49 | b1 = b1.view(-1, 1, self.embed_dim) 50 | hidden = F.elu(th.bmm(agent_qs, w1) + b1) 51 | # Second layer 52 | w_final = th.abs(self.hyper_w_final(states)) 53 | w_final = w_final.view(-1, self.embed_dim, 1) 54 | # State-dependent bias 55 | v = self.V(states).view(-1, 1, 1) 56 | # Compute final output 57 | y = th.bmm(hidden, w_final) + v 58 | # Reshape and return 59 | q_tot = y.view(bs, -1, 1) 60 | return q_tot 61 | -------------------------------------------------------------------------------- /src/modules/mixers/qtran.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class QTranBase(nn.Module): 8 | def __init__(self, args): 9 | super(QTranBase, self).__init__() 10 | 11 | self.args = args 12 | 13 | self.n_agents = args.n_agents 14 | self.n_actions = args.n_actions 15 | self.state_dim = int(np.prod(args.state_shape)) 16 | self.arch = self.args.qtran_arch # QTran architecture 17 | 18 | self.embed_dim = args.mixing_embed_dim 19 | 20 | # Q(s,u) 21 | if self.arch == "coma_critic": 22 | # Q takes [state, u] as input 23 | q_input_size = self.state_dim + (self.n_agents * self.n_actions) 24 | elif self.arch == "qtran_paper": 25 | # Q takes [state, agent_action_observation_encodings] 26 | if 
self.args.agent in ['updet', 'transformer_aggregation']: 27 | q_input_size = self.state_dim + self.args.emb + self.n_actions 28 | else: 29 | q_input_size = self.state_dim + self.args.rnn_hidden_dim + self.n_actions 30 | else: 31 | raise Exception("{} is not a valid QTran architecture".format(self.arch)) 32 | 33 | if self.args.network_size == "small": 34 | self.Q = nn.Sequential(nn.Linear(q_input_size, self.embed_dim), 35 | nn.ReLU(), 36 | nn.Linear(self.embed_dim, self.embed_dim), 37 | nn.ReLU(), 38 | nn.Linear(self.embed_dim, 1)) 39 | 40 | # V(s) 41 | self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), 42 | nn.ReLU(), 43 | nn.Linear(self.embed_dim, self.embed_dim), 44 | nn.ReLU(), 45 | nn.Linear(self.embed_dim, 1)) 46 | 47 | if self.args.agent not in ['updet', 'transformer_aggregation']: 48 | ae_input = self.args.rnn_hidden_dim + self.n_actions 49 | else: 50 | ae_input = self.args.emb + self.n_actions 51 | 52 | self.action_encoding = nn.Sequential(nn.Linear(ae_input, ae_input), 53 | nn.ReLU(), 54 | nn.Linear(ae_input, ae_input)) 55 | elif self.args.network_size == "big": 56 | self.Q = nn.Sequential(nn.Linear(q_input_size, self.embed_dim), 57 | nn.ReLU(), 58 | nn.Linear(self.embed_dim, self.embed_dim), 59 | nn.ReLU(), 60 | nn.Linear(self.embed_dim, self.embed_dim), 61 | nn.ReLU(), 62 | nn.Linear(self.embed_dim, 1)) 63 | # V(s) 64 | self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), 65 | nn.ReLU(), 66 | nn.Linear(self.embed_dim, self.embed_dim), 67 | nn.ReLU(), 68 | nn.Linear(self.embed_dim, self.embed_dim), 69 | nn.ReLU(), 70 | nn.Linear(self.embed_dim, 1)) 71 | if self.args.agent not in ['updet', 'transformer_aggregation']: 72 | ae_input = self.args.rnn_hidden_dim + self.n_actions 73 | else: 74 | ae_input = self.args.emb + self.n_actions 75 | self.action_encoding = nn.Sequential(nn.Linear(ae_input, ae_input), 76 | nn.ReLU(), 77 | nn.Linear(ae_input, ae_input)) 78 | else: 79 | assert False 80 | 81 | def forward(self, batch, hidden_states, actions=None): 82 | bs = batch.batch_size 83 | ts = batch.max_seq_length 84 | 85 | states = batch["state"].reshape(bs * ts, self.state_dim) 86 | 87 | if self.arch == "coma_critic": 88 | if actions is None: 89 | # Use the actions taken by the agents 90 | actions = batch["actions_onehot"].reshape(bs * ts, self.n_agents * self.n_actions) 91 | else: 92 | # It will arrive as (bs, ts, agents, actions), we need to reshape it 93 | actions = actions.reshape(bs * ts, self.n_agents * self.n_actions) 94 | inputs = th.cat([states, actions], dim=1) 95 | elif self.arch == "qtran_paper": 96 | if actions is None: 97 | # Use the actions taken by the agents 98 | actions = batch["actions_onehot"].reshape(bs * ts, self.n_agents, self.n_actions) 99 | else: 100 | # It will arrive as (bs, ts, agents, actions), we need to reshape it 101 | actions = actions.reshape(bs * ts, self.n_agents, self.n_actions) 102 | 103 | hidden_states = hidden_states.reshape(bs * ts, self.n_agents, -1) 104 | agent_state_action_input = th.cat([hidden_states, actions], dim=2) 105 | agent_state_action_encoding = self.action_encoding(agent_state_action_input.reshape(bs * ts * self.n_agents, -1)).reshape(bs * ts, self.n_agents, -1) 106 | agent_state_action_encoding = agent_state_action_encoding.sum(dim=1) # Sum across agents 107 | 108 | inputs = th.cat([states, agent_state_action_encoding], dim=1) 109 | 110 | q_outputs = self.Q(inputs) 111 | 112 | states = batch["state"].reshape(bs * ts, self.state_dim) 113 | v_outputs = self.V(states) 114 | 115 | return q_outputs, v_outputs 116 
| 117 | -------------------------------------------------------------------------------- /src/modules/mixers/vdn.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | 4 | 5 | class VDNMixer(nn.Module): 6 | def __init__(self): 7 | super(VDNMixer, self).__init__() 8 | 9 | def forward(self, agent_qs, batch): 10 | return th.sum(agent_qs, dim=2, keepdim=True) -------------------------------------------------------------------------------- /src/run.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import pprint 4 | import time 5 | import threading 6 | import torch as th 7 | from types import SimpleNamespace as SN 8 | from utils.logging import Logger 9 | from utils.timehelper import time_left, time_str 10 | from os.path import dirname, abspath 11 | 12 | from learners import REGISTRY as le_REGISTRY 13 | from runners import REGISTRY as r_REGISTRY 14 | from controllers import REGISTRY as mac_REGISTRY 15 | from components.episode_buffer import ReplayBuffer 16 | from components.transforms import OneHot 17 | 18 | 19 | def run(_run, _config, _log): 20 | 21 | # check args sanity 22 | _config = args_sanity_check(_config, _log) 23 | 24 | args = SN(**_config) 25 | args.device = "cuda" if args.use_cuda else "cpu" 26 | 27 | # setup loggers 28 | logger = Logger(_log) 29 | 30 | _log.info("Experiment Parameters:") 31 | experiment_params = pprint.pformat(_config, 32 | indent=4, 33 | width=1) 34 | _log.info("\n\n" + experiment_params + "\n") 35 | 36 | # configure tensorboard logger 37 | unique_token = "{}-{}-{}-{}-dim-{}-heads-{}-depth".format(args.name, args.agent, 38 | args.env_args['map_name'], 39 | args.emb, args.heads, 40 | args.depth) 41 | 42 | unique_token += "-seed-{}".format(_config["seed"]) 43 | 44 | args.unique_token = unique_token 45 | if args.use_tensorboard: 46 | tb_logs_direc = os.path.join(dirname(dirname(abspath(__file__))), "results", "tb_logs") 47 | tb_exp_direc = os.path.join(tb_logs_direc, "{}").format(unique_token) 48 | logger.setup_tb(tb_exp_direc) 49 | 50 | # sacred is on by default 51 | logger.setup_sacred(_run) 52 | 53 | # Run and train 54 | run_sequential(args=args, logger=logger) 55 | 56 | # Clean up after finishing 57 | print("Exiting Main") 58 | 59 | print("Stopping all threads") 60 | for t in threading.enumerate(): 61 | if t.name != "MainThread": 62 | print("Thread {} is alive! 
Is daemon: {}".format(t.name, t.daemon)) 63 | t.join(timeout=1) 64 | print("Thread joined") 65 | 66 | print("Exiting script") 67 | 68 | # Making sure framework really exits 69 | os._exit(os.EX_OK) 70 | 71 | 72 | def evaluate_sequential(args, runner): 73 | 74 | for _ in range(args.test_nepisode): 75 | runner.run(test_mode=True) 76 | 77 | if args.save_replay: 78 | runner.save_replay() 79 | 80 | runner.close_env() 81 | 82 | def run_sequential(args, logger): 83 | 84 | # Init runner so we can get env info 85 | runner = r_REGISTRY[args.runner](args=args, logger=logger) 86 | 87 | # Set up schemes and groups here 88 | env_info = runner.get_env_info() 89 | args.n_agents = env_info["n_agents"] 90 | args.n_actions = env_info["n_actions"] 91 | args.state_shape = env_info["state_shape"] 92 | 93 | # Default/Base scheme 94 | scheme = { 95 | "state": {"vshape": env_info["state_shape"]}, 96 | "obs": {"vshape": env_info["obs_shape"], "group": "agents"}, 97 | "actions": {"vshape": (1,), "group": "agents", "dtype": th.long}, 98 | "avail_actions": {"vshape": (env_info["n_actions"],), "group": "agents", "dtype": th.int}, 99 | "reward": {"vshape": (1,)}, 100 | "terminated": {"vshape": (1,), "dtype": th.uint8}, 101 | } 102 | groups = { 103 | "agents": args.n_agents 104 | } 105 | preprocess = { 106 | "actions": ("actions_onehot", [OneHot(out_dim=args.n_actions)]) 107 | } 108 | 109 | buffer = ReplayBuffer(scheme, groups, args.buffer_size, env_info["episode_limit"] + 1, 110 | preprocess=preprocess, 111 | device="cpu" if args.buffer_cpu_only else args.device) 112 | 113 | # Setup multiagent controller here 114 | mac = mac_REGISTRY[args.mac](buffer.scheme, groups, args) 115 | 116 | # Give runner the scheme 117 | runner.setup(scheme=scheme, groups=groups, preprocess=preprocess, mac=mac) 118 | 119 | # Learner 120 | learner = le_REGISTRY[args.learner](mac, buffer.scheme, logger, args) 121 | 122 | if args.use_cuda: 123 | learner.cuda() 124 | 125 | if args.checkpoint_path != "": 126 | 127 | timesteps = [] 128 | timestep_to_load = 0 129 | 130 | if not os.path.isdir(args.checkpoint_path): 131 | logger.console_logger.info("Checkpoint directiory {} doesn't exist".format(args.checkpoint_path)) 132 | return 133 | 134 | # Go through all files in args.checkpoint_path 135 | for name in os.listdir(args.checkpoint_path): 136 | full_name = os.path.join(args.checkpoint_path, name) 137 | # Check if they are dirs the names of which are numbers 138 | if os.path.isdir(full_name) and name.isdigit(): 139 | timesteps.append(int(name)) 140 | 141 | if args.load_step == 0: 142 | # choose the max timestep 143 | timestep_to_load = max(timesteps) 144 | else: 145 | # choose the timestep closest to load_step 146 | timestep_to_load = min(timesteps, key=lambda x: abs(x - args.load_step)) 147 | 148 | model_path = os.path.join(args.checkpoint_path, str(timestep_to_load)) 149 | 150 | logger.console_logger.info("Loading model from {}".format(model_path)) 151 | learner.load_models(model_path) 152 | runner.t_env = timestep_to_load 153 | 154 | if args.evaluate or args.save_replay: 155 | evaluate_sequential(args, runner) 156 | return 157 | 158 | # start training 159 | episode = 0 160 | last_test_T = -args.test_interval - 1 161 | last_log_T = 0 162 | model_save_time = 0 163 | 164 | start_time = time.time() 165 | last_time = start_time 166 | 167 | logger.console_logger.info("Beginning training for {} timesteps".format(args.t_max)) 168 | 169 | while runner.t_env <= args.t_max: 170 | 171 | # Run for a whole episode at a time 172 | episode_batch = 
runner.run(test_mode=False) 173 | buffer.insert_episode_batch(episode_batch) 174 | 175 | if buffer.can_sample(args.batch_size): 176 | episode_sample = buffer.sample(args.batch_size) 177 | 178 | # Truncate batch to only filled timesteps 179 | max_ep_t = episode_sample.max_t_filled() 180 | episode_sample = episode_sample[:, :max_ep_t] 181 | 182 | if episode_sample.device != args.device: 183 | episode_sample.to(args.device) 184 | 185 | learner.train(episode_sample, runner.t_env, episode) 186 | 187 | # Execute test runs once in a while 188 | n_test_runs = max(1, args.test_nepisode // runner.batch_size) 189 | if (runner.t_env - last_test_T) / args.test_interval >= 1.0: 190 | 191 | logger.console_logger.info("t_env: {} / {}".format(runner.t_env, args.t_max)) 192 | logger.console_logger.info("Estimated time left: {}. Time passed: {}".format( 193 | time_left(last_time, last_test_T, runner.t_env, args.t_max), time_str(time.time() - start_time))) 194 | last_time = time.time() 195 | 196 | last_test_T = runner.t_env 197 | for _ in range(n_test_runs): 198 | runner.run(test_mode=True) 199 | 200 | if args.save_model and (runner.t_env - model_save_time >= args.save_model_interval or model_save_time == 0): 201 | model_save_time = runner.t_env 202 | save_path = os.path.join(args.local_results_path, "models", args.unique_token, str(runner.t_env)) 203 | #"results/models/{}".format(unique_token) 204 | os.makedirs(save_path, exist_ok=True) 205 | logger.console_logger.info("Saving models to {}".format(save_path)) 206 | 207 | # learner should handle saving/loading -- delegate actor save/load to mac, 208 | # use appropriate filenames to do critics, optimizer states 209 | learner.save_models(save_path) 210 | 211 | episode += args.batch_size_run 212 | 213 | if (runner.t_env - last_log_T) >= args.log_interval: 214 | logger.log_stat("episode", episode, runner.t_env) 215 | logger.print_recent_stats() 216 | last_log_T = runner.t_env 217 | 218 | runner.close_env() 219 | logger.console_logger.info("Finished Training") 220 | 221 | 222 | def args_sanity_check(config, _log): 223 | 224 | # set CUDA flags 225 | # config["use_cuda"] = True # Use cuda whenever possible! 
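    # Two sanity fixes follow: fall back to CPU when CUDA was requested but no
    # device is visible, and force test_nepisode to be a whole multiple of
    # batch_size_run (at least one full batch), since test episodes are always
    # collected batch_size_run at a time.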
226 | if config["use_cuda"] and not th.cuda.is_available(): 227 | config["use_cuda"] = False 228 | _log.warning("CUDA flag use_cuda was switched OFF automatically because no CUDA devices are available!") 229 | 230 | if config["test_nepisode"] < config["batch_size_run"]: 231 | config["test_nepisode"] = config["batch_size_run"] 232 | else: 233 | config["test_nepisode"] = (config["test_nepisode"]//config["batch_size_run"]) * config["batch_size_run"] 234 | 235 | return config 236 | -------------------------------------------------------------------------------- /src/runners/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from .episode_runner import EpisodeRunner 4 | REGISTRY["episode"] = EpisodeRunner 5 | 6 | from .parallel_runner import ParallelRunner 7 | REGISTRY["parallel"] = ParallelRunner 8 | -------------------------------------------------------------------------------- /src/runners/episode_runner.py: -------------------------------------------------------------------------------- 1 | from envs import REGISTRY as env_REGISTRY 2 | from functools import partial 3 | from components.episode_buffer import EpisodeBatch 4 | import numpy as np 5 | 6 | 7 | class EpisodeRunner: 8 | 9 | def __init__(self, args, logger): 10 | self.args = args 11 | self.logger = logger 12 | self.batch_size = self.args.batch_size_run 13 | assert self.batch_size == 1 14 | 15 | self.env = env_REGISTRY[self.args.env](**self.args.env_args) 16 | self.episode_limit = self.env.episode_limit 17 | self.t = 0 18 | 19 | self.t_env = 0 20 | 21 | self.train_returns = [] 22 | self.test_returns = [] 23 | self.train_stats = {} 24 | self.test_stats = {} 25 | 26 | # Log the first run 27 | self.log_train_stats_t = -1000000 28 | 29 | def setup(self, scheme, groups, preprocess, mac): 30 | self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1, 31 | preprocess=preprocess, device=self.args.device) 32 | self.mac = mac 33 | 34 | def get_env_info(self): 35 | return self.env.get_env_info() 36 | 37 | def save_replay(self): 38 | self.env.save_replay() 39 | 40 | def close_env(self): 41 | self.env.close() 42 | 43 | def reset(self): 44 | self.batch = self.new_batch() 45 | self.env.reset() 46 | self.t = 0 47 | 48 | def run(self, test_mode=False): 49 | self.reset() 50 | 51 | terminated = False 52 | episode_return = 0 53 | self.mac.init_hidden(batch_size=self.batch_size) 54 | 55 | while not terminated: 56 | 57 | pre_transition_data = { 58 | "state": [self.env.get_state()], 59 | "avail_actions": [self.env.get_avail_actions()], 60 | "obs": [self.env.get_obs()] 61 | } 62 | 63 | self.batch.update(pre_transition_data, ts=self.t) 64 | 65 | # Pass the entire batch of experiences up till now to the agents 66 | # Receive the actions for each agent at this timestep in a batch of size 1 67 | actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode) 68 | 69 | reward, terminated, env_info = self.env.step(actions[0]) 70 | episode_return += reward 71 | 72 | post_transition_data = { 73 | "actions": actions, 74 | "reward": [(reward,)], 75 | "terminated": [(terminated != env_info.get("episode_limit", False),)], 76 | } 77 | 78 | self.batch.update(post_transition_data, ts=self.t) 79 | 80 | self.t += 1 81 | 82 | last_data = { 83 | "state": [self.env.get_state()], 84 | "avail_actions": [self.env.get_avail_actions()], 85 | "obs": [self.env.get_obs()] 86 | } 87 | self.batch.update(last_data, ts=self.t) 88 | 89 | # Select 
actions in the last stored state 90 | actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, test_mode=test_mode) 91 | self.batch.update({"actions": actions}, ts=self.t) 92 | 93 | cur_stats = self.test_stats if test_mode else self.train_stats 94 | cur_returns = self.test_returns if test_mode else self.train_returns 95 | log_prefix = "test_" if test_mode else "" 96 | cur_stats.update({k: cur_stats.get(k, 0) + env_info.get(k, 0) for k in set(cur_stats) | set(env_info)}) 97 | cur_stats["n_episodes"] = 1 + cur_stats.get("n_episodes", 0) 98 | cur_stats["ep_length"] = self.t + cur_stats.get("ep_length", 0) 99 | 100 | if not test_mode: 101 | self.t_env += self.t 102 | 103 | cur_returns.append(episode_return) 104 | 105 | if test_mode and (len(self.test_returns) == self.args.test_nepisode): 106 | self._log(cur_returns, cur_stats, log_prefix) 107 | elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval: 108 | self._log(cur_returns, cur_stats, log_prefix) 109 | if hasattr(self.mac.action_selector, "epsilon"): 110 | self.logger.log_stat("epsilon", self.mac.action_selector.epsilon, self.t_env) 111 | self.log_train_stats_t = self.t_env 112 | 113 | return self.batch 114 | 115 | def _log(self, returns, stats, prefix): 116 | self.logger.log_stat(prefix + "return_mean", np.mean(returns), self.t_env) 117 | self.logger.log_stat(prefix + "return_std", np.std(returns), self.t_env) 118 | returns.clear() 119 | 120 | for k, v in stats.items(): 121 | if k != "n_episodes": 122 | self.logger.log_stat(prefix + k + "_mean" , v/stats["n_episodes"], self.t_env) 123 | stats.clear() 124 | -------------------------------------------------------------------------------- /src/runners/parallel_runner.py: -------------------------------------------------------------------------------- 1 | from envs import REGISTRY as env_REGISTRY 2 | from functools import partial 3 | from components.episode_buffer import EpisodeBatch 4 | from multiprocessing import Pipe, Process 5 | import numpy as np 6 | import torch as th 7 | 8 | 9 | # Based (very) heavily on SubprocVecEnv from OpenAI Baselines 10 | # https://github.com/openai/baselines/blob/master/baselines/common/vec_env/subproc_vec_env.py 11 | class ParallelRunner: 12 | 13 | def __init__(self, args, logger): 14 | self.args = args 15 | self.logger = logger 16 | self.batch_size = self.args.batch_size_run 17 | 18 | # Make subprocesses for the envs 19 | self.parent_conns, self.worker_conns = zip(*[Pipe() for _ in range(self.batch_size)]) 20 | env_fn = env_REGISTRY[self.args.env] 21 | self.ps = [Process(target=env_worker, args=(worker_conn, CloudpickleWrapper(partial(env_fn, **self.args.env_args)))) 22 | for worker_conn in self.worker_conns] 23 | 24 | for p in self.ps: 25 | p.daemon = True 26 | p.start() 27 | 28 | self.parent_conns[0].send(("get_env_info", None)) 29 | self.env_info = self.parent_conns[0].recv() 30 | self.episode_limit = self.env_info["episode_limit"] 31 | 32 | self.t = 0 33 | 34 | self.t_env = 0 35 | 36 | self.train_returns = [] 37 | self.test_returns = [] 38 | self.train_stats = {} 39 | self.test_stats = {} 40 | 41 | self.log_train_stats_t = -100000 42 | 43 | def setup(self, scheme, groups, preprocess, mac): 44 | self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1, 45 | preprocess=preprocess, device=self.args.device) 46 | self.mac = mac 47 | self.scheme = scheme 48 | self.groups = groups 49 | self.preprocess = preprocess 50 | 51 | def get_env_info(self): 52 | return self.env_info 53 | 54 | def 
save_replay(self): 55 | pass 56 | 57 | def close_env(self): 58 | for parent_conn in self.parent_conns: 59 | parent_conn.send(("close", None)) 60 | 61 | def reset(self): 62 | self.batch = self.new_batch() 63 | 64 | # Reset the envs 65 | for parent_conn in self.parent_conns: 66 | parent_conn.send(("reset", None)) 67 | 68 | pre_transition_data = { 69 | "state": [], 70 | "avail_actions": [], 71 | "obs": [] 72 | } 73 | # Get the obs, state and avail_actions back 74 | for parent_conn in self.parent_conns: 75 | data = parent_conn.recv() 76 | pre_transition_data["state"].append(data["state"]) 77 | pre_transition_data["avail_actions"].append(data["avail_actions"]) 78 | pre_transition_data["obs"].append(data["obs"]) 79 | 80 | self.batch.update(pre_transition_data, ts=0) 81 | 82 | self.t = 0 83 | self.env_steps_this_run = 0 84 | 85 | def run(self, test_mode=False): 86 | self.reset() 87 | 88 | all_terminated = False 89 | episode_returns = [0 for _ in range(self.batch_size)] 90 | episode_lengths = [0 for _ in range(self.batch_size)] 91 | self.mac.init_hidden(batch_size=self.batch_size) 92 | terminated = [False for _ in range(self.batch_size)] 93 | envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed] 94 | final_env_infos = [] # may store extra stats like battle won. this is filled in ORDER OF TERMINATION 95 | 96 | while True: 97 | 98 | # Pass the entire batch of experiences up till now to the agents 99 | # Receive the actions for each agent at this timestep in a batch for each un-terminated env 100 | actions = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env, bs=envs_not_terminated, test_mode=test_mode) 101 | cpu_actions = actions.to("cpu").numpy() 102 | 103 | # Update the actions taken 104 | actions_chosen = { 105 | "actions": actions.unsqueeze(1) 106 | } 107 | self.batch.update(actions_chosen, bs=envs_not_terminated, ts=self.t, mark_filled=False) 108 | 109 | # Send actions to each env 110 | action_idx = 0 111 | for idx, parent_conn in enumerate(self.parent_conns): 112 | if idx in envs_not_terminated: # We produced actions for this env 113 | if not terminated[idx]: # Only send the actions to the env if it hasn't terminated 114 | parent_conn.send(("step", cpu_actions[action_idx])) 115 | action_idx += 1 # actions is not a list over every env 116 | 117 | # Update envs_not_terminated 118 | envs_not_terminated = [b_idx for b_idx, termed in enumerate(terminated) if not termed] 119 | all_terminated = all(terminated) 120 | if all_terminated: 121 | break 122 | 123 | # Post step data we will insert for the current timestep 124 | post_transition_data = { 125 | "reward": [], 126 | "terminated": [] 127 | } 128 | # Data for the next step we will insert in order to select an action 129 | pre_transition_data = { 130 | "state": [], 131 | "avail_actions": [], 132 | "obs": [] 133 | } 134 | 135 | # Receive data back for each unterminated env 136 | for idx, parent_conn in enumerate(self.parent_conns): 137 | if not terminated[idx]: 138 | data = parent_conn.recv() 139 | # Remaining data for this current timestep 140 | post_transition_data["reward"].append((data["reward"],)) 141 | 142 | episode_returns[idx] += data["reward"] 143 | episode_lengths[idx] += 1 144 | if not test_mode: 145 | self.env_steps_this_run += 1 146 | 147 | env_terminated = False 148 | if data["terminated"]: 149 | final_env_infos.append(data["info"]) 150 | if data["terminated"] and not data["info"].get("episode_limit", False): 151 | env_terminated = True 152 | terminated[idx] = data["terminated"] 153 | 
post_transition_data["terminated"].append((env_terminated,)) 154 | 155 | # Data for the next timestep needed to select an action 156 | pre_transition_data["state"].append(data["state"]) 157 | pre_transition_data["avail_actions"].append(data["avail_actions"]) 158 | pre_transition_data["obs"].append(data["obs"]) 159 | 160 | # Add post_transiton data into the batch 161 | self.batch.update(post_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=False) 162 | 163 | # Move onto the next timestep 164 | self.t += 1 165 | 166 | # Add the pre-transition data 167 | self.batch.update(pre_transition_data, bs=envs_not_terminated, ts=self.t, mark_filled=True) 168 | 169 | if not test_mode: 170 | self.t_env += self.env_steps_this_run 171 | 172 | # Get stats back for each env 173 | for parent_conn in self.parent_conns: 174 | parent_conn.send(("get_stats",None)) 175 | 176 | env_stats = [] 177 | for parent_conn in self.parent_conns: 178 | env_stat = parent_conn.recv() 179 | env_stats.append(env_stat) 180 | 181 | cur_stats = self.test_stats if test_mode else self.train_stats 182 | cur_returns = self.test_returns if test_mode else self.train_returns 183 | log_prefix = "test_" if test_mode else "" 184 | infos = [cur_stats] + final_env_infos 185 | cur_stats.update({k: sum(d.get(k, 0) for d in infos) for k in set.union(*[set(d) for d in infos])}) 186 | cur_stats["n_episodes"] = self.batch_size + cur_stats.get("n_episodes", 0) 187 | cur_stats["ep_length"] = sum(episode_lengths) + cur_stats.get("ep_length", 0) 188 | 189 | cur_returns.extend(episode_returns) 190 | 191 | n_test_runs = max(1, self.args.test_nepisode // self.batch_size) * self.batch_size 192 | if test_mode and (len(self.test_returns) == n_test_runs): 193 | self._log(cur_returns, cur_stats, log_prefix) 194 | elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval: 195 | self._log(cur_returns, cur_stats, log_prefix) 196 | if hasattr(self.mac.action_selector, "epsilon"): 197 | self.logger.log_stat("epsilon", self.mac.action_selector.epsilon, self.t_env) 198 | self.log_train_stats_t = self.t_env 199 | 200 | return self.batch 201 | 202 | def _log(self, returns, stats, prefix): 203 | self.logger.log_stat(prefix + "return_mean", np.mean(returns), self.t_env) 204 | self.logger.log_stat(prefix + "return_std", np.std(returns), self.t_env) 205 | returns.clear() 206 | 207 | for k, v in stats.items(): 208 | if k != "n_episodes": 209 | self.logger.log_stat(prefix + k + "_mean" , v/stats["n_episodes"], self.t_env) 210 | stats.clear() 211 | 212 | 213 | def env_worker(remote, env_fn): 214 | # Make environment 215 | env = env_fn.x() 216 | while True: 217 | cmd, data = remote.recv() 218 | if cmd == "step": 219 | actions = data 220 | # Take a step in the environment 221 | reward, terminated, env_info = env.step(actions) 222 | # Return the observations, avail_actions and state to make the next action 223 | state = env.get_state() 224 | avail_actions = env.get_avail_actions() 225 | obs = env.get_obs() 226 | remote.send({ 227 | # Data for the next timestep needed to pick an action 228 | "state": state, 229 | "avail_actions": avail_actions, 230 | "obs": obs, 231 | # Rest of the data for the current timestep 232 | "reward": reward, 233 | "terminated": terminated, 234 | "info": env_info 235 | }) 236 | elif cmd == "reset": 237 | env.reset() 238 | remote.send({ 239 | "state": env.get_state(), 240 | "avail_actions": env.get_avail_actions(), 241 | "obs": env.get_obs() 242 | }) 243 | elif cmd == "close": 244 | env.close() 245 | remote.close() 246 | break 
247 | elif cmd == "get_env_info": 248 | remote.send(env.get_env_info()) 249 | elif cmd == "get_stats": 250 | remote.send(env.get_stats()) 251 | else: 252 | raise NotImplementedError 253 | 254 | 255 | class CloudpickleWrapper(): 256 | """ 257 | Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) 258 | """ 259 | def __init__(self, x): 260 | self.x = x 261 | def __getstate__(self): 262 | import cloudpickle 263 | return cloudpickle.dumps(self.x) 264 | def __setstate__(self, ob): 265 | import pickle 266 | self.x = pickle.loads(ob) 267 | 268 | -------------------------------------------------------------------------------- /src/utils/dict2namedtuple.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | def convert(dictionary): 5 | return namedtuple('GenericDict', dictionary.keys())(**dictionary) 6 | -------------------------------------------------------------------------------- /src/utils/logging.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import logging 3 | import numpy as np 4 | 5 | class Logger: 6 | def __init__(self, console_logger): 7 | self.console_logger = console_logger 8 | 9 | self.use_tb = False 10 | self.use_sacred = False 11 | self.use_hdf = False 12 | 13 | self.stats = defaultdict(lambda: []) 14 | 15 | def setup_tb(self, directory_name): 16 | # Import here so it doesn't have to be installed if you don't use it 17 | from tensorboard_logger import configure, log_value 18 | configure(directory_name) 19 | self.tb_logger = log_value 20 | self.use_tb = True 21 | 22 | def setup_sacred(self, sacred_run_dict): 23 | self.sacred_info = sacred_run_dict.info 24 | self.use_sacred = True 25 | 26 | def log_stat(self, key, value, t, to_sacred=True): 27 | self.stats[key].append((t, value)) 28 | 29 | if self.use_tb: 30 | self.tb_logger(key, value, t) 31 | 32 | if self.use_sacred and to_sacred: 33 | if key in self.sacred_info: 34 | self.sacred_info["{}_T".format(key)].append(t) 35 | self.sacred_info[key].append(value) 36 | else: 37 | self.sacred_info["{}_T".format(key)] = [t] 38 | self.sacred_info[key] = [value] 39 | 40 | def print_recent_stats(self): 41 | log_str = "Recent Stats | t_env: {:>10} | Episode: {:>8}\n".format(*self.stats["episode"][-1]) 42 | i = 0 43 | for (k, v) in sorted(self.stats.items()): 44 | if k == "episode": 45 | continue 46 | i += 1 47 | window = 5 if k != "epsilon" else 1 48 | item = "{:.4f}".format(np.mean([x[1] for x in self.stats[k][-window:]])) 49 | log_str += "{:<25}{:>8}".format(k + ":", item) 50 | log_str += "\n" if i % 4 == 0 else "\t" 51 | self.console_logger.info(log_str) 52 | 53 | 54 | # set up a custom logger 55 | def get_logger(): 56 | logger = logging.getLogger() 57 | logger.handlers = [] 58 | ch = logging.StreamHandler() 59 | formatter = logging.Formatter('[%(levelname)s %(asctime)s] %(name)s %(message)s', '%H:%M:%S') 60 | ch.setFormatter(formatter) 61 | logger.addHandler(ch) 62 | logger.setLevel('DEBUG') 63 | 64 | return logger 65 | 66 | -------------------------------------------------------------------------------- /src/utils/rl_utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | def build_td_lambda_targets(rewards, terminated, mask, target_qs, n_agents, gamma, td_lambda): 5 | # Assumes target_qs in B*T*A and rewards, terminated, mask in (at least) B*T-1*1 6 | # Initialise last lambda-return for not terminated episodes 
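# The backward loop below is the recursive form of the forward-view TD(lambda) target:
#     G_t = r_t + gamma * ((1 - td_lambda) * Q_target(t + 1) + td_lambda * G_{t+1})
# It is initialised at the final step with G_T = Q_target(T) for episodes that never
# terminate within the recorded horizon; `mask` zeroes the reward/bootstrap term on
# padded timesteps, and `(1 - terminated)` stops bootstrapping past the end of an episode.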
7 | ret = target_qs.new_zeros(*target_qs.shape) 8 | ret[:, -1] = target_qs[:, -1] * (1 - th.sum(terminated, dim=1)) 9 | # Backwards recursive update of the "forward view" 10 | for t in range(ret.shape[1] - 2, -1, -1): 11 | ret[:, t] = td_lambda * gamma * ret[:, t + 1] + mask[:, t] \ 12 | * (rewards[:, t] + (1 - td_lambda) * gamma * target_qs[:, t + 1] * (1 - terminated[:, t])) 13 | # Returns lambda-return from t=0 to t=T-1, i.e. in B*T-1*A 14 | return ret[:, 0:-1] 15 | 16 | -------------------------------------------------------------------------------- /src/utils/timehelper.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | 5 | def print_time(start_time, T, t_max, episode, episode_rewards): 6 | time_elapsed = time.time() - start_time 7 | T = max(1, T) 8 | time_left = time_elapsed * (t_max - T) / T 9 | # Just in case its over 100 days 10 | time_left = min(time_left, 60 * 60 * 24 * 100) 11 | last_reward = "N\A" 12 | if len(episode_rewards) > 5: 13 | last_reward = "{:.2f}".format(np.mean(episode_rewards[-50:])) 14 | print("\033[F\033[F\x1b[KEp: {:,}, T: {:,}/{:,}, Reward: {}, \n\x1b[KElapsed: {}, Left: {}\n".format(episode, T, t_max, last_reward, time_str(time_elapsed), time_str(time_left)), " " * 10, end="\r") 15 | 16 | 17 | def time_left(start_time, t_start, t_current, t_max): 18 | if t_current >= t_max: 19 | return "-" 20 | time_elapsed = time.time() - start_time 21 | t_current = max(1, t_current) 22 | time_left = time_elapsed * (t_max - t_current) / (t_current - t_start) 23 | # Just in case its over 100 days 24 | time_left = min(time_left, 60 * 60 * 24 * 100) 25 | return time_str(time_left) 26 | 27 | 28 | def time_str(s): 29 | """ 30 | Convert seconds to a nicer string showing days, hours, minutes and seconds 31 | """ 32 | days, remainder = divmod(s, 60 * 60 * 24) 33 | hours, remainder = divmod(remainder, 60 * 60) 34 | minutes, seconds = divmod(remainder, 60) 35 | string = "" 36 | if days > 0: 37 | string += "{:d} days, ".format(int(days)) 38 | if hours > 0: 39 | string += "{:d} hours, ".format(int(hours)) 40 | if minutes > 0: 41 | string += "{:d} minutes, ".format(int(minutes)) 42 | string += "{:d} seconds".format(int(seconds)) 43 | return string 44 | --------------------------------------------------------------------------------
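As a quick sanity check on the shape convention noted at the top of `build_td_lambda_targets` in `src/utils/rl_utils.py`, the short sketch below calls the helper on dummy tensors. The batch size, horizon, agent count and the `gamma`/`td_lambda` values here are illustrative placeholders rather than values taken from this repository's configs, and the import assumes `src/` is on the Python path.

```python
import torch as th

from utils.rl_utils import build_td_lambda_targets  # assumes src/ is importable

# Illustrative sizes only: 4 episodes, a horizon of 10 steps, 3 agents.
B, T, A = 4, 10, 3
target_qs = th.rand(B, T, A)          # bootstrap Q-values, B*T*A
rewards = th.rand(B, T - 1, 1)        # B*(T-1)*1
terminated = th.zeros(B, T - 1, 1)    # 1.0 at the step an episode actually ends
mask = th.ones(B, T - 1, 1)           # 0.0 on padded timesteps

targets = build_td_lambda_targets(rewards, terminated, mask, target_qs,
                                  n_agents=A, gamma=0.99, td_lambda=0.8)
print(targets.shape)  # torch.Size([4, 9, 3]): lambda-returns for t = 0 .. T-2
```

The returned tensor is one step shorter than `target_qs` along the time dimension because of the final `ret[:, 0:-1]` slice in the helper.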