├── marlbase ├── __init__.py ├── ac │ ├── __init__.py │ ├── eval.py │ ├── train.py │ └── model.py ├── dqn │ ├── __init__.py │ ├── eval.py │ ├── train.py │ └── model.py ├── utils │ ├── __init__.py │ ├── video.py │ ├── smaclite_wrapper.py │ ├── standardise_stream.py │ ├── postprocessing │ │ ├── find_best_hyperparams.py │ │ ├── plot_runs.py │ │ ├── hiplot_fetcher.py │ │ ├── export_multirun.py │ │ └── load_data.py │ ├── utils.py │ ├── envs.py │ ├── wrappers.py │ ├── loggers.py │ ├── stats.py │ └── models.py ├── configs │ ├── logger │ │ ├── basic.yaml │ │ ├── wandb.yaml │ │ └── filesystemlogger.yaml │ ├── algorithm │ │ ├── vdn.yaml │ │ ├── qmix.yaml │ │ ├── ia2c.yaml │ │ ├── maa2c.yaml │ │ ├── ippo.yaml │ │ ├── mappo.yaml │ │ └── idqn.yaml │ ├── eval.yaml │ ├── hydra │ │ └── job_logging │ │ │ ├── file.yaml │ │ │ └── console.yaml │ ├── default.yaml │ └── sweeps │ │ └── sample.yaml ├── run.py ├── eval.py └── search.py ├── .flake8 ├── requirements.txt ├── setup.py ├── .editorconfig ├── .gitignore └── README.md /marlbase/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marlbase/ac/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marlbase/dqn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marlbase/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /marlbase/configs/logger/basic.yaml: -------------------------------------------------------------------------------- 1 | _target_: utils.loggers.Logger 2 | project_name: fastmarl -------------------------------------------------------------------------------- /marlbase/configs/logger/wandb.yaml: -------------------------------------------------------------------------------- 1 | _target_: utils.loggers.WandbLogger 2 | project_name: fastmarl -------------------------------------------------------------------------------- /marlbase/configs/logger/filesystemlogger.yaml: -------------------------------------------------------------------------------- 1 | _target_: utils.loggers.FileSystemLogger 2 | project_name: fastmarl 3 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E226,E302,E41, F841 3 | max-line-length = 89 4 | exclude = tests/* 5 | max-complexity = 10 -------------------------------------------------------------------------------- /marlbase/configs/algorithm/vdn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - idqn 5 | 6 | env: 7 | wrappers : 8 | - CooperativeReward 9 | 10 | algorithm: 11 | name: "vdn" 12 | model: 13 | _target_: dqn.model.VDNetwork 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click==8.1.7 2 | einops==0.8.0 3 | gymnasium<1.0.0 4 | hydra-core==1.3.2 5 | hydra-submitit-launcher==1.2.0 6 | hydra-ax-sweeper==1.1.5 7 | imageio==2.9.0 8 | 
imageio-ffmpeg==0.5.1 9 | matplotlib==3.9.2 10 | munch==4.0.0 11 | pandas==2.2.2 12 | seaborn==0.13.2 13 | torch==2.4.0 14 | pyyaml==6.0.2 15 | -------------------------------------------------------------------------------- /marlbase/configs/algorithm/qmix.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - idqn 5 | 6 | env: 7 | wrappers : 8 | - CooperativeReward 9 | 10 | algorithm: 11 | name: "qmix" 12 | model: 13 | _target_: dqn.model.QMixNetwork 14 | mixing: 15 | embed_dim: 64 16 | hypernet_layers: 2 17 | hypernet_embed: 32 18 | -------------------------------------------------------------------------------- /marlbase/configs/eval.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - logger: filesystemlogger 4 | - override hydra/job_logging: file 5 | - override hydra/launcher: submitit_local 6 | 7 | hydra: 8 | run: 9 | dir: evals/${random:4} 10 | launcher: 11 | timeout_min: 2880 12 | cpus_per_task: 1 13 | mem_gb: 4 14 | job: 15 | chdir: True 16 | 17 | path: null 18 | load_step: null 19 | video_frames: 10000 20 | seed: null 21 | 22 | -------------------------------------------------------------------------------- /marlbase/utils/video.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | 3 | 4 | class VideoRecorder: 5 | def __init__(self, fps=30): 6 | self.fps = fps 7 | self.frames = [] 8 | 9 | def reset(self): 10 | self.frames = [] 11 | 12 | def record_frame(self, env): 13 | frame = env.unwrapped.render() 14 | self.frames.append(frame) 15 | 16 | def save(self, filename): 17 | imageio.mimsave(f"{filename}", self.frames, fps=self.fps) 18 | -------------------------------------------------------------------------------- /marlbase/configs/hydra/job_logging/file.yaml: -------------------------------------------------------------------------------- 1 | version: 1 2 | formatters: 3 | simple: 4 | format: '(%(process)d) [%(levelname)s] - (%(asctime)s) - %(name)s >> %(message)s' 5 | datefmt: '%m/%d %H:%M:%S' 6 | handlers: 7 | console: 8 | class: logging.StreamHandler 9 | formatter: simple 10 | stream: ext://sys.stdout 11 | file: 12 | class: logging.FileHandler 13 | formatter: simple 14 | filename: ${hydra.job.name}.log 15 | root: 16 | level: INFO 17 | handlers: [console, file] 18 | 19 | disable_existing_loggers: false -------------------------------------------------------------------------------- /marlbase/configs/hydra/job_logging/console.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | version: 1 3 | formatters: 4 | simple: 5 | format: '(%(process)d) [%(levelname)s] - (%(asctime)s) - %(name)s >> %(message)s' 6 | datefmt: '%m/%d %H:%M:%S' 7 | handlers: 8 | console: 9 | class: logging.StreamHandler 10 | formatter: simple 11 | stream: ext://sys.stdout 12 | file: 13 | class: logging.FileHandler 14 | formatter: simple 15 | filename: ${hydra.job.name}.log 16 | root: 17 | level: INFO 18 | handlers: [console] 19 | 20 | disable_existing_loggers: false -------------------------------------------------------------------------------- /marlbase/ac/eval.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | import torch 4 | 5 | from marlbase.ac.train import record_episodes 6 | 7 | 8 | def main(env, ckpt_path, **cfg): 9 | cfg = DictConfig(cfg) 10 | 11 | 
model = hydra.utils.instantiate( 12 | cfg.model, env.observation_space, env.action_space, cfg 13 | ) 14 | print(f"Loading model from {ckpt_path}") 15 | state_dict = torch.load(ckpt_path, weights_only=True) 16 | model.load_state_dict(state_dict) 17 | 18 | record_episodes( 19 | env, 20 | model, 21 | cfg.video_frames, 22 | "./eval.mp4", 23 | cfg.model.device, 24 | ) 25 | 26 | env.close() 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="marlbase", 5 | version="0.1.0", 6 | description="Fast Multi-Agent RL: a collection of algorithms applied in multi-agent environments", 7 | author="Filippos Christianos, Lukas Schäfer", 8 | url="https://github.com/marl-book/codebase", 9 | packages=find_packages(exclude=["contrib", "docs", "tests"]), 10 | classifiers=[ 11 | "Intended Audience :: Developers", 12 | "Programming Language :: Python :: 3.10", 13 | ], 14 | install_requires=["hydra-core>=1.1", "torch", "cpprb", "einops"], 15 | extras_require={"test": ["pytest"]}, 16 | include_package_data=True, 17 | ) 18 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # EditorConfig is awesome: https://EditorConfig.org 2 | 3 | # top-most EditorConfig file 4 | root = true 5 | 6 | # Unix-style newlines with a newline ending every file 7 | [*] 8 | end_of_line = lf 9 | insert_final_newline = true 10 | 11 | # Matches multiple files with brace expansion notation 12 | # Set default charset 13 | [*.{js,py}] 14 | charset = utf-8 15 | 16 | # 4 space indentation 17 | [*.py] 18 | indent_style = space 19 | indent_size = 4 20 | 21 | # Tab indentation (no size specified) 22 | [Makefile] 23 | indent_style = tab 24 | 25 | # Matches the exact files either package.json or .travis.yml 26 | [{package.json,.travis.yml}] 27 | indent_style = space 28 | indent_size = 2 29 | -------------------------------------------------------------------------------- /marlbase/dqn/eval.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | from omegaconf import DictConfig 3 | import torch 4 | 5 | from marlbase.dqn.train import record_episodes 6 | 7 | 8 | def main(env, ckpt_path, **cfg): 9 | cfg = DictConfig(cfg) 10 | 11 | model = hydra.utils.instantiate( 12 | cfg.model, env.observation_space, env.action_space, cfg 13 | ) 14 | print(f"Loading model from {ckpt_path}") 15 | state_dict = torch.load(ckpt_path, weights_only=True) 16 | model.load_state_dict(state_dict) 17 | 18 | record_episodes( 19 | env, 20 | model, 21 | cfg.video_frames, 22 | "./eval.mp4", 23 | cfg.model.device, 24 | cfg.eps_evaluation, 25 | ) 26 | 27 | env.close() 28 | -------------------------------------------------------------------------------- /marlbase/configs/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - logger: filesystemlogger 4 | - override hydra/job_logging: file 5 | - override hydra/launcher: submitit_local 6 | 7 | hydra: 8 | run: 9 | dir: outputs/${env.name}/${algorithm.name}/${random:4} 10 | launcher: 11 | timeout_min: 2880 12 | cpus_per_task: 1 13 | mem_gb: 4 14 | job: 15 | chdir: True 16 | 17 | seed: null 18 | 19 | algorithm: 20 | total_steps: 100_000 21 | log_interval: 10_000 22 | save_interval: False 23 | 
eval_interval: 10_000 24 | eval_episodes: 100 25 | video_interval: False 26 | video_frames: 500 27 | 28 | env: 29 | _target_: utils.envs.make_env 30 | name : ??? 31 | time_limit : ??? 32 | clear_info: False 33 | observe_id: False 34 | standardise_rewards: False 35 | wrappers : null 36 | -------------------------------------------------------------------------------- /marlbase/configs/sweeps/sample.yaml: -------------------------------------------------------------------------------- 1 | # constants 2 | algorithm.save_interval: False 3 | algorithm.eval_episodes: 10 4 | algorithm.video_interval: False 5 | 6 | # top-level choice (grid) search 7 | algorithm.standardise_returns: 8 | - True 9 | - False 10 | 11 | hparam-tuples-1: 12 | - !!python/tuple [env.name: lbforaging:Foraging-10x10-3p-3f-v3, env.time_limit: 25] 13 | - !!python/tuple [env.name: lbforaging:Foraging-8x8-2p-2f-coop-v3, env.time_limit: 25] 14 | 15 | hparam-tuples-2: 16 | - !!python/tuple 17 | - "+algorithm": "idqn" 18 | - algorithm.total_steps: 2_000_000 19 | - algorithm.batch_size: 20 | - 128 21 | - 256 22 | 23 | - !!python/tuple 24 | - "+algorithm": "ia2c" 25 | - algorithm.total_steps: 2_000_000 26 | - algorithm.entropy_coef: 27 | - 0.01 28 | - 0.001 29 | -------------------------------------------------------------------------------- /marlbase/configs/algorithm/ia2c.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | env: 3 | parallel_envs: 10 4 | 5 | algorithm: 6 | _target_: ac.train.main 7 | name: "ia2c" 8 | model: 9 | _target_: ac.model.A2CNetwork 10 | actor: 11 | layers: 12 | - 128 13 | - 128 14 | parameter_sharing: False # True/False/List[int] (seps_indices) 15 | use_orthogonal_init: True 16 | use_rnn: False 17 | critic: 18 | centralised: False 19 | layers: 20 | - 128 21 | - 128 22 | parameter_sharing: False # True/False/List[int] (seps_indices) 23 | use_orthogonal_init: True 24 | use_rnn: False 25 | 26 | device : "cpu" # a pytorch device ("cpu" or "cuda") 27 | 28 | optimizer : "Adam" 29 | lr: 3.e-4 30 | 31 | grad_clip: False 32 | 33 | n_steps: 5 34 | gamma: 0.99 35 | entropy_coef: 0.001 36 | value_loss_coef: 0.5 37 | use_proper_termination: False 38 | standardise_returns: False 39 | 40 | target_update_interval_or_tau: 200 41 | -------------------------------------------------------------------------------- /marlbase/configs/algorithm/maa2c.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | env: 3 | parallel_envs: 10 4 | 5 | algorithm: 6 | _target_: ac.train.main 7 | name: "maa2c" 8 | model: 9 | _target_: ac.model.A2CNetwork 10 | actor: 11 | layers: 12 | - 128 13 | - 128 14 | parameter_sharing: False # True/False/List[int] (seps_indices) 15 | use_orthogonal_init: True 16 | use_rnn: False 17 | critic: 18 | centralised: True 19 | layers: 20 | - 128 21 | - 128 22 | parameter_sharing: False # True/False/List[int] (seps_indices) 23 | use_orthogonal_init: True 24 | use_rnn: False 25 | 26 | device : "cpu" # a pytorch device ("cpu" or "cuda") 27 | 28 | optimizer : "Adam" 29 | lr: 3.e-4 30 | 31 | grad_clip: False 32 | 33 | n_steps: 5 34 | gamma: 0.99 35 | entropy_coef: 0.001 36 | value_loss_coef: 0.5 37 | use_proper_termination: False 38 | standardise_returns: False 39 | 40 | target_update_interval_or_tau: 200 41 | -------------------------------------------------------------------------------- /marlbase/configs/algorithm/ippo.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | env: 3 | parallel_envs: 10 4 | 5 | algorithm: 6 | _target_: ac.train.main 7 | name: "ippo" 8 | model: 9 | _target_: ac.model.PPONetwork 10 | actor: 11 | layers: 12 | - 128 13 | - 128 14 | parameter_sharing: False # True/False/List[int] (seps_indices) 15 | use_orthogonal_init: True 16 | use_rnn: False 17 | critic: 18 | centralised: False 19 | layers: 20 | - 128 21 | - 128 22 | parameter_sharing: False # True/False/List[int] (seps_indices) 23 | use_orthogonal_init: True 24 | use_rnn: False 25 | 26 | device : "cpu" # a pytorch device ("cpu" or "cuda") 27 | 28 | optimizer : "Adam" 29 | lr: 3.e-4 30 | 31 | grad_clip: False 32 | 33 | n_steps: 5 34 | gamma: 0.99 35 | entropy_coef: 0.001 36 | value_loss_coef: 0.5 37 | use_proper_termination: False 38 | standardise_returns: False 39 | 40 | num_epochs: 4 41 | ppo_clip: 0.2 42 | 43 | target_update_interval_or_tau: 200 44 | -------------------------------------------------------------------------------- /marlbase/configs/algorithm/mappo.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | env: 3 | parallel_envs: 10 4 | 5 | algorithm: 6 | _target_: ac.train.main 7 | name: "mappo" 8 | model: 9 | _target_: ac.model.PPONetwork 10 | actor: 11 | layers: 12 | - 128 13 | - 128 14 | parameter_sharing: False # True/False/List[int] (seps_indices) 15 | use_orthogonal_init: True 16 | use_rnn: False 17 | critic: 18 | centralised: True 19 | layers: 20 | - 128 21 | - 128 22 | parameter_sharing: False # True/False/List[int] (seps_indices) 23 | use_orthogonal_init: True 24 | use_rnn: False 25 | 26 | device : "cpu" # a pytorch device ("cpu" or "cuda") 27 | 28 | optimizer : "Adam" 29 | lr: 3.e-4 30 | 31 | grad_clip: False 32 | 33 | n_steps: 5 34 | gamma: 0.99 35 | entropy_coef: 0.001 36 | value_loss_coef: 0.5 37 | use_proper_termination: False 38 | standardise_returns: False 39 | 40 | num_epochs: 4 41 | ppo_clip: 0.2 42 | 43 | target_update_interval_or_tau: 200 44 | -------------------------------------------------------------------------------- /marlbase/configs/algorithm/idqn.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | algorithm: 4 | _target_: dqn.train.main 5 | name: "idqn" 6 | model: 7 | _target_: dqn.model.QNetwork 8 | layers: 9 | - 128 10 | - 128 11 | parameter_sharing: False # True/False/List[int] (seps_indices) 12 | use_orthogonal_init: True 13 | use_rnn: False 14 | device: "cpu" # a pytorch device ("cpu" or "cuda") 15 | 16 | training_start: 2000 17 | buffer_size: 10000 # number of *episodes* to store in the replay buffer 18 | 19 | optimizer: "Adam" 20 | lr: 3e-4 21 | gamma: 0.99 22 | batch_size: 32 23 | double_q: True 24 | 25 | grad_clip: 1.0 26 | 27 | use_proper_termination: False # True/ False 28 | standardise_returns: False 29 | 30 | eps_decay_style: "linear" # "linear" or "exponential" 31 | eps_decay_over: 0.5 # fraction of total steps over which to decay epsilon 32 | eps_start: 1.0 33 | eps_end: 0.05 34 | eps_exp_decay_rate: 6.5 # exponential decay rate (ignored for linear decay) 35 | eps_evaluation: 0.05 36 | 37 | target_update_interval_or_tau: 200 38 | -------------------------------------------------------------------------------- /marlbase/utils/smaclite_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gymnasium as gym 3 | 4 | 5 | 
class SMACliteWrapper(gym.Wrapper): 6 | def __init__(self, env): 7 | super().__init__(env) 8 | self.n_agents = self.env.unwrapped.n_agents 9 | 10 | def step(self, actions): 11 | """Returns obss, reward, terminated, truncated, info""" 12 | actions = [int(act) for act in actions] 13 | obs, reward, terminated, truncated, info = self.env.step(actions) 14 | info["action_mask"] = np.array( 15 | self.env.unwrapped.get_avail_actions(), dtype=np.float32 16 | ) 17 | return obs, [reward] * self.n_agents, terminated, truncated, info 18 | 19 | def reset(self, seed=None, options=None): 20 | """Returns initial observations and info""" 21 | obs, info = self.env.reset(seed=seed, options=options) 22 | info["action_mask"] = np.array( 23 | self.env.unwrapped.get_avail_actions(), dtype=np.float32 24 | ) 25 | return obs, info 26 | 27 | def render(self): 28 | self.env.render() 29 | 30 | def close(self): 31 | self.env.close() 32 | -------------------------------------------------------------------------------- /marlbase/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import hydra 4 | import numpy as np 5 | from omegaconf import OmegaConf, DictConfig 6 | import torch 7 | 8 | OmegaConf.register_new_resolver( 9 | "random", 10 | lambda x: os.urandom(x).hex(), 11 | ) 12 | 13 | 14 | @hydra.main(config_path="configs", config_name="default", version_base="1.3") 15 | def main(cfg: DictConfig): 16 | logger = hydra.utils.instantiate(cfg.logger, cfg=cfg, _recursive_=False) 17 | 18 | env = hydra.utils.call(cfg.env, seed=cfg.seed) 19 | 20 | # Use singular env for evaluation/ recording 21 | if "parallel_envs" in cfg.env: 22 | del cfg.env.parallel_envs 23 | eval_env = hydra.utils.call( 24 | cfg.env, 25 | enable_video=True if cfg.algorithm.video_interval else False, 26 | seed=cfg.seed, 27 | ) 28 | 29 | torch.set_num_threads(1) 30 | 31 | if cfg.seed is not None: 32 | torch.manual_seed(cfg.seed) 33 | np.random.seed(cfg.seed) 34 | else: 35 | logger.warning("No seed has been set.") 36 | 37 | assert cfg.env.time_limit is not None, "Time limit must be set." 38 | hydra.utils.call( 39 | cfg.algorithm, 40 | env, 41 | eval_env, 42 | logger, 43 | time_limit=cfg.env.time_limit, 44 | _recursive_=False, 45 | ) 46 | 47 | return logger.get_state() 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /marlbase/utils/standardise_stream.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | 5 | 6 | class RunningMeanStd(object): 7 | def __init__(self, epsilon: float = 1e-4, shape: Tuple[int, ...] 
= (), device: str = "cpu"): 8 | """ 9 | https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 10 | """ 11 | self.mean = torch.zeros(shape, dtype=torch.float32, device=device) 12 | self.var = torch.ones(shape, dtype=torch.float32, device=device) 13 | self.count = epsilon 14 | 15 | def update(self, arr): 16 | arr = arr.reshape(-1, arr.size(-1)) 17 | batch_mean = torch.mean(arr, dim=0) 18 | batch_var = torch.var(arr, dim=0) 19 | batch_count = arr.shape[0] 20 | self.update_from_moments(batch_mean, batch_var, batch_count) 21 | 22 | def update_from_moments(self, batch_mean, batch_var, batch_count: int): 23 | delta = batch_mean - self.mean 24 | tot_count = self.count + batch_count 25 | 26 | new_mean = self.mean + delta * batch_count / tot_count 27 | m_a = self.var * self.count 28 | m_b = batch_var * batch_count 29 | m_2 = ( 30 | m_a 31 | + m_b 32 | + torch.square(delta) 33 | * self.count 34 | * batch_count 35 | / (self.count + batch_count) 36 | ) 37 | new_var = m_2 / (self.count + batch_count) 38 | 39 | new_count = batch_count + self.count 40 | 41 | self.mean = new_mean 42 | self.var = new_var 43 | self.count = new_count 44 | -------------------------------------------------------------------------------- /marlbase/utils/postprocessing/find_best_hyperparams.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import click 4 | from omegaconf import OmegaConf 5 | 6 | from marlbase.utils.postprocessing.load_data import load_and_group_runs 7 | 8 | 9 | DEFAULT_METRIC = "mean_episode_returns" 10 | 11 | 12 | @click.command() 13 | @click.option("--source", type=click.Path(dir_okay=True, writable=False), required=True) 14 | @click.option("--metric", type=str, default=DEFAULT_METRIC) 15 | def run(source, metric): 16 | groups = load_and_group_runs(Path(source)) 17 | assert len(groups) > 0, "No groups found" 18 | 19 | assert all( 20 | group.has_metric(metric) for group in groups 21 | ), f"Metric {metric} not found in all groups" 22 | 23 | envs = set([group.config.env.name for group in groups]) 24 | 25 | for env in envs: 26 | env_groups = [group for group in groups if group.config.env.name == env] 27 | 28 | best_group = None 29 | best_value = -float("inf") 30 | 31 | for group in env_groups: 32 | values = group.get_metric(metric) 33 | mean = values.mean() 34 | if mean > best_value: 35 | best_group = group 36 | best_value = mean 37 | 38 | click.echo( 39 | "Best group for " 40 | + click.style(env, fg="red", bold=True) 41 | + " according to " 42 | + click.style(metric, fg="red", bold=True) 43 | + ": " 44 | + click.style(best_group.name, fg="red", bold=True) 45 | ) 46 | 47 | click.echo(OmegaConf.to_yaml(best_group.config)) 48 | 49 | click.echo(85 * "-" + "\n") 50 | 51 | 52 | if __name__ == "__main__": 53 | run() 54 | -------------------------------------------------------------------------------- /marlbase/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | import hydra 5 | import numpy as np 6 | from omegaconf import OmegaConf, DictConfig 7 | import torch 8 | 9 | 10 | OmegaConf.register_new_resolver( 11 | "random", 12 | lambda x: os.urandom(x).hex(), 13 | ) 14 | 15 | 16 | @hydra.main(config_path="configs", config_name="eval", version_base="1.3") 17 | def main(cfg: DictConfig): 18 | path = Path(__file__).parent / cfg.path 19 | assert path.exists(), f"Path {path} does not exist." 
20 | assert path.is_dir(), f"Path {path} is not a directory." 21 | 22 | config_path = path / "config.yaml" 23 | assert config_path.exists(), f"Config file {config_path} does not exist." 24 | run_config = OmegaConf.load(config_path) 25 | 26 | if "parallel_envs" in run_config.env: 27 | del run_config.env.parallel_envs 28 | env = hydra.utils.call( 29 | run_config.env, 30 | enable_video=True, 31 | seed=cfg.seed, 32 | ) 33 | 34 | torch.set_num_threads(1) 35 | 36 | if cfg.seed is not None: 37 | torch.manual_seed(cfg.seed) 38 | np.random.seed(cfg.seed) 39 | 40 | run_config.algorithm._target_ = run_config.algorithm._target_.replace("train", "eval") 41 | 42 | if cfg.load_step is not None: 43 | load_step = cfg.load_step 44 | else: 45 | # Find the latest checkpoint 46 | load_step = max( 47 | [ 48 | int(f.stem.split("_")[-1][1:]) 49 | for f in (path / "checkpoints").glob("model_s*.pt") 50 | ] 51 | ) 52 | ckpt_path = path / "checkpoints" / f"model_s{load_step}.pt" 53 | assert ckpt_path.exists(), f"Checkpoint {ckpt_path} does not exist." 54 | 55 | hydra.utils.call( 56 | run_config.algorithm, 57 | env, 58 | ckpt_path, 59 | _recursive_=False, 60 | ) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /marlbase/utils/postprocessing/plot_runs.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import click 4 | import matplotlib.pyplot as plt 5 | import seaborn as sns 6 | 7 | from marlbase.utils.postprocessing.load_data import load_and_group_runs 8 | 9 | 10 | DEFAULT_METRIC = "mean_episode_returns" 11 | 12 | 13 | @click.command() 14 | @click.option("--source", type=click.Path(dir_okay=True, writable=False), required=True) 15 | @click.option("--minimal-name", type=bool, default=True) 16 | @click.option("--metric", type=str, default=DEFAULT_METRIC) 17 | @click.option("--save_path", type=click.Path(dir_okay=True, writable=True)) 18 | def run(source, minimal_name, metric, save_path): 19 | groups = load_and_group_runs(Path(source), minimal_name) 20 | assert len(groups) > 0, "No groups found" 21 | 22 | click.echo(f"Loaded {len(groups)} groups:") 23 | for group in groups: 24 | click.echo(f"\t{group.name} with {len(group.runs)} runs") 25 | 26 | assert all( 27 | group.has_metric(metric) for group in groups 28 | ), f"Metric {metric} not found in all groups" 29 | 30 | envs = set([group.config.env.name for group in groups]) 31 | 32 | for env in envs: 33 | env_groups = [group for group in groups if group.config.env.name == env] 34 | 35 | sns.set_style("whitegrid") 36 | plt.figure() 37 | for group in env_groups: 38 | steps = group.get_metric("environment_steps").mean(axis=0) 39 | values = group.get_metric(metric) 40 | means = values.mean(axis=0) 41 | stds = values.std(axis=0) 42 | plt.plot(steps, means, label=group.name) 43 | plt.fill_between( 44 | steps, 45 | means - stds, 46 | means + stds, 47 | alpha=0.3, 48 | ) 49 | plt.legend() 50 | plt.xlabel("Environment steps") 51 | plt.ylabel(metric) 52 | plt.title(env) 53 | if save_path: 54 | path = Path(save_path) / f"{env.replace('/', ':')}_{metric}.pdf" 55 | path.parent.mkdir(parents=True, exist_ok=True) 56 | plt.savefig(path) 57 | plt.show() 58 | 59 | 60 | if __name__ == "__main__": 61 | run() 62 | -------------------------------------------------------------------------------- /marlbase/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class 
MultiCategorical: 5 | def __init__(self, categoricals): 6 | self.categoricals = categoricals 7 | 8 | def __getitem__(self, key): 9 | return self.categoricals[key] 10 | 11 | def sample(self): 12 | return [c.sample().unsqueeze(-1) for c in self.categoricals] 13 | 14 | def log_probs(self, actions): 15 | return [ 16 | c.log_prob(a.squeeze(-1)).unsqueeze(-1) 17 | for c, a in zip(self.categoricals, actions) 18 | ] 19 | 20 | def mode(self): 21 | return [c.mode for c in self.categoricals] 22 | 23 | def entropy(self): 24 | return [c.entropy() for c in self.categoricals] 25 | 26 | 27 | def to_onehot(tensor, n_dims): 28 | """ 29 | Convert tensor of indices to one-hot representation 30 | :param tensor: tensor of indices (batch_size, ..., 1) 31 | :param n_dims: number of dimensions 32 | :return: one-hot representation (batch_size, ..., n_dims) 33 | """ 34 | onehot = torch.zeros(tensor.shape + (n_dims,), device=tensor.device) 35 | return onehot.scatter(-1, tensor.unsqueeze(-1), 1) 36 | 37 | 38 | def compute_nstep_returns(rewards, done, next_values, nsteps, gamma): 39 | """ 40 | Computed n-step returns 41 | :param rewards: tensor of shape (ep_length, batch_size, n_agents) 42 | :param done: tensor of shape (ep_length, batch_size, n_agents) 43 | :param next_values: tensor of shape (ep_length, batch_size, n_agents) 44 | :param nsteps: number of steps to bootstrap 45 | :param gamma: discount factor 46 | :return: tensor of shape with returns (ep_length, batch_size, n_agents) 47 | """ 48 | ep_length = rewards.size(0) 49 | nstep_values = torch.zeros_like(rewards) 50 | for t_start in range(ep_length): 51 | nstep_return_t = torch.zeros_like(rewards[0]) 52 | for step in range(nsteps + 1): 53 | t = t_start + step 54 | if t >= ep_length: 55 | # episode has ended 56 | break 57 | elif step == nsteps: 58 | # last n-step value --> bootstrap from the next value 59 | nstep_return_t += gamma**step * next_values[t] * (1 - done[t]) 60 | else: 61 | nstep_return_t += gamma**step * rewards[t] * (1 - done[t]) 62 | nstep_values[t_start] = nstep_return_t 63 | return nstep_values 64 | -------------------------------------------------------------------------------- /marlbase/utils/postprocessing/hiplot_fetcher.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import click 4 | import hiplot as hip 5 | import json 6 | import numpy as np 7 | import pandas as pd 8 | 9 | 10 | class NumpyEncoder(json.JSONEncoder): 11 | """ Custom encoder for numpy data types """ 12 | def default(self, obj): 13 | if isinstance(obj, (np.int_, np.intc, np.intp, np.int8, 14 | np.int16, np.int32, np.int64, np.uint8, 15 | np.uint16, np.uint32, np.uint64)): 16 | 17 | return int(obj) 18 | 19 | elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): 20 | return float(obj) 21 | 22 | elif isinstance(obj, (np.complex_, np.complex64, np.complex128)): 23 | return {'real': obj.real, 'imag': obj.imag} 24 | 25 | elif isinstance(obj, (np.ndarray,)): 26 | return obj.tolist() 27 | 28 | elif isinstance(obj, (np.bool_)): 29 | return bool(obj) 30 | 31 | elif isinstance(obj, (np.void)): 32 | return None 33 | 34 | return json.JSONEncoder.default(self, obj) 35 | 36 | def experiment_fetcher(uri): 37 | 38 | PREFIX = "exp://" 39 | 40 | if not uri.startswith(PREFIX): 41 | # Let other fetchers handle this one 42 | raise hip.ExperimentFetcherDoesntApply() 43 | uri = uri[len(PREFIX):] # Remove the prefix 44 | 45 | exported_file = uri.split("/")[0] 46 | 47 | df = pd.read_hdf(exported_file, 
"df") 48 | configs = pd.read_hdf(exported_file, "configs") 49 | 50 | df = ( 51 | df.groupby(axis=1, level=[0, 1, 2]).mean().max() 52 | ) 53 | 54 | data = defaultdict(lambda: defaultdict(list)) 55 | 56 | for env, df in df.groupby(level=0): 57 | df = df.xs(env) 58 | for alg, df in df.groupby(level=0): 59 | df = df.xs(alg) 60 | 61 | for index, perf in df.iteritems(): 62 | data[env][alg].append({**configs.loc[index].to_dict(), "performance": perf, "uid": index}) 63 | 64 | env, alg = uri.split("/")[1], uri.split("/")[2] 65 | 66 | data = json.loads(json.dumps(data[env][alg], cls=NumpyEncoder)) 67 | exp = hip.Experiment.from_iterable(data) 68 | 69 | return exp 70 | 71 | 72 | if __name__ == "__main__": 73 | click.echo("Run with \"hiplot fastmarl.utils.postprocessing.hiplot_fetcher.experiment_fetcher\"") 74 | click.echo("And enter \"exp://filename.hd5/envname/alg\" in the textbox") 75 | 76 | -------------------------------------------------------------------------------- /marlbase/utils/postprocessing/export_multirun.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from hashlib import sha256 3 | from pathlib import Path 4 | 5 | import click 6 | import json 7 | from munch import munchify 8 | import pandas as pd 9 | import yaml 10 | 11 | 12 | def _load_data(folder): 13 | path = Path(folder) 14 | config_files = path.glob("**/**/.hydra/config.yaml") 15 | 16 | data = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) 17 | seed_data = defaultdict(lambda: defaultdict(lambda: defaultdict(set))) 18 | 19 | algos = set() 20 | envs = set() 21 | 22 | hash_to_config = defaultdict() 23 | 24 | for config_path in config_files: 25 | # hash = config_path.parent.parent.name 26 | 27 | with config_path.open() as fp: 28 | config = munchify(yaml.safe_load(fp)) 29 | 30 | env = config.env.name 31 | try: 32 | env = env.split(":")[1] # remove the library 33 | except IndexError: 34 | pass 35 | 36 | algo = config.algorithm.name 37 | 38 | algos.add(algo) 39 | envs.add(env) 40 | 41 | seed = config.seed 42 | del config.seed 43 | 44 | raw_data = json.dumps(config, sort_keys=True).encode("utf8") 45 | hash = sha256(raw_data).hexdigest()[:12] 46 | 47 | hash_to_config[hash] = pd.json_normalize(config) 48 | 49 | df = pd.read_csv(config_path.parent.parent / "results.csv", index_col=0)[ 50 | "mean_episode_returns" 51 | ] 52 | data[env][algo][hash].append(df.rename(f"seed={seed}")) 53 | assert seed not in seed_data[env][algo][hash], "Duplicate seed" 54 | seed_data[env][algo][hash].add(seed) 55 | 56 | env_df_list = [] 57 | for env in data.keys(): 58 | algo_df_list = [] 59 | for algo in data[env].keys(): 60 | lst = [] 61 | for hash in data[env][algo].keys(): 62 | lst.append(pd.concat(data[env][algo][hash], axis=1)) 63 | df = pd.concat(lst, axis=1, keys=[h for h in data[env][algo].keys()]) 64 | algo_df_list.append(df) 65 | df = pd.concat(algo_df_list, axis=1, keys=data[env].keys()) 66 | env_df_list.append(df) 67 | df = pd.concat(env_df_list, axis=1, keys=data.keys()) 68 | 69 | return pd.concat(hash_to_config).droplevel(1), df 70 | 71 | 72 | @click.command() 73 | @click.option("--folder", type=click.Path(exists=True), default="outputs/") 74 | @click.option("--export-file", type=click.Path(dir_okay=False, writable=True)) 75 | @click.pass_context 76 | def run(ctx, folder, export_file): 77 | 78 | hash_to_config, df = _load_data(folder) 79 | 80 | df.to_hdf(export_file, key="df", mode="w", complevel=9) 81 | hash_to_config.to_hdf(export_file, key="configs") 82 | 
83 | 84 | if __name__ == "__main__": 85 | run() 86 | -------------------------------------------------------------------------------- /marlbase/utils/envs.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import random 3 | 4 | import gymnasium as gym 5 | from omegaconf import DictConfig 6 | 7 | from marlbase.utils import wrappers as mwrappers 8 | from marlbase.utils.smaclite_wrapper import SMACliteWrapper 9 | 10 | 11 | def _make_parallel_envs( 12 | name, 13 | parallel_envs, 14 | wrappers, 15 | time_limit, 16 | clear_info, 17 | observe_id, 18 | standardise_rewards, 19 | seed, 20 | enable_video, 21 | **kwargs, 22 | ): 23 | def _env_thunk(seed): 24 | if "smaclite" in name: 25 | import smaclite # noqa 26 | 27 | env = gym.make( 28 | name, 29 | seed=seed, 30 | render_mode="rgb_array" if enable_video else None, 31 | **kwargs, 32 | ) 33 | env = SMACliteWrapper(env) 34 | else: 35 | env = gym.make( 36 | name, **kwargs, render_mode="rgb_array" if enable_video else None 37 | ) 38 | if clear_info: 39 | env = mwrappers.ClearInfo(env) 40 | if time_limit: 41 | env = gym.wrappers.TimeLimit(env, time_limit) 42 | env = mwrappers.RecordEpisodeStatistics(env) 43 | if observe_id: 44 | env = mwrappers.ObserveID(env) 45 | if standardise_rewards: 46 | env = mwrappers.StandardiseReward(env) 47 | if wrappers is not None: 48 | for wrapper in wrappers: 49 | wrapper = ( 50 | getattr(mwrappers, wrapper) 51 | if hasattr(mwrappers, wrapper) 52 | else getattr(gym.wrappers, wrapper) 53 | ) 54 | env = wrapper(env) 55 | env.reset(seed=seed) 56 | return env 57 | 58 | if seed is None: 59 | seed = random.randint(0, 99999) 60 | 61 | envs = gym.vector.AsyncVectorEnv( 62 | [partial(_env_thunk, seed + i) for i in range(parallel_envs)] 63 | ) 64 | 65 | return envs 66 | 67 | 68 | def _make_env( 69 | name, 70 | time_limit, 71 | clear_info, 72 | observe_id, 73 | standardise_rewards, 74 | wrappers, 75 | seed, 76 | enable_video, 77 | **kwargs, 78 | ): 79 | if "smaclite" in name: 80 | import smaclite # noqa 81 | 82 | env = gym.make( 83 | name, 84 | seed=seed, 85 | render_mode="rgb_array" if enable_video else None, 86 | **kwargs, 87 | ) 88 | env = SMACliteWrapper(env) 89 | else: 90 | env = gym.make( 91 | name, render_mode="rgb_array" if enable_video else None, **kwargs 92 | ) 93 | if clear_info: 94 | env = mwrappers.ClearInfo(env) 95 | if time_limit: 96 | env = gym.wrappers.TimeLimit(env, time_limit) 97 | env = mwrappers.RecordEpisodeStatistics(env) 98 | if observe_id: 99 | env = mwrappers.ObserveID(env) 100 | if standardise_rewards: 101 | env = mwrappers.StandardiseReward(env) 102 | if wrappers is not None: 103 | for wrapper in wrappers: 104 | wrapper = ( 105 | getattr(mwrappers, wrapper) 106 | if hasattr(mwrappers, wrapper) 107 | else getattr(gym.wrappers, wrapper) 108 | ) 109 | env = wrapper(env) 110 | 111 | env.reset(seed=seed) 112 | return env 113 | 114 | 115 | def make_env(seed, enable_video=False, **env_config): 116 | env_config = DictConfig(env_config) 117 | if "parallel_envs" in env_config and env_config.parallel_envs: 118 | return _make_parallel_envs(**env_config, enable_video=enable_video, seed=seed) 119 | return _make_env(**env_config, enable_video=enable_video, seed=seed) 120 | -------------------------------------------------------------------------------- /marlbase/search.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from itertools import product 3 | import multiprocessing 4 | 
import random 5 | import subprocess 6 | import time 7 | 8 | import click 9 | import yaml 10 | 11 | _CPU_COUNT = multiprocessing.cpu_count() - 1 12 | 13 | 14 | def _flatten_lists(object): 15 | for item in object: 16 | if isinstance(item, (list, tuple, set)): 17 | yield from _flatten_lists(item) 18 | else: 19 | yield item 20 | 21 | 22 | def _seed_and_shuffle(configs, shuffle, seeds): 23 | configs = [[f"+hypergroup=hp_grp_{i}"] + c for i, c in enumerate(configs)] 24 | configs = list(product(configs, [f"seed={i}" for i in range(seeds)])) 25 | configs = [list(_flatten_lists(c)) for c in configs] 26 | 27 | if shuffle: 28 | random.Random(1337).shuffle(configs) 29 | 30 | return configs 31 | 32 | 33 | def _load_config(filename): 34 | config = yaml.load(filename, Loader=yaml.CLoader) 35 | return config 36 | 37 | 38 | def _gen_combos(config, built_config): 39 | built_config = deepcopy(built_config) 40 | if not config: 41 | return [[f"{k}={v}" for k, v in built_config.items()]] 42 | 43 | k, v = list(config.items())[0] 44 | 45 | configs = [] 46 | if type(v) is list: 47 | for item in v: 48 | new_config = deepcopy(config) 49 | del new_config[k] 50 | new_config[k] = item 51 | configs += _gen_combos(new_config, built_config) 52 | elif type(v) is tuple: 53 | new_config = deepcopy(config) 54 | del new_config[k] 55 | for item in v: 56 | new_config.update(item) 57 | 58 | configs += _gen_combos(new_config, built_config) 59 | else: 60 | new_config = deepcopy(config) 61 | del new_config[k] 62 | built_config[k] = v 63 | configs += _gen_combos(new_config, built_config) 64 | return configs 65 | 66 | 67 | def work(cmd, sleep): 68 | cmd = cmd.split(" ") 69 | time.sleep(sleep) 70 | return subprocess.call(cmd, shell=False) 71 | 72 | 73 | @click.group() 74 | def cli(): 75 | pass 76 | 77 | 78 | @cli.command() 79 | @click.argument("output", type=click.Path(exists=False, dir_okay=False, writable=True)) 80 | def write(output): 81 | raise NotImplementedError 82 | 83 | 84 | @cli.group() 85 | @click.option("--config", type=click.File(), default="config.yaml") 86 | @click.option("--shuffle/--no-shuffle", default=True) 87 | @click.option("--seeds", default=3, show_default=True, help="How many seeds to run") 88 | @click.pass_context 89 | def run(ctx, config, shuffle, seeds): 90 | config = _load_config(config) 91 | configs = _gen_combos(config, {}) 92 | 93 | configs = _seed_and_shuffle(configs, shuffle, seeds) 94 | if len(configs) == 0: 95 | click.echo("No valid combinations. Aborted!") 96 | exit(1) 97 | ctx.obj = configs 98 | 99 | 100 | @run.command() 101 | @click.option( 102 | "--cpus", 103 | default=_CPU_COUNT, 104 | show_default=True, 105 | help="How many processes to run in parallel", 106 | ) 107 | @click.pass_obj 108 | def locally(combos, cpus): 109 | configs = [ 110 | "python run.py " + "-m " + " ".join([c for c in combo]) for combo in combos 111 | ] 112 | args = [(conf, i * 2) for i, conf in enumerate(configs)] 113 | 114 | click.confirm( 115 | f"There are {click.style(str(len(combos)), fg='red')} combinations of configurations. Up to {cpus} will run in parallel. 
Continue?", 116 | abort=True, 117 | ) 118 | 119 | pool = multiprocessing.Pool(processes=cpus) 120 | print(pool.starmap(work, args)) 121 | 122 | 123 | @run.command() 124 | @click.pass_obj 125 | def dry_run(combos): 126 | configs = [" ".join([c for c in combo]) for combo in combos] 127 | click.echo( 128 | f"There are {click.style(str(len(combos)), fg='red')} configurations as shown below:" 129 | ) 130 | for c in configs: 131 | click.echo(c) 132 | 133 | 134 | @run.command() 135 | @click.argument( 136 | "index", 137 | type=int, 138 | ) 139 | @click.pass_obj 140 | def single(combos, index): 141 | """Runs a single hyperparameter combination 142 | INDEX is the index of the combination to run in the generated combination list 143 | """ 144 | 145 | config = combos[index] 146 | cmd = "python run.py " + " ".join([c for c in config]) 147 | print(cmd) 148 | work(cmd) 149 | 150 | 151 | if __name__ == "__main__": 152 | cli() 153 | -------------------------------------------------------------------------------- /marlbase/utils/postprocessing/load_data.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from pathlib import Path 3 | from typing import Dict, List, Union 4 | 5 | import numpy as np 6 | from omegaconf import OmegaConf 7 | import pandas as pd 8 | 9 | 10 | def _load_csv(path: Path) -> Dict[str, List[float]]: 11 | assert ( 12 | path.exists() and path.is_file() and path.suffix == ".csv" 13 | ), f"{path} is not a valid csv file" 14 | return pd.read_csv(path).to_dict(orient="list") 15 | 16 | 17 | def _load_config(path: Path) -> OmegaConf: 18 | assert ( 19 | path.exists() and path.is_file() and path.suffix == ".yaml" 20 | ), f"{path} is not a valid config file" 21 | return OmegaConf.load(path) 22 | 23 | 24 | class Run: 25 | def __init__(self, config: OmegaConf, data: Dict[str, List[float]], path: Path): 26 | self.data = data 27 | self.config = config 28 | self.path = path 29 | 30 | def __from_path__(path: Path) -> "Run": 31 | assert path.exists() and path.is_dir(), f"{path} is not a valid run directory" 32 | data = _load_csv(path / "results.csv") 33 | config = _load_config(path / "config.yaml") 34 | return Run(config, data, path) 35 | 36 | def __str__(self) -> str: 37 | return f"Run {self.path}" 38 | 39 | def get_config_name(self) -> str: 40 | return " ".join( 41 | [f"{key}={value}" for key, value in self.config.items() if key != "seed"] 42 | ) 43 | 44 | 45 | def _load_run(path: Path): 46 | return Run.__from_path__(path) 47 | 48 | 49 | class Group: 50 | def __init__(self, name, runs: List[Run]): 51 | self.name = name 52 | self.config = runs[0].config 53 | self.config.pop("seed") 54 | self.runs = runs 55 | 56 | def __str__(self) -> str: 57 | return f"Group {self.name} ({len(self.runs)} runs)" 58 | 59 | def has_metric(self, key) -> bool: 60 | has_metrics = [key in run.data for run in self.runs] 61 | assert all(has_metrics) or not any( 62 | has_metrics 63 | ), f"Key {key} is present in some but not all runs" 64 | return has_metrics[0] 65 | 66 | def get_metric(self, key) -> np.ndarray: 67 | assert self.has_metric(key), f"Key {key} is not present in all runs" 68 | values = [run.data[key] for run in self.runs] 69 | assert all( 70 | len(value) == len(values[0]) for value in values 71 | ), f"Values for key {key} have different lengths" 72 | return np.array(values) 73 | 74 | 75 | def _load_runs(path: Path) -> List[Run]: 76 | assert path.exists() and path.is_dir(), f"{path} is not a valid directory" 77 | runs = [] 78 | for run in 
path.glob("**/results.csv"): 79 | run = _load_run(run.parent) 80 | runs.append(run) 81 | return runs 82 | 83 | 84 | def _flatten_omegaconf( 85 | config: OmegaConf, base_name=None 86 | ) -> Dict[str, Union[str, float, int]]: 87 | flat_config = {} 88 | for key, value in config.items(): 89 | key = f"{base_name}.{key}" if base_name else key 90 | if OmegaConf.is_config(value) and not OmegaConf.is_list(value): 91 | flat_config.update(_flatten_omegaconf(value, key)) 92 | else: 93 | flat_config[key] = value 94 | return flat_config 95 | 96 | 97 | def load_and_group_runs(path: Path, minimal_name: bool = True) -> List[Group]: 98 | """ 99 | Load all runs in a directory and group them by unique configurations 100 | :param path: Path to directory containing runs 101 | :param minimal_name: Use minimal name for each group 102 | :return: List of Group objects 103 | """ 104 | # group runs by unique confrigurations 105 | runs_by_config_name = defaultdict(list) 106 | for run in _load_runs(path): 107 | runs_by_config_name[run.get_config_name()].append(run) 108 | 109 | if minimal_name: 110 | # identify minimal hyperparameters that differentiate runs 111 | values_by_key = defaultdict(set) 112 | for config_name, runs in runs_by_config_name.items(): 113 | group_config = _flatten_omegaconf(runs[0].config) 114 | for key, value in group_config.items(): 115 | if ( 116 | key == "seed" 117 | or key == "algorithm.name" 118 | or "_target_" in key 119 | or key == "hypergroup" 120 | or "wrappers" in key 121 | ): 122 | continue 123 | values_by_key[key].add(value) 124 | 125 | # distinguishing hyperparameters 126 | distinguishing_keys = [ 127 | key for key, values in values_by_key.items() if len(values) > 1 128 | ] 129 | 130 | runs_by_minimal_config_name = {} 131 | for runs in runs_by_config_name.values(): 132 | group_config = _flatten_omegaconf(runs[0].config) 133 | minimal_config_name = group_config["algorithm.name"].upper() 134 | config_name = " ".join( 135 | [ 136 | f"{key}={group_config[key]}" 137 | for key in distinguishing_keys 138 | if key in group_config 139 | ] 140 | ) 141 | if config_name: 142 | minimal_config_name += f" ({config_name})" 143 | runs_by_minimal_config_name[minimal_config_name] = runs 144 | 145 | runs_by_config_name = runs_by_minimal_config_name 146 | 147 | return [Group(name, runs) for name, runs in runs_by_config_name.items()] 148 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | fastmarl/outputs/ 3 | fastmarl/multirun/ 4 | # Created by https://www.gitignore.io/api/linux,python,windows,pycharm+all,visualstudiocode 5 | # Edit at https://www.gitignore.io/?templates=linux,python,windows,pycharm+all,visualstudiocode 6 | .vscode 7 | ### Linux ### 8 | *~ 9 | 10 | # temporary files which can be created if a process still has a handle open of a deleted file 11 | .fuse_hidden* 12 | 13 | # KDE directory preferences 14 | .directory 15 | 16 | # Linux trash folder which might appear on any partition or disk 17 | .Trash-* 18 | 19 | # .nfs files are created when an open file is removed but is still being accessed 20 | .nfs* 21 | 22 | ### PyCharm+all ### 23 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 24 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 25 | 26 | # User-specific stuff 27 | .idea/**/workspace.xml 28 | .idea/**/tasks.xml 29 | .idea/**/usage.statistics.xml 30 | 
.idea/**/dictionaries 31 | .idea/**/shelf 32 | 33 | # Generated files 34 | .idea/**/contentModel.xml 35 | 36 | # Sensitive or high-churn files 37 | .idea/**/dataSources/ 38 | .idea/**/dataSources.ids 39 | .idea/**/dataSources.local.xml 40 | .idea/**/sqlDataSources.xml 41 | .idea/**/dynamic.xml 42 | .idea/**/uiDesigner.xml 43 | .idea/**/dbnavigator.xml 44 | 45 | # Gradle 46 | .idea/**/gradle.xml 47 | .idea/**/libraries 48 | 49 | # Gradle and Maven with auto-import 50 | # When using Gradle or Maven with auto-import, you should exclude module files, 51 | # since they will be recreated, and may cause churn. Uncomment if using 52 | # auto-import. 53 | # .idea/modules.xml 54 | # .idea/*.iml 55 | # .idea/modules 56 | # *.iml 57 | # *.ipr 58 | 59 | # CMake 60 | cmake-build-*/ 61 | 62 | # Mongo Explorer plugin 63 | .idea/**/mongoSettings.xml 64 | 65 | # File-based project format 66 | *.iws 67 | 68 | # IntelliJ 69 | out/ 70 | 71 | # mpeltonen/sbt-idea plugin 72 | .idea_modules/ 73 | 74 | # JIRA plugin 75 | atlassian-ide-plugin.xml 76 | 77 | # Cursive Clojure plugin 78 | .idea/replstate.xml 79 | 80 | # Crashlytics plugin (for Android Studio and IntelliJ) 81 | com_crashlytics_export_strings.xml 82 | crashlytics.properties 83 | crashlytics-build.properties 84 | fabric.properties 85 | 86 | # Editor-based Rest Client 87 | .idea/httpRequests 88 | 89 | # Android studio 3.1+ serialized cache file 90 | .idea/caches/build_file_checksums.ser 91 | 92 | ### PyCharm+all Patch ### 93 | # Ignores the whole .idea folder and all .iml files 94 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 95 | 96 | .idea/ 97 | 98 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 99 | 100 | *.iml 101 | modules.xml 102 | .idea/misc.xml 103 | *.ipr 104 | 105 | # Sonarlint plugin 106 | .idea/sonarlint 107 | 108 | ### Python ### 109 | # Byte-compiled / optimized / DLL files 110 | __pycache__/ 111 | *.py[cod] 112 | *$py.class 113 | 114 | # C extensions 115 | *.so 116 | 117 | # Distribution / packaging 118 | .Python 119 | build/ 120 | develop-eggs/ 121 | dist/ 122 | downloads/ 123 | eggs/ 124 | .eggs/ 125 | lib/ 126 | lib64/ 127 | parts/ 128 | sdist/ 129 | var/ 130 | wheels/ 131 | pip-wheel-metadata/ 132 | share/python-wheels/ 133 | *.egg-info/ 134 | .installed.cfg 135 | *.egg 136 | MANIFEST 137 | 138 | # PyInstaller 139 | # Usually these files are written by a python script from a template 140 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 141 | *.manifest 142 | *.spec 143 | 144 | # Installer logs 145 | pip-log.txt 146 | pip-delete-this-directory.txt 147 | 148 | # Unit test / coverage reports 149 | htmlcov/ 150 | .tox/ 151 | .nox/ 152 | .coverage 153 | .coverage.* 154 | .cache 155 | nosetests.xml 156 | coverage.xml 157 | *.cover 158 | .hypothesis/ 159 | .pytest_cache/ 160 | 161 | # Translations 162 | *.mo 163 | *.pot 164 | 165 | # Scrapy stuff: 166 | .scrapy 167 | 168 | # Sphinx documentation 169 | docs/_build/ 170 | 171 | # PyBuilder 172 | target/ 173 | 174 | # pyenv 175 | .python-version 176 | 177 | # pipenv 178 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 179 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 180 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 181 | # install all needed dependencies. 
182 | #Pipfile.lock 183 | 184 | # celery beat schedule file 185 | celerybeat-schedule 186 | 187 | # SageMath parsed files 188 | *.sage.py 189 | 190 | # Spyder project settings 191 | .spyderproject 192 | .spyproject 193 | 194 | # Rope project settings 195 | .ropeproject 196 | 197 | # Mr Developer 198 | .mr.developer.cfg 199 | .project 200 | .pydevproject 201 | 202 | # mkdocs documentation 203 | /site 204 | 205 | # mypy 206 | .mypy_cache/ 207 | .dmypy.json 208 | dmypy.json 209 | 210 | # Pyre type checker 211 | .pyre/ 212 | 213 | ### VisualStudioCode ### 214 | .vscode/* 215 | !.vscode/settings.json 216 | !.vscode/tasks.json 217 | !.vscode/launch.json 218 | !.vscode/extensions.json 219 | 220 | ### VisualStudioCode Patch ### 221 | # Ignore all local history of files 222 | .history 223 | 224 | ### Windows ### 225 | # Windows thumbnail cache files 226 | Thumbs.db 227 | Thumbs.db:encryptable 228 | ehthumbs.db 229 | ehthumbs_vista.db 230 | 231 | # Dump file 232 | *.stackdump 233 | 234 | # Folder config file 235 | [Dd]esktop.ini 236 | 237 | # Recycle Bin used on file shares 238 | $RECYCLE.BIN/ 239 | 240 | # Windows Installer files 241 | *.cab 242 | *.msi 243 | *.msix 244 | *.msm 245 | *.msp 246 | 247 | # Windows shortcuts 248 | *.lnk 249 | 250 | # End of https://www.gitignore.io/api/linux,python,windows,pycharm+all,visualstudiocode -------------------------------------------------------------------------------- /marlbase/utils/wrappers.py: -------------------------------------------------------------------------------- 1 | """ 2 | A collection of environment wrappers for multi-agent environments 3 | """ 4 | 5 | from collections import deque 6 | from time import perf_counter 7 | 8 | import gymnasium as gym 9 | from gymnasium import ObservationWrapper, spaces 10 | import numpy as np 11 | 12 | 13 | class RecordEpisodeStatistics(gym.Wrapper): 14 | """Multi-agent version of RecordEpisodeStatistics gym wrapper""" 15 | 16 | def __init__(self, env, deque_size=100): 17 | super().__init__(env) 18 | self.t0 = perf_counter() 19 | self.episode_reward = np.zeros(self.unwrapped.n_agents) 20 | self.episode_length = 0 21 | self.reward_queue = deque(maxlen=deque_size) 22 | self.length_queue = deque(maxlen=deque_size) 23 | 24 | def reset(self, **kwargs): 25 | obs, info = super().reset(**kwargs) 26 | self.episode_reward = 0 27 | self.episode_length = 0 28 | self.t0 = perf_counter() 29 | return obs, info 30 | 31 | def step(self, action): 32 | observation, reward, done, truncated, info = super().step(action) 33 | self.episode_reward += np.array(reward, dtype=np.float32) 34 | self.episode_length += 1 35 | if done or truncated: 36 | info["episode_returns"] = self.episode_reward 37 | if len(self.episode_reward) == self.unwrapped.n_agents: 38 | for i, agent_reward in enumerate(self.episode_reward): 39 | info[f"agent{i}/episode_returns"] = agent_reward 40 | info["episode_length"] = self.episode_length 41 | info["episode_time"] = perf_counter() - self.t0 42 | 43 | self.reward_queue.append(self.episode_reward) 44 | self.length_queue.append(self.episode_length) 45 | return observation, reward, done, truncated, info 46 | 47 | 48 | class FlattenObservation(ObservationWrapper): 49 | r"""Observation wrapper that flattens the observation of individual agents.""" 50 | 51 | def __init__(self, env): 52 | super(FlattenObservation, self).__init__(env) 53 | ma_spaces = [] 54 | for sa_obs in env.observation_space: 55 | flatdim = spaces.flatdim(sa_obs) 56 | ma_spaces += [ 57 | spaces.Box( 58 | low=-float("inf"), 59 | high=float("inf"), 60 | 
shape=(flatdim,), 61 | dtype=np.float32, 62 | ) 63 | ] 64 | self.observation_space = spaces.Tuple(tuple(ma_spaces)) 65 | 66 | def observation(self, observation): 67 | return tuple( 68 | [ 69 | spaces.flatten(obs_space, obs) 70 | for obs_space, obs in zip(self.env.observation_space, observation) 71 | ] 72 | ) 73 | 74 | 75 | class ObserveID(gym.ObservationWrapper): 76 | def __init__(self, env): 77 | super().__init__(env) 78 | agent_count = env.unwrapped.n_agents 79 | for obs_space in self.observation_space: 80 | assert ( 81 | isinstance(obs_space, gym.spaces.Box) and len(obs_space.shape) == 1 82 | ), "ObserveID wrapper assumes flattened observation space." 83 | self.observation_space = gym.spaces.Tuple( 84 | tuple( 85 | [ 86 | gym.spaces.Box( 87 | low=-np.inf, 88 | high=np.inf, 89 | shape=((x.shape[0] + agent_count),), 90 | dtype=x.dtype, 91 | ) 92 | for x in self.observation_space 93 | ] 94 | ) 95 | ) 96 | 97 | def observation(self, observation): 98 | observation = np.stack(observation) 99 | observation = np.concatenate( 100 | (np.eye(self.unwrapped.n_agents, dtype=observation.dtype), observation), 101 | axis=1, 102 | ) 103 | return [o.squeeze() for o in np.split(observation, self.unwrapped.n_agents)] 104 | 105 | 106 | class CooperativeReward(gym.RewardWrapper): 107 | def reward(self, reward): 108 | return self.unwrapped.n_agents * [sum(reward)] 109 | 110 | 111 | class StandardiseReward(gym.RewardWrapper): 112 | def __init__(self, *args, **kwargs): 113 | super().__init__(*args, **kwargs) 114 | self.stdr_wrp_sumw = np.zeros(self.unwrapped.n_agents, dtype=np.float32) 115 | self.stdr_wrp_wmean = np.zeros(self.unwrapped.n_agents, dtype=np.float32) 116 | self.stdr_wrp_t = np.zeros(self.unwrapped.n_agents, dtype=np.float32) 117 | self.stdr_wrp_n = 0 118 | 119 | def reward(self, reward): 120 | # based on http://www.nowozin.net/sebastian/blog/streaming-mean-and-variance-computation.html 121 | # update running avg and std 122 | weight = 1.0 123 | 124 | q = reward - self.stdr_wrp_wmean 125 | temp_sumw = self.stdr_wrp_sumw + weight 126 | r = q * weight / temp_sumw 127 | 128 | self.stdr_wrp_wmean += r 129 | self.stdr_wrp_t += q * r * self.stdr_wrp_sumw 130 | self.stdr_wrp_sumw = temp_sumw 131 | self.stdr_wrp_n += 1 132 | 133 | if self.stdr_wrp_n == 1: 134 | return reward 135 | 136 | # calculate standardised reward 137 | var = (self.stdr_wrp_t * self.stdr_wrp_n) / ( 138 | self.stdr_wrp_sumw * (self.stdr_wrp_n - 1) 139 | ) 140 | stdr_rew = (reward - self.stdr_wrp_wmean) / (np.sqrt(var) + 1e-6) 141 | return stdr_rew 142 | 143 | 144 | class ClearInfo(gym.Wrapper): 145 | def step(self, action): 146 | observation, reward, done, truncated, _ = self.env.step(action) 147 | return observation, reward, done, truncated, {} 148 | -------------------------------------------------------------------------------- /marlbase/utils/loggers.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from hashlib import sha256 3 | import json 4 | import logging 5 | import math 6 | import time 7 | from typing import Dict, List 8 | 9 | import numpy as np 10 | from omegaconf import DictConfig, OmegaConf 11 | import pandas as pd 12 | 13 | 14 | def squash_info(info): 15 | new_info = {} 16 | keys = set([k for i in info for k in i.keys()]) 17 | keys.discard("TimeLimit.truncated") 18 | keys.discard("terminal_observation") 19 | for key in keys: 20 | values = [d[key] for d in info if key in d] 21 | if len(values) == 1: 22 | new_info[key] = values[0] 23 | continue 24 | 25 
| mean = np.mean([np.array(v).sum() for v in values]) 26 | std = np.std([np.array(v).sum() for v in values]) 27 | 28 | split_key = key.rsplit("/", 1) 29 | mean_key = split_key[:] 30 | std_key = split_key[:] 31 | mean_key[-1] = "mean_" + mean_key[-1] 32 | std_key[-1] = "std_" + std_key[-1] 33 | 34 | new_info["/".join(mean_key)] = mean 35 | new_info["/".join(std_key)] = std 36 | return new_info 37 | 38 | 39 | class Logger: 40 | def __init__(self, project_name, cfg: DictConfig) -> None: 41 | self.config_hash = sha256( 42 | json.dumps( 43 | {k: v for k, v in OmegaConf.to_container(cfg).items() if k != "seed"}, 44 | sort_keys=True, 45 | ).encode("utf8") 46 | ).hexdigest()[-10:] 47 | 48 | self._total_steps = cfg.algorithm.total_steps 49 | self._start_time = time.time() 50 | self._prev_time = None 51 | self._prev_steps = (0, 0) # steps (updates) and env_samples 52 | 53 | def log_metrics(self, metrics: List[Dict]): ... 54 | 55 | def print_progress(self, updates, steps, mean_returns, episodes): 56 | self.info(f"Updates {updates}, Environment timesteps {steps}") 57 | 58 | time_now = time.time() 59 | 60 | elapsed_wallclock = time_now - self._prev_time[0] if self._prev_time else None 61 | elapsed_cpu = ( 62 | time.process_time() - self._prev_time[1] if self._prev_time else None 63 | ) 64 | elapsed_from_start = timedelta(seconds=math.ceil((time_now - self._start_time))) 65 | 66 | completed = steps / self._total_steps 67 | 68 | if elapsed_wallclock: 69 | ups = (updates - self._prev_steps[0]) / elapsed_wallclock 70 | fps = (steps - self._prev_steps[1]) / elapsed_wallclock 71 | self.info(f"UPS: {ups:.2f}, FPS: {fps:.2f} (wall time)") 72 | 73 | # ups = (updates - self._prev_steps[0]) / elapsed_cpu 74 | # fps = (steps - self._prev_steps[1]) / elapsed_cpu 75 | # self.info(f"UPS: {ups:.2f}, FPS: {fps:.2f} (cpu time)") 76 | 77 | eta = elapsed_from_start * (1 - completed) / completed 78 | eta = timedelta(seconds=math.ceil(eta.total_seconds())) 79 | self.info(f"Elapsed Time: {elapsed_from_start}") 80 | self.info(f"Estim. 
Time Left: {eta}") 81 | 82 | self.info(f"Completed: {100*completed:.2f}%") 83 | 84 | self._prev_steps = (updates, steps) 85 | self._prev_time = time.time(), time.process_time() 86 | 87 | self.info(f"Last {episodes} episodes with mean returns: {mean_returns:.3f}") 88 | self.info("-------------------------------------------") 89 | 90 | def watch(self, model): 91 | self.debug(model) 92 | 93 | def debug(self, *args, **kwargs): 94 | return logging.debug(*args, **kwargs) 95 | 96 | def info(self, *args, **kwargs): 97 | return logging.info(*args, **kwargs) 98 | 99 | def warning(self, *args, **kwargs): 100 | return logging.warning(*args, **kwargs) 101 | 102 | def error(self, *args, **kwargs): 103 | return logging.error(*args, **kwargs) 104 | 105 | def critical(self, *args, **kwargs): 106 | return logging.critical(*args, **kwargs) 107 | 108 | def get_state(self): 109 | return None 110 | 111 | 112 | class WandbLogger(Logger): 113 | def __init__(self, project_name, cfg: DictConfig) -> None: 114 | import wandb 115 | 116 | super().__init__(project_name, cfg) 117 | self._run = wandb.init( 118 | project=project_name, 119 | config=OmegaConf.to_container(cfg), 120 | monitor_gym=True, 121 | group=self.config_hash, 122 | ) 123 | 124 | def log_metrics(self, metrics: List[Dict]): 125 | d = squash_info(metrics) 126 | self._run.log(d) 127 | 128 | self.print_progress( 129 | d["updates"], 130 | d["environment_steps"], 131 | d["mean_episode_returns"], 132 | len(metrics) - 1, 133 | ) 134 | 135 | def watch(self, model): 136 | self.debug(model) 137 | self._run.watch(model) 138 | 139 | 140 | class FileSystemLogger(Logger): 141 | def __init__(self, project_name, cfg): 142 | super().__init__(project_name, cfg) 143 | 144 | self.results_path = "results.csv" 145 | self.config_path = "config.yaml" 146 | with open(self.config_path, "w") as f: 147 | OmegaConf.save(cfg, f) 148 | 149 | def log_metrics(self, metrics): 150 | d = squash_info(metrics) 151 | df = pd.DataFrame.from_dict([d])[ 152 | ["environment_steps"] 153 | + sorted([k for k in d.keys() if k != "environment_steps"]) 154 | ] 155 | # Since we are appending, we only want to write the csv headers if the file does not already exist 156 | # the following codeblock handles this automatically 157 | with open(self.results_path, "a") as f: 158 | df.to_csv(f, header=f.tell() == 0, index=False) 159 | 160 | self.print_progress( 161 | d["updates"], 162 | d["environment_steps"], 163 | d["mean_episode_returns"], 164 | len(metrics) - 1, 165 | ) 166 | 167 | def get_state(self): 168 | df = pd.read_csv(self.results_path, index_col=0) 169 | return df 170 | -------------------------------------------------------------------------------- /marlbase/utils/stats.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | 7 | def _load_data_from_subfolder(folder, metric, step=None, step_metric=None): 8 | """Helper function for pulling results data from logs 9 | 10 | Args: 11 | folder (str): 12 | metric (str): 13 | step (int): 14 | step_metric (str): 15 | 16 | Returns: 17 | list of performance values 18 | """ 19 | # The given folder will contain several sub-folders with random hashes like "1a8fdsk3" 20 | # Within each sub-folder is the data we need 21 | results = [] 22 | 23 | for subfolder in os.listdir(folder): 24 | data = pd.read_csv(f'{os.path.join(folder, subfolder, "results.csv")}') 25 | 26 | if step is not None and step_metric is not None: 27 | data = [data[data[step_metric] == 
step][metric].tolist()[0]] 28 | 29 | else: 30 | data = data[metric].tolist() 31 | 32 | results.append(data) 33 | 34 | return results 35 | 36 | 37 | def make_agg_metrics_intervals(folders, algos, metric, step=None, step_metric=None): 38 | """Pulls results for the 'Aggregate metrics with 95% Stratified Bootstrap CIs' plot 39 | Can also be used for "Performance Profiles" plot 40 | 41 | Below is an example usage for this function: 42 | make_agg_metrics_intervals( 43 | folders=[folder, folder, folder, folder], 44 | algos=['ac', 'ac', 'dqn', 'dqn'], 45 | metric=['mean_reward', 'mean_reward', 'mean_reward', 'mean_reward'], 46 | step=[240, 240, 500, 500], 47 | step_metric=['environment_steps', 'environment_steps', 'updates', 'updates'] 48 | ) 49 | 50 | Shape of the output data is {'algo_1': (n_runs x n_envs), ..., 'algo_j': (n_runs x n_envs} 51 | 52 | Args: 53 | folders (List[str]): 54 | algos (List[str]): 55 | metric (List[str]): 56 | step (List[int]): 57 | step_metric (List[str]): 58 | 59 | Returns: 60 | Dict of performance matrices 61 | """ 62 | # For the interval estimates plot, we need performance at a specific point during training/evaluation 63 | if step is None: 64 | raise ValueError('For interval plots, a specific step must be specified') 65 | if step_metric is None: 66 | raise ValueError('For interval plots, a specific step_metric must be specified') 67 | 68 | # Process for reading in the data 69 | results = {} 70 | 71 | for i in range(len(folders)): 72 | data = _load_data_from_subfolder(os.path.join(folders[i], algos[i]), metric[i], step[i], step_metric[i]) 73 | 74 | if algos[i] not in results.keys(): 75 | results[algos[i]] = [] 76 | 77 | results[algos[i]].append(data) 78 | 79 | # Now we need to transpose the pulled results into results matrices. 
For specific shape, see function docstring 80 | results_T = {} 81 | 82 | for algo in results.keys(): 83 | pulled_results = results[algo] 84 | results_T[algo] = np.array(pulled_results).T[0] 85 | 86 | return results_T 87 | 88 | 89 | def make_agg_metrics_pxy(folders, algos, metric, step=None, step_metric=None): 90 | """Pulls results for the 'Probability of Improvement' plot 91 | 92 | Below is an example usage for this function: 93 | make_agg_metrics_pxy( 94 | folders=[folder, folder, folder, folder], 95 | algos=['ac', 'ac', 'dqn', 'dqn'], 96 | metric=['mean_reward', 'mean_reward', 'mean_reward', 'mean_reward'], 97 | step=[240, 240, 500, 500], 98 | step_metric=['environment_steps', 'environment_steps', 'updates', 'updates'] 99 | ) 100 | 101 | Shape of the output data is {'algo_1,algo_2': ((n_runs x n_envs), (n_runs x n_envs)), ...} 102 | 103 | Args: 104 | folders (List[str]): 105 | algos (List[str]): 106 | metric (List[str]): 107 | step (List[int]): 108 | step_metric (List[str]): 109 | 110 | Returns: 111 | Dicts of comparative performance matrices 112 | """ 113 | # First pulling the metrics as we would for other single-value plots 114 | agg_metrics = make_agg_metrics_intervals(folders=folders, algos=algos, metric=metric, 115 | step=step, step_metric=step_metric) 116 | 117 | # Now building out the combinatorics dict 118 | results = {} 119 | 120 | for i in range(len(algos)): 121 | for j in range(len(algos)): 122 | if i == j: 123 | continue 124 | results[f'{algos[i]},{algos[j]}'] = (agg_metrics[algos[i]], agg_metrics[algos[j]]) 125 | 126 | return results 127 | 128 | 129 | def make_agg_metrics_efficiency(folders, algos, metric): 130 | """Pulls results for the 'Aggregate metrics with 95% Stratified Bootstrap CIs' plot 131 | Can also be used for "Performance Profiles" plot 132 | 133 | Below is an example usage for this function: 134 | make_agg_metrics_efficiency( 135 | folders=[folder, folder, folder, folder], 136 | algos=['ac', 'ac', 'dqn', 'dqn'], 137 | metric=['mean_reward', 'mean_reward', 'mean_reward', 'mean_reward'], 138 | ) 139 | 140 | Shape of the output data is {'algo_1': (n_runs x n_envs x n_steps), ...,} 141 | 142 | Args: 143 | folders (List[str]): 144 | algos (List[str]): 145 | metric (List[str]): 146 | step (List[int]): 147 | step_metric (List[str]): 148 | 149 | Returns: 150 | Dict of performance matrices 151 | """ 152 | step = [None for _ in range(len(algos))] 153 | step_metric = [None for _ in range(len(algos))] 154 | 155 | # Process for reading in the data 156 | results = {} 157 | 158 | for i in range(len(folders)): 159 | data = _load_data_from_subfolder(os.path.join(folders[i], algos[i]), metric[i], step[i], step_metric[i]) 160 | 161 | if algos[i] not in results.keys(): 162 | results[algos[i]] = [] 163 | 164 | results[algos[i]].append(data) 165 | 166 | results_T = {} 167 | 168 | for algo in results.keys(): 169 | pulled_results = results[algo] 170 | 171 | n_envs = len(pulled_results) 172 | n_runs = len(pulled_results[0]) 173 | n_steps = len(pulled_results[0][0]) 174 | 175 | 176 | results_T[algo] = np.array(pulled_results).reshape((n_runs, n_envs, n_steps)) 177 | 178 | return results_T 179 | -------------------------------------------------------------------------------- /marlbase/ac/train.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | from pathlib import Path 3 | 4 | from einops import rearrange 5 | import numpy as np 6 | from gymnasium.spaces import flatdim 7 | import hydra 8 | from omegaconf import 
DictConfig 9 | import torch 10 | 11 | from marlbase.utils.video import VideoRecorder 12 | 13 | 14 | Batch = namedtuple( 15 | "Batch", ["obss", "actions", "rewards", "dones", "filled", "action_masks"] 16 | ) 17 | 18 | 19 | def _log_progress(infos, step, updates, logger): 20 | infos.append({"updates": updates, "environment_steps": step}) 21 | logger.log_metrics(infos) 22 | 23 | 24 | def _collect_trajectories( 25 | envs, model, max_ep_length, parallel_envs, n_agents, device, use_proper_termination 26 | ): 27 | running = torch.ones(parallel_envs, device=device, dtype=torch.bool) 28 | 29 | # creates and initialises storage 30 | obss, info = envs.reset() 31 | obss = [torch.tensor(o, device=device) for o in obss] 32 | parallel_envs = envs.observation_space[0].shape[0] 33 | obs_dim = flatdim(envs.single_observation_space) 34 | num_actions = max(action_space.n for action_space in envs.single_action_space) 35 | 36 | batch_obs = torch.zeros( 37 | max_ep_length + 1, 38 | parallel_envs, 39 | obs_dim, 40 | device=device, 41 | ) 42 | batch_done = torch.zeros( 43 | max_ep_length + 1, parallel_envs, device=device, dtype=torch.bool 44 | ) 45 | batch_act = torch.zeros( 46 | max_ep_length, parallel_envs, n_agents, device=device, dtype=torch.long 47 | ) 48 | batch_rew = torch.zeros(max_ep_length, parallel_envs, n_agents, device=device) 49 | batch_filled = torch.zeros(max_ep_length, parallel_envs, device=device) 50 | t = 0 51 | infos = [] 52 | 53 | # if the environment provides action masks, use them 54 | if "action_mask" in info: 55 | batch_action_masks = torch.ones( 56 | max_ep_length + 1, parallel_envs, n_agents, num_actions, device=device 57 | ) 58 | mask = np.stack(info["action_mask"], dtype=np.float32) 59 | action_mask = torch.tensor(mask, dtype=torch.float32, device=device) 60 | batch_action_masks[0] = action_mask 61 | action_mask = action_mask.swapaxes(0, 1) 62 | else: 63 | batch_action_masks = None 64 | action_mask = None 65 | 66 | # set initial obs 67 | batch_obs[0] = torch.cat(obss, dim=-1) 68 | 69 | actor_hiddens = model.init_actor_hiddens(parallel_envs) 70 | 71 | while running.any(): 72 | with torch.no_grad(): 73 | actions, actor_hiddens = model.act( 74 | obss, 75 | actor_hiddens, 76 | action_mask=action_mask, 77 | ) 78 | 79 | next_obss, rewards, done, truncated, info = envs.step( 80 | actions.squeeze().tolist() 81 | ) 82 | next_obss = [torch.tensor(o, device=device) for o in next_obss] 83 | 84 | done = torch.tensor(done, dtype=torch.bool, device=device) 85 | truncated = torch.tensor(truncated, dtype=torch.bool, device=device) 86 | if not use_proper_termination: 87 | # TODO: does this make sense? 
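                    # Folding `truncated` into `done` below treats time-limit cutoffs as true
                    # terminations, so the stored done flags should stop n-step return targets
                    # from bootstrapping past the time limit. With use_proper_termination=True
                    # the two signals stay separate and only genuine environment terminations
                    # end bootstrapping at that step.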
88 | done = torch.logical_or(done, truncated) 89 | 90 | batch_obs[t + 1, running, :] = torch.cat(next_obss, dim=1)[running] 91 | batch_act[t, running] = rearrange(actions, "N B 1 -> B N")[running] 92 | batch_done[t + 1, running] = done[running] 93 | batch_rew[t, running] = torch.tensor(rewards, dtype=torch.float32, device=device)[running] 94 | batch_filled[t, running] = 1 95 | if "action_mask" in info: 96 | mask = np.stack(info["action_mask"], dtype=np.float32) 97 | action_mask = torch.tensor(mask, dtype=torch.float32, device=device) 98 | batch_action_masks[t + 1, running] = action_mask[running] 99 | action_mask = action_mask.swapaxes(0, 1) 100 | 101 | if done.any(): 102 | for i, d in enumerate(done): 103 | if d: 104 | assert ( 105 | "final_info" in info 106 | and info["final_info"][i] is not None 107 | and "episode_returns" in info["final_info"][i] 108 | ), "Finished episode info does not contain expected statistics." 109 | infos.append(info["final_info"][i]) 110 | running[i] = False 111 | 112 | t += 1 113 | obss = next_obss 114 | 115 | batch = Batch( 116 | batch_obs, batch_act, batch_rew, batch_done, batch_filled, batch_action_masks 117 | ) 118 | 119 | return t, batch, infos 120 | 121 | 122 | def record_episodes(env, model, n_timesteps, path, device): 123 | recorder = VideoRecorder() 124 | done = True 125 | 126 | for _ in range(n_timesteps): 127 | if done: 128 | obss, info = env.reset() 129 | hiddens = model.init_actor_hiddens(1) 130 | if "action_mask" in info: 131 | action_mask = torch.tensor( 132 | info["action_mask"], dtype=torch.float32, device=device 133 | ) 134 | else: 135 | action_mask = None 136 | done = False 137 | else: 138 | with torch.no_grad(): 139 | obss = torch.tensor(obss, dtype=torch.float32, device=device).unsqueeze( 140 | 1 141 | ) 142 | actions, hiddens = model.act(obss, hiddens, action_mask) 143 | obss, _, done, truncated, info = env.step([a.item() for a in actions]) 144 | if "action_mask" in info: 145 | action_mask = torch.tensor( 146 | info["action_mask"], dtype=torch.float32, device=device 147 | ) 148 | done = done or truncated 149 | recorder.record_frame(env) 150 | 151 | Path(path).parent.mkdir(parents=True, exist_ok=True) 152 | recorder.save(path) 153 | 154 | 155 | def main(envs, eval_env, logger, time_limit, **cfg): 156 | cfg = DictConfig(cfg) 157 | 158 | model = hydra.utils.instantiate( 159 | cfg.model, envs.single_observation_space, envs.single_action_space, cfg 160 | ) 161 | logger.watch(model) 162 | 163 | parallel_envs = envs.observation_space[0].shape[0] 164 | 165 | step = 0 166 | updates = 0 167 | last_eval = 0 168 | last_save = 0 169 | last_video = 0 170 | while step < cfg.total_steps + 1: 171 | t, batch, infos = _collect_trajectories( 172 | envs, 173 | model, 174 | time_limit, 175 | parallel_envs, 176 | model.n_agents, 177 | cfg.model.device, 178 | cfg.use_proper_termination, 179 | ) 180 | 181 | metrics = model.update(batch, step) 182 | infos.append(metrics) 183 | 184 | if (step - last_eval) >= cfg.eval_interval: 185 | _log_progress(infos, step, updates, logger) 186 | last_eval = step 187 | 188 | if cfg.save_interval and (step - last_save) >= cfg.save_interval: 189 | Path("checkpoints").mkdir(exist_ok=True) 190 | torch.save(model.state_dict(), f"checkpoints/model_s{step}.pt") 191 | last_save = step 192 | 193 | if cfg.video_interval and (step - last_video) >= cfg.video_interval: 194 | record_episodes( 195 | eval_env, 196 | model, 197 | cfg.video_frames, 198 | f"./videos/step-{step}.mp4", 199 | cfg.model.device, 200 | ) 201 | last_video = step 202 | 203 | 
updates += 1 204 | step += t * parallel_envs 205 | 206 | envs.close() 207 | -------------------------------------------------------------------------------- /marlbase/utils/models.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Union 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | 8 | def orthogonal_init(m): 9 | nn.init.orthogonal_(m.weight.data, gain=np.sqrt(2)) 10 | nn.init.constant_(m.bias.data, 0) 11 | return m 12 | 13 | 14 | class FCNetwork(nn.Module): 15 | def __init__( 16 | self, dims, activation=nn.ReLU, final_activation=None, use_orthogonal_init=True 17 | ): 18 | """ 19 | Create fully-connected network 20 | :param dims: list of dimensions for the network 21 | :param activation: activation function to use 22 | :param final_activation: activation function to use on output (if any) 23 | :param use_orthogonal_init: whether to use orthogonal initialization 24 | :return: sequential network 25 | """ 26 | super().__init__() 27 | mods = [] 28 | 29 | input_size = dims[0] 30 | h_sizes = dims[1:] 31 | 32 | init_fn = orthogonal_init if use_orthogonal_init else lambda x: x 33 | 34 | mods = [init_fn(nn.Linear(input_size, h_sizes[0]))] 35 | for i in range(len(h_sizes) - 1): 36 | mods.append(activation()) 37 | mods.append(init_fn(nn.Linear(h_sizes[i], h_sizes[i + 1]))) 38 | 39 | if final_activation: 40 | mods.append(final_activation()) 41 | 42 | self.network = nn.Sequential(*mods) 43 | 44 | def init_hiddens(self, batch_size, device): 45 | return None 46 | 47 | def forward(self, x, h=None): 48 | return self.network(x), None 49 | 50 | 51 | class RNNNetwork(nn.Module): 52 | def __init__( 53 | self, 54 | dims, 55 | rnn=nn.GRU, 56 | activation=nn.ReLU, 57 | final_activation=None, 58 | use_orthogonal_init=True, 59 | ): 60 | """ 61 | Create recurrent network (last layer is fully connected) 62 | :param dims: list of dimensions for the network 63 | :param activation: activation function to use 64 | :param final_activation: activation function to use on output (if any) 65 | :param use_orthogonal_init: whether to use orthogonal initialization 66 | :return: sequential network 67 | """ 68 | super().__init__() 69 | assert ( 70 | len(dims) > 2 71 | ), "Need at least 3 dimensions for RNN (1 input dim, >= 1 hidden dim, 1 output dim)" 72 | 73 | assert rnn in [nn.GRU, nn.LSTM], "Only GRU and LSTM are supported" 74 | 75 | input_size = dims[0] 76 | rnn_hiddens = dims[1:-1] 77 | rnn_layers = len(rnn_hiddens) - 1 78 | rnn_hidden_size = rnn_hiddens[0] 79 | assert all( 80 | rnn_hidden_size == h for h in rnn_hiddens 81 | ), "Expect same hidden size across all RNN layers" 82 | output_size = dims[-1] 83 | 84 | self.first_layer = nn.Linear(input_size, rnn_hidden_size) 85 | self.rnn = rnn( 86 | input_size=rnn_hidden_size, 87 | hidden_size=rnn_hidden_size, 88 | num_layers=rnn_layers, 89 | batch_first=False, 90 | ) 91 | self.activation = activation() 92 | self.final_layer = nn.Linear(rnn_hidden_size, output_size) 93 | if use_orthogonal_init: 94 | self.final_layer = orthogonal_init(self.final_layer) 95 | 96 | self.final_activation = final_activation 97 | 98 | def init_hiddens(self, batch_size, device): 99 | return torch.zeros( 100 | self.rnn.num_layers, 101 | batch_size, 102 | self.rnn.hidden_size, 103 | device=device, 104 | ) 105 | 106 | def forward(self, x, h=None): 107 | assert x.dim() == 3, "Expect input to be 3D tensor (seq_len, batch, input_size)" 108 | assert ( 109 | h is None or h.dim() == 3 110 | ), "Expect hidden 
state to be 3D tensor (num_layers, batch, hidden_size)" 111 | x = self.activation(self.first_layer(x)) 112 | x, h = self.rnn(x, h) 113 | x = self.final_layer(x) 114 | if self.final_activation: 115 | x = self.final_activation(x) 116 | return x, h 117 | 118 | 119 | def make_network( 120 | dims, 121 | use_rnn=False, 122 | rnn=nn.GRU, 123 | activation=nn.ReLU, 124 | final_activation=None, 125 | use_orthogonal_init=True, 126 | ): 127 | if use_rnn: 128 | return RNNNetwork(dims, rnn, activation, final_activation, use_orthogonal_init) 129 | else: 130 | return FCNetwork(dims, activation, final_activation, use_orthogonal_init) 131 | 132 | 133 | class MultiAgentIndependentNetwork(nn.Module): 134 | def __init__( 135 | self, 136 | input_sizes, 137 | hidden_dims, 138 | output_sizes, 139 | use_rnn=False, 140 | use_orthogonal_init=True, 141 | ): 142 | super().__init__() 143 | assert len(input_sizes) == len( 144 | output_sizes 145 | ), "Expect same number of input and output sizes" 146 | self.independent = nn.ModuleList() 147 | 148 | for in_size, out_size in zip(input_sizes, output_sizes): 149 | dims = [in_size] + hidden_dims + [out_size] 150 | self.independent.append( 151 | make_network( 152 | dims, use_rnn=use_rnn, use_orthogonal_init=use_orthogonal_init 153 | ) 154 | ) 155 | 156 | def forward( 157 | self, 158 | inputs: Union[List[torch.Tensor], torch.Tensor], 159 | hiddens: Optional[List[torch.Tensor]] = None, 160 | ): 161 | if hiddens is None: 162 | hiddens = [None] * len(inputs) 163 | futures = [ 164 | torch.jit.fork(model, x, h) 165 | for model, x, h in zip(self.independent, inputs, hiddens) 166 | ] 167 | results = [torch.jit.wait(fut) for fut in futures] 168 | outs = [x for x, _ in results] 169 | hiddens = [h for _, h in results] 170 | return outs, hiddens 171 | 172 | def init_hiddens(self, batch_size, device): 173 | return [model.init_hiddens(batch_size, device) for model in self.independent] 174 | 175 | 176 | class MultiAgentSharedNetwork(nn.Module): 177 | def __init__( 178 | self, 179 | input_sizes, 180 | hidden_dims, 181 | output_sizes, 182 | sharing_indices, 183 | use_rnn=False, 184 | use_orthogonal_init=True, 185 | ): 186 | super().__init__() 187 | assert len(input_sizes) == len( 188 | output_sizes 189 | ), "Expect same number of input and output sizes" 190 | self.num_agents = len(input_sizes) 191 | 192 | if sharing_indices is True: 193 | self.sharing_indices = len(input_sizes) * [0] 194 | elif sharing_indices is False: 195 | self.sharing_indices = list(range(len(input_sizes))) 196 | else: 197 | self.sharing_indices = sharing_indices 198 | assert len(self.sharing_indices) == len( 199 | input_sizes 200 | ), "Expect same number of sharing indices as agents" 201 | 202 | self.num_networks = 0 203 | self.networks = nn.ModuleList() 204 | self.agents_by_network = [] 205 | self.input_sizes = [] 206 | self.output_sizes = [] 207 | created_networks = set() 208 | for i in self.sharing_indices: 209 | if i in created_networks: 210 | # network already created 211 | continue 212 | 213 | # agent indices that share this network 214 | network_agents = [ 215 | j for j, idx in enumerate(self.sharing_indices) if idx == i 216 | ] 217 | in_sizes = [input_sizes[j] for j in network_agents] 218 | in_size = in_sizes[0] 219 | assert all( 220 | idim == in_size for idim in in_sizes 221 | ), f"Expect same input sizes across all agents sharing network {i}" 222 | out_sizes = [output_sizes[j] for j in network_agents] 223 | out_size = out_sizes[0] 224 | assert all( 225 | odim == out_size for odim in out_sizes 226 | ), f"Expect 
same output sizes across all agents sharing network {i}" 227 | 228 | dims = [in_size] + hidden_dims + [out_size] 229 | self.networks.append( 230 | make_network( 231 | dims, use_rnn=use_rnn, use_orthogonal_init=use_orthogonal_init 232 | ) 233 | ) 234 | self.agents_by_network.append(network_agents) 235 | self.input_sizes.append(in_size) 236 | self.output_sizes.append(out_size) 237 | self.num_networks += 1 238 | created_networks.add(i) 239 | 240 | def forward( 241 | self, 242 | inputs: Union[List[torch.Tensor], torch.Tensor], 243 | hiddens: Optional[List[torch.Tensor]] = None, 244 | ): 245 | assert all( 246 | x.dim() == 3 for x in inputs 247 | ), "Expect each agent input to be 3D tensor (seq_len, batch, input_size)" 248 | assert hiddens is None or all( 249 | x is None or x.dim() == 3 for x in hiddens 250 | ), "Expect hidden state to be 3D tensor (num_layers, batch, hidden_size)" 251 | 252 | batch_size = inputs[0].size(1) 253 | assert all( 254 | x.size(1) == batch_size for x in inputs 255 | ), "Expect all agent inputs to have same batch size" 256 | 257 | # group inputs and hiddens by network 258 | network_inputs = [] 259 | network_hiddens = [] 260 | for agent_indices in self.agents_by_network: 261 | net_inputs = [inputs[i] for i in agent_indices] 262 | if hiddens is None or all(h is None for h in hiddens): 263 | net_hiddens = None 264 | else: 265 | net_hiddens = [hiddens[i] for i in agent_indices] 266 | network_inputs.append(torch.cat(net_inputs, dim=1)) 267 | network_hiddens.append( 268 | torch.cat(net_hiddens, dim=1) if net_hiddens is not None else None 269 | ) 270 | 271 | # forward through networks 272 | futures = [ 273 | torch.jit.fork(network, x, h) 274 | for network, x, h in zip(self.networks, network_inputs, network_hiddens) 275 | ] 276 | results = [torch.jit.wait(fut) for fut in futures] 277 | outs = [x.split(batch_size, dim=1) for x, _ in results] 278 | hiddens = [ 279 | h.split(batch_size, dim=1) if h is not None else None for _, h in results 280 | ] 281 | 282 | # group outputs by agents 283 | agent_outputs = [] 284 | agent_hiddens = [] 285 | self.idx_by_network = [0] * self.num_networks 286 | for network_idx in self.sharing_indices: 287 | idx_within_network = self.idx_by_network[network_idx] 288 | agent_outputs.append(outs[network_idx][idx_within_network]) 289 | if hiddens[network_idx] is not None: 290 | agent_hiddens.append(hiddens[network_idx][idx_within_network]) 291 | else: 292 | agent_hiddens.append(None) 293 | self.idx_by_network[network_idx] += 1 294 | return agent_outputs, agent_hiddens 295 | 296 | def init_hiddens(self, batch_size, device): 297 | return [ 298 | self.networks[network_idx].init_hiddens(batch_size, device) 299 | for network_idx in self.sharing_indices 300 | ] 301 | -------------------------------------------------------------------------------- /marlbase/dqn/train.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import math 3 | from pathlib import Path 4 | 5 | # from cpprb import ReplayBuffer 6 | import hydra 7 | import numpy as np 8 | from omegaconf import DictConfig 9 | import torch 10 | 11 | from marlbase.utils.video import VideoRecorder 12 | 13 | 14 | Batch = namedtuple( 15 | "Batch", ["obss", "actions", "rewards", "dones", "filled", "action_mask"] 16 | ) 17 | 18 | 19 | class ReplayBuffer: 20 | def __init__( 21 | self, 22 | buffer_size, 23 | n_agents, 24 | observation_space, 25 | action_space, 26 | max_episode_length, 27 | device, 28 | store_action_masks=False, 29 | ): 30 | 
self.buffer_size = buffer_size 31 | self.n_agents = n_agents 32 | self.max_episode_length = max_episode_length 33 | self.store_action_masks = store_action_masks 34 | self.device = device 35 | 36 | self.pos = 0 37 | self.cur_pos = 0 38 | self.t = 0 39 | 40 | self.observations = [ 41 | np.zeros( 42 | (max_episode_length + 1, buffer_size, *observation_space[i].shape), 43 | dtype=np.float32, 44 | ) 45 | for i in range(n_agents) 46 | ] 47 | self.actions = np.zeros( 48 | (n_agents, max_episode_length, buffer_size), dtype=np.int64 49 | ) 50 | self.rewards = np.zeros( 51 | (n_agents, max_episode_length, buffer_size), dtype=np.float32 52 | ) 53 | self.dones = np.zeros((max_episode_length + 1, buffer_size), dtype=bool) 54 | self.filled = np.zeros((max_episode_length, buffer_size), dtype=bool) 55 | if store_action_masks: 56 | action_dim = max(action_space.n for action_space in action_space) 57 | self.action_masks = np.zeros( 58 | (n_agents, max_episode_length + 1, buffer_size, action_dim), 59 | dtype=np.float32, 60 | ) 61 | 62 | def __len__(self): 63 | return min(self.pos, self.buffer_size) 64 | 65 | def init_episode(self, obss, action_masks=None): 66 | self.t = 0 67 | for i in range(self.n_agents): 68 | self.observations[i][0, self.cur_pos] = obss[i] 69 | if action_masks is not None: 70 | assert self.store_action_masks, "Action masks not stored in buffer!" 71 | self.action_masks[:, 0, self.cur_pos] = action_masks 72 | 73 | def add(self, obss, acts, rews, done, action_masks=None): 74 | assert self.t < self.max_episode_length, "Episode longer than given max length!" 75 | for i in range(self.n_agents): 76 | self.observations[i][self.t + 1, self.cur_pos] = obss[i] 77 | self.actions[:, self.t, self.cur_pos] = acts 78 | self.rewards[:, self.t, self.cur_pos] = rews 79 | self.dones[self.t + 1, self.cur_pos] = done 80 | self.filled[self.t, self.cur_pos] = True 81 | if action_masks is not None: 82 | assert self.store_action_masks, "Action masks not stored in buffer!" 83 | self.action_masks[:, self.t + 1, self.cur_pos] = action_masks 84 | self.t += 1 85 | 86 | if done: 87 | self.pos += 1 88 | self.cur_pos = self.pos % self.buffer_size 89 | self.t = 0 90 | 91 | def can_sample(self, batch_size): 92 | return self.pos >= batch_size 93 | 94 | def sample(self, batch_size): 95 | idx = np.random.randint(0, len(self), size=batch_size) 96 | obss = torch.stack( 97 | [ 98 | torch.tensor( 99 | self.observations[i][:, idx], 100 | dtype=torch.float32, 101 | device=self.device, 102 | ) 103 | for i in range(self.n_agents) 104 | ] 105 | ) 106 | actions = torch.tensor( 107 | self.actions[:, :, idx], dtype=torch.int64, device=self.device 108 | ) 109 | rewards = torch.tensor( 110 | self.rewards[:, :, idx], dtype=torch.float32, device=self.device 111 | ) 112 | dones = torch.tensor( 113 | self.dones[:, idx], dtype=torch.float32, device=self.device 114 | ) 115 | filled = torch.tensor( 116 | self.filled[:, idx], dtype=torch.float32, device=self.device 117 | ) 118 | if self.store_action_masks: 119 | action_mask = torch.tensor( 120 | self.action_masks[:, :, idx], dtype=torch.float32, device=self.device 121 | ) 122 | else: 123 | action_mask = None 124 | return Batch(obss, actions, rewards, dones, filled, action_mask) 125 | 126 | 127 | def _epsilon_schedule( 128 | decay_style, decay_over, eps_start, eps_end, exp_decay_rate, total_steps 129 | ): 130 | """ 131 | Exponential decay schedule for exploration epsilon. 132 | :param decay_style: style of epsilon schedule. One of "linear"/ "lin" or "exponential"/ "exp". 
133 | :param decay_over: fraction of total steps over which to decay epsilon. 134 | :param eps_start: starting epsilon value. 135 | :param eps_end: ending epsilon value. 136 | :param exp_decay_rate: decay rate for exponential decay. 137 | :param total_steps: total number of steps to take. 138 | :return: Epsilon schedule function mapping step number to epsilon value. 139 | """ 140 | assert decay_style in [ 141 | "linear", 142 | "lin", 143 | "exponential", 144 | "exp", 145 | ], "decay_style must be one of 'linear' or 'exponential'" 146 | assert 0 <= eps_start <= 1 and 0 <= eps_end <= 1, "eps must be in [0, 1]" 147 | assert eps_start >= eps_end, "eps_start must be >= eps_end" 148 | assert 0 < decay_over <= 1, "decay_over must be in (0, 1]" 149 | assert total_steps > 0, "total_steps must be > 0" 150 | assert exp_decay_rate > 0, "eps_decay must be > 0" 151 | 152 | if decay_style in ["linear", "lin"]: 153 | 154 | def _thunk(steps_done): 155 | return max( 156 | eps_end 157 | + (eps_start - eps_end) * (1 - steps_done / (total_steps * decay_over)), 158 | eps_end, 159 | ) 160 | 161 | elif decay_style in ["exponential", "exp"]: 162 | # decaying over all steps 163 | # eps_decay = (eps_start - eps_end) / total_steps * exp_decay_rate 164 | # decaying over decay_over fraction of steps 165 | eps_decay = (eps_start - eps_end) / (total_steps * decay_over) * exp_decay_rate 166 | 167 | def _thunk(steps_done): 168 | return max( 169 | eps_end + (eps_start - eps_end) * math.exp(-eps_decay * steps_done), 170 | eps_end, 171 | ) 172 | else: 173 | raise ValueError("decay_style must be one of 'linear' or 'exponential'") 174 | return _thunk 175 | 176 | 177 | def _evaluate(env, model, eval_episodes, eval_epsilon): 178 | infos = [] 179 | while len(infos) < eval_episodes: 180 | obs, info = env.reset() 181 | hiddens = model.init_hiddens(1) 182 | action_mask = ( 183 | np.stack(info["action_mask"], dtype=np.float32) 184 | if "action_mask" in info 185 | else None 186 | ) 187 | done = False 188 | while not done: 189 | with torch.no_grad(): 190 | actions, hiddens = model.act(obs, hiddens, eval_epsilon, action_mask) 191 | obs, _, done, truncated, info = env.step(actions) 192 | done = done or truncated 193 | action_mask = ( 194 | np.stack(info["action_mask"], dtype=np.float32) 195 | if "action_mask" in info 196 | else None 197 | ) 198 | infos.append(info) 199 | return infos 200 | 201 | 202 | def _collect_trajectory(env, model, rb, epsilon, use_proper_termination): 203 | obss, info = env.reset() 204 | action_mask = ( 205 | np.stack(info["action_mask"], dtype=np.float32) 206 | if "action_mask" in info 207 | else None 208 | ) 209 | rb.init_episode(obss, action_mask) 210 | hiddens = model.init_hiddens(1) 211 | done = False 212 | t = 0 213 | 214 | while not done: 215 | with torch.no_grad(): 216 | actions, hiddens = model.act(obss, hiddens, epsilon, action_mask) 217 | next_obss, rews, done, truncated, info = env.step(actions) 218 | 219 | if use_proper_termination: 220 | # TODO: Previously was always False here? 221 | # also previously had other option "ignore"? Why was that separate from "ignore"? 222 | proper_done = done 223 | else: 224 | # here previously was always done? 
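            # When use_proper_termination is False, truncation is merged into the terminal
            # flag written to the replay buffer, so TD targets computed from sampled
            # episodes should not bootstrap past a time-limit cutoff; with
            # use_proper_termination=True only real terminations are stored as terminal.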
225 | proper_done = done or truncated 226 | done = done or truncated 227 | action_mask = ( 228 | np.stack(info["action_mask"], dtype=np.float32) 229 | if "action_mask" in info 230 | else None 231 | ) 232 | 233 | rb.add(next_obss, actions, rews, proper_done, action_mask) 234 | t += 1 235 | obss = next_obss 236 | 237 | return t, info 238 | 239 | 240 | def record_episodes(env, model, n_timesteps, path, device, epsilon): 241 | recorder = VideoRecorder() 242 | done = True 243 | 244 | for _ in range(n_timesteps): 245 | if done: 246 | obss, info = env.reset() 247 | hiddens = model.init_hiddens(1) 248 | done = False 249 | else: 250 | action_mask = ( 251 | np.stack(info["action_mask"], dtype=np.float32) 252 | if "action_mask" in info 253 | else None 254 | ) 255 | with torch.no_grad(): 256 | actions, hiddens = model.act(obss, hiddens, epsilon, action_mask) 257 | obss, _, done, truncated, info = env.step(actions) 258 | recorder.record_frame(env) 259 | 260 | Path(path).parent.mkdir(parents=True, exist_ok=True) 261 | recorder.save(path) 262 | 263 | 264 | def main(env, eval_env, logger, time_limit, **cfg): 265 | cfg = DictConfig(cfg) 266 | 267 | _, info = env.reset() 268 | model = hydra.utils.instantiate( 269 | cfg.model, env.observation_space, env.action_space, cfg 270 | ) 271 | 272 | logger.watch(model) 273 | 274 | rb = ReplayBuffer( 275 | cfg.buffer_size, 276 | env.unwrapped.n_agents, 277 | env.observation_space, 278 | env.action_space, 279 | time_limit, 280 | cfg.model.device, 281 | store_action_masks="action_mask" in info, 282 | ) 283 | 284 | eps_sched = _epsilon_schedule( 285 | cfg.eps_decay_style, 286 | cfg.eps_decay_over, 287 | cfg.eps_start, 288 | cfg.eps_end, 289 | cfg.eps_exp_decay_rate, 290 | cfg.total_steps, 291 | ) 292 | 293 | updates = 0 294 | step = 0 295 | last_eval = 0 296 | last_video = 0 297 | last_save = 0 298 | while step < cfg.total_steps + 1: 299 | t, _ = _collect_trajectory( 300 | env, 301 | model, 302 | rb, 303 | eps_sched(step), 304 | cfg.use_proper_termination, 305 | ) 306 | step += t 307 | 308 | if step > cfg.training_start and rb.can_sample(cfg.batch_size): 309 | batch = rb.sample(cfg.batch_size) 310 | metrics = model.update(batch) 311 | updates += 1 312 | else: 313 | metrics = {} 314 | 315 | if cfg.eval_interval and (step - last_eval) >= cfg.eval_interval: 316 | infos = _evaluate(eval_env, model, cfg.eval_episodes, cfg.eps_evaluation) 317 | if metrics: 318 | infos.append(metrics) 319 | infos.append( 320 | { 321 | "updates": updates, 322 | "environment_steps": step, 323 | "epsilon": eps_sched(step), 324 | } 325 | ) 326 | logger.log_metrics(infos) 327 | last_eval = step 328 | 329 | if cfg.video_interval and (step - last_video) >= cfg.video_interval: 330 | record_episodes( 331 | eval_env, 332 | model, 333 | cfg.video_frames, 334 | f"./videos/step-{step}.mp4", 335 | cfg.model.device, 336 | cfg.eps_evaluation, 337 | ) 338 | last_video = step 339 | 340 | if cfg.save_interval and (step - last_save) >= cfg.save_interval: 341 | Path("checkpoints").mkdir(exist_ok=True) 342 | torch.save(model.state_dict(), f"checkpoints/model_s{step}.pt") 343 | last_save = step 344 | 345 | env.close() 346 | -------------------------------------------------------------------------------- /marlbase/ac/model.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | from einops import rearrange 4 | from gymnasium.spaces import flatdim 5 | import torch 6 | from torch.distributions import Categorical 7 | import torch.nn as nn 8 | 
from torch import optim 9 | 10 | from marlbase.utils.models import MultiAgentIndependentNetwork, MultiAgentSharedNetwork 11 | from marlbase.utils.utils import MultiCategorical, compute_nstep_returns 12 | from marlbase.utils.standardise_stream import RunningMeanStd 13 | 14 | 15 | def _split_batch(splits): 16 | def thunk(batch): 17 | return torch.split(batch, splits, dim=-1) 18 | 19 | return thunk 20 | 21 | 22 | class A2CNetwork(nn.Module): 23 | def __init__( 24 | self, 25 | obs_space, 26 | action_space, 27 | cfg, 28 | actor, 29 | critic, 30 | device, 31 | ): 32 | super(A2CNetwork, self).__init__() 33 | self.gamma = cfg.gamma 34 | self.entropy_coef = cfg.entropy_coef 35 | self.n_steps = cfg.n_steps 36 | self.grad_clip = cfg.grad_clip 37 | self.value_loss_coef = cfg.value_loss_coef 38 | self.device = device 39 | 40 | self.n_agents = len(obs_space) 41 | obs_dims = [flatdim(o) for o in obs_space] 42 | act_dims = [flatdim(a) for a in action_space] 43 | 44 | if not actor.parameter_sharing: 45 | self.actor = MultiAgentIndependentNetwork( 46 | obs_dims, 47 | list(actor.layers), 48 | act_dims, 49 | actor.use_rnn, 50 | actor.use_orthogonal_init, 51 | ) 52 | else: 53 | self.actor = MultiAgentSharedNetwork( 54 | obs_dims, 55 | list(actor.layers), 56 | act_dims, 57 | actor.parameter_sharing, 58 | actor.use_rnn, 59 | actor.use_orthogonal_init, 60 | ) 61 | 62 | self.centralised_critic = critic.centralised 63 | critic_obs_shape = ( 64 | self.n_agents * [sum(obs_dims)] if critic.centralised else obs_dims 65 | ) 66 | 67 | if not critic.parameter_sharing: 68 | self.critic = MultiAgentIndependentNetwork( 69 | critic_obs_shape, 70 | list(critic.layers), 71 | [1] * self.n_agents, 72 | critic.use_rnn, 73 | critic.use_orthogonal_init, 74 | ) 75 | self.target_critic = MultiAgentIndependentNetwork( 76 | critic_obs_shape, 77 | list(critic.layers), 78 | [1] * self.n_agents, 79 | critic.use_rnn, 80 | critic.use_orthogonal_init, 81 | ) 82 | else: 83 | self.critic = MultiAgentSharedNetwork( 84 | critic_obs_shape, 85 | list(critic.layers), 86 | [1] * self.n_agents, 87 | critic.parameter_sharing, 88 | critic.use_rnn, 89 | critic.use_orthogonal_init, 90 | ) 91 | self.target_critic = MultiAgentSharedNetwork( 92 | critic_obs_shape, 93 | list(critic.layers), 94 | [1] * self.n_agents, 95 | critic.parameter_sharing, 96 | critic.use_rnn, 97 | critic.use_orthogonal_init, 98 | ) 99 | 100 | self.soft_update(1.0) 101 | self.to(device) 102 | 103 | optimizer = getattr(optim, cfg.optimizer) 104 | if type(optimizer) is str: 105 | optimizer = getattr(optim, optimizer) 106 | self.optimizer_class = optimizer 107 | 108 | lr = cfg.lr 109 | self.optimizer = optimizer(self.parameters(), lr=lr) 110 | self.target_update_interval_or_tau = cfg.target_update_interval_or_tau 111 | 112 | self.standardise_returns = cfg.standardise_returns 113 | if self.standardise_returns: 114 | self.ret_ms = RunningMeanStd(shape=(self.n_agents,), device=device) 115 | 116 | self.split_obs = _split_batch([flatdim(s) for s in obs_space]) 117 | self.split_act = _split_batch(self.n_agents * [1]) 118 | 119 | print(self) 120 | 121 | def init_critic_hiddens(self, batch_size, target=False): 122 | if target: 123 | return self.target_critic.init_hiddens(batch_size, self.device) 124 | else: 125 | return self.critic.init_hiddens(batch_size, self.device) 126 | 127 | def init_actor_hiddens(self, batch_size): 128 | return self.actor.init_hiddens(batch_size, self.device) 129 | 130 | def forward(self, inputs, rnn_hxs, masks): 131 | raise NotImplementedError( 132 | "Forward not 
implemented. Use act, get_value, get_target_value or evaluate_actions instead." 133 | ) 134 | 135 | def get_dist(self, action_logits, action_mask=None): 136 | if action_mask is not None: 137 | masked_logits = [] 138 | for logits, mask in zip(action_logits, action_mask): 139 | masked_logits.append(logits * mask + (1 - mask) * -1e8) 140 | action_logits = masked_logits 141 | 142 | dist = MultiCategorical( 143 | [Categorical(logits=logits) for logits in action_logits] 144 | ) 145 | return dist 146 | 147 | def act(self, inputs, actor_hiddens, action_mask=None): 148 | inputs = [i.unsqueeze(0) for i in inputs] 149 | actor_logits, actor_hiddens = self.actor(inputs, actor_hiddens) 150 | actor_logits = [logits.squeeze(0) for logits in actor_logits] 151 | dist = self.get_dist(actor_logits, action_mask) 152 | actions = dist.sample() 153 | return torch.stack(actions, dim=0), actor_hiddens 154 | 155 | def get_value(self, inputs, critic_hiddens, target=False): 156 | if self.centralised_critic: 157 | inputs = self.n_agents * [torch.cat(inputs, dim=-1)] 158 | 159 | if target: 160 | values, critic_hiddens = self.target_critic(inputs, critic_hiddens) 161 | else: 162 | values, critic_hiddens = self.critic(inputs, critic_hiddens) 163 | return torch.cat(values, dim=-1), critic_hiddens 164 | 165 | def evaluate_actions( 166 | self, 167 | inputs, 168 | action, 169 | critic_hiddens, 170 | actor_hiddens, 171 | action_mask=None, 172 | state=None, 173 | ): 174 | if state is None: 175 | state = inputs 176 | value, critic_hiddens = self.get_value(state, critic_hiddens) 177 | actor_features, actor_hiddens = self.actor(inputs, actor_hiddens) 178 | dist = self.get_dist(actor_features, action_mask) 179 | action_log_probs = torch.cat(dist.log_probs(action), dim=-1) 180 | dist_entropy = torch.stack(dist.entropy(), dim=-1).sum(dim=-1) 181 | 182 | return (value, action_log_probs, dist_entropy, critic_hiddens, actor_hiddens) 183 | 184 | def soft_update(self, t): 185 | source, target = self.critic, self.target_critic 186 | for target_param, source_param in zip(target.parameters(), source.parameters()): 187 | target_param.data.copy_((1 - t) * target_param.data + t * source_param.data) 188 | 189 | def update(self, batch, step): 190 | with torch.no_grad(): 191 | next_value, _ = self.get_value( 192 | self.split_obs(batch.obss), critic_hiddens=None, target=True 193 | ) 194 | 195 | if self.standardise_returns: 196 | next_value = next_value * torch.sqrt(self.ret_ms.var) + self.ret_ms.mean 197 | 198 | batch_done = batch.dones.float().unsqueeze(-1).repeat(1, 1, self.n_agents) 199 | returns = compute_nstep_returns( 200 | batch.rewards, batch_done, next_value, self.n_steps, self.gamma 201 | ) 202 | if self.standardise_returns: 203 | self.ret_ms.update(returns) 204 | returns = (returns - self.ret_ms.mean) / torch.sqrt(self.ret_ms.var) 205 | 206 | values, action_log_probs, entropy, _, _ = self.evaluate_actions( 207 | self.split_obs(batch.obss[:-1]), 208 | self.split_act(batch.actions), 209 | critic_hiddens=None, 210 | actor_hiddens=None, 211 | action_mask=rearrange(batch.action_masks[:-1], "E B N A -> N E B A") 212 | if batch.action_masks is not None 213 | else None, 214 | ) 215 | 216 | advantage = returns - values 217 | 218 | actor_loss = ( 219 | -(action_log_probs * advantage.detach()).sum(dim=-1) 220 | - self.entropy_coef * entropy 221 | ) 222 | actor_loss = (actor_loss * batch.filled).sum() / batch.filled.sum() 223 | value_loss = (returns - values).pow(2).sum(dim=-1) 224 | value_loss = (value_loss * batch.filled).sum() / 
batch.filled.sum() 225 | 226 | loss = actor_loss + self.value_loss_coef * value_loss 227 | self.optimizer.zero_grad() 228 | loss.backward() 229 | if self.grad_clip: 230 | torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip) 231 | self.optimizer.step() 232 | 233 | if ( 234 | self.target_update_interval_or_tau > 1.0 235 | and step % self.target_update_interval_or_tau == 0 236 | ): 237 | self.soft_update(1.0) 238 | elif self.target_update_interval_or_tau < 1.0: 239 | self.soft_update(self.target_update_interval_or_tau) 240 | 241 | return { 242 | "loss": loss.item(), 243 | "actor_loss": actor_loss.item(), 244 | "value_loss": value_loss.item(), 245 | "entropy": ((entropy * batch.filled).sum() / batch.filled.sum()).item(), 246 | } 247 | 248 | 249 | class PPONetwork(A2CNetwork): 250 | def __init__( 251 | self, 252 | obs_space, 253 | action_space, 254 | cfg, 255 | actor, 256 | critic, 257 | device, 258 | ): 259 | super(PPONetwork, self).__init__( 260 | obs_space, action_space, cfg, actor, critic, device 261 | ) 262 | self.num_epochs = cfg.num_epochs 263 | self.ppo_clip = cfg.ppo_clip 264 | 265 | def update(self, batch, step): 266 | # compute returns 267 | with torch.no_grad(): 268 | next_value, _ = self.get_value( 269 | self.split_obs(batch.obss), critic_hiddens=None, target=True 270 | ) 271 | 272 | if self.standardise_returns: 273 | next_value = next_value * torch.sqrt(self.ret_ms.var) + self.ret_ms.mean 274 | 275 | batch_done = batch.dones.float().unsqueeze(-1).repeat(1, 1, self.n_agents) 276 | returns = compute_nstep_returns( 277 | batch.rewards, batch_done, next_value, self.n_steps, self.gamma 278 | ).detach() 279 | if self.standardise_returns: 280 | self.ret_ms.update(returns) 281 | returns = (returns - self.ret_ms.mean) / torch.sqrt(self.ret_ms.var) 282 | 283 | # compute old policy log probs 284 | with torch.no_grad(): 285 | actor_features, _ = self.actor(self.split_obs(batch.obss[:-1]), None) 286 | dist = self.get_dist( 287 | actor_features, 288 | action_mask=rearrange(batch.action_masks[:-1], "E B N A -> N E B A") 289 | if batch.action_masks is not None 290 | else None, 291 | ) 292 | old_action_log_probs = torch.cat( 293 | dist.log_probs(self.split_act(batch.actions)), dim=-1 294 | ).detach() 295 | 296 | metrics = defaultdict(list) 297 | for _ in range(self.num_epochs): 298 | # sample from current policy 299 | values, action_log_probs, entropy, _, _ = self.evaluate_actions( 300 | self.split_obs(batch.obss[:-1]), 301 | self.split_act(batch.actions), 302 | critic_hiddens=None, 303 | actor_hiddens=None, 304 | action_mask=rearrange(batch.action_masks[:-1], "E B N A -> N E B A") 305 | if batch.action_masks is not None 306 | else None, 307 | ) 308 | 309 | # compute advantage and value loss 310 | advantage = returns - values 311 | value_loss = advantage.pow(2).sum(dim=-1) 312 | 313 | # compute policy loss 314 | ratio = torch.exp(action_log_probs - old_action_log_probs) 315 | surr1 = ratio * advantage.detach() 316 | surr2 = ( 317 | torch.clamp(ratio, 1.0 - self.ppo_clip, 1.0 + self.ppo_clip) 318 | * advantage.detach() 319 | ) 320 | actor_loss = ( 321 | -torch.min(surr1, surr2).sum(dim=-1) - self.entropy_coef * entropy 322 | ) 323 | 324 | # apply masks and compute total loss per epoch 325 | actor_loss = (actor_loss * batch.filled).sum() / batch.filled.sum() 326 | value_loss = (value_loss * batch.filled).sum() / batch.filled.sum() 327 | loss = actor_loss + self.value_loss_coef * value_loss 328 | 329 | # optimisation step 330 | self.optimizer.zero_grad() 331 | loss.backward() 332 | if 
self.grad_clip: 333 | torch.nn.utils.clip_grad_norm_(self.parameters(), self.grad_clip) 334 | self.optimizer.step() 335 | 336 | metrics["loss"].append(loss.item()) 337 | metrics["actor_loss"].append(actor_loss.item()) 338 | metrics["value_loss"].append(value_loss.item()) 339 | metrics["entropy"].append( 340 | ((entropy * batch.filled).sum() / batch.filled.sum()).item() 341 | ) 342 | 343 | # update target network after last epoch 344 | if ( 345 | self.target_update_interval_or_tau > 1.0 346 | and step % self.target_update_interval_or_tau == 0 347 | ): 348 | self.soft_update(1.0) 349 | elif self.target_update_interval_or_tau < 1.0: 350 | self.soft_update(self.target_update_interval_or_tau) 351 | 352 | return {key: sum(values) / len(values) for key, values in metrics.items()} 353 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |