├── marlbase ├── __init__.py ├── ac │ ├── __init__.py │ ├── eval.py │ ├── train.py │ └── model.py ├── dqn │ ├── __init__.py │ ├── eval.py │ ├── train.py │ └── model.py ├── utils │ ├── __init__.py │ ├── video.py │ ├── smaclite_wrapper.py │ ├── standardise_stream.py │ ├── postprocessing │ │ ├── find_best_hyperparams.py │ │ ├── plot_runs.py │ │ ├── hiplot_fetcher.py │ │ ├── export_multirun.py │ │ └── load_data.py │ ├── utils.py │ ├── envs.py │ ├── wrappers.py │ ├── loggers.py │ ├── stats.py │ └── models.py ├── configs │ ├── logger │ │ ├── basic.yaml │ │ ├── wandb.yaml │ │ └── filesystemlogger.yaml │ ├── algorithm │ │ ├── vdn.yaml │ │ ├── qmix.yaml │ │ ├── ia2c.yaml │ │ ├── maa2c.yaml │ │ ├── ippo.yaml │ │ ├── mappo.yaml │ │ └── idqn.yaml │ ├── eval.yaml │ ├── hydra │ │ └── job_logging │ │ │ ├── file.yaml │ │ │ └── console.yaml │ ├── default.yaml │ └── sweeps │ │ └── sample.yaml ├── run.py ├── eval.py └── search.py ├── .flake8 ├── requirements.txt ├── setup.py ├── .editorconfig ├── .gitignore └── README.md /marlbase/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marlbase/ac/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marlbase/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marlbase/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marlbase/configs/logger/basic.yaml: -------------------------------------------------------------------------------- 1 | _target_: utils.loggers.Logger 2 | project_name: fastmarl -------------------------------------------------------------------------------- /marlbase/configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | _target_: utils.loggers.WandbLogger 2 | project_name: fastmarl -------------------------------------------------------------------------------- /marlbase/configs/logger/filesystemlogger.yaml: -------------------------------------------------------------------------------- 1 | _target_: utils.loggers.FileSystemLogger 2 | project_name: fastmarl 3 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E226,E302,E41, F841 3 | max-line-length = 89 4 | exclude = tests/* 5 | max-complexity = 10 -------------------------------------------------------------------------------- /marlbase/configs/algorithm/vdn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - idqn 5 | 6 | env: 7 | wrappers : 8 | - CooperativeReward 9 | 10 | algorithm: 11 | name: "vdn" 12 | model: 13 | _target_: dqn.model.VDNetwork 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click==8.1.7 2 | einops==0.8.0 3 | gymnasium<1.0.0 4 | hydra-core==1.3.2 5 | hydra-submitit-launcher==1.2.0 6 | hydra-ax-sweeper==1.1.5 7 | imageio==2.9.0 8 | 
imageio-ffmpeg==0.5.1 9 | matplotlib==3.9.2 10 | munch==4.0.0 11 | pandas==2.2.2 12 | seaborn==0.13.2 13 | torch==2.4.0 14 | pyyaml==6.0.2 15 | -------------------------------------------------------------------------------- /marlbase/configs/algorithm/qmix.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - idqn 5 | 6 | env: 7 | wrappers : 8 | - CooperativeReward 9 | 10 | algorithm: 11 | name: "qmix" 12 | model: 13 | _target_: dqn.model.QMixNetwork 14 | mixing: 15 | embed_dim: 64 16 | hypernet_layers: 2 17 | hypernet_embed: 32 18 | -------------------------------------------------------------------------------- /marlbase/configs/eval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - logger: filesystemlogger 4 | - override hydra/job_logging: file 5 | - override hydra/launcher: submitit_local 6 | 7 | hydra: 8 | run: 9 | dir: evals/${random:4} 10 | launcher: 11 | timeout_min: 2880 12 | cpus_per_task: 1 13 | mem_gb: 4 14 | job: 15 | chdir: True 16 | 17 | path: null 18 | load_step: null 19 | video_frames: 10000 20 | seed: null 21 | 22 | -------------------------------------------------------------------------------- /marlbase/utils/video.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | 3 | 4 | class VideoRecorder: 5 | def __init__(self, fps=30): 6 | self.fps = fps 7 | self.frames = [] 8 | 9 | def reset(self): 10 | self.frames = [] 11 | 12 | def record_frame(self, env): 13 | frame = env.unwrapped.render() 14 | self.frames.append(frame) 15 | 16 | def save(self, filename): 17 | imageio.mimsave(f"{filename}", self.frames, fps=self.fps) 18 | -------------------------------------------------------------------------------- /marlbase/configs/hydra/job_logging/file.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | formatters: 3 | simple: 4 | format: '(%(process)d) [%(levelname)s] - (%(asctime)s) - %(name)s >> %(message)s' 5 | datefmt: '%m/%d %H:%M:%S' 6 | handlers: 7 | console: 8 | class: logging.StreamHandler 9 | formatter: simple 10 | stream: ext://sys.stdout 11 | file: 12 | class: logging.FileHandler 13 | formatter: simple 14 | filename: ${hydra.job.name}.log 15 | root: 16 | level: INFO 17 | handlers: [console, file] 18 | 19 | disable_existing_loggers: false -------------------------------------------------------------------------------- /marlbase/configs/hydra/job_logging/console.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | version: 1 3 | formatters: 4 | simple: 5 | format: '(%(process)d) [%(levelname)s] - (%(asctime)s) - %(name)s >> %(message)s' 6 | datefmt: '%m/%d %H:%M:%S' 7 | handlers: 8 | console: 9 | class: logging.StreamHandler 10 | formatter: simple 11 | stream: ext://sys.stdout 12 | file: 13 | class: logging.FileHandler 14 | formatter: simple 15 | filename: ${hydra.job.name}.log 16 | root: 17 | level: INFO 18 | handlers: [console] 19 | 20 | disable_existing_loggers: false -------------------------------------------------------------------------------- /marlbase/ac/eval.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | import torch 4 | 5 | from marlbase.ac.train import record_episodes 6 | 7 | 8 | def main(env, ckpt_path, **cfg): 9 | cfg = DictConfig(cfg) 10 | 11 | 
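# `hydra.utils.instantiate` below resolves `cfg.model._target_` (e.g.
# "ac.model.A2CNetwork" in the ia2c/maa2c configs) to a class and calls it with
# the positional arguments given here, passing the remaining keys of `cfg.model`
# as keyword arguments. Roughly, and assuming the network constructor takes the
# spaces and config positionally (ac/model.py is not shown in this listing),
# this is equivalent to:
#
#     from ac.model import A2CNetwork
#     model = A2CNetwork(
#         env.observation_space, env.action_space, cfg,
#         actor=cfg.model.actor, critic=cfg.model.critic, ...,
#     )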
model = hydra.utils.instantiate( 12 | cfg.model, env.observation_space, env.action_space, cfg 13 | ) 14 | print(f"Loading model from {ckpt_path}") 15 | state_dict = torch.load(ckpt_path, weights_only=True) 16 | model.load_state_dict(state_dict) 17 | 18 | record_episodes( 19 | env, 20 | model, 21 | cfg.video_frames, 22 | "./eval.mp4", 23 | cfg.model.device, 24 | ) 25 | 26 | env.close() 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="marlbase", 5 | version="0.1.0", 6 | description="Fast Multi-Agent RL: a collection of algorithms applied in multi-agent environments", 7 | author="Filippos Christianos, Lukas Schäfer", 8 | url="https://github.com/marl-book/codebase", 9 | packages=find_packages(exclude=["contrib", "docs", "tests"]), 10 | classifiers=[ 11 | "Intended Audience :: Developers", 12 | "Programming Language :: Python :: 3.10", 13 | ], 14 | install_requires=["hydra-core>=1.1", "torch", "cpprb", "einops"], 15 | extras_require={"test": ["pytest"]}, 16 | include_package_data=True, 17 | ) 18 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig is awesome: https://EditorConfig.org 2 | 3 | # top-most EditorConfig file 4 | root = true 5 | 6 | # Unix-style newlines with a newline ending every file 7 | [*] 8 | end_of_line = lf 9 | insert_final_newline = true 10 | 11 | # Matches multiple files with brace expansion notation 12 | # Set default charset 13 | [*.{js,py}] 14 | charset = utf-8 15 | 16 | # 4 space indentation 17 | [*.py] 18 | indent_style = space 19 | indent_size = 4 20 | 21 | # Tab indentation (no size specified) 22 | [Makefile] 23 | indent_style = tab 24 | 25 | # Matches the exact files either package.json or .travis.yml 26 | [{package.json,.travis.yml}] 27 | indent_style = space 28 | indent_size = 2 29 | -------------------------------------------------------------------------------- /marlbase/dqn/eval.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | import torch 4 | 5 | from marlbase.dqn.train import record_episodes 6 | 7 | 8 | def main(env, ckpt_path, **cfg): 9 | cfg = DictConfig(cfg) 10 | 11 | model = hydra.utils.instantiate( 12 | cfg.model, env.observation_space, env.action_space, cfg 13 | ) 14 | print(f"Loading model from {ckpt_path}") 15 | state_dict = torch.load(ckpt_path, weights_only=True) 16 | model.load_state_dict(state_dict) 17 | 18 | record_episodes( 19 | env, 20 | model, 21 | cfg.video_frames, 22 | "./eval.mp4", 23 | cfg.model.device, 24 | cfg.eps_evaluation, 25 | ) 26 | 27 | env.close() 28 | -------------------------------------------------------------------------------- /marlbase/configs/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - logger: filesystemlogger 4 | - override hydra/job_logging: file 5 | - override hydra/launcher: submitit_local 6 | 7 | hydra: 8 | run: 9 | dir: outputs/${env.name}/${algorithm.name}/${random:4} 10 | launcher: 11 | timeout_min: 2880 12 | cpus_per_task: 1 13 | mem_gb: 4 14 | job: 15 | chdir: True 16 | 17 | seed: null 18 | 19 | algorithm: 20 | total_steps: 100_000 21 | log_interval: 10_000 22 | save_interval: False 23 | 
eval_interval: 10_000 24 | eval_episodes: 100 25 | video_interval: False 26 | video_frames: 500 27 | 28 | env: 29 | _target_: utils.envs.make_env 30 | name : ??? 31 | time_limit : ??? 32 | clear_info: False 33 | observe_id: False 34 | standardise_rewards: False 35 | wrappers : null 36 | -------------------------------------------------------------------------------- /marlbase/configs/sweeps/sample.yaml: -------------------------------------------------------------------------------- 1 | # constants 2 | algorithm.save_interval: False 3 | algorithm.eval_episodes: 10 4 | algorithm.video_interval: False 5 | 6 | # top-level choice (grid) search 7 | algorithm.standardise_returns: 8 | - True 9 | - False 10 | 11 | hparam-tuples-1: 12 | - !!python/tuple [env.name: lbforaging:Foraging-10x10-3p-3f-v3, env.time_limit: 25] 13 | - !!python/tuple [env.name: lbforaging:Foraging-8x8-2p-2f-coop-v3, env.time_limit: 25] 14 | 15 | hparam-tuples-2: 16 | - !!python/tuple 17 | - "+algorithm": "idqn" 18 | - algorithm.total_steps: 2_000_000 19 | - algorithm.batch_size: 20 | - 128 21 | - 256 22 | 23 | - !!python/tuple 24 | - "+algorithm": "ia2c" 25 | - algorithm.total_steps: 2_000_000 26 | - algorithm.entropy_coef: 27 | - 0.01 28 | - 0.001 29 | -------------------------------------------------------------------------------- /marlbase/configs/algorithm/ia2c.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | env: 3 | parallel_envs: 10 4 | 5 | algorithm: 6 | _target_: ac.train.main 7 | name: "ia2c" 8 | model: 9 | _target_: ac.model.A2CNetwork 10 | actor: 11 | layers: 12 | - 128 13 | - 128 14 | parameter_sharing: False # True/False/List[int] (seps_indices) 15 | use_orthogonal_init: True 16 | use_rnn: False 17 | critic: 18 | centralised: False 19 | layers: 20 | - 128 21 | - 128 22 | parameter_sharing: False # True/False/List[int] (seps_indices) 23 | use_orthogonal_init: True 24 | use_rnn: False 25 | 26 | device : "cpu" # a pytorch device ("cpu" or "cuda") 27 | 28 | optimizer : "Adam" 29 | lr: 3.e-4 30 | 31 | grad_clip: False 32 | 33 | n_steps: 5 34 | gamma: 0.99 35 | entropy_coef: 0.001 36 | value_loss_coef: 0.5 37 | use_proper_termination: False 38 | standardise_returns: False 39 | 40 | target_update_interval_or_tau: 200 41 | -------------------------------------------------------------------------------- /marlbase/configs/algorithm/maa2c.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | env: 3 | parallel_envs: 10 4 | 5 | algorithm: 6 | _target_: ac.train.main 7 | name: "maa2c" 8 | model: 9 | _target_: ac.model.A2CNetwork 10 | actor: 11 | layers: 12 | - 128 13 | - 128 14 | parameter_sharing: False # True/False/List[int] (seps_indices) 15 | use_orthogonal_init: True 16 | use_rnn: False 17 | critic: 18 | centralised: True 19 | layers: 20 | - 128 21 | - 128 22 | parameter_sharing: False # True/False/List[int] (seps_indices) 23 | use_orthogonal_init: True 24 | use_rnn: False 25 | 26 | device : "cpu" # a pytorch device ("cpu" or "cuda") 27 | 28 | optimizer : "Adam" 29 | lr: 3.e-4 30 | 31 | grad_clip: False 32 | 33 | n_steps: 5 34 | gamma: 0.99 35 | entropy_coef: 0.001 36 | value_loss_coef: 0.5 37 | use_proper_termination: False 38 | standardise_returns: False 39 | 40 | target_update_interval_or_tau: 200 41 | -------------------------------------------------------------------------------- /marlbase/configs/algorithm/ippo.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | env: 3 | parallel_envs: 10 4 | 5 | algorithm: 6 | _target_: ac.train.main 7 | name: "ippo" 8 | model: 9 | _target_: ac.model.PPONetwork 10 | actor: 11 | layers: 12 | - 128 13 | - 128 14 | parameter_sharing: False # True/False/List[int] (seps_indices) 15 | use_orthogonal_init: True 16 | use_rnn: False 17 | critic: 18 | centralised: False 19 | layers: 20 | - 128 21 | - 128 22 | parameter_sharing: False # True/False/List[int] (seps_indices) 23 | use_orthogonal_init: True 24 | use_rnn: False 25 | 26 | device : "cpu" # a pytorch device ("cpu" or "cuda") 27 | 28 | optimizer : "Adam" 29 | lr: 3.e-4 30 | 31 | grad_clip: False 32 | 33 | n_steps: 5 34 | gamma: 0.99 35 | entropy_coef: 0.001 36 | value_loss_coef: 0.5 37 | use_proper_termination: False 38 | standardise_returns: False 39 | 40 | num_epochs: 4 41 | ppo_clip: 0.2 42 | 43 | target_update_interval_or_tau: 200 44 | -------------------------------------------------------------------------------- /marlbase/configs/algorithm/mappo.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | env: 3 | parallel_envs: 10 4 | 5 | algorithm: 6 | _target_: ac.train.main 7 | name: "mappo" 8 | model: 9 | _target_: ac.model.PPONetwork 10 | actor: 11 | layers: 12 | - 128 13 | - 128 14 | parameter_sharing: False # True/False/List[int] (seps_indices) 15 | use_orthogonal_init: True 16 | use_rnn: False 17 | critic: 18 | centralised: True 19 | layers: 20 | - 128 21 | - 128 22 | parameter_sharing: False # True/False/List[int] (seps_indices) 23 | use_orthogonal_init: True 24 | use_rnn: False 25 | 26 | device : "cpu" # a pytorch device ("cpu" or "cuda") 27 | 28 | optimizer : "Adam" 29 | lr: 3.e-4 30 | 31 | grad_clip: False 32 | 33 | n_steps: 5 34 | gamma: 0.99 35 | entropy_coef: 0.001 36 | value_loss_coef: 0.5 37 | use_proper_termination: False 38 | standardise_returns: False 39 | 40 | num_epochs: 4 41 | ppo_clip: 0.2 42 | 43 | target_update_interval_or_tau: 200 44 | -------------------------------------------------------------------------------- /marlbase/configs/algorithm/idqn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | algorithm: 4 | _target_: dqn.train.main 5 | name: "idqn" 6 | model: 7 | _target_: dqn.model.QNetwork 8 | layers: 9 | - 128 10 | - 128 11 | parameter_sharing: False # True/False/List[int] (seps_indices) 12 | use_orthogonal_init: True 13 | use_rnn: False 14 | device: "cpu" # a pytorch device ("cpu" or "cuda") 15 | 16 | training_start: 2000 17 | buffer_size: 10000 # number of *episodes* to store in the replay buffer 18 | 19 | optimizer: "Adam" 20 | lr: 3e-4 21 | gamma: 0.99 22 | batch_size: 32 23 | double_q: True 24 | 25 | grad_clip: 1.0 26 | 27 | use_proper_termination: False # True/ False 28 | standardise_returns: False 29 | 30 | eps_decay_style: "linear" # "linear" or "exponential" 31 | eps_decay_over: 0.5 # fraction of total steps over which to decay epsilon 32 | eps_start: 1.0 33 | eps_end: 0.05 34 | eps_exp_decay_rate: 6.5 # exponential decay rate (ignored for linear decay) 35 | eps_evaluation: 0.05 36 | 37 | target_update_interval_or_tau: 200 38 | -------------------------------------------------------------------------------- /marlbase/utils/smaclite_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gymnasium as gym 3 | 4 | 5 | 
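# SMAClite exposes a single scalar reward and per-agent action masks via
# `get_avail_actions()`. The wrapper below adapts this to the convention used by
# the rest of the codebase: the reward becomes a list with one (identical) entry
# per agent, and the availability mask is attached to `info["action_mask"]`.
# A minimal usage sketch, assuming the smaclite package is installed and
# registers an environment id such as "smaclite/MMM2-v0":
#
#     import gymnasium as gym
#     import smaclite  # noqa: F401  (registers the environments)
#
#     env = SMACliteWrapper(gym.make("smaclite/MMM2-v0"))
#     obs, info = env.reset(seed=0)
#     mask = info["action_mask"]     # shape (n_agents, n_actions), 1 = available
#     actions = mask.argmax(axis=1)  # first available action per agent
#     obs, rewards, terminated, truncated, info = env.step(actions)
#     assert len(rewards) == env.n_agents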
class SMACliteWrapper(gym.Wrapper): 6 | def __init__(self, env): 7 | super().__init__(env) 8 | self.n_agents = self.env.unwrapped.n_agents 9 | 10 | def step(self, actions): 11 | """Returns obss, reward, terminated, truncated, info""" 12 | actions = [int(act) for act in actions] 13 | obs, reward, terminated, truncated, info = self.env.step(actions) 14 | info["action_mask"] = np.array( 15 | self.env.unwrapped.get_avail_actions(), dtype=np.float32 16 | ) 17 | return obs, [reward] * self.n_agents, terminated, truncated, info 18 | 19 | def reset(self, seed=None, options=None): 20 | """Returns initial observations and info""" 21 | obs, info = self.env.reset(seed=seed, options=options) 22 | info["action_mask"] = np.array( 23 | self.env.unwrapped.get_avail_actions(), dtype=np.float32 24 | ) 25 | return obs, info 26 | 27 | def render(self): 28 | self.env.render() 29 | 30 | def close(self): 31 | self.env.close() 32 | -------------------------------------------------------------------------------- /marlbase/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import hydra 4 | import numpy as np 5 | from omegaconf import OmegaConf, DictConfig 6 | import torch 7 | 8 | OmegaConf.register_new_resolver( 9 | "random", 10 | lambda x: os.urandom(x).hex(), 11 | ) 12 | 13 | 14 | @hydra.main(config_path="configs", config_name="default", version_base="1.3") 15 | def main(cfg: DictConfig): 16 | logger = hydra.utils.instantiate(cfg.logger, cfg=cfg, _recursive_=False) 17 | 18 | env = hydra.utils.call(cfg.env, seed=cfg.seed) 19 | 20 | # Use singular env for evaluation/ recording 21 | if "parallel_envs" in cfg.env: 22 | del cfg.env.parallel_envs 23 | eval_env = hydra.utils.call( 24 | cfg.env, 25 | enable_video=True if cfg.algorithm.video_interval else False, 26 | seed=cfg.seed, 27 | ) 28 | 29 | torch.set_num_threads(1) 30 | 31 | if cfg.seed is not None: 32 | torch.manual_seed(cfg.seed) 33 | np.random.seed(cfg.seed) 34 | else: 35 | logger.warning("No seed has been set.") 36 | 37 | assert cfg.env.time_limit is not None, "Time limit must be set." 38 | hydra.utils.call( 39 | cfg.algorithm, 40 | env, 41 | eval_env, 42 | logger, 43 | time_limit=cfg.env.time_limit, 44 | _recursive_=False, 45 | ) 46 | 47 | return logger.get_state() 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /marlbase/utils/standardise_stream.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | 6 | class RunningMeanStd(object): 7 | def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] 
= (), device: str = "cpu"): 8 | """ 9 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 10 | """ 11 | self.mean = torch.zeros(shape, dtype=torch.float32, device=device) 12 | self.var = torch.ones(shape, dtype=torch.float32, device=device) 13 | self.count = epsilon 14 | 15 | def update(self, arr): 16 | arr = arr.reshape(-1, arr.size(-1)) 17 | batch_mean = torch.mean(arr, dim=0) 18 | batch_var = torch.var(arr, dim=0) 19 | batch_count = arr.shape[0] 20 | self.update_from_moments(batch_mean, batch_var, batch_count) 21 | 22 | def update_from_moments(self, batch_mean, batch_var, batch_count: int): 23 | delta = batch_mean - self.mean 24 | tot_count = self.count + batch_count 25 | 26 | new_mean = self.mean + delta * batch_count / tot_count 27 | m_a = self.var * self.count 28 | m_b = batch_var * batch_count 29 | m_2 = ( 30 | m_a 31 | + m_b 32 | + torch.square(delta) 33 | * self.count 34 | * batch_count 35 | / (self.count + batch_count) 36 | ) 37 | new_var = m_2 / (self.count + batch_count) 38 | 39 | new_count = batch_count + self.count 40 | 41 | self.mean = new_mean 42 | self.var = new_var 43 | self.count = new_count 44 | -------------------------------------------------------------------------------- /marlbase/utils/postprocessing/find_best_hyperparams.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import click 4 | from omegaconf import OmegaConf 5 | 6 | from marlbase.utils.postprocessing.load_data import load_and_group_runs 7 | 8 | 9 | DEFAULT_METRIC = "mean_episode_returns" 10 | 11 | 12 | @click.command() 13 | @click.option("--source", type=click.Path(dir_okay=True, writable=False), required=True) 14 | @click.option("--metric", type=str, default=DEFAULT_METRIC) 15 | def run(source, metric): 16 | groups = load_and_group_runs(Path(source)) 17 | assert len(groups) > 0, "No groups found" 18 | 19 | assert all( 20 | group.has_metric(metric) for group in groups 21 | ), f"Metric {metric} not found in all groups" 22 | 23 | envs = set([group.config.env.name for group in groups]) 24 | 25 | for env in envs: 26 | env_groups = [group for group in groups if group.config.env.name == env] 27 | 28 | best_group = None 29 | best_value = -float("inf") 30 | 31 | for group in env_groups: 32 | values = group.get_metric(metric) 33 | mean = values.mean() 34 | if mean > best_value: 35 | best_group = group 36 | best_value = mean 37 | 38 | click.echo( 39 | "Best group for " 40 | + click.style(env, fg="red", bold=True) 41 | + " according to " 42 | + click.style(metric, fg="red", bold=True) 43 | + ": " 44 | + click.style(best_group.name, fg="red", bold=True) 45 | ) 46 | 47 | click.echo(OmegaConf.to_yaml(best_group.config)) 48 | 49 | click.echo(85 * "-" + "\n") 50 | 51 | 52 | if __name__ == "__main__": 53 | run() 54 | -------------------------------------------------------------------------------- /marlbase/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import hydra 5 | import numpy as np 6 | from omegaconf import OmegaConf, DictConfig 7 | import torch 8 | 9 | 10 | OmegaConf.register_new_resolver( 11 | "random", 12 | lambda x: os.urandom(x).hex(), 13 | ) 14 | 15 | 16 | @hydra.main(config_path="configs", config_name="eval", version_base="1.3") 17 | def main(cfg: DictConfig): 18 | path = Path(__file__).parent / cfg.path 19 | assert path.exists(), f"Path {path} does not exist." 
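# `cfg.path` is resolved relative to this file and is expected to point at a
# training run written by the FileSystemLogger, i.e. a directory such as
# outputs/<env.name>/<algorithm.name>/<hash>/ containing config.yaml and a
# checkpoints/ folder with model_s<step>.pt files. A hypothetical invocation
# (the run directory and step are placeholders for an actual trained run):
#
#     python eval.py path='outputs/lbforaging:Foraging-10x10-3p-3f-v3/idqn/1a2b3c4d' \
#         load_step=100000 seed=1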
20 | assert path.is_dir(), f"Path {path} is not a directory." 21 | 22 | config_path = path / "config.yaml" 23 | assert config_path.exists(), f"Config file {config_path} does not exist." 24 | run_config = OmegaConf.load(config_path) 25 | 26 | if "parallel_envs" in run_config.env: 27 | del run_config.env.parallel_envs 28 | env = hydra.utils.call( 29 | run_config.env, 30 | enable_video=True, 31 | seed=cfg.seed, 32 | ) 33 | 34 | torch.set_num_threads(1) 35 | 36 | if cfg.seed is not None: 37 | torch.manual_seed(cfg.seed) 38 | np.random.seed(cfg.seed) 39 | 40 | run_config.algorithm._target_ = run_config.algorithm._target_.replace("train", "eval") 41 | 42 | if cfg.load_step is not None: 43 | load_step = cfg.load_step 44 | else: 45 | # Find the latest checkpoint 46 | load_step = max( 47 | [ 48 | int(f.stem.split("_")[-1][1:]) 49 | for f in (path / "checkpoints").glob("model_s*.pt") 50 | ] 51 | ) 52 | ckpt_path = path / "checkpoints" / f"model_s{load_step}.pt" 53 | assert ckpt_path.exists(), f"Checkpoint {ckpt_path} does not exist." 54 | 55 | hydra.utils.call( 56 | run_config.algorithm, 57 | env, 58 | ckpt_path, 59 | _recursive_=False, 60 | ) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /marlbase/utils/postprocessing/plot_runs.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import click 4 | import matplotlib.pyplot as plt 5 | import seaborn as sns 6 | 7 | from marlbase.utils.postprocessing.load_data import load_and_group_runs 8 | 9 | 10 | DEFAULT_METRIC = "mean_episode_returns" 11 | 12 | 13 | @click.command() 14 | @click.option("--source", type=click.Path(dir_okay=True, writable=False), required=True) 15 | @click.option("--minimal-name", type=bool, default=True) 16 | @click.option("--metric", type=str, default=DEFAULT_METRIC) 17 | @click.option("--save_path", type=click.Path(dir_okay=True, writable=True)) 18 | def run(source, minimal_name, metric, save_path): 19 | groups = load_and_group_runs(Path(source), minimal_name) 20 | assert len(groups) > 0, "No groups found" 21 | 22 | click.echo(f"Loaded {len(groups)} groups:") 23 | for group in groups: 24 | click.echo(f"\t{group.name} with {len(group.runs)} runs") 25 | 26 | assert all( 27 | group.has_metric(metric) for group in groups 28 | ), f"Metric {metric} not found in all groups" 29 | 30 | envs = set([group.config.env.name for group in groups]) 31 | 32 | for env in envs: 33 | env_groups = [group for group in groups if group.config.env.name == env] 34 | 35 | sns.set_style("whitegrid") 36 | plt.figure() 37 | for group in env_groups: 38 | steps = group.get_metric("environment_steps").mean(axis=0) 39 | values = group.get_metric(metric) 40 | means = values.mean(axis=0) 41 | stds = values.std(axis=0) 42 | plt.plot(steps, means, label=group.name) 43 | plt.fill_between( 44 | steps, 45 | means - stds, 46 | means + stds, 47 | alpha=0.3, 48 | ) 49 | plt.legend() 50 | plt.xlabel("Environment steps") 51 | plt.ylabel(metric) 52 | plt.title(env) 53 | if save_path: 54 | path = Path(save_path) / f"{env.replace('/', ':')}_{metric}.pdf" 55 | path.parent.mkdir(parents=True, exist_ok=True) 56 | plt.savefig(path) 57 | plt.show() 58 | 59 | 60 | if __name__ == "__main__": 61 | run() 62 | -------------------------------------------------------------------------------- /marlbase/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class 
MultiCategorical: 5 | def __init__(self, categoricals): 6 | self.categoricals = categoricals 7 | 8 | def __getitem__(self, key): 9 | return self.categoricals[key] 10 | 11 | def sample(self): 12 | return [c.sample().unsqueeze(-1) for c in self.categoricals] 13 | 14 | def log_probs(self, actions): 15 | return [ 16 | c.log_prob(a.squeeze(-1)).unsqueeze(-1) 17 | for c, a in zip(self.categoricals, actions) 18 | ] 19 | 20 | def mode(self): 21 | return [c.mode for c in self.categoricals] 22 | 23 | def entropy(self): 24 | return [c.entropy() for c in self.categoricals] 25 | 26 | 27 | def to_onehot(tensor, n_dims): 28 | """ 29 | Convert tensor of indices to one-hot representation 30 | :param tensor: tensor of indices (batch_size, ..., 1) 31 | :param n_dims: number of dimensions 32 | :return: one-hot representation (batch_size, ..., n_dims) 33 | """ 34 | onehot = torch.zeros(tensor.shape + (n_dims,), device=tensor.device) 35 | return onehot.scatter(-1, tensor.unsqueeze(-1), 1) 36 | 37 | 38 | def compute_nstep_returns(rewards, done, next_values, nsteps, gamma): 39 | """ 40 | Computed n-step returns 41 | :param rewards: tensor of shape (ep_length, batch_size, n_agents) 42 | :param done: tensor of shape (ep_length, batch_size, n_agents) 43 | :param next_values: tensor of shape (ep_length, batch_size, n_agents) 44 | :param nsteps: number of steps to bootstrap 45 | :param gamma: discount factor 46 | :return: tensor of shape with returns (ep_length, batch_size, n_agents) 47 | """ 48 | ep_length = rewards.size(0) 49 | nstep_values = torch.zeros_like(rewards) 50 | for t_start in range(ep_length): 51 | nstep_return_t = torch.zeros_like(rewards[0]) 52 | for step in range(nsteps + 1): 53 | t = t_start + step 54 | if t >= ep_length: 55 | # episode has ended 56 | break 57 | elif step == nsteps: 58 | # last n-step value --> bootstrap from the next value 59 | nstep_return_t += gamma**step * next_values[t] * (1 - done[t]) 60 | else: 61 | nstep_return_t += gamma**step * rewards[t] * (1 - done[t]) 62 | nstep_values[t_start] = nstep_return_t 63 | return nstep_values 64 | -------------------------------------------------------------------------------- /marlbase/utils/postprocessing/hiplot_fetcher.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import click 4 | import hiplot as hip 5 | import json 6 | import numpy as np 7 | import pandas as pd 8 | 9 | 10 | class NumpyEncoder(json.JSONEncoder): 11 | """ Custom encoder for numpy data types """ 12 | def default(self, obj): 13 | if isinstance(obj, (np.int_, np.intc, np.intp, np.int8, 14 | np.int16, np.int32, np.int64, np.uint8, 15 | np.uint16, np.uint32, np.uint64)): 16 | 17 | return int(obj) 18 | 19 | elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): 20 | return float(obj) 21 | 22 | elif isinstance(obj, (np.complex_, np.complex64, np.complex128)): 23 | return {'real': obj.real, 'imag': obj.imag} 24 | 25 | elif isinstance(obj, (np.ndarray,)): 26 | return obj.tolist() 27 | 28 | elif isinstance(obj, (np.bool_)): 29 | return bool(obj) 30 | 31 | elif isinstance(obj, (np.void)): 32 | return None 33 | 34 | return json.JSONEncoder.default(self, obj) 35 | 36 | def experiment_fetcher(uri): 37 | 38 | PREFIX = "exp://" 39 | 40 | if not uri.startswith(PREFIX): 41 | # Let other fetchers handle this one 42 | raise hip.ExperimentFetcherDoesntApply() 43 | uri = uri[len(PREFIX):] # Remove the prefix 44 | 45 | exported_file = uri.split("/")[0] 46 | 47 | df = pd.read_hdf(exported_file, 
"df") 48 | configs = pd.read_hdf(exported_file, "configs") 49 | 50 | df = ( 51 | df.groupby(axis=1, level=[0, 1, 2]).mean().max() 52 | ) 53 | 54 | data = defaultdict(lambda: defaultdict(list)) 55 | 56 | for env, df in df.groupby(level=0): 57 | df = df.xs(env) 58 | for alg, df in df.groupby(level=0): 59 | df = df.xs(alg) 60 | 61 | for index, perf in df.iteritems(): 62 | data[env][alg].append({**configs.loc[index].to_dict(), "performance": perf, "uid": index}) 63 | 64 | env, alg = uri.split("/")[1], uri.split("/")[2] 65 | 66 | data = json.loads(json.dumps(data[env][alg], cls=NumpyEncoder)) 67 | exp = hip.Experiment.from_iterable(data) 68 | 69 | return exp 70 | 71 | 72 | if __name__ == "__main__": 73 | click.echo("Run with \"hiplot fastmarl.utils.postprocessing.hiplot_fetcher.experiment_fetcher\"") 74 | click.echo("And enter \"exp://filename.hd5/envname/alg\" in the textbox") 75 | 76 | -------------------------------------------------------------------------------- /marlbase/utils/postprocessing/export_multirun.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from hashlib import sha256 3 | from pathlib import Path 4 | 5 | import click 6 | import json 7 | from munch import munchify 8 | import pandas as pd 9 | import yaml 10 | 11 | 12 | def _load_data(folder): 13 | path = Path(folder) 14 | config_files = path.glob("**/**/.hydra/config.yaml") 15 | 16 | data = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) 17 | seed_data = defaultdict(lambda: defaultdict(lambda: defaultdict(set))) 18 | 19 | algos = set() 20 | envs = set() 21 | 22 | hash_to_config = defaultdict() 23 | 24 | for config_path in config_files: 25 | # hash = config_path.parent.parent.name 26 | 27 | with config_path.open() as fp: 28 | config = munchify(yaml.safe_load(fp)) 29 | 30 | env = config.env.name 31 | try: 32 | env = env.split(":")[1] # remove the library 33 | except IndexError: 34 | pass 35 | 36 | algo = config.algorithm.name 37 | 38 | algos.add(algo) 39 | envs.add(env) 40 | 41 | seed = config.seed 42 | del config.seed 43 | 44 | raw_data = json.dumps(config, sort_keys=True).encode("utf8") 45 | hash = sha256(raw_data).hexdigest()[:12] 46 | 47 | hash_to_config[hash] = pd.json_normalize(config) 48 | 49 | df = pd.read_csv(config_path.parent.parent / "results.csv", index_col=0)[ 50 | "mean_episode_returns" 51 | ] 52 | data[env][algo][hash].append(df.rename(f"seed={seed}")) 53 | assert seed not in seed_data[env][algo][hash], "Duplicate seed" 54 | seed_data[env][algo][hash].add(seed) 55 | 56 | env_df_list = [] 57 | for env in data.keys(): 58 | algo_df_list = [] 59 | for algo in data[env].keys(): 60 | lst = [] 61 | for hash in data[env][algo].keys(): 62 | lst.append(pd.concat(data[env][algo][hash], axis=1)) 63 | df = pd.concat(lst, axis=1, keys=[h for h in data[env][algo].keys()]) 64 | algo_df_list.append(df) 65 | df = pd.concat(algo_df_list, axis=1, keys=data[env].keys()) 66 | env_df_list.append(df) 67 | df = pd.concat(env_df_list, axis=1, keys=data.keys()) 68 | 69 | return pd.concat(hash_to_config).droplevel(1), df 70 | 71 | 72 | @click.command() 73 | @click.option("--folder", type=click.Path(exists=True), default="outputs/") 74 | @click.option("--export-file", type=click.Path(dir_okay=False, writable=True)) 75 | @click.pass_context 76 | def run(ctx, folder, export_file): 77 | 78 | hash_to_config, df = _load_data(folder) 79 | 80 | df.to_hdf(export_file, key="df", mode="w", complevel=9) 81 | hash_to_config.to_hdf(export_file, key="configs") 82 | 
83 | 84 | if __name__ == "__main__": 85 | run() 86 | -------------------------------------------------------------------------------- /marlbase/utils/envs.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import random 3 | 4 | import gymnasium as gym 5 | from omegaconf import DictConfig 6 | 7 | from marlbase.utils import wrappers as mwrappers 8 | from marlbase.utils.smaclite_wrapper import SMACliteWrapper 9 | 10 | 11 | def _make_parallel_envs( 12 | name, 13 | parallel_envs, 14 | wrappers, 15 | time_limit, 16 | clear_info, 17 | observe_id, 18 | standardise_rewards, 19 | seed, 20 | enable_video, 21 | **kwargs, 22 | ): 23 | def _env_thunk(seed): 24 | if "smaclite" in name: 25 | import smaclite # noqa 26 | 27 | env = gym.make( 28 | name, 29 | seed=seed, 30 | render_mode="rgb_array" if enable_video else None, 31 | **kwargs, 32 | ) 33 | env = SMACliteWrapper(env) 34 | else: 35 | env = gym.make( 36 | name, **kwargs, render_mode="rgb_array" if enable_video else None 37 | ) 38 | if clear_info: 39 | env = mwrappers.ClearInfo(env) 40 | if time_limit: 41 | env = gym.wrappers.TimeLimit(env, time_limit) 42 | env = mwrappers.RecordEpisodeStatistics(env) 43 | if observe_id: 44 | env = mwrappers.ObserveID(env) 45 | if standardise_rewards: 46 | env = mwrappers.StandardiseReward(env) 47 | if wrappers is not None: 48 | for wrapper in wrappers: 49 | wrapper = ( 50 | getattr(mwrappers, wrapper) 51 | if hasattr(mwrappers, wrapper) 52 | else getattr(gym.wrappers, wrapper) 53 | ) 54 | env = wrapper(env) 55 | env.reset(seed=seed) 56 | return env 57 | 58 | if seed is None: 59 | seed = random.randint(0, 99999) 60 | 61 | envs = gym.vector.AsyncVectorEnv( 62 | [partial(_env_thunk, seed + i) for i in range(parallel_envs)] 63 | ) 64 | 65 | return envs 66 | 67 | 68 | def _make_env( 69 | name, 70 | time_limit, 71 | clear_info, 72 | observe_id, 73 | standardise_rewards, 74 | wrappers, 75 | seed, 76 | enable_video, 77 | **kwargs, 78 | ): 79 | if "smaclite" in name: 80 | import smaclite # noqa 81 | 82 | env = gym.make( 83 | name, 84 | seed=seed, 85 | render_mode="rgb_array" if enable_video else None, 86 | **kwargs, 87 | ) 88 | env = SMACliteWrapper(env) 89 | else: 90 | env = gym.make( 91 | name, render_mode="rgb_array" if enable_video else None, **kwargs 92 | ) 93 | if clear_info: 94 | env = mwrappers.ClearInfo(env) 95 | if time_limit: 96 | env = gym.wrappers.TimeLimit(env, time_limit) 97 | env = mwrappers.RecordEpisodeStatistics(env) 98 | if observe_id: 99 | env = mwrappers.ObserveID(env) 100 | if standardise_rewards: 101 | env = mwrappers.StandardiseReward(env) 102 | if wrappers is not None: 103 | for wrapper in wrappers: 104 | wrapper = ( 105 | getattr(mwrappers, wrapper) 106 | if hasattr(mwrappers, wrapper) 107 | else getattr(gym.wrappers, wrapper) 108 | ) 109 | env = wrapper(env) 110 | 111 | env.reset(seed=seed) 112 | return env 113 | 114 | 115 | def make_env(seed, enable_video=False, **env_config): 116 | env_config = DictConfig(env_config) 117 | if "parallel_envs" in env_config and env_config.parallel_envs: 118 | return _make_parallel_envs(**env_config, enable_video=enable_video, seed=seed) 119 | return _make_env(**env_config, enable_video=enable_video, seed=seed) 120 | -------------------------------------------------------------------------------- /marlbase/search.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from itertools import product 3 | import multiprocessing 4 | 
import random 5 | import subprocess 6 | import time 7 | 8 | import click 9 | import yaml 10 | 11 | _CPU_COUNT = multiprocessing.cpu_count() - 1 12 | 13 | 14 | def _flatten_lists(object): 15 | for item in object: 16 | if isinstance(item, (list, tuple, set)): 17 | yield from _flatten_lists(item) 18 | else: 19 | yield item 20 | 21 | 22 | def _seed_and_shuffle(configs, shuffle, seeds): 23 | configs = [[f"+hypergroup=hp_grp_{i}"] + c for i, c in enumerate(configs)] 24 | configs = list(product(configs, [f"seed={i}" for i in range(seeds)])) 25 | configs = [list(_flatten_lists(c)) for c in configs] 26 | 27 | if shuffle: 28 | random.Random(1337).shuffle(configs) 29 | 30 | return configs 31 | 32 | 33 | def _load_config(filename): 34 | config = yaml.load(filename, Loader=yaml.CLoader) 35 | return config 36 | 37 | 38 | def _gen_combos(config, built_config): 39 | built_config = deepcopy(built_config) 40 | if not config: 41 | return [[f"{k}={v}" for k, v in built_config.items()]] 42 | 43 | k, v = list(config.items())[0] 44 | 45 | configs = [] 46 | if type(v) is list: 47 | for item in v: 48 | new_config = deepcopy(config) 49 | del new_config[k] 50 | new_config[k] = item 51 | configs += _gen_combos(new_config, built_config) 52 | elif type(v) is tuple: 53 | new_config = deepcopy(config) 54 | del new_config[k] 55 | for item in v: 56 | new_config.update(item) 57 | 58 | configs += _gen_combos(new_config, built_config) 59 | else: 60 | new_config = deepcopy(config) 61 | del new_config[k] 62 | built_config[k] = v 63 | configs += _gen_combos(new_config, built_config) 64 | return configs 65 | 66 | 67 | def work(cmd, sleep): 68 | cmd = cmd.split(" ") 69 | time.sleep(sleep) 70 | return subprocess.call(cmd, shell=False) 71 | 72 | 73 | @click.group() 74 | def cli(): 75 | pass 76 | 77 | 78 | @cli.command() 79 | @click.argument("output", type=click.Path(exists=False, dir_okay=False, writable=True)) 80 | def write(output): 81 | raise NotImplementedError 82 | 83 | 84 | @cli.group() 85 | @click.option("--config", type=click.File(), default="config.yaml") 86 | @click.option("--shuffle/--no-shuffle", default=True) 87 | @click.option("--seeds", default=3, show_default=True, help="How many seeds to run") 88 | @click.pass_context 89 | def run(ctx, config, shuffle, seeds): 90 | config = _load_config(config) 91 | configs = _gen_combos(config, {}) 92 | 93 | configs = _seed_and_shuffle(configs, shuffle, seeds) 94 | if len(configs) == 0: 95 | click.echo("No valid combinations. Aborted!") 96 | exit(1) 97 | ctx.obj = configs 98 | 99 | 100 | @run.command() 101 | @click.option( 102 | "--cpus", 103 | default=_CPU_COUNT, 104 | show_default=True, 105 | help="How many processes to run in parallel", 106 | ) 107 | @click.pass_obj 108 | def locally(combos, cpus): 109 | configs = [ 110 | "python run.py " + "-m " + " ".join([c for c in combo]) for combo in combos 111 | ] 112 | args = [(conf, i * 2) for i, conf in enumerate(configs)] 113 | 114 | click.confirm( 115 | f"There are {click.style(str(len(combos)), fg='red')} combinations of configurations. Up to {cpus} will run in parallel. 
Continue?", 116 | abort=True, 117 | ) 118 | 119 | pool = multiprocessing.Pool(processes=cpus) 120 | print(pool.starmap(work, args)) 121 | 122 | 123 | @run.command() 124 | @click.pass_obj 125 | def dry_run(combos): 126 | configs = [" ".join([c for c in combo]) for combo in combos] 127 | click.echo( 128 | f"There are {click.style(str(len(combos)), fg='red')} configurations as shown below:" 129 | ) 130 | for c in configs: 131 | click.echo(c) 132 | 133 | 134 | @run.command() 135 | @click.argument( 136 | "index", 137 | type=int, 138 | ) 139 | @click.pass_obj 140 | def single(combos, index): 141 | """Runs a single hyperparameter combination 142 | INDEX is the index of the combination to run in the generated combination list 143 | """ 144 | 145 | config = combos[index] 146 | cmd = "python run.py " + " ".join([c for c in config]) 147 | print(cmd) 148 | work(cmd) 149 | 150 | 151 | if __name__ == "__main__": 152 | cli() 153 | -------------------------------------------------------------------------------- /marlbase/utils/postprocessing/load_data.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from pathlib import Path 3 | from typing import Dict, List, Union 4 | 5 | import numpy as np 6 | from omegaconf import OmegaConf 7 | import pandas as pd 8 | 9 | 10 | def _load_csv(path: Path) -> Dict[str, List[float]]: 11 | assert ( 12 | path.exists() and path.is_file() and path.suffix == ".csv" 13 | ), f"{path} is not a valid csv file" 14 | return pd.read_csv(path).to_dict(orient="list") 15 | 16 | 17 | def _load_config(path: Path) -> OmegaConf: 18 | assert ( 19 | path.exists() and path.is_file() and path.suffix == ".yaml" 20 | ), f"{path} is not a valid config file" 21 | return OmegaConf.load(path) 22 | 23 | 24 | class Run: 25 | def __init__(self, config: OmegaConf, data: Dict[str, List[float]], path: Path): 26 | self.data = data 27 | self.config = config 28 | self.path = path 29 | 30 | def __from_path__(path: Path) -> "Run": 31 | assert path.exists() and path.is_dir(), f"{path} is not a valid run directory" 32 | data = _load_csv(path / "results.csv") 33 | config = _load_config(path / "config.yaml") 34 | return Run(config, data, path) 35 | 36 | def __str__(self) -> str: 37 | return f"Run {self.path}" 38 | 39 | def get_config_name(self) -> str: 40 | return " ".join( 41 | [f"{key}={value}" for key, value in self.config.items() if key != "seed"] 42 | ) 43 | 44 | 45 | def _load_run(path: Path): 46 | return Run.__from_path__(path) 47 | 48 | 49 | class Group: 50 | def __init__(self, name, runs: List[Run]): 51 | self.name = name 52 | self.config = runs[0].config 53 | self.config.pop("seed") 54 | self.runs = runs 55 | 56 | def __str__(self) -> str: 57 | return f"Group {self.name} ({len(self.runs)} runs)" 58 | 59 | def has_metric(self, key) -> bool: 60 | has_metrics = [key in run.data for run in self.runs] 61 | assert all(has_metrics) or not any( 62 | has_metrics 63 | ), f"Key {key} is present in some but not all runs" 64 | return has_metrics[0] 65 | 66 | def get_metric(self, key) -> np.ndarray: 67 | assert self.has_metric(key), f"Key {key} is not present in all runs" 68 | values = [run.data[key] for run in self.runs] 69 | assert all( 70 | len(value) == len(values[0]) for value in values 71 | ), f"Values for key {key} have different lengths" 72 | return np.array(values) 73 | 74 | 75 | def _load_runs(path: Path) -> List[Run]: 76 | assert path.exists() and path.is_dir(), f"{path} is not a valid directory" 77 | runs = [] 78 | for run in 
path.glob("**/results.csv"): 79 | run = _load_run(run.parent) 80 | runs.append(run) 81 | return runs 82 | 83 | 84 | def _flatten_omegaconf( 85 | config: OmegaConf, base_name=None 86 | ) -> Dict[str, Union[str, float, int]]: 87 | flat_config = {} 88 | for key, value in config.items(): 89 | key = f"{base_name}.{key}" if base_name else key 90 | if OmegaConf.is_config(value) and not OmegaConf.is_list(value): 91 | flat_config.update(_flatten_omegaconf(value, key)) 92 | else: 93 | flat_config[key] = value 94 | return flat_config 95 | 96 | 97 | def load_and_group_runs(path: Path, minimal_name: bool = True) -> List[Group]: 98 | """ 99 | Load all runs in a directory and group them by unique configurations 100 | :param path: Path to directory containing runs 101 | :param minimal_name: Use minimal name for each group 102 | :return: List of Group objects 103 | """ 104 | # group runs by unique confrigurations 105 | runs_by_config_name = defaultdict(list) 106 | for run in _load_runs(path): 107 | runs_by_config_name[run.get_config_name()].append(run) 108 | 109 | if minimal_name: 110 | # identify minimal hyperparameters that differentiate runs 111 | values_by_key = defaultdict(set) 112 | for config_name, runs in runs_by_config_name.items(): 113 | group_config = _flatten_omegaconf(runs[0].config) 114 | for key, value in group_config.items(): 115 | if ( 116 | key == "seed" 117 | or key == "algorithm.name" 118 | or "_target_" in key 119 | or key == "hypergroup" 120 | or "wrappers" in key 121 | ): 122 | continue 123 | values_by_key[key].add(value) 124 | 125 | # distinguishing hyperparameters 126 | distinguishing_keys = [ 127 | key for key, values in values_by_key.items() if len(values) > 1 128 | ] 129 | 130 | runs_by_minimal_config_name = {} 131 | for runs in runs_by_config_name.values(): 132 | group_config = _flatten_omegaconf(runs[0].config) 133 | minimal_config_name = group_config["algorithm.name"].upper() 134 | config_name = " ".join( 135 | [ 136 | f"{key}={group_config[key]}" 137 | for key in distinguishing_keys 138 | if key in group_config 139 | ] 140 | ) 141 | if config_name: 142 | minimal_config_name += f" ({config_name})" 143 | runs_by_minimal_config_name[minimal_config_name] = runs 144 | 145 | runs_by_config_name = runs_by_minimal_config_name 146 | 147 | return [Group(name, runs) for name, runs in runs_by_config_name.items()] 148 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | fastmarl/outputs/ 3 | fastmarl/multirun/ 4 | # Created by https://www.gitignore.io/api/linux,python,windows,pycharm+all,visualstudiocode 5 | # Edit at https://www.gitignore.io/?templates=linux,python,windows,pycharm+all,visualstudiocode 6 | .vscode 7 | ### Linux ### 8 | *~ 9 | 10 | # temporary files which can be created if a process still has a handle open of a deleted file 11 | .fuse_hidden* 12 | 13 | # KDE directory preferences 14 | .directory 15 | 16 | # Linux trash folder which might appear on any partition or disk 17 | .Trash-* 18 | 19 | # .nfs files are created when an open file is removed but is still being accessed 20 | .nfs* 21 | 22 | ### PyCharm+all ### 23 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 24 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 25 | 26 | # User-specific stuff 27 | .idea/**/workspace.xml 28 | .idea/**/tasks.xml 29 | .idea/**/usage.statistics.xml 30 | 
.idea/**/dictionaries 31 | .idea/**/shelf 32 | 33 | # Generated files 34 | .idea/**/contentModel.xml 35 | 36 | # Sensitive or high-churn files 37 | .idea/**/dataSources/ 38 | .idea/**/dataSources.ids 39 | .idea/**/dataSources.local.xml 40 | .idea/**/sqlDataSources.xml 41 | .idea/**/dynamic.xml 42 | .idea/**/uiDesigner.xml 43 | .idea/**/dbnavigator.xml 44 | 45 | # Gradle 46 | .idea/**/gradle.xml 47 | .idea/**/libraries 48 | 49 | # Gradle and Maven with auto-import 50 | # When using Gradle or Maven with auto-import, you should exclude module files, 51 | # since they will be recreated, and may cause churn. Uncomment if using 52 | # auto-import. 53 | # .idea/modules.xml 54 | # .idea/*.iml 55 | # .idea/modules 56 | # *.iml 57 | # *.ipr 58 | 59 | # CMake 60 | cmake-build-*/ 61 | 62 | # Mongo Explorer plugin 63 | .idea/**/mongoSettings.xml 64 | 65 | # File-based project format 66 | *.iws 67 | 68 | # IntelliJ 69 | out/ 70 | 71 | # mpeltonen/sbt-idea plugin 72 | .idea_modules/ 73 | 74 | # JIRA plugin 75 | atlassian-ide-plugin.xml 76 | 77 | # Cursive Clojure plugin 78 | .idea/replstate.xml 79 | 80 | # Crashlytics plugin (for Android Studio and IntelliJ) 81 | com_crashlytics_export_strings.xml 82 | crashlytics.properties 83 | crashlytics-build.properties 84 | fabric.properties 85 | 86 | # Editor-based Rest Client 87 | .idea/httpRequests 88 | 89 | # Android studio 3.1+ serialized cache file 90 | .idea/caches/build_file_checksums.ser 91 | 92 | ### PyCharm+all Patch ### 93 | # Ignores the whole .idea folder and all .iml files 94 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 95 | 96 | .idea/ 97 | 98 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 99 | 100 | *.iml 101 | modules.xml 102 | .idea/misc.xml 103 | *.ipr 104 | 105 | # Sonarlint plugin 106 | .idea/sonarlint 107 | 108 | ### Python ### 109 | # Byte-compiled / optimized / DLL files 110 | __pycache__/ 111 | *.py[cod] 112 | *$py.class 113 | 114 | # C extensions 115 | *.so 116 | 117 | # Distribution / packaging 118 | .Python 119 | build/ 120 | develop-eggs/ 121 | dist/ 122 | downloads/ 123 | eggs/ 124 | .eggs/ 125 | lib/ 126 | lib64/ 127 | parts/ 128 | sdist/ 129 | var/ 130 | wheels/ 131 | pip-wheel-metadata/ 132 | share/python-wheels/ 133 | *.egg-info/ 134 | .installed.cfg 135 | *.egg 136 | MANIFEST 137 | 138 | # PyInstaller 139 | # Usually these files are written by a python script from a template 140 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 141 | *.manifest 142 | *.spec 143 | 144 | # Installer logs 145 | pip-log.txt 146 | pip-delete-this-directory.txt 147 | 148 | # Unit test / coverage reports 149 | htmlcov/ 150 | .tox/ 151 | .nox/ 152 | .coverage 153 | .coverage.* 154 | .cache 155 | nosetests.xml 156 | coverage.xml 157 | *.cover 158 | .hypothesis/ 159 | .pytest_cache/ 160 | 161 | # Translations 162 | *.mo 163 | *.pot 164 | 165 | # Scrapy stuff: 166 | .scrapy 167 | 168 | # Sphinx documentation 169 | docs/_build/ 170 | 171 | # PyBuilder 172 | target/ 173 | 174 | # pyenv 175 | .python-version 176 | 177 | # pipenv 178 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 179 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 180 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 181 | # install all needed dependencies. 
182 | #Pipfile.lock 183 | 184 | # celery beat schedule file 185 | celerybeat-schedule 186 | 187 | # SageMath parsed files 188 | *.sage.py 189 | 190 | # Spyder project settings 191 | .spyderproject 192 | .spyproject 193 | 194 | # Rope project settings 195 | .ropeproject 196 | 197 | # Mr Developer 198 | .mr.developer.cfg 199 | .project 200 | .pydevproject 201 | 202 | # mkdocs documentation 203 | /site 204 | 205 | # mypy 206 | .mypy_cache/ 207 | .dmypy.json 208 | dmypy.json 209 | 210 | # Pyre type checker 211 | .pyre/ 212 | 213 | ### VisualStudioCode ### 214 | .vscode/* 215 | !.vscode/settings.json 216 | !.vscode/tasks.json 217 | !.vscode/launch.json 218 | !.vscode/extensions.json 219 | 220 | ### VisualStudioCode Patch ### 221 | # Ignore all local history of files 222 | .history 223 | 224 | ### Windows ### 225 | # Windows thumbnail cache files 226 | Thumbs.db 227 | Thumbs.db:encryptable 228 | ehthumbs.db 229 | ehthumbs_vista.db 230 | 231 | # Dump file 232 | *.stackdump 233 | 234 | # Folder config file 235 | [Dd]esktop.ini 236 | 237 | # Recycle Bin used on file shares 238 | $RECYCLE.BIN/ 239 | 240 | # Windows Installer files 241 | *.cab 242 | *.msi 243 | *.msix 244 | *.msm 245 | *.msp 246 | 247 | # Windows shortcuts 248 | *.lnk 249 | 250 | # End of https://www.gitignore.io/api/linux,python,windows,pycharm+all,visualstudiocode -------------------------------------------------------------------------------- /marlbase/utils/wrappers.py: -------------------------------------------------------------------------------- 1 | """ 2 | A collection of environment wrappers for multi-agent environments 3 | """ 4 | 5 | from collections import deque 6 | from time import perf_counter 7 | 8 | import gymnasium as gym 9 | from gymnasium import ObservationWrapper, spaces 10 | import numpy as np 11 | 12 | 13 | class RecordEpisodeStatistics(gym.Wrapper): 14 | """Multi-agent version of RecordEpisodeStatistics gym wrapper""" 15 | 16 | def __init__(self, env, deque_size=100): 17 | super().__init__(env) 18 | self.t0 = perf_counter() 19 | self.episode_reward = np.zeros(self.unwrapped.n_agents) 20 | self.episode_length = 0 21 | self.reward_queue = deque(maxlen=deque_size) 22 | self.length_queue = deque(maxlen=deque_size) 23 | 24 | def reset(self, **kwargs): 25 | obs, info = super().reset(**kwargs) 26 | self.episode_reward = 0 27 | self.episode_length = 0 28 | self.t0 = perf_counter() 29 | return obs, info 30 | 31 | def step(self, action): 32 | observation, reward, done, truncated, info = super().step(action) 33 | self.episode_reward += np.array(reward, dtype=np.float32) 34 | self.episode_length += 1 35 | if done or truncated: 36 | info["episode_returns"] = self.episode_reward 37 | if len(self.episode_reward) == self.unwrapped.n_agents: 38 | for i, agent_reward in enumerate(self.episode_reward): 39 | info[f"agent{i}/episode_returns"] = agent_reward 40 | info["episode_length"] = self.episode_length 41 | info["episode_time"] = perf_counter() - self.t0 42 | 43 | self.reward_queue.append(self.episode_reward) 44 | self.length_queue.append(self.episode_length) 45 | return observation, reward, done, truncated, info 46 | 47 | 48 | class FlattenObservation(ObservationWrapper): 49 | r"""Observation wrapper that flattens the observation of individual agents.""" 50 | 51 | def __init__(self, env): 52 | super(FlattenObservation, self).__init__(env) 53 | ma_spaces = [] 54 | for sa_obs in env.observation_space: 55 | flatdim = spaces.flatdim(sa_obs) 56 | ma_spaces += [ 57 | spaces.Box( 58 | low=-float("inf"), 59 | high=float("inf"), 60 | 
shape=(flatdim,), 61 | dtype=np.float32, 62 | ) 63 | ] 64 | self.observation_space = spaces.Tuple(tuple(ma_spaces)) 65 | 66 | def observation(self, observation): 67 | return tuple( 68 | [ 69 | spaces.flatten(obs_space, obs) 70 | for obs_space, obs in zip(self.env.observation_space, observation) 71 | ] 72 | ) 73 | 74 | 75 | class ObserveID(gym.ObservationWrapper): 76 | def __init__(self, env): 77 | super().__init__(env) 78 | agent_count = env.unwrapped.n_agents 79 | for obs_space in self.observation_space: 80 | assert ( 81 | isinstance(obs_space, gym.spaces.Box) and len(obs_space.shape) == 1 82 | ), "ObserveID wrapper assumes flattened observation space." 83 | self.observation_space = gym.spaces.Tuple( 84 | tuple( 85 | [ 86 | gym.spaces.Box( 87 | low=-np.inf, 88 | high=np.inf, 89 | shape=((x.shape[0] + agent_count),), 90 | dtype=x.dtype, 91 | ) 92 | for x in self.observation_space 93 | ] 94 | ) 95 | ) 96 | 97 | def observation(self, observation): 98 | observation = np.stack(observation) 99 | observation = np.concatenate( 100 | (np.eye(self.unwrapped.n_agents, dtype=observation.dtype), observation), 101 | axis=1, 102 | ) 103 | return [o.squeeze() for o in np.split(observation, self.unwrapped.n_agents)] 104 | 105 | 106 | class CooperativeReward(gym.RewardWrapper): 107 | def reward(self, reward): 108 | return self.unwrapped.n_agents * [sum(reward)] 109 | 110 | 111 | class StandardiseReward(gym.RewardWrapper): 112 | def __init__(self, *args, **kwargs): 113 | super().__init__(*args, **kwargs) 114 | self.stdr_wrp_sumw = np.zeros(self.unwrapped.n_agents, dtype=np.float32) 115 | self.stdr_wrp_wmean = np.zeros(self.unwrapped.n_agents, dtype=np.float32) 116 | self.stdr_wrp_t = np.zeros(self.unwrapped.n_agents, dtype=np.float32) 117 | self.stdr_wrp_n = 0 118 | 119 | def reward(self, reward): 120 | # based on http://www.nowozin.net/sebastian/blog/streaming-mean-and-variance-computation.html 121 | # update running avg and std 122 | weight = 1.0 123 | 124 | q = reward - self.stdr_wrp_wmean 125 | temp_sumw = self.stdr_wrp_sumw + weight 126 | r = q * weight / temp_sumw 127 | 128 | self.stdr_wrp_wmean += r 129 | self.stdr_wrp_t += q * r * self.stdr_wrp_sumw 130 | self.stdr_wrp_sumw = temp_sumw 131 | self.stdr_wrp_n += 1 132 | 133 | if self.stdr_wrp_n == 1: 134 | return reward 135 | 136 | # calculate standardised reward 137 | var = (self.stdr_wrp_t * self.stdr_wrp_n) / ( 138 | self.stdr_wrp_sumw * (self.stdr_wrp_n - 1) 139 | ) 140 | stdr_rew = (reward - self.stdr_wrp_wmean) / (np.sqrt(var) + 1e-6) 141 | return stdr_rew 142 | 143 | 144 | class ClearInfo(gym.Wrapper): 145 | def step(self, action): 146 | observation, reward, done, truncated, _ = self.env.step(action) 147 | return observation, reward, done, truncated, {} 148 | -------------------------------------------------------------------------------- /marlbase/utils/loggers.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from hashlib import sha256 3 | import json 4 | import logging 5 | import math 6 | import time 7 | from typing import Dict, List 8 | 9 | import numpy as np 10 | from omegaconf import DictConfig, OmegaConf 11 | import pandas as pd 12 | 13 | 14 | def squash_info(info): 15 | new_info = {} 16 | keys = set([k for i in info for k in i.keys()]) 17 | keys.discard("TimeLimit.truncated") 18 | keys.discard("terminal_observation") 19 | for key in keys: 20 | values = [d[key] for d in info if key in d] 21 | if len(values) == 1: 22 | new_info[key] = values[0] 23 | continue 24 | 25 
| mean = np.mean([np.array(v).sum() for v in values]) 26 | std = np.std([np.array(v).sum() for v in values]) 27 | 28 | split_key = key.rsplit("/", 1) 29 | mean_key = split_key[:] 30 | std_key = split_key[:] 31 | mean_key[-1] = "mean_" + mean_key[-1] 32 | std_key[-1] = "std_" + std_key[-1] 33 | 34 | new_info["/".join(mean_key)] = mean 35 | new_info["/".join(std_key)] = std 36 | return new_info 37 | 38 | 39 | class Logger: 40 | def __init__(self, project_name, cfg: DictConfig) -> None: 41 | self.config_hash = sha256( 42 | json.dumps( 43 | {k: v for k, v in OmegaConf.to_container(cfg).items() if k != "seed"}, 44 | sort_keys=True, 45 | ).encode("utf8") 46 | ).hexdigest()[-10:] 47 | 48 | self._total_steps = cfg.algorithm.total_steps 49 | self._start_time = time.time() 50 | self._prev_time = None 51 | self._prev_steps = (0, 0) # steps (updates) and env_samples 52 | 53 | def log_metrics(self, metrics: List[Dict]): ... 54 | 55 | def print_progress(self, updates, steps, mean_returns, episodes): 56 | self.info(f"Updates {updates}, Environment timesteps {steps}") 57 | 58 | time_now = time.time() 59 | 60 | elapsed_wallclock = time_now - self._prev_time[0] if self._prev_time else None 61 | elapsed_cpu = ( 62 | time.process_time() - self._prev_time[1] if self._prev_time else None 63 | ) 64 | elapsed_from_start = timedelta(seconds=math.ceil((time_now - self._start_time))) 65 | 66 | completed = steps / self._total_steps 67 | 68 | if elapsed_wallclock: 69 | ups = (updates - self._prev_steps[0]) / elapsed_wallclock 70 | fps = (steps - self._prev_steps[1]) / elapsed_wallclock 71 | self.info(f"UPS: {ups:.2f}, FPS: {fps:.2f} (wall time)") 72 | 73 | # ups = (updates - self._prev_steps[0]) / elapsed_cpu 74 | # fps = (steps - self._prev_steps[1]) / elapsed_cpu 75 | # self.info(f"UPS: {ups:.2f}, FPS: {fps:.2f} (cpu time)") 76 | 77 | eta = elapsed_from_start * (1 - completed) / completed 78 | eta = timedelta(seconds=math.ceil(eta.total_seconds())) 79 | self.info(f"Elapsed Time: {elapsed_from_start}") 80 | self.info(f"Estim. 
Time Left: {eta}") 81 | 82 | self.info(f"Completed: {100*completed:.2f}%") 83 | 84 | self._prev_steps = (updates, steps) 85 | self._prev_time = time.time(), time.process_time() 86 | 87 | self.info(f"Last {episodes} episodes with mean returns: {mean_returns:.3f}") 88 | self.info("-------------------------------------------") 89 | 90 | def watch(self, model): 91 | self.debug(model) 92 | 93 | def debug(self, *args, **kwargs): 94 | return logging.debug(*args, **kwargs) 95 | 96 | def info(self, *args, **kwargs): 97 | return logging.info(*args, **kwargs) 98 | 99 | def warning(self, *args, **kwargs): 100 | return logging.warning(*args, **kwargs) 101 | 102 | def error(self, *args, **kwargs): 103 | return logging.error(*args, **kwargs) 104 | 105 | def critical(self, *args, **kwargs): 106 | return logging.critical(*args, **kwargs) 107 | 108 | def get_state(self): 109 | return None 110 | 111 | 112 | class WandbLogger(Logger): 113 | def __init__(self, project_name, cfg: DictConfig) -> None: 114 | import wandb 115 | 116 | super().__init__(project_name, cfg) 117 | self._run = wandb.init( 118 | project=project_name, 119 | config=OmegaConf.to_container(cfg), 120 | monitor_gym=True, 121 | group=self.config_hash, 122 | ) 123 | 124 | def log_metrics(self, metrics: List[Dict]): 125 | d = squash_info(metrics) 126 | self._run.log(d) 127 | 128 | self.print_progress( 129 | d["updates"], 130 | d["environment_steps"], 131 | d["mean_episode_returns"], 132 | len(metrics) - 1, 133 | ) 134 | 135 | def watch(self, model): 136 | self.debug(model) 137 | self._run.watch(model) 138 | 139 | 140 | class FileSystemLogger(Logger): 141 | def __init__(self, project_name, cfg): 142 | super().__init__(project_name, cfg) 143 | 144 | self.results_path = "results.csv" 145 | self.config_path = "config.yaml" 146 | with open(self.config_path, "w") as f: 147 | OmegaConf.save(cfg, f) 148 | 149 | def log_metrics(self, metrics): 150 | d = squash_info(metrics) 151 | df = pd.DataFrame.from_dict([d])[ 152 | ["environment_steps"] 153 | + sorted([k for k in d.keys() if k != "environment_steps"]) 154 | ] 155 | # Since we are appending, we only want to write the csv headers if the file does not already exist 156 | # the following codeblock handles this automatically 157 | with open(self.results_path, "a") as f: 158 | df.to_csv(f, header=f.tell() == 0, index=False) 159 | 160 | self.print_progress( 161 | d["updates"], 162 | d["environment_steps"], 163 | d["mean_episode_returns"], 164 | len(metrics) - 1, 165 | ) 166 | 167 | def get_state(self): 168 | df = pd.read_csv(self.results_path, index_col=0) 169 | return df 170 | -------------------------------------------------------------------------------- /marlbase/utils/stats.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | def _load_data_from_subfolder(folder, metric, step=None, step_metric=None): 8 | """Helper function for pulling results data from logs 9 | 10 | Args: 11 | folder (str): 12 | metric (str): 13 | step (int): 14 | step_metric (str): 15 | 16 | Returns: 17 | list of performance values 18 | """ 19 | # The given folder will contain several sub-folders with random hashes like "1a8fdsk3" 20 | # Within each sub-folder is the data we need 21 | results = [] 22 | 23 | for subfolder in os.listdir(folder): 24 | data = pd.read_csv(f'{os.path.join(folder, subfolder, "results.csv")}') 25 | 26 | if step is not None and step_metric is not None: 27 | data = [data[data[step_metric] == 
step][metric].tolist()[0]] 28 | 29 | else: 30 | data = data[metric].tolist() 31 | 32 | results.append(data) 33 | 34 | return results 35 | 36 | 37 | def make_agg_metrics_intervals(folders, algos, metric, step=None, step_metric=None): 38 | """Pulls results for the 'Aggregate metrics with 95% Stratified Bootstrap CIs' plot 39 | Can also be used for "Performance Profiles" plot 40 | 41 | Below is an example usage for this function: 42 | make_agg_metrics_intervals( 43 | folders=[folder, folder, folder, folder], 44 | algos=['ac', 'ac', 'dqn', 'dqn'], 45 | metric=['mean_reward', 'mean_reward', 'mean_reward', 'mean_reward'], 46 | step=[240, 240, 500, 500], 47 | step_metric=['environment_steps', 'environment_steps', 'updates', 'updates'] 48 | ) 49 | 50 | Shape of the output data is {'algo_1': (n_runs x n_envs), ..., 'algo_j': (n_runs x n_envs} 51 | 52 | Args: 53 | folders (List[str]): 54 | algos (List[str]): 55 | metric (List[str]): 56 | step (List[int]): 57 | step_metric (List[str]): 58 | 59 | Returns: 60 | Dict of performance matrices 61 | """ 62 | # For the interval estimates plot, we need performance at a specific point during training/evaluation 63 | if step is None: 64 | raise ValueError('For interval plots, a specific step must be specified') 65 | if step_metric is None: 66 | raise ValueError('For interval plots, a specific step_metric must be specified') 67 | 68 | # Process for reading in the data 69 | results = {} 70 | 71 | for i in range(len(folders)): 72 | data = _load_data_from_subfolder(os.path.join(folders[i], algos[i]), metric[i], step[i], step_metric[i]) 73 | 74 | if algos[i] not in results.keys(): 75 | results[algos[i]] = [] 76 | 77 | results[algos[i]].append(data) 78 | 79 | # Now we need to transpose the pulled results into results matrices. 
For specific shape, see function docstring 80 | results_T = {} 81 | 82 | for algo in results.keys(): 83 | pulled_results = results[algo] 84 | results_T[algo] = np.array(pulled_results).T[0] 85 | 86 | return results_T 87 | 88 | 89 | def make_agg_metrics_pxy(folders, algos, metric, step=None, step_metric=None): 90 | """Pulls results for the 'Probability of Improvement' plot 91 | 92 | Below is an example usage for this function: 93 | make_agg_metrics_pxy( 94 | folders=[folder, folder, folder, folder], 95 | algos=['ac', 'ac', 'dqn', 'dqn'], 96 | metric=['mean_reward', 'mean_reward', 'mean_reward', 'mean_reward'], 97 | step=[240, 240, 500, 500], 98 | step_metric=['environment_steps', 'environment_steps', 'updates', 'updates'] 99 | ) 100 | 101 | Shape of the output data is {'algo_1,algo_2': ((n_runs x n_envs), (n_runs x n_envs)), ...} 102 | 103 | Args: 104 | folders (List[str]): 105 | algos (List[str]): 106 | metric (List[str]): 107 | step (List[int]): 108 | step_metric (List[str]): 109 | 110 | Returns: 111 | Dicts of comparative performance matrices 112 | """ 113 | # First pulling the metrics as we would for other single-value plots 114 | agg_metrics = make_agg_metrics_intervals(folders=folders, algos=algos, metric=metric, 115 | step=step, step_metric=step_metric) 116 | 117 | # Now building out the combinatorics dict 118 | results = {} 119 | 120 | for i in range(len(algos)): 121 | for j in range(len(algos)): 122 | if i == j: 123 | continue 124 | results[f'{algos[i]},{algos[j]}'] = (agg_metrics[algos[i]], agg_metrics[algos[j]]) 125 | 126 | return results 127 | 128 | 129 | def make_agg_metrics_efficiency(folders, algos, metric): 130 | """Pulls results for the 'Aggregate metrics with 95% Stratified Bootstrap CIs' plot 131 | Can also be used for "Performance Profiles" plot 132 | 133 | Below is an example usage for this function: 134 | make_agg_metrics_efficiency( 135 | folders=[folder, folder, folder, folder], 136 | algos=['ac', 'ac', 'dqn', 'dqn'], 137 | metric=['mean_reward', 'mean_reward', 'mean_reward', 'mean_reward'], 138 | ) 139 | 140 | Shape of the output data is {'algo_1': (n_runs x n_envs x n_steps), ...,} 141 | 142 | Args: 143 | folders (List[str]): 144 | algos (List[str]): 145 | metric (List[str]): 146 | step (List[int]): 147 | step_metric (List[str]): 148 | 149 | Returns: 150 | Dict of performance matrices 151 | """ 152 | step = [None for _ in range(len(algos))] 153 | step_metric = [None for _ in range(len(algos))] 154 | 155 | # Process for reading in the data 156 | results = {} 157 | 158 | for i in range(len(folders)): 159 | data = _load_data_from_subfolder(os.path.join(folders[i], algos[i]), metric[i], step[i], step_metric[i]) 160 | 161 | if algos[i] not in results.keys(): 162 | results[algos[i]] = [] 163 | 164 | results[algos[i]].append(data) 165 | 166 | results_T = {} 167 | 168 | for algo in results.keys(): 169 | pulled_results = results[algo] 170 | 171 | n_envs = len(pulled_results) 172 | n_runs = len(pulled_results[0]) 173 | n_steps = len(pulled_results[0][0]) 174 | 175 | 176 | results_T[algo] = np.array(pulled_results).reshape((n_runs, n_envs, n_steps)) 177 | 178 | return results_T 179 | -------------------------------------------------------------------------------- /marlbase/ac/train.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from pathlib import Path 3 | 4 | from einops import rearrange 5 | import numpy as np 6 | from gymnasium.spaces import flatdim 7 | import hydra 8 | from omegaconf import 
DictConfig 9 | import torch 10 | 11 | from marlbase.utils.video import VideoRecorder 12 | 13 | 14 | Batch = namedtuple( 15 | "Batch", ["obss", "actions", "rewards", "dones", "filled", "action_masks"] 16 | ) 17 | 18 | 19 | def _log_progress(infos, step, updates, logger): 20 | infos.append({"updates": updates, "environment_steps": step}) 21 | logger.log_metrics(infos) 22 | 23 | 24 | def _collect_trajectories( 25 | envs, model, max_ep_length, parallel_envs, n_agents, device, use_proper_termination 26 | ): 27 | running = torch.ones(parallel_envs, device=device, dtype=torch.bool) 28 | 29 | # creates and initialises storage 30 | obss, info = envs.reset() 31 | obss = [torch.tensor(o, device=device) for o in obss] 32 | parallel_envs = envs.observation_space[0].shape[0] 33 | obs_dim = flatdim(envs.single_observation_space) 34 | num_actions = max(action_space.n for action_space in envs.single_action_space) 35 | 36 | batch_obs = torch.zeros( 37 | max_ep_length + 1, 38 | parallel_envs, 39 | obs_dim, 40 | device=device, 41 | ) 42 | batch_done = torch.zeros( 43 | max_ep_length + 1, parallel_envs, device=device, dtype=torch.bool 44 | ) 45 | batch_act = torch.zeros( 46 | max_ep_length, parallel_envs, n_agents, device=device, dtype=torch.long 47 | ) 48 | batch_rew = torch.zeros(max_ep_length, parallel_envs, n_agents, device=device) 49 | batch_filled = torch.zeros(max_ep_length, parallel_envs, device=device) 50 | t = 0 51 | infos = [] 52 | 53 | # if the environment provides action masks, use them 54 | if "action_mask" in info: 55 | batch_action_masks = torch.ones( 56 | max_ep_length + 1, parallel_envs, n_agents, num_actions, device=device 57 | ) 58 | mask = np.stack(info["action_mask"], dtype=np.float32) 59 | action_mask = torch.tensor(mask, dtype=torch.float32, device=device) 60 | batch_action_masks[0] = action_mask 61 | action_mask = action_mask.swapaxes(0, 1) 62 | else: 63 | batch_action_masks = None 64 | action_mask = None 65 | 66 | # set initial obs 67 | batch_obs[0] = torch.cat(obss, dim=-1) 68 | 69 | actor_hiddens = model.init_actor_hiddens(parallel_envs) 70 | 71 | while running.any(): 72 | with torch.no_grad(): 73 | actions, actor_hiddens = model.act( 74 | obss, 75 | actor_hiddens, 76 | action_mask=action_mask, 77 | ) 78 | 79 | next_obss, rewards, done, truncated, info = envs.step( 80 | actions.squeeze().tolist() 81 | ) 82 | next_obss = [torch.tensor(o, device=device) for o in next_obss] 83 | 84 | done = torch.tensor(done, dtype=torch.bool, device=device) 85 | truncated = torch.tensor(truncated, dtype=torch.bool, device=device) 86 | if not use_proper_termination: 87 | # TODO: does this make sense? 
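# NOTE: without proper termination handling, time-limit truncations are folded into `done`, so truncated episodes are treated as real terminations when computing returns (no bootstrapping past the time limit).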
88 | done = torch.logical_or(done, truncated) 89 | 90 | batch_obs[t + 1, running, :] = torch.cat(next_obss, dim=1)[running] 91 | batch_act[t, running] = rearrange(actions, "N B 1 -> B N")[running] 92 | batch_done[t + 1, running] = done[running] 93 | batch_rew[t, running] = torch.tensor(rewards, dtype=torch.float32, device=device)[running] 94 | batch_filled[t, running] = 1 95 | if "action_mask" in info: 96 | mask = np.stack(info["action_mask"], dtype=np.float32) 97 | action_mask = torch.tensor(mask, dtype=torch.float32, device=device) 98 | batch_action_masks[t + 1, running] = action_mask[running] 99 | action_mask = action_mask.swapaxes(0, 1) 100 | 101 | if done.any(): 102 | for i, d in enumerate(done): 103 | if d: 104 | assert ( 105 | "final_info" in info 106 | and info["final_info"][i] is not None 107 | and "episode_returns" in info["final_info"][i] 108 | ), "Finished episode info does not contain expected statistics." 109 | infos.append(info["final_info"][i]) 110 | running[i] = False 111 | 112 | t += 1 113 | obss = next_obss 114 | 115 | batch = Batch( 116 | batch_obs, batch_act, batch_rew, batch_done, batch_filled, batch_action_masks 117 | ) 118 | 119 | return t, batch, infos 120 | 121 | 122 | def record_episodes(env, model, n_timesteps, path, device): 123 | recorder = VideoRecorder() 124 | done = True 125 | 126 | for _ in range(n_timesteps): 127 | if done: 128 | obss, info = env.reset() 129 | hiddens = model.init_actor_hiddens(1) 130 | if "action_mask" in info: 131 | action_mask = torch.tensor( 132 | info["action_mask"], dtype=torch.float32, device=device 133 | ) 134 | else: 135 | action_mask = None 136 | done = False 137 | else: 138 | with torch.no_grad(): 139 | obss = torch.tensor(obss, dtype=torch.float32, device=device).unsqueeze( 140 | 1 141 | ) 142 | actions, hiddens = model.act(obss, hiddens, action_mask) 143 | obss, _, done, truncated, info = env.step([a.item() for a in actions]) 144 | if "action_mask" in info: 145 | action_mask = torch.tensor( 146 | info["action_mask"], dtype=torch.float32, device=device 147 | ) 148 | done = done or truncated 149 | recorder.record_frame(env) 150 | 151 | Path(path).parent.mkdir(parents=True, exist_ok=True) 152 | recorder.save(path) 153 | 154 | 155 | def main(envs, eval_env, logger, time_limit, **cfg): 156 | cfg = DictConfig(cfg) 157 | 158 | model = hydra.utils.instantiate( 159 | cfg.model, envs.single_observation_space, envs.single_action_space, cfg 160 | ) 161 | logger.watch(model) 162 | 163 | parallel_envs = envs.observation_space[0].shape[0] 164 | 165 | step = 0 166 | updates = 0 167 | last_eval = 0 168 | last_save = 0 169 | last_video = 0 170 | while step < cfg.total_steps + 1: 171 | t, batch, infos = _collect_trajectories( 172 | envs, 173 | model, 174 | time_limit, 175 | parallel_envs, 176 | model.n_agents, 177 | cfg.model.device, 178 | cfg.use_proper_termination, 179 | ) 180 | 181 | metrics = model.update(batch, step) 182 | infos.append(metrics) 183 | 184 | if (step - last_eval) >= cfg.eval_interval: 185 | _log_progress(infos, step, updates, logger) 186 | last_eval = step 187 | 188 | if cfg.save_interval and (step - last_save) >= cfg.save_interval: 189 | Path("checkpoints").mkdir(exist_ok=True) 190 | torch.save(model.state_dict(), f"checkpoints/model_s{step}.pt") 191 | last_save = step 192 | 193 | if cfg.video_interval and (step - last_video) >= cfg.video_interval: 194 | record_episodes( 195 | eval_env, 196 | model, 197 | cfg.video_frames, 198 | f"./videos/step-{step}.mp4", 199 | cfg.model.device, 200 | ) 201 | last_video = step 202 | 203 | 
updates += 1 204 | step += t * parallel_envs 205 | 206 | envs.close() 207 | -------------------------------------------------------------------------------- /marlbase/utils/models.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | 8 | def orthogonal_init(m): 9 | nn.init.orthogonal_(m.weight.data, gain=np.sqrt(2)) 10 | nn.init.constant_(m.bias.data, 0) 11 | return m 12 | 13 | 14 | class FCNetwork(nn.Module): 15 | def __init__( 16 | self, dims, activation=nn.ReLU, final_activation=None, use_orthogonal_init=True 17 | ): 18 | """ 19 | Create fully-connected network 20 | :param dims: list of dimensions for the network 21 | :param activation: activation function to use 22 | :param final_activation: activation function to use on output (if any) 23 | :param use_orthogonal_init: whether to use orthogonal initialization 24 | :return: sequential network 25 | """ 26 | super().__init__() 27 | mods = [] 28 | 29 | input_size = dims[0] 30 | h_sizes = dims[1:] 31 | 32 | init_fn = orthogonal_init if use_orthogonal_init else lambda x: x 33 | 34 | mods = [init_fn(nn.Linear(input_size, h_sizes[0]))] 35 | for i in range(len(h_sizes) - 1): 36 | mods.append(activation()) 37 | mods.append(init_fn(nn.Linear(h_sizes[i], h_sizes[i + 1]))) 38 | 39 | if final_activation: 40 | mods.append(final_activation()) 41 | 42 | self.network = nn.Sequential(*mods) 43 | 44 | def init_hiddens(self, batch_size, device): 45 | return None 46 | 47 | def forward(self, x, h=None): 48 | return self.network(x), None 49 | 50 | 51 | class RNNNetwork(nn.Module): 52 | def __init__( 53 | self, 54 | dims, 55 | rnn=nn.GRU, 56 | activation=nn.ReLU, 57 | final_activation=None, 58 | use_orthogonal_init=True, 59 | ): 60 | """ 61 | Create recurrent network (last layer is fully connected) 62 | :param dims: list of dimensions for the network 63 | :param activation: activation function to use 64 | :param final_activation: activation function to use on output (if any) 65 | :param use_orthogonal_init: whether to use orthogonal initialization 66 | :return: sequential network 67 | """ 68 | super().__init__() 69 | assert ( 70 | len(dims) > 2 71 | ), "Need at least 3 dimensions for RNN (1 input dim, >= 1 hidden dim, 1 output dim)" 72 | 73 | assert rnn in [nn.GRU, nn.LSTM], "Only GRU and LSTM are supported" 74 | 75 | input_size = dims[0] 76 | rnn_hiddens = dims[1:-1] 77 | rnn_layers = len(rnn_hiddens) - 1 78 | rnn_hidden_size = rnn_hiddens[0] 79 | assert all( 80 | rnn_hidden_size == h for h in rnn_hiddens 81 | ), "Expect same hidden size across all RNN layers" 82 | output_size = dims[-1] 83 | 84 | self.first_layer = nn.Linear(input_size, rnn_hidden_size) 85 | self.rnn = rnn( 86 | input_size=rnn_hidden_size, 87 | hidden_size=rnn_hidden_size, 88 | num_layers=rnn_layers, 89 | batch_first=False, 90 | ) 91 | self.activation = activation() 92 | self.final_layer = nn.Linear(rnn_hidden_size, output_size) 93 | if use_orthogonal_init: 94 | self.final_layer = orthogonal_init(self.final_layer) 95 | 96 | self.final_activation = final_activation 97 | 98 | def init_hiddens(self, batch_size, device): 99 | return torch.zeros( 100 | self.rnn.num_layers, 101 | batch_size, 102 | self.rnn.hidden_size, 103 | device=device, 104 | ) 105 | 106 | def forward(self, x, h=None): 107 | assert x.dim() == 3, "Expect input to be 3D tensor (seq_len, batch, input_size)" 108 | assert ( 109 | h is None or h.dim() == 3 110 | ), "Expect hidden 
state to be 3D tensor (num_layers, batch, hidden_size)" 111 | x = self.activation(self.first_layer(x)) 112 | x, h = self.rnn(x, h) 113 | x = self.final_layer(x) 114 | if self.final_activation: 115 | x = self.final_activation(x) 116 | return x, h 117 | 118 | 119 | def make_network( 120 | dims, 121 | use_rnn=False, 122 | rnn=nn.GRU, 123 | activation=nn.ReLU, 124 | final_activation=None, 125 | use_orthogonal_init=True, 126 | ): 127 | if use_rnn: 128 | return RNNNetwork(dims, rnn, activation, final_activation, use_orthogonal_init) 129 | else: 130 | return FCNetwork(dims, activation, final_activation, use_orthogonal_init) 131 | 132 | 133 | class MultiAgentIndependentNetwork(nn.Module): 134 | def __init__( 135 | self, 136 | input_sizes, 137 | hidden_dims, 138 | output_sizes, 139 | use_rnn=False, 140 | use_orthogonal_init=True, 141 | ): 142 | super().__init__() 143 | assert len(input_sizes) == len( 144 | output_sizes 145 | ), "Expect same number of input and output sizes" 146 | self.independent = nn.ModuleList() 147 | 148 | for in_size, out_size in zip(input_sizes, output_sizes): 149 | dims = [in_size] + hidden_dims + [out_size] 150 | self.independent.append( 151 | make_network( 152 | dims, use_rnn=use_rnn, use_orthogonal_init=use_orthogonal_init 153 | ) 154 | ) 155 | 156 | def forward( 157 | self, 158 | inputs: Union[List[torch.Tensor], torch.Tensor], 159 | hiddens: Optional[List[torch.Tensor]] = None, 160 | ): 161 | if hiddens is None: 162 | hiddens = [None] * len(inputs) 163 | futures = [ 164 | torch.jit.fork(model, x, h) 165 | for model, x, h in zip(self.independent, inputs, hiddens) 166 | ] 167 | results = [torch.jit.wait(fut) for fut in futures] 168 | outs = [x for x, _ in results] 169 | hiddens = [h for _, h in results] 170 | return outs, hiddens 171 | 172 | def init_hiddens(self, batch_size, device): 173 | return [model.init_hiddens(batch_size, device) for model in self.independent] 174 | 175 | 176 | class MultiAgentSharedNetwork(nn.Module): 177 | def __init__( 178 | self, 179 | input_sizes, 180 | hidden_dims, 181 | output_sizes, 182 | sharing_indices, 183 | use_rnn=False, 184 | use_orthogonal_init=True, 185 | ): 186 | super().__init__() 187 | assert len(input_sizes) == len( 188 | output_sizes 189 | ), "Expect same number of input and output sizes" 190 | self.num_agents = len(input_sizes) 191 | 192 | if sharing_indices is True: 193 | self.sharing_indices = len(input_sizes) * [0] 194 | elif sharing_indices is False: 195 | self.sharing_indices = list(range(len(input_sizes))) 196 | else: 197 | self.sharing_indices = sharing_indices 198 | assert len(self.sharing_indices) == len( 199 | input_sizes 200 | ), "Expect same number of sharing indices as agents" 201 | 202 | self.num_networks = 0 203 | self.networks = nn.ModuleList() 204 | self.agents_by_network = [] 205 | self.input_sizes = [] 206 | self.output_sizes = [] 207 | created_networks = set() 208 | for i in self.sharing_indices: 209 | if i in created_networks: 210 | # network already created 211 | continue 212 | 213 | # agent indices that share this network 214 | network_agents = [ 215 | j for j, idx in enumerate(self.sharing_indices) if idx == i 216 | ] 217 | in_sizes = [input_sizes[j] for j in network_agents] 218 | in_size = in_sizes[0] 219 | assert all( 220 | idim == in_size for idim in in_sizes 221 | ), f"Expect same input sizes across all agents sharing network {i}" 222 | out_sizes = [output_sizes[j] for j in network_agents] 223 | out_size = out_sizes[0] 224 | assert all( 225 | odim == out_size for odim in out_sizes 226 | ), f"Expect 
same output sizes across all agents sharing network {i}" 227 | 228 | dims = [in_size] + hidden_dims + [out_size] 229 | self.networks.append( 230 | make_network( 231 | dims, use_rnn=use_rnn, use_orthogonal_init=use_orthogonal_init 232 | ) 233 | ) 234 | self.agents_by_network.append(network_agents) 235 | self.input_sizes.append(in_size) 236 | self.output_sizes.append(out_size) 237 | self.num_networks += 1 238 | created_networks.add(i) 239 | 240 | def forward( 241 | self, 242 | inputs: Union[List[torch.Tensor], torch.Tensor], 243 | hiddens: Optional[List[torch.Tensor]] = None, 244 | ): 245 | assert all( 246 | x.dim() == 3 for x in inputs 247 | ), "Expect each agent input to be 3D tensor (seq_len, batch, input_size)" 248 | assert hiddens is None or all( 249 | x is None or x.dim() == 3 for x in hiddens 250 | ), "Expect hidden state to be 3D tensor (num_layers, batch, hidden_size)" 251 | 252 | batch_size = inputs[0].size(1) 253 | assert all( 254 | x.size(1) == batch_size for x in inputs 255 | ), "Expect all agent inputs to have same batch size" 256 | 257 | # group inputs and hiddens by network 258 | network_inputs = [] 259 | network_hiddens = [] 260 | for agent_indices in self.agents_by_network: 261 | net_inputs = [inputs[i] for i in agent_indices] 262 | if hiddens is None or all(h is None for h in hiddens): 263 | net_hiddens = None 264 | else: 265 | net_hiddens = [hiddens[i] for i in agent_indices] 266 | network_inputs.append(torch.cat(net_inputs, dim=1)) 267 | network_hiddens.append( 268 | torch.cat(net_hiddens, dim=1) if net_hiddens is not None else None 269 | ) 270 | 271 | # forward through networks 272 | futures = [ 273 | torch.jit.fork(network, x, h) 274 | for network, x, h in zip(self.networks, network_inputs, network_hiddens) 275 | ] 276 | results = [torch.jit.wait(fut) for fut in futures] 277 | outs = [x.split(batch_size, dim=1) for x, _ in results] 278 | hiddens = [ 279 | h.split(batch_size, dim=1) if h is not None else None for _, h in results 280 | ] 281 | 282 | # group outputs by agents 283 | agent_outputs = [] 284 | agent_hiddens = [] 285 | self.idx_by_network = [0] * self.num_networks 286 | for network_idx in self.sharing_indices: 287 | idx_within_network = self.idx_by_network[network_idx] 288 | agent_outputs.append(outs[network_idx][idx_within_network]) 289 | if hiddens[network_idx] is not None: 290 | agent_hiddens.append(hiddens[network_idx][idx_within_network]) 291 | else: 292 | agent_hiddens.append(None) 293 | self.idx_by_network[network_idx] += 1 294 | return agent_outputs, agent_hiddens 295 | 296 | def init_hiddens(self, batch_size, device): 297 | return [ 298 | self.networks[network_idx].init_hiddens(batch_size, device) 299 | for network_idx in self.sharing_indices 300 | ] 301 | -------------------------------------------------------------------------------- /marlbase/dqn/train.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import math 3 | from pathlib import Path 4 | 5 | # from cpprb import ReplayBuffer 6 | import hydra 7 | import numpy as np 8 | from omegaconf import DictConfig 9 | import torch 10 | 11 | from marlbase.utils.video import VideoRecorder 12 | 13 | 14 | Batch = namedtuple( 15 | "Batch", ["obss", "actions", "rewards", "dones", "filled", "action_mask"] 16 | ) 17 | 18 | 19 | class ReplayBuffer: 20 | def __init__( 21 | self, 22 | buffer_size, 23 | n_agents, 24 | observation_space, 25 | action_space, 26 | max_episode_length, 27 | device, 28 | store_action_masks=False, 29 | ): 30 | 
self.buffer_size = buffer_size 31 | self.n_agents = n_agents 32 | self.max_episode_length = max_episode_length 33 | self.store_action_masks = store_action_masks 34 | self.device = device 35 | 36 | self.pos = 0 37 | self.cur_pos = 0 38 | self.t = 0 39 | 40 | self.observations = [ 41 | np.zeros( 42 | (max_episode_length + 1, buffer_size, *observation_space[i].shape), 43 | dtype=np.float32, 44 | ) 45 | for i in range(n_agents) 46 | ] 47 | self.actions = np.zeros( 48 | (n_agents, max_episode_length, buffer_size), dtype=np.int64 49 | ) 50 | self.rewards = np.zeros( 51 | (n_agents, max_episode_length, buffer_size), dtype=np.float32 52 | ) 53 | self.dones = np.zeros((max_episode_length + 1, buffer_size), dtype=bool) 54 | self.filled = np.zeros((max_episode_length, buffer_size), dtype=bool) 55 | if store_action_masks: 56 | action_dim = max(action_space.n for action_space in action_space) 57 | self.action_masks = np.zeros( 58 | (n_agents, max_episode_length + 1, buffer_size, action_dim), 59 | dtype=np.float32, 60 | ) 61 | 62 | def __len__(self): 63 | return min(self.pos, self.buffer_size) 64 | 65 | def init_episode(self, obss, action_masks=None): 66 | self.t = 0 67 | for i in range(self.n_agents): 68 | self.observations[i][0, self.cur_pos] = obss[i] 69 | if action_masks is not None: 70 | assert self.store_action_masks, "Action masks not stored in buffer!" 71 | self.action_masks[:, 0, self.cur_pos] = action_masks 72 | 73 | def add(self, obss, acts, rews, done, action_masks=None): 74 | assert self.t < self.max_episode_length, "Episode longer than given max length!" 75 | for i in range(self.n_agents): 76 | self.observations[i][self.t + 1, self.cur_pos] = obss[i] 77 | self.actions[:, self.t, self.cur_pos] = acts 78 | self.rewards[:, self.t, self.cur_pos] = rews 79 | self.dones[self.t + 1, self.cur_pos] = done 80 | self.filled[self.t, self.cur_pos] = True 81 | if action_masks is not None: 82 | assert self.store_action_masks, "Action masks not stored in buffer!" 83 | self.action_masks[:, self.t + 1, self.cur_pos] = action_masks 84 | self.t += 1 85 | 86 | if done: 87 | self.pos += 1 88 | self.cur_pos = self.pos % self.buffer_size 89 | self.t = 0 90 | 91 | def can_sample(self, batch_size): 92 | return self.pos >= batch_size 93 | 94 | def sample(self, batch_size): 95 | idx = np.random.randint(0, len(self), size=batch_size) 96 | obss = torch.stack( 97 | [ 98 | torch.tensor( 99 | self.observations[i][:, idx], 100 | dtype=torch.float32, 101 | device=self.device, 102 | ) 103 | for i in range(self.n_agents) 104 | ] 105 | ) 106 | actions = torch.tensor( 107 | self.actions[:, :, idx], dtype=torch.int64, device=self.device 108 | ) 109 | rewards = torch.tensor( 110 | self.rewards[:, :, idx], dtype=torch.float32, device=self.device 111 | ) 112 | dones = torch.tensor( 113 | self.dones[:, idx], dtype=torch.float32, device=self.device 114 | ) 115 | filled = torch.tensor( 116 | self.filled[:, idx], dtype=torch.float32, device=self.device 117 | ) 118 | if self.store_action_masks: 119 | action_mask = torch.tensor( 120 | self.action_masks[:, :, idx], dtype=torch.float32, device=self.device 121 | ) 122 | else: 123 | action_mask = None 124 | return Batch(obss, actions, rewards, dones, filled, action_mask) 125 | 126 | 127 | def _epsilon_schedule( 128 | decay_style, decay_over, eps_start, eps_end, exp_decay_rate, total_steps 129 | ): 130 | """ 131 | Exponential decay schedule for exploration epsilon. 132 | :param decay_style: style of epsilon schedule. One of "linear"/ "lin" or "exponential"/ "exp". 
133 | :param decay_over: fraction of total steps over which to decay epsilon. 134 | :param eps_start: starting epsilon value. 135 | :param eps_end: ending epsilon value. 136 | :param exp_decay_rate: decay rate for exponential decay. 137 | :param total_steps: total number of steps to take. 138 | :return: Epsilon schedule function mapping step number to epsilon value. 139 | """ 140 | assert decay_style in [ 141 | "linear", 142 | "lin", 143 | "exponential", 144 | "exp", 145 | ], "decay_style must be one of 'linear' or 'exponential'" 146 | assert 0 <= eps_start <= 1 and 0 <= eps_end <= 1, "eps must be in [0, 1]" 147 | assert eps_start >= eps_end, "eps_start must be >= eps_end" 148 | assert 0 < decay_over <= 1, "decay_over must be in (0, 1]" 149 | assert total_steps > 0, "total_steps must be > 0" 150 | assert exp_decay_rate > 0, "eps_decay must be > 0" 151 | 152 | if decay_style in ["linear", "lin"]: 153 | 154 | def _thunk(steps_done): 155 | return max( 156 | eps_end 157 | + (eps_start - eps_end) * (1 - steps_done / (total_steps * decay_over)), 158 | eps_end, 159 | ) 160 | 161 | elif decay_style in ["exponential", "exp"]: 162 | # decaying over all steps 163 | # eps_decay = (eps_start - eps_end) / total_steps * exp_decay_rate 164 | # decaying over decay_over fraction of steps 165 | eps_decay = (eps_start - eps_end) / (total_steps * decay_over) * exp_decay_rate 166 | 167 | def _thunk(steps_done): 168 | return max( 169 | eps_end + (eps_start - eps_end) * math.exp(-eps_decay * steps_done), 170 | eps_end, 171 | ) 172 | else: 173 | raise ValueError("decay_style must be one of 'linear' or 'exponential'") 174 | return _thunk 175 | 176 | 177 | def _evaluate(env, model, eval_episodes, eval_epsilon): 178 | infos = [] 179 | while len(infos) < eval_episodes: 180 | obs, info = env.reset() 181 | hiddens = model.init_hiddens(1) 182 | action_mask = ( 183 | np.stack(info["action_mask"], dtype=np.float32) 184 | if "action_mask" in info 185 | else None 186 | ) 187 | done = False 188 | while not done: 189 | with torch.no_grad(): 190 | actions, hiddens = model.act(obs, hiddens, eval_epsilon, action_mask) 191 | obs, _, done, truncated, info = env.step(actions) 192 | done = done or truncated 193 | action_mask = ( 194 | np.stack(info["action_mask"], dtype=np.float32) 195 | if "action_mask" in info 196 | else None 197 | ) 198 | infos.append(info) 199 | return infos 200 | 201 | 202 | def _collect_trajectory(env, model, rb, epsilon, use_proper_termination): 203 | obss, info = env.reset() 204 | action_mask = ( 205 | np.stack(info["action_mask"], dtype=np.float32) 206 | if "action_mask" in info 207 | else None 208 | ) 209 | rb.init_episode(obss, action_mask) 210 | hiddens = model.init_hiddens(1) 211 | done = False 212 | t = 0 213 | 214 | while not done: 215 | with torch.no_grad(): 216 | actions, hiddens = model.act(obss, hiddens, epsilon, action_mask) 217 | next_obss, rews, done, truncated, info = env.step(actions) 218 | 219 | if use_proper_termination: 220 | # TODO: Previously was always False here? 221 | # also previously had other option "ignore"? Why was that separate from "ignore"? 222 | proper_done = done 223 | else: 224 | # here previously was always done? 
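# NOTE: in this branch a time-limit truncation is stored as a terminal transition in the replay buffer, so update targets treat the episode as having ended there.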
225 | proper_done = done or truncated 226 | done = done or truncated 227 | action_mask = ( 228 | np.stack(info["action_mask"], dtype=np.float32) 229 | if "action_mask" in info 230 | else None 231 | ) 232 | 233 | rb.add(next_obss, actions, rews, proper_done, action_mask) 234 | t += 1 235 | obss = next_obss 236 | 237 | return t, info 238 | 239 | 240 | def record_episodes(env, model, n_timesteps, path, device, epsilon): 241 | recorder = VideoRecorder() 242 | done = True 243 | 244 | for _ in range(n_timesteps): 245 | if done: 246 | obss, info = env.reset() 247 | hiddens = model.init_hiddens(1) 248 | done = False 249 | else: 250 | action_mask = ( 251 | np.stack(info["action_mask"], dtype=np.float32) 252 | if "action_mask" in info 253 | else None 254 | ) 255 | with torch.no_grad(): 256 | actions, hiddens = model.act(obss, hiddens, epsilon, action_mask) 257 | obss, _, done, truncated, info = env.step(actions) 258 | recorder.record_frame(env) 259 | 260 | Path(path).parent.mkdir(parents=True, exist_ok=True) 261 | recorder.save(path) 262 | 263 | 264 | def main(env, eval_env, logger, time_limit, **cfg): 265 | cfg = DictConfig(cfg) 266 | 267 | _, info = env.reset() 268 | model = hydra.utils.instantiate( 269 | cfg.model, env.observation_space, env.action_space, cfg 270 | ) 271 | 272 | logger.watch(model) 273 | 274 | rb = ReplayBuffer( 275 | cfg.buffer_size, 276 | env.unwrapped.n_agents, 277 | env.observation_space, 278 | env.action_space, 279 | time_limit, 280 | cfg.model.device, 281 | store_action_masks="action_mask" in info, 282 | ) 283 | 284 | eps_sched = _epsilon_schedule( 285 | cfg.eps_decay_style, 286 | cfg.eps_decay_over, 287 | cfg.eps_start, 288 | cfg.eps_end, 289 | cfg.eps_exp_decay_rate, 290 | cfg.total_steps, 291 | ) 292 | 293 | updates = 0 294 | step = 0 295 | last_eval = 0 296 | last_video = 0 297 | last_save = 0 298 | while step < cfg.total_steps + 1: 299 | t, _ = _collect_trajectory( 300 | env, 301 | model, 302 | rb, 303 | eps_sched(step), 304 | cfg.use_proper_termination, 305 | ) 306 | step += t 307 | 308 | if step > cfg.training_start and rb.can_sample(cfg.batch_size): 309 | batch = rb.sample(cfg.batch_size) 310 | metrics = model.update(batch) 311 | updates += 1 312 | else: 313 | metrics = {} 314 | 315 | if cfg.eval_interval and (step - last_eval) >= cfg.eval_interval: 316 | infos = _evaluate(eval_env, model, cfg.eval_episodes, cfg.eps_evaluation) 317 | if metrics: 318 | infos.append(metrics) 319 | infos.append( 320 | { 321 | "updates": updates, 322 | "environment_steps": step, 323 | "epsilon": eps_sched(step), 324 | } 325 | ) 326 | logger.log_metrics(infos) 327 | last_eval = step 328 | 329 | if cfg.video_interval and (step - last_video) >= cfg.video_interval: 330 | record_episodes( 331 | eval_env, 332 | model, 333 | cfg.video_frames, 334 | f"./videos/step-{step}.mp4", 335 | cfg.model.device, 336 | cfg.eps_evaluation, 337 | ) 338 | last_video = step 339 | 340 | if cfg.save_interval and (step - last_save) >= cfg.save_interval: 341 | Path("checkpoints").mkdir(exist_ok=True) 342 | torch.save(model.state_dict(), f"checkpoints/model_s{step}.pt") 343 | last_save = step 344 | 345 | env.close() 346 | -------------------------------------------------------------------------------- /marlbase/ac/model.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from einops import rearrange 4 | from gymnasium.spaces import flatdim 5 | import torch 6 | from torch.distributions import Categorical 7 | import torch.nn as nn 8 | 
from torch import optim 9 | 10 | from marlbase.utils.models import MultiAgentIndependentNetwork, MultiAgentSharedNetwork 11 | from marlbase.utils.utils import MultiCategorical, compute_nstep_returns 12 | from marlbase.utils.standardise_stream import RunningMeanStd 13 | 14 | 15 | def _split_batch(splits): 16 | def thunk(batch): 17 | return torch.split(batch, splits, dim=-1) 18 | 19 | return thunk 20 | 21 | 22 | class A2CNetwork(nn.Module): 23 | def __init__( 24 | self, 25 | obs_space, 26 | action_space, 27 | cfg, 28 | actor, 29 | critic, 30 | device, 31 | ): 32 | super(A2CNetwork, self).__init__() 33 | self.gamma = cfg.gamma 34 | self.entropy_coef = cfg.entropy_coef 35 | self.n_steps = cfg.n_steps 36 | self.grad_clip = cfg.grad_clip 37 | self.value_loss_coef = cfg.value_loss_coef 38 | self.device = device 39 | 40 | self.n_agents = len(obs_space) 41 | obs_dims = [flatdim(o) for o in obs_space] 42 | act_dims = [flatdim(a) for a in action_space] 43 | 44 | if not actor.parameter_sharing: 45 | self.actor = MultiAgentIndependentNetwork( 46 | obs_dims, 47 | list(actor.layers), 48 | act_dims, 49 | actor.use_rnn, 50 | actor.use_orthogonal_init, 51 | ) 52 | else: 53 | self.actor = MultiAgentSharedNetwork( 54 | obs_dims, 55 | list(actor.layers), 56 | act_dims, 57 | actor.parameter_sharing, 58 | actor.use_rnn, 59 | actor.use_orthogonal_init, 60 | ) 61 | 62 | self.centralised_critic = critic.centralised 63 | critic_obs_shape = ( 64 | self.n_agents * [sum(obs_dims)] if critic.centralised else obs_dims 65 | ) 66 | 67 | if not critic.parameter_sharing: 68 | self.critic = MultiAgentIndependentNetwork( 69 | critic_obs_shape, 70 | list(critic.layers), 71 | [1] * self.n_agents, 72 | critic.use_rnn, 73 | critic.use_orthogonal_init, 74 | ) 75 | self.target_critic = MultiAgentIndependentNetwork( 76 | critic_obs_shape, 77 | list(critic.layers), 78 | [1] * self.n_agents, 79 | critic.use_rnn, 80 | critic.use_orthogonal_init, 81 | ) 82 | else: 83 | self.critic = MultiAgentSharedNetwork( 84 | critic_obs_shape, 85 | list(critic.layers), 86 | [1] * self.n_agents, 87 | critic.parameter_sharing, 88 | critic.use_rnn, 89 | critic.use_orthogonal_init, 90 | ) 91 | self.target_critic = MultiAgentSharedNetwork( 92 | critic_obs_shape, 93 | list(critic.layers), 94 | [1] * self.n_agents, 95 | critic.parameter_sharing, 96 | critic.use_rnn, 97 | critic.use_orthogonal_init, 98 | ) 99 | 100 | self.soft_update(1.0) 101 | self.to(device) 102 | 103 | optimizer = getattr(optim, cfg.optimizer) 104 | if type(optimizer) is str: 105 | optimizer = getattr(optim, optimizer) 106 | self.optimizer_class = optimizer 107 | 108 | lr = cfg.lr 109 | self.optimizer = optimizer(self.parameters(), lr=lr) 110 | self.target_update_interval_or_tau = cfg.target_update_interval_or_tau 111 | 112 | self.standardise_returns = cfg.standardise_returns 113 | if self.standardise_returns: 114 | self.ret_ms = RunningMeanStd(shape=(self.n_agents,), device=device) 115 | 116 | self.split_obs = _split_batch([flatdim(s) for s in obs_space]) 117 | self.split_act = _split_batch(self.n_agents * [1]) 118 | 119 | print(self) 120 | 121 | def init_critic_hiddens(self, batch_size, target=False): 122 | if target: 123 | return self.target_critic.init_hiddens(batch_size, self.device) 124 | else: 125 | return self.critic.init_hiddens(batch_size, self.device) 126 | 127 | def init_actor_hiddens(self, batch_size): 128 | return self.actor.init_hiddens(batch_size, self.device) 129 | 130 | def forward(self, inputs, rnn_hxs, masks): 131 | raise NotImplementedError( 132 | "Forward not 
implemented. Use act, get_value, get_target_value or evaluate_actions instead." 133 | ) 134 | 135 | def get_dist(self, action_logits, action_mask=None): 136 | if action_mask is not None: 137 | masked_logits = [] 138 | for logits, mask in zip(action_logits, action_mask): 139 | masked_logits.append(logits * mask + (1 - mask) * -1e8) 140 | action_logits = masked_logits 141 | 142 | dist = MultiCategorical( 143 | [Categorical(logits=logits) for logits in action_logits] 144 | ) 145 | return dist 146 | 147 | def act(self, inputs, actor_hiddens, action_mask=None): 148 | inputs = [i.unsqueeze(0) for i in inputs] 149 | actor_logits, actor_hiddens = self.actor(inputs, actor_hiddens) 150 | actor_logits = [logits.squeeze(0) for logits in actor_logits] 151 | dist = self.get_dist(actor_logits, action_mask) 152 | actions = dist.sample() 153 | return torch.stack(actions, dim=0), actor_hiddens 154 | 155 | def get_value(self, inputs, critic_hiddens, target=False): 156 | if self.centralised_critic: 157 | inputs = self.n_agents * [torch.cat(inputs, dim=-1)] 158 | 159 | if target: 160 | values, critic_hiddens = self.target_critic(inputs, critic_hiddens) 161 | else: 162 | values, critic_hiddens = self.critic(inputs, critic_hiddens) 163 | return torch.cat(values, dim=-1), critic_hiddens 164 | 165 | def evaluate_actions( 166 | self, 167 | inputs, 168 | action, 169 | critic_hiddens, 170 | actor_hiddens, 171 | action_mask=None, 172 | state=None, 173 | ): 174 | if state is None: 175 | state = inputs 176 | value, critic_hiddens = self.get_value(state, critic_hiddens) 177 | actor_features, actor_hiddens = self.actor(inputs, actor_hiddens) 178 | dist = self.get_dist(actor_features, action_mask) 179 | action_log_probs = torch.cat(dist.log_probs(action), dim=-1) 180 | dist_entropy = torch.stack(dist.entropy(), dim=-1).sum(dim=-1) 181 | 182 | return (value, action_log_probs, dist_entropy, critic_hiddens, actor_hiddens) 183 | 184 | def soft_update(self, t): 185 | source, target = self.critic, self.target_critic 186 | for target_param, source_param in zip(target.parameters(), source.parameters()): 187 | target_param.data.copy_((1 - t) * target_param.data + t * source_param.data) 188 | 189 | def update(self, batch, step): 190 | with torch.no_grad(): 191 | next_value, _ = self.get_value( 192 | self.split_obs(batch.obss), critic_hiddens=None, target=True 193 | ) 194 | 195 | if self.standardise_returns: 196 | next_value = next_value * torch.sqrt(self.ret_ms.var) + self.ret_ms.mean 197 | 198 | batch_done = batch.dones.float().unsqueeze(-1).repeat(1, 1, self.n_agents) 199 | returns = compute_nstep_returns( 200 | batch.rewards, batch_done, next_value, self.n_steps, self.gamma 201 | ) 202 | if self.standardise_returns: 203 | self.ret_ms.update(returns) 204 | returns = (returns - self.ret_ms.mean) / torch.sqrt(self.ret_ms.var) 205 | 206 | values, action_log_probs, entropy, _, _ = self.evaluate_actions( 207 | self.split_obs(batch.obss[:-1]), 208 | self.split_act(batch.actions), 209 | critic_hiddens=None, 210 | actor_hiddens=None, 211 | action_mask=rearrange(batch.action_masks[:-1], "E B N A -> N E B A") 212 | if batch.action_masks is not None 213 | else None, 214 | ) 215 | 216 | advantage = returns - values 217 | 218 | actor_loss = ( 219 | -(action_log_probs * advantage.detach()).sum(dim=-1) 220 | - self.entropy_coef * entropy 221 | ) 222 | actor_loss = (actor_loss * batch.filled).sum() / batch.filled.sum() 223 | value_loss = (returns - values).pow(2).sum(dim=-1) 224 | value_loss = (value_loss * batch.filled).sum() / 
batch.filled.sum() 225 | 226 | loss = actor_loss + self.value_loss_coef * value_loss 227 | self.optimizer.zero_grad() 228 | loss.backward() 229 | if self.grad_clip: 230 | torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip) 231 | self.optimizer.step() 232 | 233 | if ( 234 | self.target_update_interval_or_tau > 1.0 235 | and step % self.target_update_interval_or_tau == 0 236 | ): 237 | self.soft_update(1.0) 238 | elif self.target_update_interval_or_tau < 1.0: 239 | self.soft_update(self.target_update_interval_or_tau) 240 | 241 | return { 242 | "loss": loss.item(), 243 | "actor_loss": actor_loss.item(), 244 | "value_loss": value_loss.item(), 245 | "entropy": ((entropy * batch.filled).sum() / batch.filled.sum()).item(), 246 | } 247 | 248 | 249 | class PPONetwork(A2CNetwork): 250 | def __init__( 251 | self, 252 | obs_space, 253 | action_space, 254 | cfg, 255 | actor, 256 | critic, 257 | device, 258 | ): 259 | super(PPONetwork, self).__init__( 260 | obs_space, action_space, cfg, actor, critic, device 261 | ) 262 | self.num_epochs = cfg.num_epochs 263 | self.ppo_clip = cfg.ppo_clip 264 | 265 | def update(self, batch, step): 266 | # compute returns 267 | with torch.no_grad(): 268 | next_value, _ = self.get_value( 269 | self.split_obs(batch.obss), critic_hiddens=None, target=True 270 | ) 271 | 272 | if self.standardise_returns: 273 | next_value = next_value * torch.sqrt(self.ret_ms.var) + self.ret_ms.mean 274 | 275 | batch_done = batch.dones.float().unsqueeze(-1).repeat(1, 1, self.n_agents) 276 | returns = compute_nstep_returns( 277 | batch.rewards, batch_done, next_value, self.n_steps, self.gamma 278 | ).detach() 279 | if self.standardise_returns: 280 | self.ret_ms.update(returns) 281 | returns = (returns - self.ret_ms.mean) / torch.sqrt(self.ret_ms.var) 282 | 283 | # compute old policy log probs 284 | with torch.no_grad(): 285 | actor_features, _ = self.actor(self.split_obs(batch.obss[:-1]), None) 286 | dist = self.get_dist( 287 | actor_features, 288 | action_mask=rearrange(batch.action_masks[:-1], "E B N A -> N E B A") 289 | if batch.action_masks is not None 290 | else None, 291 | ) 292 | old_action_log_probs = torch.cat( 293 | dist.log_probs(self.split_act(batch.actions)), dim=-1 294 | ).detach() 295 | 296 | metrics = defaultdict(list) 297 | for _ in range(self.num_epochs): 298 | # sample from current policy 299 | values, action_log_probs, entropy, _, _ = self.evaluate_actions( 300 | self.split_obs(batch.obss[:-1]), 301 | self.split_act(batch.actions), 302 | critic_hiddens=None, 303 | actor_hiddens=None, 304 | action_mask=rearrange(batch.action_masks[:-1], "E B N A -> N E B A") 305 | if batch.action_masks is not None 306 | else None, 307 | ) 308 | 309 | # compute advantage and value loss 310 | advantage = returns - values 311 | value_loss = advantage.pow(2).sum(dim=-1) 312 | 313 | # compute policy loss 314 | ratio = torch.exp(action_log_probs - old_action_log_probs) 315 | surr1 = ratio * advantage.detach() 316 | surr2 = ( 317 | torch.clamp(ratio, 1.0 - self.ppo_clip, 1.0 + self.ppo_clip) 318 | * advantage.detach() 319 | ) 320 | actor_loss = ( 321 | -torch.min(surr1, surr2).sum(dim=-1) - self.entropy_coef * entropy 322 | ) 323 | 324 | # apply masks and compute total loss per epoch 325 | actor_loss = (actor_loss * batch.filled).sum() / batch.filled.sum() 326 | value_loss = (value_loss * batch.filled).sum() / batch.filled.sum() 327 | loss = actor_loss + self.value_loss_coef * value_loss 328 | 329 | # optimisation step 330 | self.optimizer.zero_grad() 331 | loss.backward() 332 | if 
self.grad_clip: 333 | torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip) 334 | self.optimizer.step() 335 | 336 | metrics["loss"].append(loss.item()) 337 | metrics["actor_loss"].append(actor_loss.item()) 338 | metrics["value_loss"].append(value_loss.item()) 339 | metrics["entropy"].append( 340 | ((entropy * batch.filled).sum() / batch.filled.sum()).item() 341 | ) 342 | 343 | # update target network after last epoch 344 | if ( 345 | self.target_update_interval_or_tau > 1.0 346 | and step % self.target_update_interval_or_tau == 0 347 | ): 348 | self.soft_update(1.0) 349 | elif self.target_update_interval_or_tau < 1.0: 350 | self.soft_update(self.target_update_interval_or_tau) 351 | 352 | return {key: sum(values) / len(values) for key, values in metrics.items()} 353 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |

Multi-Agent Reinforcement Learning: 3 | Foundations and Modern Approaches

4 | 5 |

6 | Book Codebase: www.marl-book.com 7 |

8 | 9 | Cite the book using: 10 | ```latex 11 | @book{marl-book, 12 | author = {Stefano V. Albrecht and Filippos Christianos and Lukas Sch\"afer}, 13 | title = {Multi-Agent Reinforcement Learning: Foundations and Modern Approaches}, 14 | publisher = {MIT Press}, 15 | year = {2024}, 16 | url = {https://www.marl-book.com} 17 | } 18 | ``` 19 | 20 | This codebase is part of the [MARL book](http://www.marl-book.com) and provides access to basic and easy-to-understand implementations of MARL ideas. 21 | The algorithms are self-contained and the implementations focus on simplicity. 22 | Implementation tricks, while necessary for some algorithms, are kept sparse so as not to make the code overly complicated. As a result, some performance has been sacrificed. 23 | 24 | All algorithms are implemented in [_PyTorch_](https://pytorch.org/) and use the [_Gymnasium_](https://gymnasium.farama.org/) interface. 25 | 26 |

Table of Contents

27 | 28 | - [Getting Started](#getting-started) 29 | - [Installation](#installation) 30 | - [Running an algorithm](#running-an-algorithm) 31 | - [(Optional) Use Hydra's tab completion](#optional-use-hydras-tab-completion) 32 | - [Running a hyperparameter search](#running-a-hyperparameter-search) 33 | - [An advanced hyperparameter search using `search.py`](#an-advanced-hyperparameter-search-using-searchpy) 34 | - [Logging](#logging) 35 | - [File System Logger](#file-system-logger) 36 | - [WandB Logger](#wandb-logger) 37 | - [Implementing your own algorithm/ideas](#implementing-your-own-algorithmideas) 38 | - [Interpreting your results](#interpreting-your-results) 39 | - [Implemented Algorithms](#implemented-algorithms) 40 | - [Parameter Sharing](#parameter-sharing) 41 | - [Value Decomposition](#value-decomposition) 42 | - [Contact](#contact) 43 | 44 | 45 | # Getting Started 46 | 47 | ## Installation 48 | 49 | We *strongly* suggest you use a virtual environment for the instructions below. A good starting point is [Miniconda](https://docs.conda.io/en/latest/miniconda.html), with which you would do: 50 | 51 | ```sh 52 | conda create -n marlbase python=3.10 53 | conda activate marlbase 54 | ``` 55 | 56 | Then, clone and install the repository using: 57 | 58 | ```sh 59 | git clone https://github.com/marl-book/codebase.git 60 | cd codebase 61 | pip install -r requirements.txt 62 | pip install -e . 63 | ``` 64 | Do not forget to install PyTorch in your environment. Instructions for your system/setup can be found here: https://pytorch.org/get-started/locally/ 65 | 66 | ## Running an algorithm 67 | This project uses [Hydra](https://hydra.cc/) to structure its configuration. Algorithm implementations can be found under `marlbase/`. The respective configs are found in `marlbase/configs/algorithms/`. 68 | 69 | You would first need an environment that is registered in Gymnasium. This repository uses the Gymnasium API (with the only difference being that the rewards are a tuple or list - one for each agent). 70 | 71 | A good starting point would be [Level-based Foraging](https://github.com/uoe-agents/lb-foraging) and [RWARE](https://github.com/uoe-agents/robotic-warehouse). You can install both using: 72 | ```sh 73 | pip install -U lbforaging rware 74 | ``` 75 | 76 | Then, running an algorithm (e.g. IA2C) looks like: 77 | 78 | ```sh 79 | cd marlbase 80 | python run.py +algorithm=ia2c env.name="lbforaging:Foraging-8x8-2p-3f-v3" env.time_limit=25 81 | ``` 82 | 83 | Similarly, running IDQN can be done using: 84 | ```sh 85 | python run.py +algorithm=idqn env.name="lbforaging:Foraging-8x8-2p-3f-v3" env.time_limit=25 86 | ``` 87 | 88 | Overriding hyperparameters is easy and can be done in the command line. An example of overriding the `batch_size` in IDQN: 89 | ```sh 90 | python run.py +algorithm=idqn env.name="lbforaging:Foraging-8x8-2p-3f-v3" env.time_limit=25 algorithm.batch_size=256 91 | ``` 92 | 93 | Find other hyperparameters in the files under `marlbase/configs/algorithm`. 94 | 95 | ### (Optional) Use Hydra's tab completion 96 | Hydra also supports tab completion for filling in the hyperparameters. See [here](https://hydra.cc/docs/tutorials/basic/running_your_app/tab_completion), and install it with: 97 | ```sh 98 | eval "$(python run.py -sc install=bash)" 99 | ``` 100 | ## Running a hyperparameter search 101 | 102 | Can be easily done using [Hydra's multirun](https://hydra.cc/docs/tutorials/basic/running_your_app/multi-run) option. 
An example of sweeping over batch sizes is: 103 | 104 | ```sh 105 | python run.py -m +algorithm=idqn env.name="lbforaging:Foraging-8x8-2p-3f-v3" env.time_limit=25 algorithm.batch_size=32,64,128 106 | ``` 107 | 108 | ### An advanced hyperparameter search using `search.py` 109 | *This section might get deprecated in the future if Hydra implements this feature.* 110 | 111 | We include a script named `search.py` which reads a search configuration file (e.g. the included `configs/sweeps/sample.yaml`) and runs a hyperparameter search in one or more tasks. The script can be run using 112 | ```sh 113 | python search.py run --config configs/sweeps/sample.yaml --seeds 5 locally 114 | ``` 115 | In a cluster environment where one run should go to a single process, it can also be called in a batch script like: 116 | ```sh 117 | python search.py run --config configs/sweeps/sample.yaml --seeds 5 single $TASK_ID 118 | ``` 119 | Where `$TASK_ID` is an index for the experiment (i.e. 1...#number of experiments). 120 | 121 | ## Logging 122 | We implement two loggers: FileSystem Logger and WandB Logger. 123 | 124 | ### File System Logger 125 | The default logger is the FileSystemLogger which saves experiment results in a `results.csv` file. You can find that file, the configuration that has been used & more under `outputs/{env_name}/{alg_name}/{random_hash}` or `multirun/{date}/{time}/{experiment_id}` for multiruns. 126 | ### WandB Logger 127 | By appending `+logger=wandb` in the command line you can get support for WandB. Do not forget to `wandb login` first. 128 | 129 | Example: 130 | 131 | ```sh 132 | python run.py +algorithm=idqn env.name="lbforaging:Foraging-8x8-2p-3f-v3" env.time_limit=25 logger=wandb 133 | ``` 134 | You can override the project name using: 135 | 136 | ```sh 137 | python run.py +algorithm=idqn env.name="lbforaging:Foraging-8x8-2p-3f-v3" env.time_limit=25 logger=wandb logger.project_name="my-project-name" 138 | ``` 139 | 140 | # Implementing your own algorithm/ideas 141 | 142 | The fastest way would be to create a new folder starting from the algorithm of your choice e.g. 143 | ```sh 144 | cp -R ac ac_new_idea 145 | ``` 146 | and create a new configuration file: 147 | ```sh 148 | cp configs/algorithm/ia2c.yaml configs/algorithm/ac_new_idea.yaml 149 | ``` 150 | 151 | with the editor of your choice, open `ac_new_idea.yaml` and change 152 | ```yaml 153 | ... 154 | algorithm: 155 | _target_: ac.train.main 156 | name: "ac" 157 | model: 158 | _target_: ac.model.A2CNetwork 159 | ... 160 | ``` 161 | to 162 | ```yaml 163 | ... 164 | algorithm: 165 | _target_: ac_new_idea.train.main 166 | name: "ac_new_idea" 167 | model: 168 | _target_: ac_new_idea.model.NewNetwork 169 | ... 170 | ``` 171 | Make any changes you want to the files under `ac_new_idea/` and run it using: 172 | 173 | ```sh 174 | python run.py +algorithm=ac_new_idea env.name="lbforaging:Foraging-8x8-2p-3f-v3" env.time_limit=25 175 | ``` 176 | You can now add new hyperparameters, change the training procedure, or anything else you want and keep the old implementations for easy comparison. We hope that the way we have implemented these algorithms makes it easy to change any part of the algorithm without the hustle of reading through large code-bases and huge unnecessary layers of abstraction. RL research benefits from iterating over ideas quickly to see how they perform! 177 | 178 | # Interpreting your results 179 | 180 | We have multiple tools to analyze the outputs of FileSystemLogger (for WandBLogger, just login to their webpage). 
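If you first want to inspect a single run by hand, the `results.csv` written by the FileSystemLogger loads directly with pandas. A minimal sketch follows; the run directory is a placeholder for whatever path your own run produced, and the column names are the defaults logged by the included algorithms:

```python
import pandas as pd

# Placeholder path: substitute the env/algorithm/hash folder of your own run.
run_dir = "outputs/lbforaging:Foraging-8x8-2p-3f-v3/idqn/1a8fdsk3"
df = pd.read_csv(f"{run_dir}/results.csv")

# Each row is one logged evaluation point; columns are the recorded metrics.
print(df[["environment_steps", "mean_episode_returns"]].tail())
```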
181 |
182 | You can easily find the best hyperparameter configuration per environment using:
183 | ```sh
184 | python utils/postprocessing/find_best_hyperparams.py --source
185 | ```
186 | By default, this script will determine the best hyperparameters based on the average total returns across all evaluations and seeds. To use a different metric, you can specify the desired metric (from the `results.csv` files) with the `--metric` argument.
187 |
188 | Similarly, you can plot the stored runs (average/std across seeds) using:
189 | ```sh
190 | python utils/postprocessing/plot_runs.py --source
191 | ```
192 | By default, this will visualise the mean and std across seeds of the `mean_episode_returns` metric. You can specify the metric to plot using the `--metric` argument. You can also provide the additional `--save_path` argument to save the plot as a `.pdf` file.
193 |
194 | We also provide a script to export the data of multiple runs as pandas DataFrames using:
195 | ```sh
196 | python utils/postprocessing/export_multirun.py --folder folder/containing/results --export-file myfile.hd5
197 | ```
198 | The file will contain two pandas DataFrames: `df`, which contains all `mean_episode_returns` (by default summed across all agents), and `configs`, which contains information about the tested hyperparameters.
199 | You can load both in Python using:
200 | ```python
201 | import pandas as pd
202 | df = pd.read_hdf("myfile.hd5", "df")
203 | configs = pd.read_hdf("myfile.hd5", "configs")
204 | ```
205 | The imported DataFrames look like the ones below. `df` has a column multi-index over the environment name, the algorithm name, a hash identifying the hyperparameter configuration, and the seed. `configs` maps the hash to the full configuration of the run.
206 |
207 | ```ipython
208 | In [1]: df
209 | Out[1]:
210 |                   Foraging-20x20-9p-6f-v3                      ...
211 |                                     Algo1   ...         Algo2
212 |                              f7c2ecb3ddf1   ...  5284ad99ce02
213 |                                    seed=0    seed=1  ...        seed=0    seed=1
214 | environment_steps                                    ...
215 | 0                                0.178373  0.000000  ...      0.089167  0.054286
216 | 100000                           0.026786  0.066667  ...      0.054545  0.033333
217 | 200000                           0.130278  0.084650  ...      0.043333  0.055833
218 | 300000                           0.086111  0.109975  ...      0.182626  0.116768
219 | ...
220 |
221 | In [2]: configs
222 | Out[2]:
223 |              algorithm.name  algorithm.lr  algorithm.batch_size
224 | f7c2ecb3ddf1       DQN-FuPS        0.0001                   256
225 | ecaf120f572e       DQN-SePS        0.0001                   128
226 | 5a80fe220cfc       DQN-SePS        0.0003                   128
227 | d16939a558b6       DQN-FuPS        0.0003                   256
228 | ...
229 | ```
230 |
231 | Finally, you can use [HiPlot](https://github.com/facebookresearch/hiplot) to interactively visualise the performance of various hyperparameter configurations using:
232 | ```sh
233 | pip install -U hiplot
234 | hiplot marlbase.utils.postprocessing.hiplot_fetcher.experiment_fetcher
235 | ```
236 | You will then have to enter `exp://myfile.hd5/env_name/alg_name` (substituting the environment and algorithm names you want to inspect) in the browser's text box.
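The same export also lends itself to quick custom analyses in pandas. The following is an illustrative sketch only, assuming the column layout shown above (environment, algorithm, hash, and seed levels, in that order): it averages the exported returns across seeds and looks up the hyperparameters of the best-performing configuration at the last logged step.

```python
import pandas as pd

df = pd.read_hdf("myfile.hd5", "df")
configs = pd.read_hdf("myfile.hd5", "configs")

# Average and spread over the seed level (the innermost column level).
mean_returns = df.T.groupby(level=[0, 1, 2]).mean().T
std_returns = df.T.groupby(level=[0, 1, 2]).std().T

# Best (environment, algorithm, hash) by mean return at the last logged step.
# If the export contains several environments, slice the environment level first.
env, algo, best_hash = mean_returns.iloc[-1].idxmax()
print(env, algo, best_hash)
print(configs.loc[best_hash])  # hyperparameters of that configuration
```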
237 |
238 |
239 | # Implemented Algorithms
240 |
241 | |                             | IA2C               | MA-A2C             | IPPO               | MA-PPO             | DQN (Double Q)     | VDN                | QMIX               |
242 | |-----------------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
243 | | Parameter Sharing           | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
244 | | Selective Parameter Sharing | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
245 | | Return Standardisation      | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
246 | | Reward Standardisation      | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
247 | | Target Networks             | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
248 |
249 |
250 | ## Parameter Sharing
251 |
252 | Parameter sharing across agents is optional and is handled behind the scenes in the torch model.
253 | There are three types of parameter sharing:
254 | - No Parameter Sharing (default)
255 | - Full Parameter Sharing
256 | - Selective Parameter Sharing ([Christianos et al.](https://arxiv.org/pdf/2102.07475.pdf))
257 |
258 | For example, for IDQN you can enable any of these using:
259 | ```sh
260 | python run.py +algorithm=idqn env.name="lbforaging:Foraging-8x8-4p-3f-v3" env.time_limit=25 algorithm.model.parameter_sharing=False
261 | python run.py +algorithm=idqn env.name="lbforaging:Foraging-8x8-4p-3f-v3" env.time_limit=25 algorithm.model.parameter_sharing=True
262 | python run.py +algorithm=idqn env.name="lbforaging:Foraging-8x8-4p-3f-v3" env.time_limit=25 "algorithm.model.parameter_sharing=[0,0,1,1]"
263 | ```
264 | for each of the methods respectively. For Selective Parameter Sharing, you need to supply a list of indices indicating which network each agent uses. For example, `[0,0,1,1]` as above makes agents `0` and `1` share network `0`, and agents `2` and `3` share network `1`. Similarly, `[0,1,1,1]` gives the first agent its own network, while the other three share parameters.
265 |
266 | In actor-critic methods, you need to define parameter sharing separately for the actor and the critic. The respective configs are `algorithm.model.actor.parameter_sharing=...` and `algorithm.model.critic.parameter_sharing=...`.
267 |
268 | ## Value Decomposition
269 |
270 | We have implemented VDN and QMIX on top of the DQN algorithm. To use them, load the respective algorithm config with:
271 |
272 | ```sh
273 | python run.py +algorithm=vdn env.name="lbforaging:Foraging-8x8-4p-3f-v3" env.time_limit=25
274 | ```
275 |
276 | Note that, for this to work, we use the `CooperativeReward` wrapper, which _sums_ the rewards of all agents before feeding them to the training algorithm. If you have an environment that already has a cooperative reward, you still need it to return a *list of rewards* (e.g. `reward = n_agents * [reward/n_agents]`).
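To make the reward handling concrete, here is a minimal sketch of what such a cooperative-reward wrapper does in the Gymnasium-style API used here: it sums the per-agent rewards and hands every agent the same team reward, preserving the list-of-rewards interface. The class name `SumRewards` is made up for illustration; the actual `CooperativeReward` wrapper shipped with the codebase may differ in detail.

```python
import gymnasium as gym


class SumRewards(gym.Wrapper):
    """Illustrative sketch only: replace per-agent rewards with their sum."""

    def step(self, action):
        obs, rewards, terminated, truncated, info = self.env.step(action)
        team_reward = sum(rewards)
        # Keep one entry per agent so downstream code still sees a list of rewards.
        rewards = [team_reward] * len(rewards)
        return obs, rewards, terminated, truncated, info
```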
277 | 278 | 279 | # Contact 280 | - Filippos Christianos - filippos {dot} christianos {at} gmail {dot} com 281 | - Lukas Schäfer - luki {dot} schaefer96 {at} gmail {dot} com 282 | 283 | Based on: https://github.com/semitable/fast-marl (by Filippos Christianos) 284 | 285 | -------------------------------------------------------------------------------- /marlbase/dqn/model.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from einops import rearrange 4 | from gymnasium.spaces import flatdim 5 | import torch 6 | from torch import optim 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from marlbase.utils.models import MultiAgentSharedNetwork, MultiAgentIndependentNetwork 11 | from marlbase.utils.standardise_stream import RunningMeanStd 12 | 13 | 14 | class QNetwork(nn.Module): 15 | def __init__( 16 | self, 17 | obs_space, 18 | action_space, 19 | cfg, 20 | layers, 21 | parameter_sharing, 22 | use_rnn, 23 | use_orthogonal_init, 24 | device, 25 | ): 26 | super().__init__() 27 | hidden_dims = list(layers) 28 | 29 | self.action_space = action_space 30 | 31 | self.n_agents = len(obs_space) 32 | obs_shape = [flatdim(o) for o in obs_space] 33 | action_shape = [flatdim(a) for a in action_space] 34 | 35 | if not parameter_sharing: 36 | self.critic = MultiAgentIndependentNetwork( 37 | obs_shape, hidden_dims, action_shape, use_rnn, use_orthogonal_init 38 | ) 39 | self.target = MultiAgentIndependentNetwork( 40 | obs_shape, hidden_dims, action_shape, use_rnn, use_orthogonal_init 41 | ) 42 | else: 43 | self.critic = MultiAgentSharedNetwork( 44 | obs_shape, 45 | hidden_dims, 46 | action_shape, 47 | parameter_sharing, 48 | use_rnn, 49 | use_orthogonal_init, 50 | ) 51 | self.target = MultiAgentSharedNetwork( 52 | obs_shape, 53 | hidden_dims, 54 | action_shape, 55 | parameter_sharing, 56 | use_rnn, 57 | use_orthogonal_init, 58 | ) 59 | 60 | self.hard_update() 61 | self.to(device) 62 | 63 | for param in self.target.parameters(): 64 | param.requires_grad = False 65 | 66 | if type(cfg.optimizer) is str: 67 | self.optimizer_class = getattr(optim, cfg.optimizer) 68 | else: 69 | self.optimizer_class = cfg.optimizer 70 | 71 | self.optimizer = self.optimizer_class(self.critic.parameters(), lr=cfg.lr) 72 | 73 | self.gamma = cfg.gamma 74 | self.grad_clip = cfg.grad_clip 75 | self.device = device 76 | self.target_update_interval_or_tau = cfg.target_update_interval_or_tau 77 | self.double_q = cfg.double_q 78 | 79 | self.updates = 0 80 | self.last_target_update = 0 81 | 82 | self.standardise_returns = cfg.standardise_returns 83 | if self.standardise_returns: 84 | self.ret_ms = RunningMeanStd(shape=(self.n_agents,), device=device) 85 | 86 | print(self) 87 | 88 | def forward(self, inputs): 89 | raise NotImplementedError("Forward not implemented. 
Use act or update instead!") 90 | 91 | def init_hiddens(self, batch_size): 92 | return self.critic.init_hiddens(batch_size, self.device) 93 | 94 | def act(self, inputs, hiddens, epsilon, action_masks=None): 95 | with torch.no_grad(): 96 | inputs = [ 97 | torch.tensor(i, device=self.device).view(1, 1, -1) for i in inputs 98 | ] 99 | values, hiddens = self.critic(inputs, hiddens) 100 | if action_masks is not None: 101 | masked_values = [] 102 | for value, mask in zip(values, action_masks): 103 | masked_values.append(value * mask + (1 - mask) * -1e8) 104 | values = masked_values 105 | if epsilon > random.random(): 106 | if action_masks is not None: 107 | # random index of action with mask = 1 108 | actions = [ 109 | random.choice([i for i, m in enumerate(mask) if m == 1]) 110 | for mask in action_masks 111 | ] 112 | else: 113 | actions = self.action_space.sample() 114 | else: 115 | actions = [value.argmax(-1).squeeze().cpu().item() for value in values] 116 | return actions, hiddens 117 | 118 | def _compute_loss(self, batch): 119 | obss = batch.obss 120 | actions = batch.actions.unsqueeze(-1) 121 | rewards = batch.rewards 122 | dones = batch.dones[1:].unsqueeze(0).repeat(self.n_agents, 1, 1) 123 | filled = batch.filled 124 | action_masks = batch.action_mask 125 | 126 | # (n_agents, ep_length, batch_size, n_actions) 127 | q_values, _ = self.critic(obss, hiddens=None) 128 | q_values = torch.stack(q_values) 129 | chosen_q_values = q_values[:, :-1].gather(-1, actions).squeeze(-1) 130 | 131 | # compute target 132 | with torch.no_grad(): 133 | target_q_values, _ = self.target(obss, hiddens=None) 134 | target_q_values = torch.stack(target_q_values)[:, 1:] 135 | if action_masks is not None: 136 | target_q_values[action_masks[:, 1:] == 0] = -1e8 137 | 138 | if self.double_q: 139 | q_values_clone = q_values.clone().detach()[:, 1:] 140 | if action_masks is not None: 141 | q_values_clone[action_masks[:, 1:] == 0] = -1e8 142 | a_prime = q_values_clone.argmax(-1) 143 | target_qs = target_q_values.gather(-1, a_prime.unsqueeze(-1)).squeeze(-1) 144 | else: 145 | target_qs, _ = target_q_values.max(dim=-1) 146 | 147 | if self.standardise_returns: 148 | target_qs = rearrange(target_qs, "A E B -> E B A") 149 | target_qs = target_qs * torch.sqrt(self.ret_ms.var) + self.ret_ms.mean 150 | target_qs = rearrange(target_qs, "E B A -> A E B") 151 | 152 | returns = rewards + self.gamma * target_qs.detach() * (1 - dones) 153 | 154 | if self.standardise_returns: 155 | returns = rearrange(returns, "A E B -> E B A") 156 | self.ret_ms.update(returns) 157 | returns = (returns - self.ret_ms.mean) / torch.sqrt(self.ret_ms.var) 158 | returns = rearrange(returns, "E B A -> A E B") 159 | 160 | loss = torch.nn.functional.mse_loss( 161 | chosen_q_values, returns.detach(), reduction="none" 162 | ).sum(dim=0) 163 | return (loss * filled).sum() / filled.sum() 164 | 165 | def update(self, batch): 166 | loss = self._compute_loss(batch) 167 | self.optimizer.zero_grad() 168 | loss.backward() 169 | if self.grad_clip: 170 | torch.nn.utils.clip_grad_norm_(self.critic.parameters(), self.grad_clip) 171 | self.optimizer.step() 172 | self.updates += 1 173 | self.update_target() 174 | return {"loss": loss.item()} 175 | 176 | def update_target(self): 177 | if ( 178 | self.target_update_interval_or_tau > 1.0 179 | and (self.updates - self.last_target_update) 180 | >= self.target_update_interval_or_tau 181 | ): 182 | self.hard_update() 183 | self.last_target_update = self.updates 184 | elif self.target_update_interval_or_tau < 1.0: 185 | 
self.soft_update(self.target_update_interval_or_tau) 186 | 187 | def soft_update(self, tau): 188 | for target_param, source_param in zip( 189 | self.target.parameters(), self.critic.parameters() 190 | ): 191 | target_param.data.copy_( 192 | (1 - tau) * target_param.data + tau * source_param.data 193 | ) 194 | 195 | def hard_update(self): 196 | self.target.load_state_dict(self.critic.state_dict()) 197 | 198 | 199 | class VDNetwork(QNetwork): 200 | def __init__( 201 | self, 202 | obs_space, 203 | action_space, 204 | cfg, 205 | layers, 206 | parameter_sharing, 207 | use_rnn, 208 | use_orthogonal_init, 209 | device, 210 | ): 211 | super().__init__( 212 | obs_space, 213 | action_space, 214 | cfg, 215 | layers, 216 | parameter_sharing, 217 | use_rnn, 218 | use_orthogonal_init, 219 | device, 220 | ) 221 | if self.standardise_returns: 222 | self.ret_ms = RunningMeanStd(shape=(1,)) 223 | 224 | def _compute_loss(self, batch): 225 | obss = batch.obss 226 | actions = batch.actions.unsqueeze(-1) 227 | # Get reward of agent 0 --> assume cooperative rewards/ same reward for all agents 228 | rewards = batch.rewards[0] 229 | dones = batch.dones[1:] 230 | filled = batch.filled 231 | action_masks = batch.action_mask 232 | 233 | # (n_agents, ep_length, batch_size, n_actions) 234 | q_values, _ = self.critic(obss, hiddens=None) 235 | q_values = torch.stack(q_values) 236 | # sum over all agents for cooperative VDN estimate 237 | chosen_q_values = q_values[:, :-1].gather(-1, actions).squeeze(-1).sum(dim=0) 238 | 239 | # compute target 240 | with torch.no_grad(): 241 | target_q_values, _ = self.target(obss, hiddens=None) 242 | target_q_values = torch.stack(target_q_values)[:, 1:] 243 | if action_masks is not None: 244 | target_q_values[action_masks[:, 1:] == 0] = -1e8 245 | 246 | if self.double_q: 247 | q_values_clone = q_values.clone().detach()[:, 1:] 248 | if action_masks is not None: 249 | q_values_clone[action_masks[:, 1:] == 0] = -1e8 250 | a_prime = q_values_clone.argmax(-1) 251 | target_qs = target_q_values.gather(-1, a_prime.unsqueeze(-1)).squeeze(-1) 252 | else: 253 | target_qs, _ = target_q_values.max(dim=-1) 254 | target_qs = target_qs.sum(dim=0).detach() 255 | 256 | if self.standardise_returns: 257 | target_qs = target_qs * torch.sqrt(self.ret_ms.var) + self.ret_ms.mean 258 | 259 | # sum over target values of all agents for cooperative VDN target 260 | returns = rewards + self.gamma * target_qs * (1 - dones) 261 | 262 | if self.standardise_returns: 263 | self.ret_ms.update(returns) 264 | returns = (returns - self.ret_ms.mean) / torch.sqrt(self.ret_ms.var) 265 | 266 | loss = torch.nn.functional.mse_loss( 267 | chosen_q_values, returns.detach(), reduction="none" 268 | ) 269 | return (loss * filled).sum() / filled.sum() 270 | 271 | 272 | class QMixer(nn.Module): 273 | def __init__(self, n_agents, state_dim, embed_dim, hypernet_layers, hypernet_embed): 274 | super().__init__() 275 | 276 | self.n_agents = n_agents 277 | self.state_dim = state_dim 278 | 279 | self.embed_dim = embed_dim 280 | self.hypernet_layers = hypernet_layers 281 | self.hypernet_embed = hypernet_embed 282 | 283 | if hypernet_layers == 1: 284 | self.hyper_w_1 = nn.Linear(self.state_dim, self.embed_dim * self.n_agents) 285 | self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim) 286 | elif hypernet_layers == 2: 287 | hypernet_embed = self.hypernet_embed 288 | self.hyper_w_1 = nn.Sequential( 289 | nn.Linear(self.state_dim, hypernet_embed), 290 | nn.ReLU(), 291 | nn.Linear(hypernet_embed, self.embed_dim * self.n_agents), 292 | ) 293 
| self.hyper_w_final = nn.Sequential( 294 | nn.Linear(self.state_dim, hypernet_embed), 295 | nn.ReLU(), 296 | nn.Linear(hypernet_embed, self.embed_dim), 297 | ) 298 | else: 299 | raise Exception( 300 | "Error setting number of hypernet layers (please set `hypernet_layers=1` or `hypernet_layers=2`)." 301 | ) 302 | 303 | # State dependent bias for hidden layer 304 | self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim) 305 | 306 | # V(s) instead of a bias for the last layers 307 | self.V = nn.Sequential( 308 | nn.Linear(self.state_dim, self.embed_dim), 309 | nn.ReLU(), 310 | nn.Linear(self.embed_dim, 1), 311 | ) 312 | 313 | def forward(self, agent_qs, states): 314 | _, ep_length, batch_size = agent_qs.shape 315 | agent_qs = rearrange(agent_qs, "N E B -> (E B) 1 N") 316 | states = states.reshape(-1, self.state_dim) 317 | # First layer 318 | w1 = torch.abs(self.hyper_w_1(states)) 319 | b1 = self.hyper_b_1(states) 320 | w1 = w1.view(-1, self.n_agents, self.embed_dim) 321 | b1 = b1.view(-1, 1, self.embed_dim) 322 | hidden = F.elu(torch.bmm(agent_qs, w1) + b1) 323 | # Second layer 324 | w_final = torch.abs(self.hyper_w_final(states)) 325 | w_final = w_final.view(-1, self.embed_dim, 1) 326 | # State-dependent bias 327 | v = self.V(states).view(-1, 1, 1) 328 | # Compute final output 329 | y = torch.bmm(hidden, w_final) + v 330 | # Reshape and return 331 | return y.view(ep_length, batch_size) 332 | 333 | 334 | class QMixNetwork(QNetwork): 335 | def __init__( 336 | self, 337 | obs_space, 338 | action_space, 339 | cfg, 340 | layers, 341 | parameter_sharing, 342 | use_rnn, 343 | use_orthogonal_init, 344 | mixing, 345 | device, 346 | ): 347 | super().__init__( 348 | obs_space, 349 | action_space, 350 | cfg, 351 | layers, 352 | parameter_sharing, 353 | use_rnn, 354 | use_orthogonal_init, 355 | device, 356 | ) 357 | if self.standardise_returns: 358 | self.ret_ms = RunningMeanStd(shape=(1,)) 359 | 360 | state_dim = sum([flatdim(o) for o in obs_space]) 361 | self.mixer = QMixer(self.n_agents, state_dim, **mixing) 362 | self.target_mixer = QMixer(self.n_agents, state_dim, **mixing) 363 | self.hard_update() 364 | 365 | for param in self.target_mixer.parameters(): 366 | param.requires_grad = False 367 | 368 | self.optimizer = self.optimizer_class( 369 | list(self.critic.parameters()) + list(self.mixer.parameters()), 370 | lr=cfg.lr, 371 | ) 372 | print(self) 373 | 374 | def _compute_loss(self, batch): 375 | obss = batch.obss 376 | actions = batch.actions.unsqueeze(-1) 377 | # Get reward of agent 0 --> assume cooperative rewards/ same reward for all agents 378 | rewards = batch.rewards[0] 379 | dones = batch.dones[1:] 380 | filled = batch.filled 381 | action_masks = batch.action_mask 382 | 383 | # (n_agents, ep_length, batch_size, n_actions) 384 | q_values, _ = self.critic(obss, hiddens=None) 385 | q_values = torch.stack(q_values) 386 | # sum over all agents for cooperative VDN estimate 387 | chosen_q_values = self.mixer( 388 | q_values[:, :-1].gather(-1, actions).squeeze(-1), 389 | torch.concat(list(obss[:, :-1]), dim=-1), 390 | ) 391 | 392 | # compute target 393 | with torch.no_grad(): 394 | target_q_values, _ = self.target(obss, hiddens=None) 395 | target_q_values = torch.stack(target_q_values)[:, 1:] 396 | if action_masks is not None: 397 | target_q_values[action_masks[:, 1:] == 0] = -1e8 398 | 399 | if self.double_q: 400 | q_values_clone = q_values.clone().detach()[:, 1:] 401 | if action_masks is not None: 402 | q_values_clone[action_masks[:, 1:] == 0] = -1e8 403 | a_prime = q_values_clone.argmax(-1) 
404 | target_qs = target_q_values.gather(-1, a_prime.unsqueeze(-1)).squeeze( 405 | -1 406 | ) 407 | else: 408 | target_qs, _ = target_q_values.max(dim=-1) 409 | 410 | target_qs = self.target_mixer( 411 | target_qs, 412 | torch.concat(list(obss[:, 1:]), dim=-1), 413 | ).detach() 414 | 415 | if self.standardise_returns: 416 | target_qs = target_qs * torch.sqrt(self.ret_ms.var) + self.ret_ms.mean 417 | 418 | returns = rewards + self.gamma * target_qs * (1 - dones) 419 | 420 | if self.standardise_returns: 421 | self.ret_ms.update(returns) 422 | returns = (returns - self.ret_ms.mean) / torch.sqrt(self.ret_ms.var) 423 | 424 | loss = torch.nn.functional.mse_loss( 425 | chosen_q_values, returns.detach(), reduction="none" 426 | ) 427 | return (loss * filled).sum() / filled.sum() 428 | 429 | def soft_update(self, t): 430 | super().soft_update(t) 431 | try: 432 | source, target = self.mixer, self.target_mixer 433 | except AttributeError: # fix for when qmix has not initialised a mixer yet 434 | return 435 | for target_param, source_param in zip(target.parameters(), source.parameters()): 436 | target_param.data.copy_((1 - t) * target_param.data + t * source_param.data) 437 | 438 | def hard_update(self): 439 | super().hard_update() 440 | try: 441 | self.target_mixer.load_state_dict(self.mixer.state_dict()) 442 | except AttributeError: # fix for when qmix has not initialised a mixer yet 443 | return 444 | --------------------------------------------------------------------------------