├── .flake8 ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── configs ├── __init__.py ├── arguments.py ├── coma_config.py ├── idqn_config.py ├── qmix_config.py └── qtran_config.py ├── docs ├── awesome_marl.md ├── harl.md ├── marllib.md └── pymarl.md ├── marltoolkit ├── __init__.py ├── __version__.py ├── agents │ ├── __init__.py │ ├── ac_agent.py │ ├── base_agent.py │ ├── coma_agent.py │ ├── idqn_agent.py │ ├── maddpg_agent.py │ ├── mappo_agent.py │ ├── qatten_agent.py │ ├── qmix2_agent.py │ ├── qmix_agent.py │ ├── qtran_agent.py │ └── vdn_agent.py ├── data │ ├── __init__.py │ ├── base_buffer.py │ ├── ma_buffer.py │ ├── offpolicy_buffer.py │ ├── onpolicy_buffer.py │ └── shared_buffer.py ├── envs │ ├── __init__.py │ ├── base_env.py │ ├── pettingzoo │ │ ├── __init__.py │ │ ├── custom_env1.py │ │ ├── custom_env2.py │ │ └── pettingzoo_env.py │ ├── smacv1 │ │ ├── __init__.py │ │ ├── smac_env.py │ │ └── smac_vec_env.py │ ├── smacv2 │ │ ├── __init__.py │ │ └── smacv2_env.py │ ├── vec_env │ │ ├── __init__.py │ │ ├── base_vec_env.py │ │ ├── dummy_vec_env.py │ │ ├── subproc_vec_env.py │ │ ├── utils.py │ │ └── vec_monitor.py │ └── waregame │ │ ├── __init__.py │ │ └── wargame_wrapper.py ├── modules │ ├── __init__.py │ ├── actors │ │ ├── __init__.py │ │ ├── mlp.py │ │ ├── r_actor.py │ │ └── rnn.py │ ├── critics │ │ ├── __init__.py │ │ ├── coma.py │ │ ├── maddpg.py │ │ ├── mlp.py │ │ └── r_critic.py │ ├── mixers │ │ ├── __init__.py │ │ ├── qatten.py │ │ ├── qmixer.py │ │ ├── qtran.py │ │ └── vdn.py │ └── utils │ │ ├── __init__.py │ │ ├── act.py │ │ ├── common.py │ │ ├── distributions.py │ │ ├── popart.py │ │ └── valuenorm.py ├── runners │ ├── __init__.py │ ├── episode_runner.py │ ├── onpolicy_runner.py │ ├── parallel_episode_runner.py │ └── qtran_runner.py.py └── utils │ ├── __init__.py │ ├── env_utils.py │ ├── logger │ ├── __init__.py │ ├── base.py │ ├── base_orig.py │ ├── logging.py │ ├── logs.py │ ├── tensorboard.py │ └── wandb.py │ ├── lr_scheduler.py │ ├── model_utils.py │ ├── progressbar.py │ ├── timer.py │ └── transforms.py ├── requirements.txt └── scripts ├── main_coma.py ├── main_idqn.py ├── main_mappo.py ├── main_qmix.py ├── main_qtran.py ├── main_vdn.py ├── run.sh ├── smac_runner.py ├── test_multi_env.py └── test_multi_env2.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = E501, W503, E126, W504, E402,E251 3 | max-line-length = 79 4 | show-source = False 5 | application-import-names = marltoolkit 6 | exclude = 7 | .git 8 | docs/ 9 | venv 10 | build 11 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | environment/ 108 | ENVIORNMENT/ 109 | venv/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | .pyre/ 131 | work_dir/ 132 | work_dirs/ 133 | .DS_Store 134 | events* 135 | logs/ 136 | wandb/ 137 | nohup.out 138 | protein_data 139 | **/checkpoints 140 | *.npy 141 | hub/ 142 | 周报/ 143 | multiagent-particle-envs/ 144 | results/ 145 | *.mp4 146 | *.json 147 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | known_third_party =numpy,smac,torch,wandb 3 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://gitee.com/openmmlab/mirrors-flake8 3 | rev: 5.0.4 4 | hooks: 5 | - id: flake8 6 | - repo: https://gitee.com/openmmlab/mirrors-isort 7 | rev: 5.11.5 8 | hooks: 9 | - id: isort 10 | - repo: https://gitee.com/openmmlab/mirrors-yapf 11 | rev: v0.32.0 12 | hooks: 13 | - id: yapf 14 | - repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks 15 | rev: v4.3.0 16 | hooks: 17 | - id: trailing-whitespace 18 | - id: check-yaml 19 | - id: end-of-file-fixer 20 | - id: requirements-txt-fixer 21 | - id: double-quote-string-fixer 22 | - id: check-merge-conflict 23 | - id: fix-encoding-pragma 24 | args: ["--remove"] 25 | - id: mixed-line-ending 26 | args: ["--fix=lf"] 27 | - repo: https://gitee.com/openmmlab/mirrors-mdformat 28 | rev: 0.7.9 29 | hooks: 30 | - id: mdformat 31 | args: ["--number"] 32 | additional_dependencies: 33 | - mdformat-openmmlab 34 | - mdformat_frontmatter 35 | - linkify-it-py 36 | - repo: https://gitee.com/openmmlab/mirrors-docformatter 37 | rev: v1.3.1 38 | hooks: 39 | - 
id: docformatter 40 | args: ["--in-place", "--wrap-descriptions", "79"] 41 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/configs/__init__.py -------------------------------------------------------------------------------- /configs/coma_config.py: -------------------------------------------------------------------------------- 1 | class ComaConfig: 2 | """Configuration class for Coma model. 3 | 4 | ComaConfig contains parameters used to instantiate a Coma model. 5 | These parameters define the model architecture, behavior, and training settings. 6 | 7 | Args: 8 | rnn_hidden_dim (int, optional): Dimension of GRU's hidden state. 9 | gamma (float, optional): Discount factor in reinforcement learning. 10 | td_lambda (float, optional): Lambda in TD(lambda). 11 | egreedy_exploration (float, optional): Initial 'epsilon' in epsilon-greedy exploration. 12 | min_exploration (float, optional): Minimum 'epsilon' in epsilon-greedy. 13 | target_update_interval (int, optional): Sync parameters to target model after 'target_update_interval' times. 14 | learning_rate (float, optional): Learning rate of the optimizer. 15 | min_learning_rate (float, optional): Minimum learning rate of the optimizer. 16 | clip_grad_norm (float, optional): Clipped value of the global norm of gradients. 17 | learner_update_freq (int, optional): Update learner frequency. 18 | double_q (bool, optional): Use Double-DQN. 19 | algo_name (str, optional): Name of the algorithm. 20 | """ 21 | 22 | model_type: str = 'coma' 23 | 24 | def __init__( 25 | self, 26 | fc_hidden_dim: int = 64, 27 | rnn_hidden_dim: int = 64, 28 | gamma: float = 0.99, 29 | td_lambda: float = 0.8, 30 | egreedy_exploration: float = 1.0, 31 | min_exploration: float = 0.1, 32 | target_update_interval: int = 1000, 33 | learning_rate: float = 0.0005, 34 | min_learning_rate: float = 0.0001, 35 | clip_grad_norm: float = 10, 36 | learner_update_freq: int = 2, 37 | double_q: bool = True, 38 | algo_name: str = 'coma', 39 | ) -> None: 40 | # Network architecture parameters 41 | self.fc_hidden_dim = fc_hidden_dim 42 | self.rnn_hidden_dim = rnn_hidden_dim 43 | 44 | # Training parameters 45 | self.gamma = gamma 46 | self.td_lambda = td_lambda 47 | self.egreedy_exploration = egreedy_exploration 48 | self.min_exploration = min_exploration 49 | self.target_update_interval = target_update_interval 50 | self.learning_rate = learning_rate 51 | self.min_learning_rate = min_learning_rate 52 | self.clip_grad_norm = clip_grad_norm 53 | self.learner_update_freq = learner_update_freq 54 | self.double_q = double_q 55 | 56 | # Logging parameters 57 | self.algo_name = algo_name 58 | -------------------------------------------------------------------------------- /configs/idqn_config.py: -------------------------------------------------------------------------------- 1 | class IDQNConfig: 2 | """Configuration class for QMix model. 3 | 4 | IDQNConfig contains parameters used to instantiate a QMix model. 5 | These parameters define the model architecture, behavior, and training settings. 6 | 7 | Args: 8 | rnn_hidden_dim (int, optional): Dimension of GRU's hidden state. 9 | gamma (float, optional): Discount factor in reinforcement learning. 10 | egreedy_exploration (float, optional): Initial 'epsilon' in epsilon-greedy exploration. 
11 | min_exploration (float, optional): Minimum 'epsilon' in epsilon-greedy. 12 | target_update_interval (int, optional): Sync parameters to target model after 'target_update_interval' times. 13 | learning_rate (float, optional): Learning rate of the optimizer. 14 | min_learning_rate (float, optional): Minimum learning rate of the optimizer. 15 | clip_grad_norm (float, optional): Clipped value of the global norm of gradients. 16 | hypernet_layers (int, optional): Number of layers in hypernetwork. 17 | hypernet_embed_dim (int, optional): Embedding dimension for hypernetwork. 18 | learner_update_freq (int, optional): Update learner frequency. 19 | double_q (bool, optional): Use Double-DQN. 20 | algo_name (str, optional): Name of the algorithm. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | rnn_hidden_dim: int = 64, 26 | gamma: float = 0.99, 27 | egreedy_exploration: float = 1.0, 28 | min_exploration: float = 0.1, 29 | target_update_interval: int = 1000, 30 | learning_rate: float = 0.0005, 31 | min_learning_rate: float = 0.0001, 32 | clip_grad_norm: float = 10, 33 | learner_update_freq: int = 2, 34 | double_q: bool = True, 35 | algo_name: str = 'idqn', 36 | ) -> None: 37 | 38 | # Network architecture parameters 39 | self.rnn_hidden_dim = rnn_hidden_dim 40 | 41 | # Training parameters 42 | self.gamma = gamma 43 | self.egreedy_exploration = egreedy_exploration 44 | self.min_exploration = min_exploration 45 | self.target_update_interval = target_update_interval 46 | self.learning_rate = learning_rate 47 | self.min_learning_rate = min_learning_rate 48 | self.clip_grad_norm = clip_grad_norm 49 | self.learner_update_freq = learner_update_freq 50 | self.double_q = double_q 51 | 52 | # Logging parameters 53 | self.algo_name = algo_name 54 | -------------------------------------------------------------------------------- /configs/qmix_config.py: -------------------------------------------------------------------------------- 1 | class QMixConfig: 2 | """Configuration class for QMix model. 3 | 4 | QMixConfig contains parameters used to instantiate a QMix model. 5 | These parameters define the model architecture, behavior, and training settings. 6 | 7 | Args: 8 | mixing_embed_dim (int, optional): Embedding dimension of the mixing network. 9 | rnn_hidden_dim (int, optional): Dimension of GRU's hidden state. 10 | gamma (float, optional): Discount factor in reinforcement learning. 11 | egreedy_exploration (float, optional): Initial 'epsilon' in epsilon-greedy exploration. 12 | min_exploration (float, optional): Minimum 'epsilon' in epsilon-greedy. 13 | target_update_interval (int, optional): Sync parameters to target model after 'target_update_interval' times. 14 | learning_rate (float, optional): Learning rate of the optimizer. 15 | min_learning_rate (float, optional): Minimum learning rate of the optimizer. 16 | clip_grad_norm (float, optional): Clipped value of the global norm of gradients. 17 | hypernet_layers (int, optional): Number of layers in hypernetwork. 18 | hypernet_embed_dim (int, optional): Embedding dimension for hypernetwork. 19 | learner_update_freq (int, optional): Update learner frequency. 20 | double_q (bool, optional): Use Double-DQN. 21 | algo_name (str, optional): Name of the algorithm. 
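        Example:
            An illustrative sketch only -- the keyword overrides shown are
            arbitrary, and every attribute referenced is defined in
            ``__init__`` below:

            >>> config = QMixConfig(learning_rate=0.0005, mixing_embed_dim=32)
            >>> config.algo_name
            'qmix'
            >>> config.double_q
            True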
22 | """ 23 | 24 | model_type: str = 'qmix' 25 | 26 | def __init__( 27 | self, 28 | fc_hidden_dim: int = 64, 29 | rnn_hidden_dim: int = 64, 30 | gamma: float = 0.99, 31 | egreedy_exploration: float = 1.0, 32 | min_exploration: float = 0.01, 33 | target_update_tau: float = 0.05, 34 | target_update_interval: int = 100, 35 | learning_rate: float = 0.1, 36 | min_learning_rate: float = 0.00001, 37 | clip_grad_norm: float = 10, 38 | hypernet_layers: int = 2, 39 | hypernet_embed_dim: int = 64, 40 | mixing_embed_dim: int = 32, 41 | learner_update_freq: int = 3, 42 | double_q: bool = True, 43 | algo_name: str = 'qmix', 44 | ) -> None: 45 | # Network architecture parameters 46 | self.fc_hidden_dim = fc_hidden_dim 47 | self.rnn_hidden_dim = rnn_hidden_dim 48 | self.hypernet_layers = hypernet_layers 49 | self.hypernet_embed_dim = hypernet_embed_dim 50 | self.mixing_embed_dim = mixing_embed_dim 51 | 52 | # Training parameters 53 | self.gamma = gamma 54 | self.egreedy_exploration = egreedy_exploration 55 | self.min_exploration = min_exploration 56 | self.target_update_tau = target_update_tau 57 | self.target_update_interval = target_update_interval 58 | self.learning_rate = learning_rate 59 | self.min_learning_rate = min_learning_rate 60 | self.clip_grad_norm = clip_grad_norm 61 | self.learner_update_freq = learner_update_freq 62 | self.double_q = double_q 63 | 64 | # Logging parameters 65 | self.algo_name = algo_name 66 | -------------------------------------------------------------------------------- /configs/qtran_config.py: -------------------------------------------------------------------------------- 1 | QTranConfig = { 2 | 'project': 'StarCraft-II', 3 | 'scenario': '3m', 4 | 'replay_buffer_size': 5000, 5 | 'mixing_embed_dim': 32, 6 | 'rnn_hidden_dim': 64, 7 | 'learning_rate': 0.0005, 8 | 'min_learning_rate': 0.0001, 9 | 'memory_warmup_size': 32, 10 | 'gamma': 0.99, 11 | 'egreedy_exploration': 1.0, 12 | 'min_exploration': 0.1, 13 | 'target_update_interval': 2000, 14 | 'batch_size': 32, 15 | 'total_steps': 1000000, 16 | 'train_log_interval': 5, # log every 10 episode 17 | 'test_log_interval': 20, # log every 100 epidode 18 | 'clip_grad_norm': 10, 19 | 'learner_update_freq': 2, 20 | 'double_q': True, 21 | 'opt_loss_coef': 1.0, 22 | 'nopt_min_loss_coef': 0.1, 23 | 'difficulty': '7', 24 | 'algo': 'qtran', 25 | 'log_dir': 'work_dirs/', 26 | 'logger': 'wandb', 27 | } 28 | -------------------------------------------------------------------------------- /docs/harl.md: -------------------------------------------------------------------------------- 1 | # 异构智能体强化学习 2 | 3 | HARL 算法是 在一般异构智能体设置中实现有效多智能体合作的新颖解决方案,而不依赖于限制性参数共享技巧。 4 | 5 | ## 支持的算法 6 | 7 | - HAPPO 8 | - HATRPO 9 | - HAA2C 10 | - HADDPG 11 | - HATD3 12 | - HAD3QN 13 | - HASAC 14 | 15 | ## 主要特性 16 | 17 | - HARL算法通过采用顺序更新方案来实现协调智能体更新,这与MAPPO和MADDPG采用的同时更新方案不同。 18 | - HARL 算法享有单调改进和收敛于均衡的理论保证,确保其促进智能体之间合作行为的有效性。 19 | - 在线策略和离线策略 HARL 算法(分别以HAPPO和HASAC为例)在各种基准测试中都表现出了卓越的性能 20 | -------------------------------------------------------------------------------- /docs/marllib.md: -------------------------------------------------------------------------------- 1 | # MARLlib 2 | 3 | ## Intro 4 | 5 | MARLlib 是一个基于Ray和其工具包RLlib的综合多智能体强化学习算法库。它为多智能体强化学习研究社区提供了一个统一的平台,用于构建、训练和评估多智能体强化学习算法。 6 | 7 | ../_images/marllib_open.png 8 | 9 | > Overview of the MARLlib architecture.\`\` 10 | 11 | MARLlib 提供了几个突出的关键特点: 12 | 13 | 1. **统一算法Pipeline:** MARLlib通过基于Agenty 级别的分布式数据流将多样的算法Pipeline统一起来,使研究人员能够在不同的任务和环境中开发、测试和评估多智能体强化学习算法。 14 | 2. 
**支持各种任务模式:** MARLlib支持协作、协同、竞争和混合等所有任务模式。 15 | 3. **与Gym结构一致的接口:** MARLlib提供了一个新的与Gym 相似的环境接口,方便研究人员更加轻松的处理多智能环境。 16 | 4. **灵活参数共享策略:** MARLlib提供了灵活且可定制的参数共享策略。 17 | 18 | 使用 MARLlib,您可以享受到各种好处,例如: 19 | 20 | 1. **零MARL知识门槛:** MARLlib提供了18个预置算法,并具有简单的API,使研究人员能够在不具备多智能体强化学习领域知识的情况下开始进行实验。 21 | 2. **支持所有任务模式:** MARLlib支持几乎所有多智能体环境,使研究人员能够更轻松地尝试不同的任务模式。 22 | 3. **可定制的模型架构:** 研究人员可以从模型库中选择他们喜欢的模型架构,或者构建自己的模型。 23 | 4. **可定制的策略共享:** MARLlib提供了策略共享的分组选项,研究人员也可以创建自己的共享策略。 24 | 5. **访问上千个发布的实验:** 研究人员可以访问上千个已发布的实验,了解其他研究人员如何使用MARLlib。 25 | 26 | ## MARLlib 框架 27 | 28 | ## 环境接口 29 | 30 | ../_images/marl_env_right.png 31 | 32 | Agent-Environment Interface in MARLlib 33 | 34 | MARLlib 中的环境接口支持以下功能: 35 | 36 | 1. 智能体无关:每个智能体在训练阶段都有隔离的数据 37 | 2. 任务无关:多种环境支持一个统一的接口 38 | 3. 异步采样:灵活的智能体-环境交互模式 39 | 40 | 首先,MARLlib 将 MARL 视为单智能体 RL 过程的组合。 41 | 42 | 其次,MARLlib 将所有十种环境统一为一个抽象接口,有助于减轻算法设计工作的负担。并且该接口下的环境可以是任何实例,从而实现多任务 或者 任务无关的学习。 43 | 44 | 第三,与大多数现有的 MARL 框架仅支持智能体和环境之间的同步交互不同,MARLlib 支持异步交互风格。这要归功于RLlib灵活的数据收集机制,不同agent的数据可以通过同步和异步的方式收集和存储。 45 | 46 | ## 工作流 47 | 48 | ### 第一阶段:预学习 49 | 50 | MARLlib 通过实例化环境和智能体模型来开始强化学习过程。随后,根据环境特征生成模拟批次,并将其输入指定算法的采样/训练Pipeline中。成功完成学习工作流程并且没有遇到错误后,MARLlib 将进入后续阶段。 51 | 52 | ![../_images/rllib_data_flow_left.png](https://marllib.readthedocs.io/en/latest/_images/rllib_data_flow_left.png) 53 | 54 | > 预习阶段 55 | 56 | ### 第二阶段:采样和训练 57 | 58 | 预学习阶段完成后,MARLlib 将实际工作分配给 Worker 和 Leaner,并在执行计划下安排这些流程以启动学习过程。 59 | 60 | 在标准学习迭代期间,每个 Worker 使用智能体模型与其环境实例交互以采样数据,然后将数据传递到ReplayBuffer。ReplayBuffer 根据算法进行初始化,并决定如何存储数据。例如,对于on-policy算法,缓冲区是一个串联操作,而对于off-policy算法,缓冲区是一个FIFO队列。 61 | 62 | 随后,预定义的策略映射功能将收集到的数据分发给不同的智能体。一旦完全收集了一次训练迭代的所有数据,学习器就开始使用这些数据优化策略,并将新模型广播给每个Worker以进行下一轮采样。 63 | 64 | ![../_images/rllib_data_flow_right.png](https://marllib.readthedocs.io/en/latest/_images/rllib_data_flow_right.png) 65 | 66 | 采样和训练阶段 67 | 68 | ### 算法Pipeline 69 | 70 | ![../_images/pipeline.png](https://marllib.readthedocs.io/en/latest/_images/pipeline.png) 71 | 72 | ### 独立学习 73 | 74 | 在 MARLlib 中,由于 RLlib 提供了许多算法,实现独立学习(左)非常简单。要开始训练,可以从 RLlib 中选择一种算法并将其应用到多智能体环境中,与 RLlib 相比,无需额外的工作。尽管 MARL 中的独立学习不需要任何数据交换,但在大多数任务中其性能通常不如中心化训练策略。 75 | 76 | ### 中心化 Critic 77 | 78 | 中心化 Critic 是 MARLlib 支持的 CTDE 框架中的两种中心化训练策略之一。在这种方法下,智能体需要在获得策略输出之后但在临界值计算之前相互共享信息。这些共享信息包括 独立观察、动作和全局状态(如果有)。 79 | 80 | 交换的数据在采样阶段被收集并存储为过渡数据,其中每个过渡数据都包含自收集的数据和交换的数据。然后利用这些数据来优化中心化评价函数和分散式策略函数。信息共享的实现主要是在同策略算法的后处理函数中完成的。对于像 MADDPG 这样的Off-Policy算法,在数据进入训练迭代批次之前会收集其他数据,例如其他智能体提供的动作值。 81 | 82 | ### 价值分解 83 | 84 | 在 MARLlib 中,价值分解(VD)是另一类中心化训练策略,与中心化评价者的不同之处在于需要智能体共享信息。具体来说,仅需要在智能体之间共享预测的 Q 值或临界值,并且根据所使用的算法可能需要额外的数据。例如,QMIX 需要全局状态来计算混合 Q 值。 85 | 86 | VD 的数据收集和存储机制与中心化Critic的数据收集和存储机制类似,智能体在采样阶段收集和存储转换数据。联合Q学习方法(VDN、QMIX)基于原始PyMARL,五种VD算法中只有FACMAC、VDA2C和VDPPO遵循标准RLlib训练流程。 87 | 88 | ## 关键组件 89 | 90 | ### 数据收集前的处理 91 | 92 | MARL 算法采用 中心化 训练和分散执行(CTDE)范式,需要在学习阶段在Agent之间共享信息。在 QMIX、FACMAC 和 VDA2C 等值分解算法中,总 Q 或 V 值的计算需要Agent提供各自的 Q 或 V 值估计。相反,基于中心化评价的算法,如 MADDPG、MAPPO 和 HAPPO,需要Agent共享他们的观察和动作数据,以确定中心化评价值。后处理模块是Agent与同伴交换数据的理想位置。对于中心化评价算法,Agent可以从其他Agent获取附加信息来计算中心化评价值。另一方面,对于值分解算法,智能体必须向其他智能体提供其预测的 Q 或 V 值。此外,后处理模块还负责使用 GAE 或 N 步奖励调整等技术来计算各种学习目标。 93 | 94 | ![../_images/pp.png](https://marllib.readthedocs.io/en/latest/_images/pp.png) 95 | 96 | 数据收集前的后处理 97 | 98 | ### 批量学习前的后处理 99 | 100 | 在 MARL 算法的背景下,并非所有算法都可以利用后处理模块。其中一个例子是像 MADDPG 和 FACMAC 这样的Off-Policy算法,它们面临着ReplayBuffer中过时数据无法用于当前训练交互的挑战。为了应对这一挑战,我们实现了一个额外的“批量学习之前”功能,以在采样批次进入训练循环之前准确计算当前模型的 Q 或 V 值。这确保了用于训练的数据是最新且准确的,从而提高了训练效果。 101 | 102 | 
![../_images/pp_batch.png](https://marllib.readthedocs.io/en/latest/_images/pp_batch.png) 103 | 104 | 批量学习前的后处理 105 | 106 | ### 中心化价值函数 107 | 108 | 在中心化的评价家Agent模型中,传统的仅基于Agent自我观察的价值函数被能够适应算法要求的中心化 critic所取代。中心化 critic 负责处理从其他Agent收到的信息并生成 中心化 值作为输出。 109 | 110 | ### 混合价值函数 111 | 112 | 在价值分解Agent模型中,保留了原来的价值函数,但引入了新的混合价值函数以获得总体混合值。混合功能灵活,可根据用户要求定制。目前支持VDN和QMIX混合功能。 113 | 114 | ### 异构优化 115 | 116 | 在异构优化中,各个Agent参数是独立更新的,因此,策略函数不会在不同Agent之间共享。然而,根据算法证明,顺序更新Agent的策略并设置与丢失相关的传票的值可以导致任何正更新的增量求和。 117 | 118 | 为了保证算法的增量单调性,利用信任域来获得合适的参数更新,就像HATRPO算法中的情况一样。为了在考虑计算效率的同时加速策略和评价者更新过程,HAPPO 算法中采用了近端策略优化技术。 119 | 120 | ![../_images/hetero.png](https://marllib.readthedocs.io/en/latest/_images/hetero.png) 121 | 122 | 异构Agent评价优化 123 | 124 | ### 策略映射 125 | 126 | 策略映射在标准化多智能体强化学习 (MARL) 环境的接口方面发挥着至关重要的作用。在 MARLlib 中,策略映射被实现为具有层次结构的字典。顶级键代表场景名称,第二级键包含组信息,四个附加键(**description**、**team_prefix**、 **all_agents_one_policy**和**one_agent_one_policy**)用于定义各种策略设置。 team_prefix键根据Agent的名称**对**Agent进行分组,而最后两个键指示完全共享或非共享策略策略是否适用于给定场景。利用策略映射方法来初始化策略并将其分配给不同的智能体,并且每个策略仅使用其相应策略组中的智能体采样的数据来训练。 127 | -------------------------------------------------------------------------------- /marltoolkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/marltoolkit/__init__.py -------------------------------------------------------------------------------- /marltoolkit/__version__.py: -------------------------------------------------------------------------------- 1 | """Version information.""" 2 | 3 | # The following line *must* be the last in the module, exactly as formatted: 4 | __version__ = '0.10.0' 5 | -------------------------------------------------------------------------------- /marltoolkit/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_agent import BaseAgent 2 | from .idqn_agent import IDQNAgent 3 | from .qmix_agent import QMixAgent 4 | from .vdn_agent import VDNAgent 5 | 6 | __all__ = ['BaseAgent', 'QMixAgent', 'IDQNAgent', 'VDNAgent'] 7 | -------------------------------------------------------------------------------- /marltoolkit/agents/base_agent.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | 4 | class BaseAgent(ABC): 5 | 6 | def __init__(self): 7 | pass 8 | 9 | def init_hidden_states(self, **kwargs): 10 | raise NotImplementedError 11 | 12 | def sample(self, **kwargs): 13 | raise NotImplementedError 14 | 15 | def predict(self, **kwargs): 16 | raise NotImplementedError 17 | 18 | def learn(self, **kwargs): 19 | raise NotImplementedError 20 | 21 | def save_model(self, **kwargs): 22 | raise NotImplementedError 23 | 24 | def load_model(self, **kwargs): 25 | raise NotImplementedError 26 | -------------------------------------------------------------------------------- /marltoolkit/agents/maddpg_agent.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from marltoolkit.modules.actors.mlp import MLPActorModel 8 | from marltoolkit.modules.critics.mlp import MLPCriticModel 9 | 10 | from .base_agent import BaseAgent 11 | 12 | 13 | class MaddpgAgent(BaseAgent): 14 | """MADDPG algorithm. 15 | 16 | Args: 17 | model (parl.Model): forward network of actor and critic. 
18 | The function get_actor_params() of model should be implemented. 19 | agent_index (int): index of agent, in multiagent env 20 | action_space (list): action_space, gym space 21 | gamma (float): discounted factor for reward computation. 22 | tau (float): decay coefficient when updating the weights of self.target_model with self.model 23 | critic_lr (float): learning rate of the critic model 24 | actor_lr (float): learning rate of the actor model 25 | """ 26 | 27 | def __init__( 28 | self, 29 | actor_model: nn.Module, 30 | critic_model: nn.Module, 31 | n_agents: int = None, 32 | obs_shape: int = None, 33 | n_actions: int = None, 34 | gamma: float = 0.95, 35 | tau: float = 0.01, 36 | actor_lr: float = 0.01, 37 | critic_lr: float = 0.01, 38 | device: str = 'cpu', 39 | ): 40 | 41 | # checks 42 | assert isinstance(gamma, float) 43 | assert isinstance(tau, float) 44 | assert isinstance(actor_lr, float) 45 | assert isinstance(critic_lr, float) 46 | 47 | self.actor_lr = actor_lr 48 | self.critic_lr = critic_lr 49 | self.global_steps = 0 50 | self.device = device 51 | 52 | self.actor_model = [ 53 | MLPActorModel(obs_shape, n_actions) for _ in range(n_agents) 54 | ] 55 | self.critic_model = [ 56 | MLPCriticModel(n_agents, obs_shape, n_actions) 57 | for _ in range(n_agents) 58 | ] 59 | self.actor_target = copy.deepcopy(actor_model) 60 | self.critic_target = copy.deepcopy(critic_model) 61 | self.actor_optimizer = [ 62 | torch.optim.Adam(model.parameters(), lr=self.actor_lr) 63 | for model in self.actor_model 64 | ] 65 | self.critic_optimizer = [ 66 | torch.optim.Adam(model.parameters(), lr=self.critic_lr) 67 | for model in self.critic_model 68 | ] 69 | 70 | def sample(self, obs, use_target_model=False): 71 | """use the policy model to sample actions. 72 | 73 | Args: 74 | obs (torch tensor): observation, shape([B] + shape of obs_n[agent_index]) 75 | use_target_model (bool): use target_model or not 76 | 77 | Returns: 78 | act (torch tensor): action, shape([B] + shape of act_n[agent_index]), 79 | noted that in the discrete case we take the argmax along the last axis as action 80 | """ 81 | obs = torch.FloatTensor(obs.reshape(1, -1)).to(self.device) 82 | 83 | if use_target_model: 84 | policy = self.actor_target(obs) 85 | else: 86 | policy = self.actor_model(obs) 87 | 88 | # add noise for action exploration 89 | if self.continuous_actions: 90 | random_normal = torch.randn(size=policy[0].shape).to(self.device) 91 | action = policy[0] + torch.exp(policy[1]) * random_normal 92 | action = torch.tanh(action) 93 | else: 94 | uniform = torch.rand_like(policy) 95 | soft_uniform = torch.log(-1.0 * torch.log(uniform)).to(self.device) 96 | action = F.softmax(policy - soft_uniform, dim=-1) 97 | 98 | action = action.detach().cpu().numpy().flatten() 99 | return action 100 | 101 | def predict(self, obs: torch.Tensor): 102 | """use the policy model to predict actions. 
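        Unlike ``sample`` above, no exploration noise is added here: for
        continuous actions the policy mean is squashed with ``tanh``, and for
        discrete actions the argmax of the policy output is returned.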
103 | 104 | Args: 105 | obs (torch tensor): observation, shape([B] + shape of obs_n[agent_index]) 106 | 107 | Returns: 108 | act (torch tensor): action, shape([B] + shape of act_n[agent_index]), 109 | noted that in the discrete case we take the argmax along the last axis as action 110 | """ 111 | obs = torch.FloatTensor(obs.reshape(1, -1)).to(self.device) 112 | policy = self.actor_model(obs) 113 | if self.continuous_actions: 114 | action = policy[0] 115 | action = torch.tanh(action) 116 | else: 117 | action = torch.argmax(policy, dim=-1) 118 | 119 | action = action.detach().cpu().numpy().flatten() 120 | return action 121 | 122 | def q_value(self, obs_n, act_n, use_target_model=False): 123 | """use the value model to predict Q values. 124 | 125 | Args: 126 | obs_n (list of torch tensor): all agents' observation, len(agent's num) + shape([B] + shape of obs_n) 127 | act_n (list of torch tensor): all agents' action, len(agent's num) + shape([B] + shape of act_n) 128 | use_target_model (bool): use target_model or not 129 | 130 | Returns: 131 | Q (torch tensor): Q value of this agent, shape([B]) 132 | """ 133 | if use_target_model: 134 | return self.critic_target.value(obs_n, act_n) 135 | else: 136 | return self.critic_model.value(obs_n, act_n) 137 | 138 | def agent_learn(self, obs_n, act_n, target_q): 139 | """update actor and critic model with MADDPG algorithm.""" 140 | self.global_steps += 1 141 | 142 | def learn(self, obs_n, act_n, target_q): 143 | """update actor and critic model with MADDPG algorithm.""" 144 | acotr_loss = self.actor_learn(obs_n, act_n) 145 | critic_loss = self.critic_learn(obs_n, act_n, target_q) 146 | return acotr_loss, critic_loss 147 | 148 | def actor_learn(self, obs_n, act_n): 149 | i = self.agent_index 150 | 151 | sample_this_action = self.sample(obs_n[i]) 152 | action_input_n = act_n + [] 153 | action_input_n[i] = sample_this_action 154 | eval_q = self.q_value(obs_n, action_input_n) 155 | act_cost = torch.mean(-1.0 * eval_q) 156 | 157 | this_policy = self.actor_model.policy(obs_n[i]) 158 | # when continuous, 'this_policy' will be a tuple with two element: (mean, std) 159 | if self.continuous_actions: 160 | this_policy = torch.cat(this_policy, dim=-1) 161 | act_reg = torch.mean(torch.square(this_policy)) 162 | 163 | cost = act_cost + act_reg * 1e-3 164 | 165 | self.actor_optimizer.zero_grad() 166 | cost.backward() 167 | torch.nn.utils.clip_grad_norm_(self.actor_model.parameters(), 0.5) 168 | self.actor_optimizer.step() 169 | return cost 170 | 171 | def critic_learn(self, obs_n, act_n, target_q): 172 | pred_q = self.q_value(obs_n, act_n) 173 | cost = F.mse_loss(pred_q, target_q) 174 | 175 | self.critic_optimizer.zero_grad() 176 | cost.backward() 177 | torch.nn.utils.clip_grad_norm_(self.critic_model.parameters(), 0.5) 178 | self.critic_optimizer.step() 179 | return cost 180 | 181 | def update_target(self, tau): 182 | for target_param, param in zip(self.actor_target.parameters(), 183 | self.actor_model.parameters()): 184 | target_param.data.copy_(tau * param.data + 185 | (1 - tau) * target_param.data) 186 | for target_param, param in zip(self.critic_target.parameters(), 187 | self.critic_model.parameters()): 188 | target_param.data.copy_(tau * param.data + 189 | (1 - tau) * target_param.data) 190 | -------------------------------------------------------------------------------- /marltoolkit/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .ma_buffer import EpisodeData, ReplayBuffer 2 | from .offpolicy_buffer 
import (BaseBuffer, MaEpisodeData, OffPolicyBuffer, 3 | OffPolicyBufferRNN) 4 | 5 | __all__ = [ 6 | 'ReplayBuffer', 7 | 'EpisodeData', 8 | 'MaEpisodeData', 9 | 'OffPolicyBuffer', 10 | 'BaseBuffer', 11 | 'OffPolicyBufferRNN', 12 | ] 13 | -------------------------------------------------------------------------------- /marltoolkit/data/base_buffer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | class BaseBuffer(ABC): 9 | """Abstract base class for all buffers. 10 | 11 | This class provides a basic structure for reinforcement learning experience 12 | replay buffers. 13 | 14 | Args: 15 | :param num_envs: Number of environments. 16 | :param num_agents: Number of agents. 17 | :param obs_shape: Dimensionality of the observation space. 18 | :param state_shape: Dimensionality of the state space. 19 | :param action_shape: Dimensionality of the action space. 20 | :param reward_shape: Dimensionality of the reward space. 21 | :param done_shape: Dimensionality of the done space. 22 | :param device: Device on which to store the buffer data. 23 | :param kwargs: Additional keyword arguments. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | num_envs: int, 29 | num_agents: int, 30 | obs_shape: Union[int, Tuple], 31 | state_shape: Union[int, Tuple], 32 | action_shape: Union[int, Tuple], 33 | reward_shape: Union[int, Tuple], 34 | done_shape: Union[int, Tuple], 35 | device: Union[torch.device, str] = 'cpu', 36 | **kwargs, 37 | ) -> None: 38 | super().__init__() 39 | self.num_envs = num_envs 40 | self.num_agents = num_agents 41 | self.obs_shape = obs_shape 42 | self.state_shape = state_shape 43 | self.action_shape = action_shape 44 | self.reward_shape = reward_shape 45 | self.done_shape = done_shape 46 | self.device = device 47 | self.curr_ptr = 0 48 | self.curr_size = 0 49 | 50 | def reset(self) -> None: 51 | """Reset the buffer.""" 52 | self.curr_ptr = 0 53 | self.curr_size = 0 54 | 55 | @abstractmethod 56 | def store(self, *args) -> None: 57 | """Add elements to the buffer.""" 58 | raise NotImplementedError 59 | 60 | @abstractmethod 61 | def extend(self, *args, **kwargs) -> None: 62 | """Add a new batch of transitions to the buffer.""" 63 | for data in zip(*args): 64 | self.store(*data) 65 | 66 | @abstractmethod 67 | def sample(self, **kwargs): 68 | """Sample elements from the buffer.""" 69 | raise NotImplementedError 70 | 71 | @abstractmethod 72 | def store_transitions(self, **kwargs): 73 | """Store transitions in the buffer.""" 74 | raise NotImplementedError 75 | 76 | @abstractmethod 77 | def store_episodes(self, **kwargs): 78 | """Store episodes in the buffer.""" 79 | raise NotImplementedError 80 | 81 | @abstractmethod 82 | def finish_path(self, **kwargs): 83 | """Finish a trajectory path in the buffer.""" 84 | raise NotImplementedError 85 | 86 | def size(self) -> int: 87 | """Get the current size of the buffer. 88 | 89 | :return: The current size of the buffer. 90 | """ 91 | return self.curr_size 92 | 93 | def to_torch(self, array: np.ndarray, copy: bool = True) -> torch.Tensor: 94 | """Convert a numpy array to a PyTorch tensor. 95 | 96 | :param array: Numpy array to be converted. 97 | :param copy: Whether to copy the data or not. 98 | :return: PyTorch tensor. 
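        A minimal illustration (hypothetical usage; ``buf`` stands for an
        instance of a concrete subclass constructed with ``device='cpu'``):

        >>> import numpy as np
        >>> batch = np.zeros((32, 4), dtype=np.float32)
        >>> buf.to_torch(batch).shape
        torch.Size([32, 4])
        >>> buf.to_torch(batch, copy=False).device
        device(type='cpu')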
99 | """ 100 | if copy: 101 | return torch.tensor(array).to(self.device) 102 | return torch.as_tensor(array).to(self.device) 103 | -------------------------------------------------------------------------------- /marltoolkit/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_env import MultiAgentEnv 2 | from .smacv1.smac_env import SMACWrapperEnv 3 | from .vec_env import BaseVecEnv, DummyVecEnv, SubprocVecEnv 4 | 5 | __all__ = [ 6 | 'SMACWrapperEnv', 7 | 'BaseVecEnv', 8 | 'DummyVecEnv', 9 | 'SubprocVecEnv', 10 | 'MultiAgentEnv', 11 | ] 12 | -------------------------------------------------------------------------------- /marltoolkit/envs/base_env.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import gymnasium as gym 4 | 5 | 6 | class MultiAgentEnv(gym.Env): 7 | """A multi-agent environment wrapper. An environment that hosts multiple 8 | independent agents. 9 | 10 | Agents are identified by (string) agent ids. Note that these "agents" here 11 | are not to be confused with RLlib Algorithms, which are also sometimes 12 | referred to as "agents" or "RL agents". 13 | 14 | The preferred format for action- and observation space is a mapping from agent 15 | ids to their individual spaces. If that is not provided, the respective methods' 16 | observation_space_contains(), action_space_contains(), 17 | action_space_sample() and observation_space_sample() have to be overwritten. 18 | """ 19 | 20 | def __init__(self) -> None: 21 | self.num_agents = None 22 | self.n_actions = None 23 | self.episode_limit = None 24 | if not hasattr(self, 'obs_space'): 25 | self.observation_space = None 26 | if not hasattr(self, 'state_space'): 27 | self.observation_space = None 28 | if not hasattr(self, 'action_space'): 29 | self.action_space = None 30 | if not hasattr(self, '_agent_ids'): 31 | self._agent_ids = set() 32 | 33 | def step(self, actions): 34 | """Returns reward, terminated, info.""" 35 | raise NotImplementedError 36 | 37 | def get_obs(self): 38 | """Returns all agent observations in a list.""" 39 | raise NotImplementedError 40 | 41 | def get_obs_agent(self, agent_id): 42 | """Returns observation for agent_id.""" 43 | raise NotImplementedError 44 | 45 | def get_obs_size(self): 46 | """Returns the shape of the observation.""" 47 | raise NotImplementedError 48 | 49 | def get_state(self): 50 | raise NotImplementedError 51 | 52 | def get_state_size(self): 53 | """Returns the shape of the state.""" 54 | raise NotImplementedError 55 | 56 | def get_avail_agent_actions(self, agent_id): 57 | """Returns the available actions for agent_id.""" 58 | raise NotImplementedError 59 | 60 | def get_available_actions(self): 61 | raise NotImplementedError 62 | 63 | def get_actions_one_hot(self): 64 | raise NotImplementedError 65 | 66 | def get_agents_id_one_hot(self): 67 | raise NotImplementedError 68 | 69 | def reset( 70 | self, 71 | seed: Optional[int] = None, 72 | options: Optional[dict] = None, 73 | ): 74 | """Returns initial observations and states.""" 75 | super().reset(seed=seed, options=options) 76 | 77 | def render(self): 78 | raise NotImplementedError 79 | 80 | def close(self): 81 | raise NotImplementedError 82 | 83 | def seed(self): 84 | raise NotImplementedError 85 | 86 | def save_replay(self): 87 | raise NotImplementedError 88 | 89 | def get_env_info(self): 90 | env_info = { 91 | 'obs_shape': self.get_obs_size(), 92 | 'state_shape': self.get_state_size(), 93 | 'num_agents': 
self.num_agents, 94 | 'actions_shape': self.n_actions, 95 | 'episode_limit': self.episode_limit, 96 | } 97 | return env_info 98 | -------------------------------------------------------------------------------- /marltoolkit/envs/pettingzoo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/marltoolkit/envs/pettingzoo/__init__.py -------------------------------------------------------------------------------- /marltoolkit/envs/pettingzoo/custom_env2.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import gymnasium 4 | from gymnasium.spaces import Discrete 5 | from pettingzoo import ParallelEnv 6 | from pettingzoo.utils import parallel_to_aec, wrappers 7 | 8 | ROCK = 0 9 | PAPER = 1 10 | SCISSORS = 2 11 | NONE = 3 12 | MOVES = ['ROCK', 'PAPER', 'SCISSORS', 'None'] 13 | NUM_ITERS = 100 14 | REWARD_MAP = { 15 | (ROCK, ROCK): (0, 0), 16 | (ROCK, PAPER): (-1, 1), 17 | (ROCK, SCISSORS): (1, -1), 18 | (PAPER, ROCK): (1, -1), 19 | (PAPER, PAPER): (0, 0), 20 | (PAPER, SCISSORS): (-1, 1), 21 | (SCISSORS, ROCK): (-1, 1), 22 | (SCISSORS, PAPER): (1, -1), 23 | (SCISSORS, SCISSORS): (0, 0), 24 | } 25 | 26 | 27 | def env(render_mode=None): 28 | """The env function often wraps the environment in wrappers by default. 29 | 30 | You can find full documentation for these methods elsewhere in the 31 | developer documentation. 32 | """ 33 | internal_render_mode = render_mode if render_mode != 'ansi' else 'human' 34 | env = raw_env(render_mode=internal_render_mode) 35 | # This wrapper is only for environments which print results to the terminal 36 | if render_mode == 'ansi': 37 | env = wrappers.CaptureStdoutWrapper(env) 38 | # this wrapper helps error handling for discrete action spaces 39 | env = wrappers.AssertOutOfBoundsWrapper(env) 40 | # Provides a wide vareity of helpful user errors 41 | # Strongly recommended 42 | env = wrappers.OrderEnforcingWrapper(env) 43 | return env 44 | 45 | 46 | def raw_env(render_mode=None): 47 | """To support the AEC API, the raw_env() function just uses the 48 | from_parallel function to convert from a ParallelEnv to an AEC env.""" 49 | env = parallel_env(render_mode=render_mode) 50 | env = parallel_to_aec(env) 51 | return env 52 | 53 | 54 | class parallel_env(ParallelEnv): 55 | metadata = {'render_modes': ['human'], 'name': 'rps_v2'} 56 | 57 | def __init__(self, render_mode=None): 58 | """The init method takes in environment arguments and should define the 59 | following attributes: 60 | 61 | - possible_agents 62 | - render_mode 63 | 64 | Note: as of v1.18.1, the action_spaces and observation_spaces attributes are deprecated. 65 | Spaces should be defined in the action_space() and observation_space() methods. 66 | If these methods are not overridden, spaces will be inferred from self.observation_spaces/action_spaces, raising a warning. 67 | 68 | These attributes should not be changed after initialization. 69 | """ 70 | self.possible_agents = ['player_' + str(r) for r in range(2)] 71 | 72 | # optional: a mapping between agent name and ID 73 | self.agent_name_mapping = dict( 74 | zip(self.possible_agents, list(range(len(self.possible_agents))))) 75 | self.render_mode = render_mode 76 | 77 | # Observation space should be defined here. 78 | # lru_cache allows observation and action spaces to be memoized, reducing clock cycles required to get each agent's space. 
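    # (Illustrative note: with functools.lru_cache(maxsize=None), repeated calls
    # such as env.observation_space('player_0') return the same cached
    # Discrete(4) instance rather than constructing a new Space object each time.)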
79 | # If your spaces change over time, remove this line (disable caching). 80 | @functools.lru_cache(maxsize=None) 81 | def observation_space(self, agent): 82 | # gymnasium spaces are defined and documented here: https://gymnasium.farama.org/api/spaces/ 83 | return Discrete(4) 84 | 85 | # Action space should be defined here. 86 | # If your spaces change over time, remove this line (disable caching). 87 | @functools.lru_cache(maxsize=None) 88 | def action_space(self, agent): 89 | return Discrete(3) 90 | 91 | def render(self): 92 | """Renders the environment. 93 | 94 | In human mode, it can print to terminal, open up a graphical window, or 95 | open up some other display that a human can see and understand. 96 | """ 97 | if self.render_mode is None: 98 | gymnasium.logger.warn( 99 | 'You are calling render method without specifying any render mode.' 100 | ) 101 | return 102 | 103 | if len(self.agents) == 2: 104 | string = 'Current state: Agent1: {} , Agent2: {}'.format( 105 | MOVES[self.state[self.agents[0]]], 106 | MOVES[self.state[self.agents[1]]]) 107 | else: 108 | string = 'Game over' 109 | print(string) 110 | 111 | def close(self): 112 | """Close should release any graphical displays, subprocesses, network 113 | connections or any other environment data which should not be kept 114 | around after the user is no longer using the environment.""" 115 | pass 116 | 117 | def reset(self, seed=None, options=None): 118 | """Reset needs to initialize the `agents` attribute and must set up the 119 | environment so that render(), and step() can be called without issues. 120 | 121 | Here it initializes the `num_moves` variable which counts the number of 122 | hands that are played. Returns the observations for each agent 123 | """ 124 | self.agents = self.possible_agents[:] 125 | self.num_moves = 0 126 | observations = {agent: NONE for agent in self.agents} 127 | infos = {agent: {} for agent in self.agents} 128 | self.state = observations 129 | 130 | return observations, infos 131 | 132 | def step(self, actions): 133 | """step(action) takes in an action for each agent and should return 134 | the. 135 | 136 | - observations 137 | - rewards 138 | - terminations 139 | - truncations 140 | - infos 141 | dicts where each dict looks like {agent_1: item_1, agent_2: item_2} 142 | """ 143 | # If a user passes in actions with no agents, then just return empty observations, etc. 
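        # (For example, once env_truncation below has emptied self.agents, a
        # further step() call with an empty actions dict falls into this branch.)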
144 | if not actions: 145 | self.agents = [] 146 | return {}, {}, {}, {}, {} 147 | 148 | # rewards for all agents are placed in the rewards dictionary to be returned 149 | rewards = {} 150 | rewards[self.agents[0]], rewards[self.agents[1]] = REWARD_MAP[( 151 | actions[self.agents[0]], actions[self.agents[1]])] 152 | 153 | terminations = {agent: False for agent in self.agents} 154 | 155 | self.num_moves += 1 156 | env_truncation = self.num_moves >= NUM_ITERS 157 | truncations = {agent: env_truncation for agent in self.agents} 158 | 159 | # current observation is just the other player's most recent action 160 | observations = { 161 | self.agents[i]: int(actions[self.agents[1 - i]]) 162 | for i in range(len(self.agents)) 163 | } 164 | self.state = observations 165 | 166 | # typically there won't be any information in the infos, but there must 167 | # still be an entry for each agent 168 | infos = {agent: {} for agent in self.agents} 169 | 170 | if env_truncation: 171 | self.agents = [] 172 | 173 | if self.render_mode == 'human': 174 | self.render() 175 | return observations, rewards, terminations, truncations, infos 176 | -------------------------------------------------------------------------------- /marltoolkit/envs/pettingzoo/pettingzoo_env.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from typing import Any, Dict, List, Tuple 3 | 4 | import numpy as np 5 | from gymnasium import spaces 6 | from pettingzoo.utils.env import AECEnv 7 | from pettingzoo.utils.wrappers import BaseWrapper 8 | 9 | 10 | class PettingZooEnv(AECEnv, ABC): 11 | 12 | def __init__(self, env: BaseWrapper): 13 | super(PettingZooEnv, self).__init__() 14 | self.env = env 15 | self.agents = self.env.possible_agents 16 | self.agent_idx = {} 17 | for i, agent_id in enumerate(self.agents): 18 | self.agent_idx[agent_id] = i 19 | 20 | self.action_spaces = { 21 | k: self.env.action_space(k) 22 | for k in self.env.agents 23 | } 24 | self.observation_spaces = { 25 | k: self.env.observation_space(k) 26 | for k in self.env.agents 27 | } 28 | assert all( 29 | self.observation_space(agent) == self.env.observation_space( 30 | self.agents[0]) 31 | for agent in self.agents), ( 32 | 'Observation spaces for all agents must be identical. Perhaps ' 33 | "SuperSuit's pad_observations wrapper can help (useage: " 34 | '`supersuit.aec_wrappers.pad_observations(env)`') 35 | 36 | assert all( 37 | self.action_space(agent) == self.env.action_space(self.agents[0]) 38 | for agent in self.agents), ( 39 | 'Action spaces for all agents must be identical. 
Perhaps ' 40 | "SuperSuit's pad_action_space wrapper can help (useage: " 41 | '`supersuit.aec_wrappers.pad_action_space(env)`') 42 | try: 43 | self.state_space = self.env.state_space 44 | except Exception: 45 | self.state_space = None 46 | 47 | self.rewards = [0] * len(self.agents) 48 | self.metadata = self.env.metadata 49 | self.env.reset() 50 | 51 | def reset(self, *args: Any, **kwargs: Any) -> Tuple[dict, dict]: 52 | self.env.reset(*args, **kwargs) 53 | 54 | observation, reward, terminated, truncated, info = self.env.last(self) 55 | 56 | if isinstance(observation, dict) and 'action_mask' in observation: 57 | observation_dict = { 58 | 'agent_id': 59 | self.env.agent_selection, 60 | 'obs': 61 | observation['observation'], 62 | 'mask': [ 63 | True if obm == 1 else False 64 | for obm in observation['action_mask'] 65 | ], 66 | } 67 | else: 68 | if isinstance(self.action_space, spaces.Discrete): 69 | observation_dict = { 70 | 'agent_id': 71 | self.env.agent_selection, 72 | 'obs': 73 | observation, 74 | 'mask': 75 | [True] * self.env.action_space(self.env.agent_selection).n, 76 | } 77 | else: 78 | observation_dict = { 79 | 'agent_id': self.env.agent_selection, 80 | 'obs': observation, 81 | } 82 | 83 | return observation_dict, info 84 | 85 | def step(self, action: Any) -> Tuple[Dict, List[int], bool, bool, Dict]: 86 | self.env.step(action) 87 | 88 | observation, reward, term, trunc, info = self.env.last() 89 | 90 | if isinstance(observation, dict) and 'action_mask' in observation: 91 | obs = { 92 | 'agent_id': 93 | self.env.agent_selection, 94 | 'obs': 95 | observation['observation'], 96 | 'mask': [ 97 | True if obm == 1 else False 98 | for obm in observation['action_mask'] 99 | ], 100 | } 101 | else: 102 | if isinstance(self.action_space, spaces.Discrete): 103 | obs = { 104 | 'agent_id': 105 | self.env.agent_selection, 106 | 'obs': 107 | observation, 108 | 'mask': 109 | [True] * self.env.action_space(self.env.agent_selection).n, 110 | } 111 | else: 112 | obs = { 113 | 'agent_id': self.env.agent_selection, 114 | 'obs': observation 115 | } 116 | 117 | for agent_id, reward in self.env.rewards.items(): 118 | self.rewards[self.agent_idx[agent_id]] = reward 119 | return obs, self.rewards, term, trunc, info 120 | 121 | def state(self): 122 | try: 123 | return np.array(self.env.state()) 124 | except Exception: 125 | return None 126 | 127 | def seed(self, seed: Any = None) -> None: 128 | try: 129 | self.env.seed(seed) 130 | except (NotImplementedError, AttributeError): 131 | self.env.reset(seed=seed) 132 | 133 | def render(self) -> Any: 134 | return self.env.render() 135 | 136 | def close(self): 137 | self.env.close() 138 | -------------------------------------------------------------------------------- /marltoolkit/envs/smacv1/__init__.py: -------------------------------------------------------------------------------- 1 | from .smac_env import SMACWrapperEnv 2 | 3 | __all__ = ['SMACWrapperEnv'] 4 | -------------------------------------------------------------------------------- /marltoolkit/envs/smacv2/__init__.py: -------------------------------------------------------------------------------- 1 | from .smacv2_env import SMACv2Env 2 | 3 | __all__ = ['SMACv2Env'] 4 | -------------------------------------------------------------------------------- /marltoolkit/envs/smacv2/smacv2_env.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | from gymnasium.spaces import Discrete 5 | from smacv2.env.starcraft2.wrapper import 
StarCraftCapabilityEnvWrapper 6 | 7 | 8 | class SMACv2Env(StarCraftCapabilityEnvWrapper): 9 | 10 | def __init__(self, **kwargs): 11 | super(SMACv2Env, self).__init__(obs_last_action=False, **kwargs) 12 | self.action_space = [] 13 | self.observation_space = [] 14 | self.share_observation_space = [] 15 | 16 | self.n_agents = self.env.n_agents 17 | 18 | for i in range(self.env.n_agents): 19 | self.action_space.append(Discrete(self.env.n_actions)) 20 | self.observation_space.append([self.env.get_obs_size()]) 21 | self.share_observation_space.append([self.env.get_state_size()]) 22 | 23 | def seed(self, seed): 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | 27 | def step(self, actions): 28 | reward, terminated, info = self.env.step(actions) 29 | obs = self.env.get_obs() 30 | state = self.env.get_state() 31 | global_state = [state] * self.n_agents 32 | rewards = [[reward]] * self.n_agents 33 | dones = [terminated] * self.n_agents 34 | infos = [info] * self.n_agents 35 | avail_actions = self.env.get_avail_actions() 36 | 37 | bad_transition = True if self.env._episode_steps >= self.env.episode_limit else False 38 | for info in infos: 39 | info['bad_transition'] = bad_transition 40 | info['battles_won'] = self.env.battles_won 41 | info['battles_game'] = self.env.battles_game 42 | info['battles_draw'] = self.env.timeouts 43 | info['restarts'] = self.env.force_restarts 44 | info['won'] = self.env.win_counted 45 | 46 | return obs, global_state, rewards, dones, infos, avail_actions 47 | 48 | def reset(self): 49 | obs, state = super().reset() 50 | state = [state for i in range(self.env.n_agents)] 51 | avail_actions = self.env.get_avail_actions() 52 | return obs, state, avail_actions 53 | -------------------------------------------------------------------------------- /marltoolkit/envs/vec_env/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_vec_env import BaseVecEnv, CloudpickleWrapper 2 | from .dummy_vec_env import DummyVecEnv 3 | from .subproc_vec_env import SubprocVecEnv 4 | 5 | __all__ = [ 6 | 'BaseVecEnv', 7 | 'DummyVecEnv', 8 | 'SubprocVecEnv', 9 | 'CloudpickleWrapper', 10 | ] 11 | -------------------------------------------------------------------------------- /marltoolkit/envs/waregame/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/marltoolkit/envs/waregame/__init__.py -------------------------------------------------------------------------------- /marltoolkit/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/marltoolkit/modules/__init__.py -------------------------------------------------------------------------------- /marltoolkit/modules/actors/__init__.py: -------------------------------------------------------------------------------- 1 | """Utils package.""" 2 | 3 | from .mlp import MLPActorModel 4 | from .rnn import RNNActorModel 5 | 6 | __all__ = [ 7 | 'RNNActorModel', 8 | 'MLPActorModel', 9 | ] 10 | -------------------------------------------------------------------------------- /marltoolkit/modules/actors/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Initialize Policy weights 6 | def weights_init_(module: nn.Module): 7 | if 
isinstance(module, nn.Linear): 8 | nn.init.xavier_uniform_(module.weight, gain=1) 9 | nn.init.constant_(module.bias, 0) 10 | 11 | 12 | class MLPActorModel(nn.Module): 13 | 14 | def __init__( 15 | self, 16 | input_dim: int = None, 17 | hidden_dim: int = 64, 18 | n_actions: int = None, 19 | ) -> None: 20 | """Initialize the Actor network. 21 | 22 | Args: 23 | input_dim (int, optional): obs, include the agent's id and last action, 24 | shape: (batch, obs_shape + n_action + n_agents) 25 | hidden_dim (int, optional): hidden size of the network. Defaults to 64. 26 | n_actions (int, optional): number of actions. Defaults to None. 27 | """ 28 | super(MLPActorModel, self).__init__() 29 | 30 | self.fc1 = nn.Linear(input_dim, hidden_dim) 31 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 32 | self.fc3 = nn.Linear(hidden_dim, n_actions) 33 | self.relu1 = nn.ReLU(inplace=True) 34 | self.relu2 = nn.ReLU(inplace=True) 35 | self.apply(weights_init_) 36 | 37 | def forward(self, obs: torch.Tensor): 38 | hid1 = self.relu1(self.fc1(obs)) 39 | hid2 = self.relu2(self.fc2(hid1)) 40 | policy = self.fc3(hid2) 41 | return policy 42 | -------------------------------------------------------------------------------- /marltoolkit/modules/actors/r_actor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from marltoolkit.modules.utils.act import ACTLayer 5 | from marltoolkit.modules.utils.common import MLPBase, RNNLayer 6 | 7 | 8 | class R_Actor(nn.Module): 9 | """Actor network class for MAPPO. Outputs actions given observations. 10 | 11 | :param args: (argparse.Namespace) arguments containing relevant model information. 12 | :param obs_space: (gym.Space) observation space. 13 | :param action_space: (gym.Space) action space. 14 | :param device: (torch.device) specifies the device to run on (cpu/gpu). 15 | """ 16 | 17 | def __init__(self, args) -> None: 18 | super(R_Actor, self).__init__() 19 | 20 | self.use_recurrent_policy = args.use_recurrent_policy 21 | self.use_policy_active_masks = args.use_policy_active_masks 22 | self.base = MLPBase( 23 | input_dim=args.actor_input_dim, 24 | hidden_dim=args.hidden_size, 25 | activation=args.activation, 26 | use_orthogonal=args.use_orthogonal, 27 | use_feature_normalization=args.use_feature_normalization, 28 | ) 29 | if args.use_recurrent_policy: 30 | self.rnn = RNNLayer( 31 | args.hidden_size, 32 | args.hidden_size, 33 | args.rnn_layers, 34 | args.use_orthogonal, 35 | ) 36 | self.act = ACTLayer(args) 37 | self.algorithm_name = args.algorithm_name 38 | 39 | def forward( 40 | self, 41 | obs: torch.Tensor, 42 | masks: torch.Tensor, 43 | available_actions: torch.Tensor = None, 44 | rnn_hidden_states: torch.Tensor = None, 45 | deterministic: bool = False, 46 | ): 47 | """Compute actions from the given inputs. 48 | 49 | :param obs: (np.ndarray / torch.Tensor) observation inputs into network. 50 | :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros. 51 | :param available_actions: (np.ndarray / torch.Tensor) denotes which actions are available to agent 52 | (if None, all actions available) 53 | :param rnn_hidden_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN. 54 | 55 | :param deterministic: (bool) whether to sample from action distribution or return the mode. 56 | 57 | :return actions: (torch.Tensor) actions to take. 58 | :return action_log_probs: (torch.Tensor) log probabilities of taken actions. 
59 | :return rnn_hidden_states: (torch.Tensor) updated RNN hidden states. 60 | """ 61 | actor_features = self.base(obs) 62 | 63 | if self.use_recurrent_policy: 64 | actor_features, rnn_hidden_states = self.rnn( 65 | actor_features, rnn_hidden_states, masks) 66 | 67 | actions, action_log_probs = self.act(actor_features, available_actions, 68 | deterministic) 69 | 70 | return actions, action_log_probs, rnn_hidden_states 71 | 72 | def evaluate_actions( 73 | self, 74 | obs: torch.Tensor, 75 | action: torch.Tensor, 76 | masks: torch.Tensor, 77 | active_masks: torch.Tensor = None, 78 | available_actions: torch.Tensor = None, 79 | rnn_hidden_states: torch.Tensor = None, 80 | ): 81 | """Compute log probability and entropy of given actions. 82 | 83 | :param obs: (torch.Tensor) observation inputs into network. 84 | :param action: (torch.Tensor) actions whose entropy and log probability to evaluate. 85 | :param rnn_hidden_states: (torch.Tensor) if RNN network, hidden states for RNN. 86 | :param masks: (torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros. 87 | :param available_actions: (torch.Tensor) denotes which actions are available to agent 88 | (if None, all actions available) 89 | :param active_masks: (torch.Tensor) denotes whether an agent is active or dead. 90 | 91 | :return action_log_probs: (torch.Tensor) log probabilities of the input actions. 92 | :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs. 93 | """ 94 | actor_features = self.base(obs) 95 | 96 | if self.use_recurrent_policy: 97 | actor_features, rnn_hidden_states = self.rnn( 98 | actor_features, rnn_hidden_states, masks) 99 | 100 | if self.algorithm_name == 'hatrpo': 101 | action_log_probs, dist_entropy, action_mu, action_std, all_probs = ( 102 | self.act.evaluate_actions_trpo( 103 | actor_features, 104 | action, 105 | available_actions, 106 | active_masks=active_masks 107 | if self.use_policy_active_masks else None, 108 | )) 109 | 110 | return action_log_probs, dist_entropy, action_mu, action_std, all_probs 111 | else: 112 | action_log_probs, dist_entropy = self.act.evaluate_actions( 113 | actor_features, 114 | action, 115 | available_actions, 116 | active_masks=active_masks 117 | if self.use_policy_active_masks else None, 118 | ) 119 | 120 | return action_log_probs, dist_entropy 121 | -------------------------------------------------------------------------------- /marltoolkit/modules/actors/rnn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class RNNActorModel(nn.Module): 9 | """Because all the agents share the same network, 10 | input_shape=obs_shape+n_actions+n_agents. 11 | 12 | Args: 13 | input_dim (int): The input dimension. 14 | fc_hidden_dim (int): The hidden dimension of the fully connected layer. 15 | rnn_hidden_dim (int): The hidden dimension of the RNN layer. 16 | n_actions (int): The number of actions. 
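    Shape sketch (illustrative; the concrete sizes come from the ``args``
    namespace handed to ``__init__``):

        inputs:        (batch_size, obs_shape + n_actions + n_agents)
        hidden_state:  (batch_size, rnn_hidden_dim)
        returns:       logits of shape (batch_size, n_actions) and the updated
                       hidden_state of shape (batch_size, rnn_hidden_dim)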
17 | """ 18 | 19 | def __init__(self, args: argparse.Namespace = None) -> None: 20 | super(RNNActorModel, self).__init__() 21 | self.args = args 22 | self.rnn_hidden_dim = args.rnn_hidden_dim 23 | self.fc1 = nn.Linear(self.actor_input_dim, args.fc_hidden_dim) 24 | self.rnn = nn.GRUCell(input_size=args.fc_hidden_dim, 25 | hidden_size=args.rnn_hidden_dim) 26 | self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions) 27 | 28 | def init_hidden(self): 29 | # make hidden states on same device as model 30 | return self.fc1.weight.new(1, self.rnn_hidden_dim).zero_() 31 | 32 | def forward( 33 | self, 34 | inputs: torch.Tensor = None, 35 | hidden_state: torch.Tensor = None, 36 | ) -> tuple[torch.Tensor, torch.Tensor]: 37 | out = F.relu(self.fc1(inputs), inplace=True) 38 | 39 | if hidden_state is not None: 40 | h_in = hidden_state.reshape(-1, self.rnn_hidden_dim) 41 | else: 42 | h_in = torch.zeros(out.shape[0], 43 | self.rnn_hidden_dim).to(inputs.device) 44 | 45 | hidden_state = self.rnn(out, h_in) 46 | out = self.fc2(hidden_state) # (batch_size, n_actions) 47 | return out, hidden_state 48 | 49 | def update(self, model: nn.Module) -> None: 50 | self.load_state_dict(model.state_dict()) 51 | 52 | 53 | class MultiLayerRNNActorModel(nn.Module): 54 | """Because all the agents share the same network, 55 | input_shape=obs_shape+n_actions+n_agents. 56 | 57 | Args: 58 | input_dim (int): The input dimension. 59 | fc_hidden_dim (int): The hidden dimension of the fully connected layer. 60 | rnn_hidden_dim (int): The hidden dimension of the RNN layer. 61 | n_actions (int): The number of actions. 62 | """ 63 | 64 | def __init__( 65 | self, 66 | input_dim: int = None, 67 | fc_hidden_dim: int = 64, 68 | rnn_hidden_dim: int = 64, 69 | rnn_num_layers: int = 2, 70 | n_actions: int = None, 71 | **kwargs, 72 | ) -> None: 73 | super(MultiLayerRNNActorModel, self).__init__() 74 | 75 | self.rnn_hidden_dim = rnn_hidden_dim 76 | self.rnn_num_layers = rnn_num_layers 77 | 78 | self.fc1 = nn.Linear(input_dim, fc_hidden_dim) 79 | self.rnn = nn.GRU( 80 | input_size=fc_hidden_dim, 81 | hidden_size=rnn_hidden_dim, 82 | num_layers=rnn_num_layers, 83 | batch_first=True, 84 | ) 85 | self.fc2 = nn.Linear(rnn_hidden_dim, n_actions) 86 | 87 | def init_hidden(self, batch_size: int) -> torch.Tensor: 88 | # make hidden states on same device as model 89 | return self.fc1.weight.new(self.rnn_num_layers, batch_size, 90 | self.rnn_hidden_dim).zero_() 91 | 92 | def forward( 93 | self, 94 | input: torch.Tensor = None, 95 | hidden_state: torch.Tensor = None, 96 | ) -> tuple[torch.Tensor, torch.Tensor]: 97 | # input: (batch_size, episode_length, obs_dim) 98 | out = F.relu(self.fc1(input), inplace=True) 99 | # out: (batch_size, episode_length, fc_hidden_dim) 100 | batch_size = input.shape[0] 101 | if hidden_state is None: 102 | hidden_state = self.init_hidden(batch_size) 103 | else: 104 | hidden_state = hidden_state.reshape( 105 | self.rnn_num_layers, batch_size, 106 | self.rnn_hidden_dim).to(input.device) 107 | 108 | out, hidden_state = self.rnn(out, hidden_state) 109 | # out: (batch_size, seq_len, rnn_hidden_dim) 110 | # hidden_state: (num_ayers, batch_size, rnn_hidden_dim) 111 | logits = self.fc2(out) 112 | return logits, hidden_state 113 | 114 | def update(self, model: nn.Module) -> None: 115 | self.load_state_dict(model.state_dict()) 116 | 117 | 118 | if __name__ == '__main__': 119 | rnn = nn.GRU(input_size=512, 120 | hidden_size=256, 121 | num_layers=2, 122 | batch_first=True) 123 | input = torch.randn(3, 512) 124 | h0 = torch.randn(2, 256) 
125 | output, hn = rnn(input, h0) 126 | print(output.shape, hn.shape) 127 | # torch.Size([32, 512, 20]) torch.Size([2, 32, 20]) 128 | -------------------------------------------------------------------------------- /marltoolkit/modules/critics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/marltoolkit/modules/critics/__init__.py -------------------------------------------------------------------------------- /marltoolkit/modules/critics/coma.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from marltoolkit.modules.actors.mlp import MLPActorModel 4 | from marltoolkit.modules.critics.mlp import MLPCriticModel 5 | 6 | 7 | class ComaModel(nn.Module): 8 | 9 | def __init__( 10 | self, 11 | num_agents: int = None, 12 | n_actions: int = None, 13 | obs_shape: int = None, 14 | state_shape: int = None, 15 | hidden_dim: int = 64, 16 | **kwargs, 17 | ): 18 | super(ComaModel, self).__init__() 19 | 20 | self.num_agents = num_agents 21 | self.obs_shape = obs_shape 22 | self.state_shape = state_shape 23 | self.n_actions = n_actions 24 | self.actions_onehot_shape = self.n_actions * self.num_agents 25 | actor_input_dim = self._get_actor_input_dim() 26 | critic_input_dim = self._get_critic_input_dim() 27 | 28 | # Set up network layers 29 | self.actor_model = MLPActorModel(input_dim=actor_input_dim, 30 | hidden_dim=hidden_dim, 31 | n_actions=n_actions) 32 | self.critic_model = MLPCriticModel(input_dim=critic_input_dim, 33 | hidden_dim=hidden_dim, 34 | output_dim=1) 35 | 36 | def policy(self, obs, hidden_state): 37 | return self.actor_model(obs, hidden_state) 38 | 39 | def value(self, inputs): 40 | return self.critic_model(inputs) 41 | 42 | def get_actor_params(self): 43 | return self.actor_model.parameters() 44 | 45 | def get_critic_params(self): 46 | return self.critic_model.parameters() 47 | 48 | def _get_actor_input_dim(self): 49 | # observation 50 | input_dim = self.obs_shape 51 | # agent id 52 | input_dim += self.num_agents 53 | # actions and last actions 54 | input_dim += self.actions_onehot_shape * self.num_agents * 2 55 | return input_dim 56 | 57 | def _get_critic_input_dim(self): 58 | input_dim = self.state_shape # state: 48 in 3m map 59 | input_dim += self.obs_shape # obs: 30 in 3m map 60 | input_dim += self.num_agents # agent_id: 3 in 3m map 61 | input_dim += ( 62 | self.n_actions * self.num_agents * 2 63 | ) # all agents' action and last_action (one-hot): 54 in 3m map 64 | return input_dim # 48 + 30+ 3 = 135 65 | -------------------------------------------------------------------------------- /marltoolkit/modules/critics/maddpg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MADDPGCritic(nn.Module): 7 | 8 | def __init__( 9 | self, 10 | n_agents, 11 | state_shape, 12 | obs_shape, 13 | hidden_dim, 14 | n_actions, 15 | obs_agent_id: bool, 16 | obs_last_action: bool, 17 | obs_individual_obs: bool, 18 | ): 19 | super(MADDPGCritic, self).__init__() 20 | self.n_agents = n_agents 21 | self.n_actions = n_actions 22 | self.state_shape = state_shape 23 | self.obs_shape = obs_shape 24 | self.obs_individual_obs = obs_individual_obs 25 | self.obs_agent_id = obs_agent_id 26 | self.obs_last_action = obs_last_action 27 | 28 | self.input_shape = self._get_input_shape( 
29 | ) + self.n_actions * self.n_agents 30 | if self.obs_last_action: 31 | self.input_shape += self.n_actions 32 | 33 | # Set up network layers 34 | self.fc1 = nn.Linear(self.input_shape, hidden_dim) 35 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 36 | self.fc3 = nn.Linear(hidden_dim, 1) 37 | 38 | def forward(self, inputs, actions): 39 | inputs = torch.cat((inputs, actions), dim=-1) 40 | x = F.relu(self.fc1(inputs)) 41 | x = F.relu(self.fc2(x)) 42 | q = self.fc3(x) 43 | return q 44 | 45 | def _get_input_shape(self): 46 | # state_shape 47 | input_shape = self.state_shape 48 | # whether to add the individual observation 49 | if self.obs_individual_obs: 50 | input_shape += self.obs_shape 51 | # agent id 52 | if self.obs_agent_id: 53 | input_shape += self.n_agents 54 | return input_shape 55 | -------------------------------------------------------------------------------- /marltoolkit/modules/critics/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | # Initialize Policy weights 7 | def weights_init_(module: nn.Module): 8 | if isinstance(module, nn.Linear): 9 | nn.init.xavier_uniform_(module.weight, gain=1) 10 | nn.init.constant_(module.bias, 0) 11 | 12 | 13 | class MLPCriticModel(nn.Module): 14 | """MLP Critic Network. 15 | 16 | Args: 17 | nn (_type_): _description_ 18 | """ 19 | 20 | def __init__( 21 | self, 22 | input_dim: int = None, 23 | hidden_dim: int = 64, 24 | output_dim: int = 1, 25 | ) -> None: 26 | super(MLPCriticModel, self).__init__() 27 | 28 | # Set up network layers 29 | self.fc1 = nn.Linear(input_dim, hidden_dim) 30 | self.fc2 = nn.Linear(hidden_dim, hidden_dim) 31 | self.fc3 = nn.Linear(hidden_dim, output_dim) 32 | self.apply(weights_init_) 33 | 34 | def forward(self, inputs: torch.Tensor) -> torch.Tensor: 35 | """ 36 | Args: 37 | inputs (torch.Tensor): 38 | Returns: 39 | q_total (torch.Tensor): 40 | """ 41 | hidden = F.relu(self.fc1(inputs)) 42 | hidden = F.relu(self.fc2(hidden)) 43 | qvalue = self.fc3(hidden) 44 | return qvalue 45 | -------------------------------------------------------------------------------- /marltoolkit/modules/critics/r_critic.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from marltoolkit.modules.utils.common import MLPBase, RNNLayer 7 | from marltoolkit.modules.utils.popart import PopArt 8 | 9 | 10 | class R_Critic(nn.Module): 11 | """Critic network class for MAPPO. Outputs value function predictions given 12 | centralized input (MAPPO) or local observations (IPPO). 13 | 14 | :param args: (argparse.Namespace) arguments containing relevant model information. 15 | :param state_space: (gym.Space) (centralized) observation space. 16 | :param device: (torch.device) specifies the device to run on (cpu/gpu). 
17 | """ 18 | 19 | def __init__(self, args: argparse.Namespace): 20 | super(R_Critic, self).__init__() 21 | self.use_recurrent_policy = args.use_recurrent_policy 22 | init_method = [nn.init.xavier_uniform_, 23 | nn.init.orthogonal_][args.use_orthogonal] 24 | self.base = MLPBase( 25 | input_dim=args.state_dim, 26 | hidden_dim=args.hidden_size, 27 | use_orthogonal=args.use_orthogonal, 28 | use_feature_normalization=args.use_feature_normalization, 29 | ) 30 | if self.use_recurrent_policy: 31 | self.rnn = RNNLayer( 32 | args.hidden_size, 33 | args.hidden_size, 34 | args.rnn_layers, 35 | args.use_orthogonal, 36 | ) 37 | 38 | def init_weight(module: nn.Module) -> None: 39 | if isinstance(module, nn.Linear): 40 | init_method(module.weight, gain=args.gain) 41 | if module.bias is not None: 42 | nn.init.constant_(module.bias, 0) 43 | 44 | if args.use_popart: 45 | self.v_out = PopArt(args.hidden_size, 1) 46 | else: 47 | self.v_out = nn.Linear(args.hidden_size, 1) 48 | 49 | self.apply(init_weight) 50 | 51 | def forward( 52 | self, 53 | state: torch.Tensor, 54 | masks: torch.Tensor, 55 | rnn_hidden_states: torch.Tensor, 56 | ): 57 | """Compute actions from the given inputs. 58 | 59 | :param state: (np.ndarray / torch.Tensor) global observation inputs into network. 60 | :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if RNN states should be reinitialized to zeros. 61 | :param rnn_hidden_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN. 62 | 63 | :return values: (torch.Tensor) value function predictions. 64 | :return rnn_hidden_states: (torch.Tensor) updated RNN hidden states. 65 | """ 66 | critic_features = self.base(state) 67 | if self.use_recurrent_policy: 68 | critic_features, rnn_hidden_states = self.rnn( 69 | critic_features, rnn_hidden_states, masks) 70 | values = self.v_out(critic_features) 71 | return values, rnn_hidden_states 72 | -------------------------------------------------------------------------------- /marltoolkit/modules/mixers/__init__.py: -------------------------------------------------------------------------------- 1 | """Utils package.""" 2 | 3 | from .qmixer import QMixerModel 4 | from .vdn import VDNMixer 5 | 6 | __all__ = ['QMixerModel', 'VDNMixer'] 7 | -------------------------------------------------------------------------------- /marltoolkit/modules/mixers/qatten.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class QattenMixer(nn.Module): 8 | 9 | def __init__(self, 10 | n_agents: int = None, 11 | state_dim: int = None, 12 | agent_own_state_size: int = None, 13 | n_query_embedding_layer1: int = 64, 14 | n_query_embedding_layer2: int = 32, 15 | n_key_embedding_layer1: int = 32, 16 | n_head_embedding_layer1: int = 64, 17 | n_head_embedding_layer2: int = 4, 18 | num_attention_heads: int = 4, 19 | n_constrant_value: int = 32): 20 | super(QattenMixer, self).__init__() 21 | self.n_agents = n_agents 22 | self.state_dim = state_dim 23 | self.agent_own_state_size = agent_own_state_size 24 | 25 | self.n_query_embedding_layer1 = n_query_embedding_layer1 26 | self.n_query_embedding_layer2 = n_query_embedding_layer2 27 | self.n_key_embedding_layer1 = n_key_embedding_layer1 28 | self.n_head_embedding_layer1 = n_head_embedding_layer1 29 | self.n_head_embedding_layer2 = n_head_embedding_layer2 30 | self.num_attention_heads = num_attention_heads 31 | self.n_constrant_value = n_constrant_value 32 | 33 
| self.query_embedding_layers = nn.ModuleList() 34 | for _ in range(self.num_attention_heads): 35 | self.query_embedding_layers.append( 36 | nn.Sequential( 37 | nn.Linear(state_dim, n_query_embedding_layer1), 38 | nn.ReLU(inplace=True), 39 | nn.Linear(n_query_embedding_layer1, 40 | n_query_embedding_layer2))) 41 | 42 | self.key_embedding_layers = nn.ModuleList() 43 | for _ in range(num_attention_heads): 44 | self.key_embedding_layers.append( 45 | nn.Linear(self.agent_own_state_size, n_key_embedding_layer1)) 46 | 47 | self.scaled_product_value = np.sqrt(n_query_embedding_layer2) 48 | 49 | self.head_embedding_layer = nn.Sequential( 50 | nn.Linear(state_dim, n_head_embedding_layer1), 51 | nn.ReLU(inplace=True), 52 | nn.Linear(n_head_embedding_layer1, n_head_embedding_layer2)) 53 | 54 | self.constrant_value_layer = nn.Sequential( 55 | nn.Linear(state_dim, n_constrant_value), nn.ReLU(inplace=True), 56 | nn.Linear(n_constrant_value, 1)) 57 | 58 | def forward(self, agent_qs: torch.Tensor, states: torch.Tensor): 59 | ''' 60 | Args: 61 | agent_qs (torch.Tensor): (batch_size, T, n_agents) 62 | states (torch.Tensor): (batch_size, T, state_shape) 63 | Returns: 64 | q_total (torch.Tensor): (batch_size, T, 1) 65 | ''' 66 | bs = agent_qs.size(0) 67 | # states : (batch_size * T, state_shape) 68 | states = states.reshape(-1, self.state_dim) 69 | # agent_qs: (batch_size * T, 1, n_agents) 70 | agent_qs = agent_qs.view(-1, 1, self.n_agents) 71 | us = self._get_us(states) 72 | 73 | q_lambda_list = [] 74 | for i in range(self.num_attention_heads): 75 | state_embedding = self.query_embedding_layers[i](states) 76 | u_embedding = self.key_embedding_layers[i](us) 77 | 78 | # shape: [batch_size * T, 1, state_dim] 79 | state_embedding = state_embedding.reshape( 80 | -1, 1, self.n_query_embedding_layer2) 81 | # shape: [batch_size * T, state_dim, n_agent] 82 | u_embedding = u_embedding.reshape(-1, self.n_agents, 83 | self.n_key_embedding_layer1) 84 | u_embedding = u_embedding.permute(0, 2, 1) 85 | 86 | # shape: [batch_size * T, 1, n_agent] 87 | raw_lambda = torch.matmul(state_embedding, 88 | u_embedding) / self.scaled_product_value 89 | q_lambda = F.softmax(raw_lambda, dim=-1) 90 | 91 | q_lambda_list.append(q_lambda) 92 | 93 | # shape: [batch_size * T, num_attention_heads, n_agent] 94 | q_lambda_list = torch.stack(q_lambda_list, dim=1).squeeze(-2) 95 | 96 | # shape: [batch_size * T, n_agent, num_attention_heads] 97 | q_lambda_list = q_lambda_list.permute(0, 2, 1) 98 | 99 | # shape: [batch_size * T, 1, num_attention_heads] 100 | q_h = torch.matmul(agent_qs, q_lambda_list) 101 | 102 | if self.type == 'weighted': 103 | # shape: [batch_size * T, num_attention_heads] 104 | w_h = torch.abs(self.head_embedding_layer(states)) 105 | # shape: [batch_size * T, num_attention_heads, 1] 106 | w_h = w_h.reshape(-1, self.n_head_embedding_layer2, 1) 107 | 108 | # shape: [batch_size * T, 1,1] 109 | sum_q_h = torch.matmul(q_h, w_h) 110 | # shape: [batch_size * T, 1] 111 | sum_q_h = sum_q_h.reshape(-1, 1) 112 | else: 113 | # shape: [-1, 1] 114 | sum_q_h = q_h.sum(-1) 115 | sum_q_h = sum_q_h.reshape(-1, 1) 116 | 117 | c = self.constrant_value_layer(states) 118 | q_tot = sum_q_h + c 119 | q_tot = q_tot.view(bs, -1, 1) 120 | return q_tot 121 | 122 | def _get_us(self, states): 123 | agent_own_state_size = self.agent_own_state_size 124 | with torch.no_grad(): 125 | us = states[:, :agent_own_state_size * self.n_agents].reshape( 126 | -1, agent_own_state_size) 127 | return us 128 | 
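
# A minimal shape-check sketch with made-up sizes (batch 4, T = 5, 3 agents,
# a 48-dim state whose first 10 entries per agent are taken as that agent's
# own features). forward() branches on `self.type`, which __init__ does not
# assign, so the sketch sets it as a plain attribute before the call; any
# value other than 'weighted' exercises the plain-sum branch instead.
if __name__ == '__main__':
    mixer = QattenMixer(n_agents=3, state_dim=48, agent_own_state_size=10)
    mixer.type = 'weighted'
    agent_qs = torch.randn(4, 5, 3)  # (batch_size, T, n_agents)
    states = torch.randn(4, 5, 48)  # (batch_size, T, state_dim)
    q_total = mixer(agent_qs, states)
    print(q_total.shape)  # expected: torch.Size([4, 5, 1])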
-------------------------------------------------------------------------------- /marltoolkit/modules/mixers/qmixer.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: jianzhnie 3 | LastEditors: jianzhnie 4 | Description: RLToolKit is a flexible and high-efficient reinforcement learning framework. 5 | Copyright (c) 2022 by jianzhnie@126.com, All Rights Reserved. 6 | ''' 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class QMixerModel(nn.Module): 13 | """Implementation of the QMixer network. 14 | 15 | Args: 16 | num_agents (int): The number of agents. 17 | state_dim (int): The shape of the state. 18 | hypernet_layers (int): The number of layers in the hypernetwork. 19 | mixing_embed_dim (int): The dimension of the mixing embedding. 20 | hypernet_embed_dim (int): The dimension of the hypernetwork embedding. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | num_agents: int = None, 26 | state_dim: int = None, 27 | hypernet_layers: int = 2, 28 | mixing_embed_dim: int = 32, 29 | hypernet_embed_dim: int = 64, 30 | ): 31 | super(QMixerModel, self).__init__() 32 | 33 | self.num_agents = num_agents 34 | self.state_dim = state_dim 35 | self.mixing_embed_dim = mixing_embed_dim 36 | if hypernet_layers == 1: 37 | self.hyper_w_1 = nn.Linear(state_dim, 38 | mixing_embed_dim * num_agents) 39 | self.hyper_w_2 = nn.Linear(state_dim, mixing_embed_dim) 40 | elif hypernet_layers == 2: 41 | self.hyper_w_1 = nn.Sequential( 42 | nn.Linear(state_dim, hypernet_embed_dim), 43 | nn.ReLU(inplace=True), 44 | nn.Linear(hypernet_embed_dim, mixing_embed_dim * num_agents), 45 | ) 46 | self.hyper_w_2 = nn.Sequential( 47 | nn.Linear(state_dim, hypernet_embed_dim), 48 | nn.ReLU(inplace=True), 49 | nn.Linear(hypernet_embed_dim, mixing_embed_dim), 50 | ) 51 | else: 52 | raise ValueError('hypernet_layers should be "1" or "2"!') 53 | 54 | # State dependent bias for hidden layer 55 | self.hyper_b_1 = nn.Linear(state_dim, mixing_embed_dim) 56 | self.hyper_b_2 = nn.Sequential( 57 | nn.Linear(state_dim, mixing_embed_dim), 58 | nn.ReLU(inplace=True), 59 | nn.Linear(mixing_embed_dim, 1), 60 | ) 61 | 62 | def forward(self, agent_qs: torch.Tensor, states: torch.Tensor): 63 | """ 64 | Args: 65 | agent_qs (torch.Tensor): (batch_size, T, num_agents) 66 | states (torch.Tensor): (batch_size, T, state_dim) 67 | Returns: 68 | q_total (torch.Tensor): (batch_size, T, 1) 69 | """ 70 | batch_size = agent_qs.size(0) 71 | # states : (batch_size * T, state_dim) 72 | states = states.reshape(-1, self.state_dim) 73 | # agent_qs: (batch_size * T, 1, num_agents) 74 | agent_qs = agent_qs.view(-1, 1, self.num_agents) 75 | 76 | # First layer w and b 77 | w1 = torch.abs(self.hyper_w_1(states)) 78 | # w1: (batch_size * T, num_agents, embed_dim) 79 | w1 = w1.view(-1, self.num_agents, self.mixing_embed_dim) 80 | b1 = self.hyper_b_1(states) 81 | b1 = b1.view(-1, 1, self.mixing_embed_dim) 82 | 83 | # Second layer w and b 84 | # w2 : (batch_size * T, embed_dim) 85 | w2 = torch.abs(self.hyper_w_2(states)) 86 | # w2 : (batch_size * T, embed_dim, 1) 87 | w2 = w2.view(-1, self.mixing_embed_dim, 1) 88 | # State-dependent bias 89 | b2 = self.hyper_b_2(states).view(-1, 1, 1) 90 | 91 | # First hidden layer 92 | # hidden: (batch_size * T, 1, embed_dim) 93 | hidden = F.elu(torch.bmm(agent_qs, w1) + b1) 94 | # Compute final output 95 | # y: (batch_size * T, 1, 1) 96 | y = torch.bmm(hidden, w2) + b2 97 | # Reshape and return 98 | # q_total: (batch_size, T, 1) 99 | q_total = 
y.view(batch_size, -1, 1) 100 | return q_total 101 | 102 | def update(self, model): 103 | self.load_state_dict(model.state_dict()) 104 | 105 | 106 | class QMixerCentralFF(nn.Module): 107 | 108 | def __init__(self, num_agents, state_dim, central_mixing_embed_dim, 109 | central_action_embed): 110 | super(QMixerCentralFF, self).__init__() 111 | 112 | self.num_agents = num_agents 113 | self.state_dim = state_dim 114 | self.input_dim = num_agents * central_action_embed + state_dim 115 | self.central_mixing_embed_dim = central_mixing_embed_dim 116 | self.central_action_embed = central_action_embed 117 | 118 | self.net = nn.Sequential( 119 | nn.Linear(self.input_dim, central_mixing_embed_dim), 120 | nn.ReLU(inplace=True), 121 | nn.Linear(central_mixing_embed_dim, central_mixing_embed_dim), 122 | nn.ReLU(inplace=True), 123 | nn.Linear(central_mixing_embed_dim, central_mixing_embed_dim), 124 | nn.ReLU(inplace=True), nn.Linear(central_mixing_embed_dim, 1)) 125 | 126 | # V(s) instead of a bias for the last layers 127 | self.vnet = nn.Sequential( 128 | nn.Linear(state_dim, central_mixing_embed_dim), 129 | nn.ReLU(inplace=True), 130 | nn.Linear(central_mixing_embed_dim, 1), 131 | ) 132 | 133 | def forward(self, agent_qs, states): 134 | bs = agent_qs.size(0) 135 | states = states.reshape(-1, self.state_dim) 136 | agent_qs = agent_qs.reshape( 137 | -1, self.num_agents * self.central_action_embed) 138 | 139 | inputs = torch.cat([states, agent_qs], dim=1) 140 | 141 | advs = self.net(inputs) 142 | vs = self.vnet(states) 143 | y = advs + vs 144 | q_tot = y.view(bs, -1, 1) 145 | return q_tot 146 | -------------------------------------------------------------------------------- /marltoolkit/modules/mixers/vdn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class VDNMixer(nn.Module): 7 | """Computes total Q values given agent q values and global states.""" 8 | 9 | def __init__(self): 10 | super(VDNMixer, self).__init__() 11 | 12 | def forward(self, agent_qs: torch.Tensor, states: torch.Tensor): 13 | if type(agent_qs) == np.ndarray: 14 | agent_qs = torch.FloatTensor(agent_qs) 15 | return torch.sum(agent_qs, dim=2, keepdim=True) 16 | -------------------------------------------------------------------------------- /marltoolkit/modules/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/marltoolkit/modules/utils/__init__.py -------------------------------------------------------------------------------- /marltoolkit/modules/utils/common.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class MLPBase(nn.Module): 8 | 9 | def __init__( 10 | self, 11 | input_dim: int, 12 | hidden_dim: int, 13 | activation: str = 'relu', 14 | use_orthogonal: bool = False, 15 | use_feature_normalization: bool = False, 16 | ) -> None: 17 | super(MLPBase, self).__init__() 18 | 19 | use_relu = 1 if activation == 'relu' else 0 20 | active_func = [nn.ReLU(), nn.Tanh()][use_relu] 21 | if use_orthogonal: 22 | init_method = nn.init.orthogonal_ 23 | else: 24 | init_method = nn.init.xavier_uniform_ 25 | 26 | gain = nn.init.calculate_gain(['tanh', 'relu'][use_relu]) 27 | self.use_feature_normalization = use_feature_normalization 28 | if 
use_feature_normalization: 29 | self.feature_norm = nn.LayerNorm(input_dim) 30 | 31 | def init_weight(module: nn.Module) -> None: 32 | if isinstance(module, nn.Linear): 33 | init_method(module.weight, gain=gain) 34 | if module.bias is not None: 35 | nn.init.constant_(module.bias, 0) 36 | 37 | self.fc1 = nn.Sequential(nn.Linear(input_dim, hidden_dim), active_func, 38 | nn.LayerNorm(hidden_dim)) 39 | self.fc2 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), 40 | active_func, nn.LayerNorm(hidden_dim)) 41 | self.apply(init_weight) 42 | 43 | def forward(self, inputs: torch.Tensor): 44 | """Forward method for MLPBase. 45 | 46 | Args: 47 | inputs (torch.Tensor): Input tensor. Shape (batch_size, input_dim) 48 | 49 | Returns: 50 | output (torch.Tensor): Output tensor. Shape (batch_size, hidden_dim) 51 | """ 52 | if self.use_feature_normalization: 53 | output = self.feature_norm(inputs) 54 | 55 | output = self.fc1(output) 56 | output = self.fc2(output) 57 | 58 | return output 59 | 60 | 61 | class Flatten(nn.Module): 62 | 63 | def forward(self, inputs: torch.Tensor): 64 | return inputs.view(inputs.size(0), -1) 65 | 66 | 67 | class CNNBase(nn.Module): 68 | 69 | def __init__( 70 | self, 71 | obs_shape: Tuple, 72 | hidden_dim: int, 73 | kernel_size: int = 3, 74 | stride: int = 1, 75 | activation: str = 'relu', 76 | use_orthogonal: bool = False, 77 | ) -> None: 78 | super(CNNBase, self).__init__() 79 | 80 | use_relu = 1 if activation == 'relu' else 0 81 | active_func = [nn.ReLU(), nn.Tanh()][use_relu] 82 | init_method = [nn.init.xavier_uniform_, 83 | nn.init.orthogonal_][use_orthogonal] 84 | gain = nn.init.calculate_gain(['tanh', 'relu'][use_relu]) 85 | 86 | (in_channel, width, height) = obs_shape 87 | 88 | def init_weight(module: nn.Module) -> None: 89 | if isinstance(module, nn.Linear): 90 | init_method(module.weight, gain=gain) 91 | if module.bias is not None: 92 | nn.init.constant_(module.bias, 0) 93 | 94 | self.cnn = nn.Sequential( 95 | nn.Conv2d( 96 | in_channels=in_channel, 97 | out_channels=hidden_dim // 2, 98 | kernel_size=kernel_size, 99 | stride=stride, 100 | ), 101 | active_func, 102 | Flatten(), 103 | nn.Linear( 104 | hidden_dim // 2 * (width - kernel_size + stride) * 105 | (height - kernel_size + stride), 106 | hidden_dim, 107 | ), 108 | active_func, 109 | nn.Linear(hidden_dim, hidden_dim), 110 | active_func, 111 | ) 112 | self.apply(init_weight) 113 | 114 | def forward(self, inputs: torch.Tensor): 115 | inputs = inputs / 255.0 116 | output = self.cnn(inputs) 117 | return output 118 | 119 | 120 | class RNNLayer(nn.Module): 121 | 122 | def __init__( 123 | self, 124 | input_dim: int, 125 | rnn_hidden_dim: int, 126 | rnn_layers: int, 127 | use_orthogonal: bool = True, 128 | ) -> None: 129 | super(RNNLayer, self).__init__() 130 | self.rnn_layers = rnn_layers 131 | self.use_orthogonal = use_orthogonal 132 | 133 | self.rnn = nn.GRU(input_dim, 134 | rnn_hidden_dim, 135 | num_layers=rnn_layers, 136 | batch_first=True) 137 | for name, param in self.rnn.named_parameters(): 138 | if 'bias' in name: 139 | nn.init.constant_(param, 0) 140 | elif 'weight' in name: 141 | if self.use_orthogonal: 142 | nn.init.orthogonal_(param) 143 | else: 144 | nn.init.xavier_uniform_(param) 145 | self.layer_norm = nn.LayerNorm(rnn_hidden_dim) 146 | 147 | def forward( 148 | self, 149 | inputs: torch.Tensor, 150 | hidden_state: torch.Tensor, 151 | masks: torch.Tensor, 152 | ) -> tuple[Any, torch.Tensor]: 153 | """Forward method for RNNLayer. 
154 | 155 | Args: 156 | inputs (torch.Tensor): (num_agents, input_dim) 157 | hidden_state (torch.Tensor): (num_agents, rnn_layers, rnn_hidden_dim) 158 | masks (torch.Tensor): (num_agents, 1) 159 | 160 | Returns: 161 | tuple[Any, torch.Tensor]: (output, hidden_state) 162 | """ 163 | print('inputs.shape: ', inputs.shape) 164 | print('hidden_state.shape: ', hidden_state.shape) 165 | print('masks.shape: ', masks.shape) 166 | 167 | if inputs.size(0) == hidden_state.size(0): 168 | # If the batch size is the same, we can just run the RNN 169 | masks = masks.repeat(1, self.rnn_layers).unsqueeze(-1) 170 | # mask shape (num_agents, rnn_layers, 1) 171 | hidden_state = (hidden_state * masks).transpose(0, 1).contiguous() 172 | # hidden_state shape (rnn_layers, num_agents, rnn_hidden_dim) 173 | inputs = inputs.unsqueeze(0).transpose(0, 1) 174 | # inputs shape (1, num_agents, input_dim) 175 | output, hidden_state = self.rnn(inputs, hidden_state) 176 | output = output.squeeze(1) 177 | hidden_state = hidden_state.transpose(0, 1) 178 | else: 179 | # x is a (T, N, -1) tensor that has been flatten to (T * N, -1) 180 | N = hidden_state.size(0) 181 | T = int(inputs.size(0) / N) 182 | # unflatten 183 | inputs = inputs.view(T, N, inputs.size(1)) 184 | # Same deal with masks 185 | masks = masks.view(T, N) 186 | # Let's figure out which steps in the sequence have a zero for any agent 187 | # We will always assume t=0 has a zero in it as that makes the logic cleaner 188 | has_zeros = (masks[1:] == 0.0).any( 189 | dim=-1).nonzero().squeeze().cpu() 190 | 191 | # +1 to correct the masks[1:] 192 | if has_zeros.dim() == 0: 193 | # Deal with scalar 194 | has_zeros = [has_zeros.item() + 1] 195 | else: 196 | has_zeros = (has_zeros + 1).numpy().tolist() 197 | 198 | # add t=0 and t=T to the list 199 | has_zeros = [0] + has_zeros + [T] 200 | 201 | hidden_state = hidden_state.transpose(0, 1) 202 | 203 | outputs = [] 204 | for i in range(len(has_zeros) - 1): 205 | # We can now process steps that don't have any zeros in masks together! 
206 | # This is much faster 207 | start_idx = has_zeros[i] 208 | end_idx = has_zeros[i + 1] 209 | temp = (hidden_state * masks[start_idx].view(1, -1, 1).repeat( 210 | self.rnn_layers, 1, 1)).contiguous() 211 | rnn_scores, hidden_state = self.rnn(inputs[start_idx:end_idx], 212 | temp) 213 | outputs.append(rnn_scores) 214 | 215 | # assert len(outputs) == T 216 | # x is a (T, N, -1) tensor 217 | inputs = torch.cat(outputs, dim=0) 218 | 219 | # flatten 220 | inputs = inputs.reshape(T * N, -1) 221 | hidden_state = hidden_state.transpose(0, 1) 222 | 223 | output = self.layer_norm(outputs) 224 | return output, hidden_state 225 | -------------------------------------------------------------------------------- /marltoolkit/modules/utils/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.distributions import Bernoulli, Categorical, Normal 4 | 5 | 6 | # 7 | # Standardize distribution interfaces 8 | # 9 | # Categorical 10 | class FixedCategorical(Categorical): 11 | 12 | def sample(self): 13 | return super().sample().unsqueeze(-1) 14 | 15 | def log_probs(self, actions): 16 | return (super().log_prob(actions.squeeze(-1)).view( 17 | actions.size(0), -1).sum(-1).unsqueeze(-1)) 18 | 19 | def mode(self): 20 | return self.probs.argmax(dim=-1, keepdim=True) 21 | 22 | 23 | # Normal 24 | class FixedNormal(Normal): 25 | 26 | def log_probs(self, actions): 27 | return super().log_prob(actions).sum(-1, keepdim=True) 28 | 29 | def entropy(self): 30 | return super().entropy().sum(-1) 31 | 32 | def mode(self): 33 | return self.mean 34 | 35 | 36 | # Bernoulli 37 | class FixedBernoulli(Bernoulli): 38 | 39 | def log_probs(self, actions): 40 | return super.log_prob(actions).view(actions.size(0), 41 | -1).sum(-1).unsqueeze(-1) 42 | 43 | def entropy(self): 44 | return super().entropy().sum(-1) 45 | 46 | def mode(self): 47 | return torch.gt(self.probs, 0.5).float() 48 | 49 | 50 | class CustomCategorical(nn.Module): 51 | 52 | def __init__(self, 53 | num_inputs, 54 | num_outputs, 55 | use_orthogonal=True, 56 | gain=0.01): 57 | super(CustomCategorical, self).__init__() 58 | init_method = [nn.init.xavier_uniform_, 59 | nn.init.orthogonal_][use_orthogonal] 60 | 61 | def init_weight(module: nn.Module) -> None: 62 | if isinstance(module, nn.Linear): 63 | init_method(module.weight, gain=gain) 64 | if module.bias is not None: 65 | nn.init.constant_(module.bias, 0) 66 | 67 | self.linear = nn.Linear(num_inputs, num_outputs) 68 | self.apply(init_weight) 69 | 70 | def forward(self, x, available_actions=None): 71 | x = self.linear(x) 72 | if available_actions is not None: 73 | x[available_actions == 0] = -1e10 74 | return FixedCategorical(logits=x) 75 | 76 | 77 | class DiagGaussian(nn.Module): 78 | 79 | def __init__(self, 80 | num_inputs, 81 | num_outputs, 82 | use_orthogonal=True, 83 | gain=0.01): 84 | super(DiagGaussian, self).__init__() 85 | 86 | init_method = [nn.init.xavier_uniform_, 87 | nn.init.orthogonal_][use_orthogonal] 88 | 89 | def init_weight(module: nn.Module) -> None: 90 | if isinstance(module, nn.Linear): 91 | init_method(module.weight, gain=gain) 92 | if module.bias is not None: 93 | nn.init.constant_(module.bias, 0) 94 | 95 | self.fc_mean = nn.Linear(num_inputs, num_outputs) 96 | self.logstd = AddBias(torch.zeros(num_outputs)) 97 | self.apply(init_weight) 98 | 99 | def forward(self, x): 100 | action_mean = self.fc_mean(x) 101 | 102 | # An ugly hack for my KFAC implementation. 
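        # `zeros` mirrors action_mean's shape, so AddBias simply broadcasts its
        # learnable bias over the batch: the resulting log-std is a trainable
        # parameter that does not depend on the input features, i.e. the usual
        # state-independent diagonal-Gaussian parameterization.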
103 | zeros = torch.zeros(action_mean.size()) 104 | if x.is_cuda: 105 | zeros = zeros.cuda() 106 | 107 | action_logstd = self.logstd(zeros) 108 | return FixedNormal(action_mean, action_logstd.exp()) 109 | 110 | 111 | class CustomBernoulli(nn.Module): 112 | 113 | def __init__(self, 114 | num_inputs, 115 | num_outputs, 116 | use_orthogonal=True, 117 | gain=0.01): 118 | super(CustomBernoulli, self).__init__() 119 | init_method = [nn.init.xavier_uniform_, 120 | nn.init.orthogonal_][use_orthogonal] 121 | 122 | def init_weight(module: nn.Module) -> None: 123 | if isinstance(module, nn.Linear): 124 | init_method(module.weight, gain=gain) 125 | if module.bias is not None: 126 | nn.init.constant_(module.bias, 0) 127 | 128 | self.linear = nn.Linear(num_inputs, num_outputs) 129 | self.apply(init_weight) 130 | 131 | def forward(self, x): 132 | x = self.linear(x) 133 | return FixedBernoulli(logits=x) 134 | 135 | 136 | class AddBias(nn.Module): 137 | 138 | def __init__(self, bias): 139 | super(AddBias, self).__init__() 140 | self._bias = nn.Parameter(bias.unsqueeze(1)) 141 | 142 | def forward(self, x): 143 | if x.dim() == 2: 144 | bias = self._bias.t().view(1, -1) 145 | else: 146 | bias = self._bias.t().view(1, -1, 1, 1) 147 | 148 | return x + bias 149 | -------------------------------------------------------------------------------- /marltoolkit/modules/utils/popart.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class PopArt(torch.nn.Module): 10 | 11 | def __init__( 12 | self, 13 | input_shape, 14 | output_shape, 15 | norm_axes=1, 16 | beta=0.99999, 17 | epsilon=1e-5, 18 | device=torch.device('cpu'), 19 | ): 20 | super(PopArt, self).__init__() 21 | 22 | self.beta = beta 23 | self.epsilon = epsilon 24 | self.norm_axes = norm_axes 25 | self.tpdv = dict(dtype=torch.float32, device=device) 26 | 27 | self.input_shape = input_shape 28 | self.output_shape = output_shape 29 | 30 | self.weight = nn.Parameter(torch.Tensor(output_shape, 31 | input_shape)).to(**self.tpdv) 32 | self.bias = nn.Parameter(torch.Tensor(output_shape)).to(**self.tpdv) 33 | 34 | self.stddev = nn.Parameter(torch.ones(output_shape), 35 | requires_grad=False).to(**self.tpdv) 36 | self.mean = nn.Parameter(torch.zeros(output_shape), 37 | requires_grad=False).to(**self.tpdv) 38 | self.mean_sq = nn.Parameter(torch.zeros(output_shape), 39 | requires_grad=False).to(**self.tpdv) 40 | self.debiasing_term = nn.Parameter(torch.tensor(0.0), 41 | requires_grad=False).to(**self.tpdv) 42 | 43 | self.reset_parameters() 44 | 45 | def reset_parameters(self): 46 | torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) 47 | if self.bias is not None: 48 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out( 49 | self.weight) 50 | bound = 1 / math.sqrt(fan_in) 51 | torch.nn.init.uniform_(self.bias, -bound, bound) 52 | self.mean.zero_() 53 | self.mean_sq.zero_() 54 | self.debiasing_term.zero_() 55 | 56 | def forward(self, input_vector): 57 | if type(input_vector) == np.ndarray: 58 | input_vector = torch.from_numpy(input_vector) 59 | input_vector = input_vector.to(**self.tpdv) 60 | 61 | return F.linear(input_vector, self.weight, self.bias) 62 | 63 | @torch.no_grad() 64 | def update(self, input_vector): 65 | if type(input_vector) == np.ndarray: 66 | input_vector = torch.from_numpy(input_vector) 67 | input_vector = input_vector.to(**self.tpdv) 68 | 69 | old_mean, old_var = 
self.debiased_mean_var() 70 | old_stddev = torch.sqrt(old_var) 71 | 72 | batch_mean = input_vector.mean(dim=tuple(range(self.norm_axes))) 73 | batch_sq_mean = (input_vector**2).mean( 74 | dim=tuple(range(self.norm_axes))) 75 | 76 | self.mean.mul_(self.beta).add_(batch_mean * (1.0 - self.beta)) 77 | self.mean_sq.mul_(self.beta).add_(batch_sq_mean * (1.0 - self.beta)) 78 | self.debiasing_term.mul_(self.beta).add_(1.0 * (1.0 - self.beta)) 79 | 80 | self.stddev = (self.mean_sq - self.mean**2).sqrt().clamp(min=1e-4) 81 | 82 | new_mean, new_var = self.debiased_mean_var() 83 | new_stddev = torch.sqrt(new_var) 84 | 85 | self.weight = self.weight * old_stddev / new_stddev 86 | self.bias = (old_stddev * self.bias + old_mean - new_mean) / new_stddev 87 | 88 | def debiased_mean_var(self): 89 | debiased_mean = self.mean / self.debiasing_term.clamp(min=self.epsilon) 90 | debiased_mean_sq = self.mean_sq / self.debiasing_term.clamp( 91 | min=self.epsilon) 92 | debiased_var = (debiased_mean_sq - debiased_mean**2).clamp(min=1e-2) 93 | return debiased_mean, debiased_var 94 | 95 | def normalize(self, input_vector): 96 | if type(input_vector) == np.ndarray: 97 | input_vector = torch.from_numpy(input_vector) 98 | input_vector = input_vector.to(**self.tpdv) 99 | 100 | mean, var = self.debiased_mean_var() 101 | out = (input_vector - mean[(None, ) * self.norm_axes] 102 | ) / torch.sqrt(var)[(None, ) * self.norm_axes] 103 | 104 | return out 105 | 106 | def denormalize(self, input_vector): 107 | if type(input_vector) == np.ndarray: 108 | input_vector = torch.from_numpy(input_vector) 109 | input_vector = input_vector.to(**self.tpdv) 110 | 111 | mean, var = self.debiased_mean_var() 112 | out = (input_vector * torch.sqrt(var)[(None, ) * self.norm_axes] + 113 | mean[(None, ) * self.norm_axes]) 114 | 115 | out = out.cpu().numpy() 116 | 117 | return out 118 | -------------------------------------------------------------------------------- /marltoolkit/modules/utils/valuenorm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class ValueNorm(nn.Module): 7 | """Normalize a vector of observations - across the first norm_axes dimensions""" 8 | 9 | def __init__( 10 | self, 11 | input_shape, 12 | norm_axes=1, 13 | beta=0.99999, 14 | per_element_update=False, 15 | epsilon=1e-5, 16 | device=torch.device('cpu'), 17 | ): 18 | super(ValueNorm, self).__init__() 19 | 20 | self.input_shape = input_shape 21 | self.norm_axes = norm_axes 22 | self.epsilon = epsilon 23 | self.beta = beta 24 | self.per_element_update = per_element_update 25 | self.tpdv = dict(dtype=torch.float32, device=device) 26 | 27 | self.running_mean = nn.Parameter(torch.zeros(input_shape), 28 | requires_grad=False).to(**self.tpdv) 29 | self.running_mean_sq = nn.Parameter( 30 | torch.zeros(input_shape), requires_grad=False).to(**self.tpdv) 31 | self.debiasing_term = nn.Parameter(torch.tensor(0.0), 32 | requires_grad=False).to(**self.tpdv) 33 | 34 | self.reset_parameters() 35 | 36 | def reset_parameters(self): 37 | self.running_mean.zero_() 38 | self.running_mean_sq.zero_() 39 | self.debiasing_term.zero_() 40 | 41 | def running_mean_var(self): 42 | debiased_mean = self.running_mean / self.debiasing_term.clamp( 43 | min=self.epsilon) 44 | debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp( 45 | min=self.epsilon) 46 | debiased_var = (debiased_mean_sq - debiased_mean**2).clamp(min=1e-2) 47 | return debiased_mean, debiased_var 48 | 49 | 
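    # How the running statistics stay unbiased: update() keeps exponential
    # moving averages together with a debiasing term (the same bias correction
    # Adam uses), and running_mean_var() divides by that term so early
    # estimates are not dragged toward the zero initialization. Illustrative
    # numbers with beta = 0.9 (the class default is 0.99999): after one update
    # with batch_mean = 2.0, running_mean = 0.2 and debiasing_term = 0.1, so
    # the debiased mean is 0.2 / 0.1 = 2.0, exactly the observed batch mean.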
@torch.no_grad() 50 | def update(self, input_vector): 51 | if type(input_vector) == np.ndarray: 52 | input_vector = torch.from_numpy(input_vector) 53 | input_vector = input_vector.to(**self.tpdv) 54 | 55 | batch_mean = input_vector.mean(dim=tuple(range(self.norm_axes))) 56 | batch_sq_mean = (input_vector**2).mean( 57 | dim=tuple(range(self.norm_axes))) 58 | 59 | if self.per_element_update: 60 | batch_size = np.prod(input_vector.size()[:self.norm_axes]) 61 | weight = self.beta**batch_size 62 | else: 63 | weight = self.beta 64 | 65 | self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight)) 66 | self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight)) 67 | self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight)) 68 | 69 | def normalize(self, input_vector): 70 | # Make sure input is float32 71 | if type(input_vector) == np.ndarray: 72 | input_vector = torch.from_numpy(input_vector) 73 | input_vector = input_vector.to(**self.tpdv) 74 | 75 | mean, var = self.running_mean_var() 76 | out = (input_vector - mean[(None, ) * self.norm_axes] 77 | ) / torch.sqrt(var)[(None, ) * self.norm_axes] 78 | 79 | return out 80 | 81 | def denormalize(self, input_vector): 82 | """Transform normalized data back into original distribution.""" 83 | if type(input_vector) == np.ndarray: 84 | input_vector = torch.from_numpy(input_vector) 85 | input_vector = input_vector.to(**self.tpdv) 86 | 87 | mean, var = self.running_mean_var() 88 | out = (input_vector * torch.sqrt(var)[(None, ) * self.norm_axes] + 89 | mean[(None, ) * self.norm_axes]) 90 | 91 | out = out.cpu().numpy() 92 | 93 | return out 94 | -------------------------------------------------------------------------------- /marltoolkit/runners/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/marltoolkit/runners/__init__.py -------------------------------------------------------------------------------- /marltoolkit/runners/episode_runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Tuple 3 | 4 | from marltoolkit.agents.base_agent import BaseAgent 5 | from marltoolkit.data.ma_buffer import ReplayBuffer 6 | from marltoolkit.envs import MultiAgentEnv 7 | from marltoolkit.utils.logger.logs import avg_val_from_list_of_dicts 8 | 9 | 10 | def run_train_episode( 11 | env: MultiAgentEnv, 12 | agent: BaseAgent, 13 | rpm: ReplayBuffer, 14 | args: argparse.Namespace = None, 15 | ) -> dict[str, float]: 16 | episode_reward = 0.0 17 | episode_step = 0 18 | done = False 19 | agent.init_hidden_states(batch_size=1) 20 | (obs, state, info) = env.reset() 21 | while not done: 22 | available_actions = env.get_available_actions() 23 | actions = agent.sample(obs=obs, available_actions=available_actions) 24 | next_obs, next_state, reward, terminated, truncated, info = env.step( 25 | actions) 26 | done = terminated or truncated 27 | transitions = { 28 | 'obs': obs, 29 | 'state': state, 30 | 'actions': actions, 31 | 'available_actions': available_actions, 32 | 'rewards': reward, 33 | 'dones': done, 34 | 'filled': False, 35 | } 36 | 37 | rpm.store_transitions(transitions) 38 | obs, state = next_obs, next_state 39 | 40 | episode_reward += reward 41 | episode_step += 1 42 | 43 | # fill the episode 44 | for _ in range(episode_step, args.episode_limit): 45 | rpm.episode_data.fill_mask() 46 | 47 | # ReplayBuffer store the episode data 48 | 
rpm.store_episodes() 49 | is_win = env.win_counted 50 | 51 | train_res_lst = [] 52 | if rpm.size() > args.memory_warmup_size: 53 | for _ in range(args.learner_update_freq): 54 | batch = rpm.sample(args.batch_size) 55 | results = agent.learn(batch) 56 | train_res_lst.append(results) 57 | 58 | train_res_dict = avg_val_from_list_of_dicts(train_res_lst) 59 | 60 | train_res_dict['episode_reward'] = episode_reward 61 | train_res_dict['episode_step'] = episode_step 62 | train_res_dict['win_rate'] = is_win 63 | return train_res_dict 64 | 65 | 66 | def run_eval_episode( 67 | env: MultiAgentEnv, 68 | agent: BaseAgent, 69 | args: argparse.Namespace = None, 70 | ) -> Tuple[float, float, float]: 71 | eval_res_list = [] 72 | for _ in range(args.num_eval_episodes): 73 | agent.init_hidden_states(batch_size=1) 74 | episode_reward = 0.0 75 | episode_step = 0 76 | done = False 77 | obs, state, info = env.reset() 78 | 79 | while not done: 80 | available_actions = env.get_available_actions() 81 | actions = agent.predict( 82 | obs=obs, 83 | available_actions=available_actions, 84 | ) 85 | next_obs, next_state, reward, terminated, truncated, info = env.step( 86 | actions) 87 | done = terminated or truncated 88 | obs = next_obs 89 | episode_step += 1 90 | episode_reward += reward 91 | 92 | is_win = env.win_counted 93 | eval_res_list.append({ 94 | 'episode_reward': episode_reward, 95 | 'episode_step': episode_step, 96 | 'win_rate': is_win, 97 | }) 98 | eval_res_dict = avg_val_from_list_of_dicts(eval_res_list) 99 | return eval_res_dict 100 | -------------------------------------------------------------------------------- /marltoolkit/runners/onpolicy_runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Tuple 3 | 4 | from marltoolkit.agents.mappo_agent import MAPPOAgent 5 | from marltoolkit.data.shared_buffer import SharedReplayBuffer 6 | from marltoolkit.envs.smacv1 import SMACWrapperEnv 7 | from marltoolkit.utils.logger.logs import avg_val_from_list_of_dicts 8 | 9 | 10 | def run_train_episode( 11 | env: SMACWrapperEnv, 12 | agent: MAPPOAgent, 13 | rpm: SharedReplayBuffer, 14 | args: argparse.Namespace = None, 15 | ) -> dict[str, float]: 16 | episode_reward = 0.0 17 | episode_step = 0 18 | done = False 19 | agent.init_hidden_states(batch_size=1) 20 | (obs, state, info) = env.reset() 21 | available_actions = env.get_available_actions() 22 | 23 | rpm.obs[0] = obs.copy() 24 | rpm.state[0] = state.copy() 25 | rpm.available_actions[0] = available_actions.copy() 26 | 27 | for step in range(args.episode_limit): 28 | available_actions = env.get_available_actions() 29 | actions = agent.sample(obs=obs, available_actions=available_actions) 30 | next_obs, next_state, reward, terminated, truncated, info = env.step( 31 | actions) 32 | done = terminated or truncated 33 | transitions = { 34 | 'obs': obs, 35 | 'state': state, 36 | 'actions': actions, 37 | 'available_actions': available_actions, 38 | 'rewards': reward, 39 | 'dones': done, 40 | 'filled': False, 41 | } 42 | 43 | rpm.store_transitions(transitions) 44 | obs, state = next_obs, next_state 45 | 46 | episode_reward += reward 47 | episode_step += 1 48 | 49 | # fill the episode 50 | for _ in range(episode_step, args.episode_limit): 51 | rpm.episode_data.fill_mask() 52 | 53 | # ReplayBuffer store the episode data 54 | rpm.store_episodes() 55 | is_win = env.win_counted 56 | 57 | train_res_lst = [] 58 | if rpm.size() > args.memory_warmup_size: 59 | for _ in range(args.learner_update_freq): 60 | 
batch = rpm.sample(args.batch_size) 61 | results = agent.learn(batch) 62 | train_res_lst.append(results) 63 | 64 | train_res_dict = avg_val_from_list_of_dicts(train_res_lst) 65 | 66 | train_res_dict['episode_reward'] = episode_reward 67 | train_res_dict['episode_step'] = episode_step 68 | train_res_dict['win_rate'] = is_win 69 | return train_res_dict 70 | 71 | 72 | def run_eval_episode( 73 | env: SMACWrapperEnv, 74 | agent: MAPPOAgent, 75 | args: argparse.Namespace = None, 76 | ) -> Tuple[float, float, float]: 77 | eval_res_list = [] 78 | for _ in range(args.num_eval_episodes): 79 | agent.init_hidden_states(batch_size=1) 80 | episode_reward = 0.0 81 | episode_step = 0 82 | done = False 83 | obs, state, info = env.reset() 84 | 85 | while not done: 86 | available_actions = env.get_available_actions() 87 | actions = agent.predict( 88 | obs=obs, 89 | available_actions=available_actions, 90 | ) 91 | next_obs, next_state, reward, terminated, truncated, info = env.step( 92 | actions) 93 | done = terminated or truncated 94 | obs = next_obs 95 | episode_step += 1 96 | episode_reward += reward 97 | 98 | is_win = env.win_counted 99 | eval_res_list.append({ 100 | 'episode_reward': episode_reward, 101 | 'episode_step': episode_step, 102 | 'win_rate': is_win, 103 | }) 104 | eval_res_dict = avg_val_from_list_of_dicts(eval_res_list) 105 | return eval_res_dict 106 | -------------------------------------------------------------------------------- /marltoolkit/runners/parallel_episode_runner.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | from marltoolkit.agents import BaseAgent 6 | from marltoolkit.data import MaEpisodeData, OffPolicyBufferRNN 7 | from marltoolkit.envs import BaseVecEnv 8 | 9 | 10 | def run_train_episode( 11 | envs: BaseVecEnv, 12 | agent: BaseAgent, 13 | rpm: OffPolicyBufferRNN, 14 | args: argparse.Namespace = None, 15 | ): 16 | agent.reset_agent() 17 | # reset the environment 18 | obs, state, info = envs.reset() 19 | num_envs = envs.num_envs 20 | env_dones = envs.buf_dones 21 | episode_step = 0 22 | episode_score = np.zeros(num_envs, dtype=np.float32) 23 | filled = np.zeros([args.episode_limit, num_envs, 1], dtype=np.int32) 24 | episode_data = MaEpisodeData( 25 | num_envs, 26 | args.num_agents, 27 | args.episode_limit, 28 | args.obs_shape, 29 | args.state_shape, 30 | args.action_shape, 31 | args.reward_shape, 32 | args.done_shape, 33 | ) 34 | while not env_dones.all(): 35 | available_actions = envs.get_available_actions() 36 | # Get actions from the agent 37 | actions = agent.sample(obs, available_actions) 38 | # Environment step 39 | next_obs, next_state, rewards, env_dones, info = envs.step(actions) 40 | # Fill the episode buffer 41 | filled[episode_step, :] = np.ones([num_envs, 1]) 42 | transitions = dict( 43 | obs=obs, 44 | state=state, 45 | actions=actions, 46 | rewards=rewards, 47 | env_dones=env_dones, 48 | available_actions=available_actions, 49 | ) 50 | # Store the transitions 51 | episode_data.store_transitions(transitions) 52 | # Check if the episode is done 53 | for env_idx in range(num_envs): 54 | if env_dones[env_idx]: 55 | # Fill the rest of the episode with zeros 56 | filled[episode_step, env_idx, :] = 0 57 | # Get the episode score from the info 58 | final_info = info['final_info'] 59 | episode_score[env_idx] = final_info[env_idx]['episode_score'] 60 | # Get the current available actions 61 | available_actions = envs.get_available_actions() 62 | terminal_data = (next_obs, 
next_state, available_actions, 63 | filled) 64 | # Finish the episode 65 | rpm.finish_path(env_idx, episode_step, *terminal_data) 66 | 67 | # Update the episode step 68 | episode_step += 1 69 | obs, state = next_obs, next_state 70 | # Store the episode data 71 | rpm.store_episodes(episode_data.episode_buffer) 72 | 73 | mean_loss = [] 74 | mean_td_error = [] 75 | if rpm.size() > args.memory_warmup_size: 76 | for _ in range(args.learner_update_freq): 77 | batch = rpm.sample_batch(args.batch_size) 78 | loss, td_error = agent.learn(**batch) 79 | mean_loss.append(loss) 80 | mean_td_error.append(td_error) 81 | 82 | mean_loss = np.mean(mean_loss) if mean_loss else None 83 | mean_td_error = np.mean(mean_td_error) if mean_td_error else None 84 | 85 | return episode_score, episode_step, mean_loss, mean_td_error 86 | 87 | 88 | def run_eval_episode( 89 | env: BaseVecEnv, 90 | agent: BaseAgent, 91 | num_eval_episodes: int = 5, 92 | ): 93 | eval_is_win_buffer = [] 94 | eval_reward_buffer = [] 95 | eval_steps_buffer = [] 96 | for _ in range(num_eval_episodes): 97 | agent.reset_agent() 98 | episode_reward = 0.0 99 | episode_step = 0 100 | terminated = False 101 | obs, state = env.reset() 102 | while not terminated: 103 | available_actions = env.get_available_actions() 104 | actions = agent.predict(obs, available_actions) 105 | state, obs, reward, terminated = env.step(actions) 106 | episode_step += 1 107 | episode_reward += reward 108 | 109 | is_win = env.win_counted 110 | 111 | eval_reward_buffer.append(episode_reward) 112 | eval_steps_buffer.append(episode_step) 113 | eval_is_win_buffer.append(is_win) 114 | 115 | eval_rewards = np.mean(eval_reward_buffer) 116 | eval_steps = np.mean(eval_steps_buffer) 117 | eval_win_rate = np.mean(eval_is_win_buffer) 118 | 119 | return eval_rewards, eval_steps, eval_win_rate 120 | -------------------------------------------------------------------------------- /marltoolkit/runners/qtran_runner.py.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from marltoolkit.data.ma_replaybuffer import EpisodeData, ReplayBuffer 4 | 5 | 6 | def run_train_episode(env, agent, rpm: ReplayBuffer, config: dict = None): 7 | 8 | episode_limit = config['episode_limit'] 9 | agent.reset_agent() 10 | episode_reward = 0.0 11 | episode_step = 0 12 | terminated = False 13 | state, obs = env.reset() 14 | episode_experience = EpisodeData( 15 | episode_limit=episode_limit, 16 | state_shape=config['state_shape'], 17 | obs_shape=config['obs_shape'], 18 | num_actions=config['n_actions'], 19 | num_agents=config['n_agents'], 20 | ) 21 | 22 | while not terminated: 23 | available_actions = env.get_available_actions() 24 | actions = agent.sample(obs, available_actions) 25 | actions_onehot = env._get_actions_one_hot(actions) 26 | next_state, next_obs, reward, terminated = env.step(actions) 27 | episode_reward += reward 28 | episode_step += 1 29 | episode_experience.add(state, obs, actions, actions_onehot, 30 | available_actions, reward, terminated, 0) 31 | state = next_state 32 | obs = next_obs 33 | 34 | # fill the episode 35 | for _ in range(episode_step, episode_limit): 36 | episode_experience.fill_mask() 37 | 38 | episode_data = episode_experience.get_data() 39 | 40 | rpm.store(**episode_data) 41 | is_win = env.win_counted 42 | 43 | mean_loss = [] 44 | mean_td_loss = [] 45 | mean_opt_loss = [] 46 | mean_nopt_loss = [] 47 | if rpm.size() > config['memory_warmup_size']: 48 | for _ in range(config['learner_update_freq']): 49 | batch = 
rpm.sample_batch(config['batch_size']) 50 | loss, td_loss, opt_loss, nopt_loss = agent.learn(**batch) 51 | mean_loss.append(loss) 52 | mean_td_loss.append(td_loss) 53 | mean_opt_loss.append(opt_loss) 54 | mean_nopt_loss.append(nopt_loss) 55 | 56 | mean_loss = np.mean(mean_loss) if mean_loss else None 57 | mean_td_loss = np.mean(mean_td_loss) if mean_td_loss else None 58 | mean_opt_loss = np.mean(mean_opt_loss) if mean_opt_loss else None 59 | mean_nopt_loss = np.mean(mean_nopt_loss) if mean_nopt_loss else None 60 | 61 | return episode_reward, episode_step, is_win, mean_loss, mean_td_loss, mean_opt_loss, mean_nopt_loss 62 | 63 | 64 | def run_eval_episode(env, agent, num_eval_episodes=5): 65 | eval_is_win_buffer = [] 66 | eval_reward_buffer = [] 67 | eval_steps_buffer = [] 68 | for _ in range(num_eval_episodes): 69 | agent.reset_agent() 70 | episode_reward = 0.0 71 | episode_step = 0 72 | terminated = False 73 | state, obs = env.reset() 74 | while not terminated: 75 | available_actions = env.get_available_actions() 76 | actions = agent.predict(obs, available_actions) 77 | state, obs, reward, terminated = env.step(actions) 78 | episode_step += 1 79 | episode_reward += reward 80 | 81 | is_win = env.win_counted 82 | 83 | eval_reward_buffer.append(episode_reward) 84 | eval_steps_buffer.append(episode_step) 85 | eval_is_win_buffer.append(is_win) 86 | 87 | eval_rewards = np.mean(eval_reward_buffer) 88 | eval_steps = np.mean(eval_steps_buffer) 89 | eval_win_rate = np.mean(eval_is_win_buffer) 90 | 91 | return eval_rewards, eval_steps, eval_win_rate 92 | -------------------------------------------------------------------------------- /marltoolkit/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .logger import (BaseLogger, TensorboardLogger, WandbLogger, get_outdir, 2 | get_root_logger) 3 | from .lr_scheduler import (LinearDecayScheduler, MultiStepScheduler, 4 | PiecewiseScheduler) 5 | from .model_utils import (check_model_method, hard_target_update, 6 | soft_target_update) 7 | from .progressbar import ProgressBar 8 | from .timer import Timer 9 | from .transforms import OneHotTransform 10 | 11 | __all__ = [ 12 | 'BaseLogger', 'TensorboardLogger', 'WandbLogger', 'get_outdir', 13 | 'get_root_logger', 'ProgressBar', 'Timer', 'OneHotTransform', 14 | 'hard_target_update', 'soft_target_update', 'check_model_method', 15 | 'LinearDecayScheduler', 'PiecewiseScheduler', 'MultiStepScheduler' 16 | ] 17 | -------------------------------------------------------------------------------- /marltoolkit/utils/env_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from marltoolkit.envs.smacv1 import SMACWrapperEnv 4 | from marltoolkit.envs.vec_env import SubprocVecEnv 5 | 6 | 7 | def make_vec_env( 8 | env_id: str, 9 | map_name: str, 10 | num_train_envs: int = 10, 11 | num_test_envs: int = 10, 12 | **kwargs, 13 | ): 14 | 15 | def make_env(): 16 | if env_id == 'SMAC-v1': 17 | env = SMACWrapperEnv(map_name, **kwargs) 18 | else: 19 | raise ValueError(f'Unknown environment: {env_id}') 20 | return env 21 | 22 | train_envs = SubprocVecEnv([make_env for _ in range(num_train_envs)]) 23 | test_envs = SubprocVecEnv([make_env for _ in range(num_test_envs)]) 24 | return train_envs, test_envs 25 | 26 | 27 | def get_actor_input_dim(args: argparse.Namespace) -> None: 28 | """Get the input shape of the actor model. 
29 | 30 | Args: 31 | args (argparse.Namespace): The arguments 32 | Returns: 33 | input_shape (int): The input shape of the actor model. 34 | """ 35 | input_dim = args.obs_dim 36 | if args.use_global_state: 37 | input_dim += args.state_dim 38 | if args.use_last_actions: 39 | input_dim += args.n_actions 40 | if args.use_agents_id_onehot: 41 | input_dim += args.num_agents 42 | return input_dim 43 | 44 | 45 | def get_critic_input_dim(args: argparse.Namespace) -> None: 46 | """Get the input shape of the critic model. 47 | 48 | Args: 49 | args (argparse.Namespace): The arguments. 50 | 51 | Returns: 52 | input_dim (int): The input shape of the critic model. 53 | """ 54 | input_dim = args.obs_dim 55 | if args.use_global_state: 56 | input_dim += args.state_dim 57 | if args.use_last_actions: 58 | input_dim += args.n_actions 59 | if args.use_agents_id_onehot: 60 | input_dim += args.num_agents 61 | return input_dim 62 | 63 | 64 | def get_shape_from_obs_space(obs_space): 65 | if obs_space.__class__.__name__ == 'Box': 66 | obs_shape = obs_space.shape 67 | elif obs_space.__class__.__name__ == 'Tuple': 68 | obs_shape = obs_space 69 | else: 70 | raise NotImplementedError 71 | return obs_shape 72 | 73 | 74 | def get_shape_from_act_space(act_space): 75 | if act_space.__class__.__name__ == 'Discrete': 76 | act_shape = 1 77 | elif act_space.__class__.__name__ == 'MultiDiscrete': 78 | act_shape = act_space.shape 79 | elif act_space.__class__.__name__ == 'Box': 80 | act_shape = act_space.shape[0] 81 | elif act_space.__class__.__name__ == 'MultiBinary': 82 | act_shape = act_space.shape[0] 83 | else: # agar 84 | act_shape = act_space[0].shape[0] + 1 85 | return act_shape 86 | -------------------------------------------------------------------------------- /marltoolkit/utils/logger/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseLogger 2 | from .logs import get_outdir, get_root_logger 3 | from .tensorboard import TensorboardLogger 4 | from .wandb import WandbLogger 5 | 6 | __all__ = [ 7 | 'BaseLogger', 'TensorboardLogger', 'WandbLogger', 'get_root_logger', 8 | 'get_outdir' 9 | ] 10 | -------------------------------------------------------------------------------- /marltoolkit/utils/logger/base.py: -------------------------------------------------------------------------------- 1 | """Code Reference Tianshou https://github.com/thu- 2 | ml/tianshou/tree/master/tianshou/utils/logger.""" 3 | from abc import ABC, abstractmethod 4 | from enum import Enum 5 | from numbers import Number 6 | from typing import Callable, Dict, Optional, Tuple, Union 7 | 8 | import numpy as np 9 | 10 | LOG_DATA_TYPE = Dict[str, Union[int, Number, np.number, np.ndarray]] 11 | 12 | 13 | class DataScope(Enum): 14 | TRAIN = 'train' 15 | TEST = 'test' 16 | UPDATE = 'update' 17 | INFO = 'info' 18 | 19 | 20 | class BaseLogger(ABC): 21 | """The base class for any logger which is compatible with trainer. 22 | 23 | Try to overwrite write() method to use your own writer. 24 | 25 | :param train_interval: the log interval in log_train_data(). Default to 1000. 26 | :param test_interval: the log interval in log_test_data(). Default to 1. 27 | :param update_interval: the log interval in log_update_data(). Default to 1000. 28 | :param info_interval: the log interval in log_info_data(). Default to 1. 
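
    Example of a minimal concrete logger (an illustrative sketch, not part of
    the package; ``PrintLogger`` is a hypothetical subclass)::

        class PrintLogger(BaseLogger):

            def write(self, step_type, step, data):
                # `data` maps prefixed tags to scalars, e.g. {'train/reward': 1.0}
                print(step_type, step, data)

            def save_data(self, epoch, env_step, gradient_step,
                          save_checkpoint_fn=None):
                pass

            def restore_data(self):
                return 0, 0, 0

        logger = PrintLogger()
        # prefixes keys with 'train/' and calls write('train/env_step', 2000, ...)
        logger.log_train_data({'reward': 1.0}, step=2000)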
29 | """ 30 | 31 | def __init__( 32 | self, 33 | train_interval: int = 1000, 34 | test_interval: int = 1, 35 | update_interval: int = 1000, 36 | info_interval: int = 1, 37 | ) -> None: 38 | super().__init__() 39 | self.train_interval = train_interval 40 | self.test_interval = test_interval 41 | self.update_interval = update_interval 42 | self.info_interval = info_interval 43 | self.last_log_train_step = -1 44 | self.last_log_test_step = -1 45 | self.last_log_update_step = -1 46 | self.last_log_info_step = -1 47 | 48 | @abstractmethod 49 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: 50 | """Specify how the writer is used to log data. 51 | 52 | :param str step_type: namespace which the data dict belongs to. 53 | :param step: stands for the ordinate of the data dict. 54 | :param data: the data to write with format ``{key: value}``. 55 | """ 56 | pass 57 | 58 | def log_train_data(self, log_data: dict, step: int) -> None: 59 | """Use writer to log statistics generated during training. 60 | 61 | :param log_data: a dict containing information of data collected in 62 | training stage, i.e., returns of collector.collect(). 63 | :param int step: stands for the timestep the log_data being logged. 64 | """ 65 | if step - self.last_log_train_step >= self.train_interval: 66 | log_data = {f'train/{k}': v for k, v in log_data.items()} 67 | self.write('train/env_step', step, log_data) 68 | self.last_log_train_step = step 69 | 70 | def log_test_data(self, log_data: dict, step: int) -> None: 71 | """Use writer to log statistics generated during evaluating. 72 | 73 | :param log_data: a dict containing information of data collected in 74 | evaluating stage, i.e., returns of collector.collect(). 75 | :param int step: stands for the timestep the log_data being logged. 76 | """ 77 | if step - self.last_log_test_step >= self.test_interval: 78 | log_data = {f'test/{k}': v for k, v in log_data.items()} 79 | self.write('test/env_step', step, log_data) 80 | self.last_log_test_step = step 81 | 82 | def log_update_data(self, log_data: dict, step: int) -> None: 83 | """Use writer to log statistics generated during updating. 84 | 85 | :param log_data: a dict containing information of data collected in 86 | updating stage, i.e., returns of policy.update(). 87 | :param int step: stands for the timestep the log_data being logged. 88 | """ 89 | if step - self.last_log_update_step >= self.update_interval: 90 | log_data = {f'update/{k}': v for k, v in log_data.items()} 91 | self.write('update/gradient_step', step, log_data) 92 | self.last_log_update_step = step 93 | 94 | def log_info_data(self, log_data: dict, step: int) -> None: 95 | """Use writer to log global statistics. 96 | 97 | :param log_data: a dict containing information of data collected at the end of an epoch. 98 | :param step: stands for the timestep the training info is logged. 99 | """ 100 | if (step - self.last_log_info_step >= self.info_interval): 101 | log_data = {f'info/{k}': v for k, v in log_data.items()} 102 | self.write('info/epoch', step, log_data) 103 | self.last_log_info_step = step 104 | 105 | @abstractmethod 106 | def save_data( 107 | self, 108 | epoch: int, 109 | env_step: int, 110 | gradient_step: int, 111 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, 112 | ) -> None: 113 | """Use writer to log metadata when calling ``save_checkpoint_fn`` in 114 | trainer. 115 | 116 | :param int epoch: the epoch in trainer. 117 | :param int env_step: the env_step in trainer. 
118 | :param int gradient_step: the gradient_step in trainer. 119 | :param function save_checkpoint_fn: a hook defined by user, see trainer 120 | documentation for detail. 121 | """ 122 | pass 123 | 124 | @abstractmethod 125 | def restore_data(self) -> Tuple[int, int, int]: 126 | """Return the metadata from existing log. 127 | 128 | If it finds nothing or an error occurs during the recover process, it will 129 | return the default parameters. 130 | 131 | :return: epoch, env_step, gradient_step. 132 | """ 133 | pass 134 | 135 | 136 | class LazyLogger(BaseLogger): 137 | """A logger that does nothing. 138 | 139 | Used as the placeholder in trainer. 140 | """ 141 | 142 | def __init__(self) -> None: 143 | super().__init__() 144 | 145 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: 146 | """The LazyLogger writes nothing.""" 147 | pass 148 | 149 | def save_data( 150 | self, 151 | epoch: int, 152 | env_step: int, 153 | gradient_step: int, 154 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, 155 | ) -> None: 156 | pass 157 | 158 | def restore_data(self) -> Tuple[int, int, int]: 159 | return 0, 0, 0 160 | -------------------------------------------------------------------------------- /marltoolkit/utils/logger/base_orig.py: -------------------------------------------------------------------------------- 1 | """Code Reference Tianshou https://github.com/thu- 2 | ml/tianshou/tree/master/tianshou/utils/logger.""" 3 | from abc import ABC, abstractmethod 4 | from numbers import Number 5 | from typing import Callable, Dict, Optional, Tuple, Union 6 | 7 | import numpy as np 8 | 9 | LOG_DATA_TYPE = Dict[str, Union[int, Number, np.number, np.ndarray]] 10 | 11 | 12 | class BaseLogger(ABC): 13 | """The base class for any logger which is compatible with trainer. 14 | 15 | Try to overwrite write() method to use your own writer. 16 | 17 | :param int train_interval: the log interval in log_train_data(). Default to 1000. 18 | :param int test_interval: the log interval in log_test_data(). Default to 1. 19 | :param int update_interval: the log interval in log_update_data(). Default to 1000. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | train_interval: int = 1000, 25 | test_interval: int = 1, 26 | update_interval: int = 1000, 27 | ) -> None: 28 | super().__init__() 29 | self.train_interval = train_interval 30 | self.test_interval = test_interval 31 | self.update_interval = update_interval 32 | self.last_log_train_step = -1 33 | self.last_log_test_step = -1 34 | self.last_log_update_step = -1 35 | 36 | @abstractmethod 37 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: 38 | """Specify how the writer is used to log data. 39 | 40 | :param str step_type: namespace which the data dict belongs to. 41 | :param int step: stands for the ordinate of the data dict. 42 | :param dict data: the data to write with format ``{key: value}``. 43 | """ 44 | pass 45 | 46 | def log_train_data(self, collect_result: dict, step: int) -> None: 47 | """Use writer to log statistics generated during training. 48 | 49 | :param collect_result: a dict containing information of data collected in 50 | training stage, i.e., returns of collector.collect(). 51 | :param int step: stands for the timestep the collect_result being logged. 
52 | """ 53 | if step - self.last_log_train_step >= self.train_interval: 54 | log_data = {f'train/{k}': v for k, v in collect_result.items()} 55 | self.write('train/env_step', step, log_data) 56 | self.last_log_train_step = step 57 | 58 | def log_test_data(self, collect_result: dict, step: int) -> None: 59 | """Use writer to log statistics generated during evaluating. 60 | 61 | :param collect_result: a dict containing information of data collected in 62 | evaluating stage, i.e., returns of collector.collect(). 63 | :param int step: stands for the timestep the collect_result being logged. 64 | """ 65 | if step - self.last_log_test_step >= self.test_interval: 66 | log_data = {f'test/{k}': v for k, v in collect_result.items()} 67 | self.write('test/env_step', step, log_data) 68 | self.last_log_test_step = step 69 | 70 | def log_update_data(self, update_result: dict, step: int) -> None: 71 | """Use writer to log statistics generated during updating. 72 | 73 | :param update_result: a dict containing information of data collected in 74 | updating stage, i.e., returns of policy.update(). 75 | :param int step: stands for the timestep the collect_result being logged. 76 | """ 77 | if step - self.last_log_update_step >= self.update_interval: 78 | log_data = {f'update/{k}': v for k, v in update_result.items()} 79 | self.write('update/gradient_step', step, log_data) 80 | self.last_log_update_step = step 81 | 82 | @abstractmethod 83 | def save_data( 84 | self, 85 | epoch: int, 86 | env_step: int, 87 | gradient_step: int, 88 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, 89 | ) -> None: 90 | """Use writer to log metadata when calling ``save_checkpoint_fn`` in 91 | trainer. 92 | 93 | :param int epoch: the epoch in trainer. 94 | :param int env_step: the env_step in trainer. 95 | :param int gradient_step: the gradient_step in trainer. 96 | :param function save_checkpoint_fn: a hook defined by user, see trainer 97 | documentation for detail. 98 | """ 99 | pass 100 | 101 | @abstractmethod 102 | def restore_data(self) -> Tuple[int, int, int]: 103 | """Return the metadata from existing log. 104 | 105 | If it finds nothing or an error occurs during the recover process, it will 106 | return the default parameters. 107 | 108 | :return: epoch, env_step, gradient_step. 109 | """ 110 | pass 111 | 112 | 113 | class LazyLogger(BaseLogger): 114 | """A logger that does nothing. 115 | 116 | Used as the placeholder in trainer. 117 | """ 118 | 119 | def __init__(self) -> None: 120 | super().__init__() 121 | 122 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: 123 | """The LazyLogger writes nothing.""" 124 | pass 125 | 126 | def save_data( 127 | self, 128 | epoch: int, 129 | env_step: int, 130 | gradient_step: int, 131 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, 132 | ) -> None: 133 | pass 134 | 135 | def restore_data(self) -> Tuple[int, int, int]: 136 | return 0, 0, 0 137 | -------------------------------------------------------------------------------- /marltoolkit/utils/logger/logging.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import logging 3 | 4 | import torch.distributed as dist 5 | 6 | logger_initialized: dict = {} 7 | 8 | 9 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 10 | """Initialize and get a logger by name. 
11 | 12 | If the logger has not been initialized, this method will initialize the 13 | logger by adding one or two handlers, otherwise the initialized logger will 14 | be directly returned. During initialization, a StreamHandler will always be 15 | added. If `log_file` is specified and the process rank is 0, a FileHandler 16 | will also be added. 17 | 18 | Args: 19 | name (str): Logger name. 20 | log_file (str | None): The log filename. If specified, a FileHandler 21 | will be added to the logger. 22 | log_level (int): The logger level. Note that only the process of 23 | rank 0 is affected, and other processes will set the level to 24 | "Error" thus be silent most of the time. 25 | file_mode (str): The file mode used in opening log file. 26 | Defaults to 'w'. 27 | 28 | Returns: 29 | logging.Logger: The expected logger. 30 | """ 31 | logger = logging.getLogger(name) 32 | if name in logger_initialized: 33 | return logger 34 | # handle hierarchical names 35 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 36 | # initialization since it is a child of "a". 37 | for logger_name in logger_initialized: 38 | if name.startswith(logger_name): 39 | return logger 40 | 41 | # handle duplicate logs to the console 42 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) 43 | # to the root logger. As logger.propagate is True by default, this root 44 | # level handler causes logging messages from rank>0 processes to 45 | # unexpectedly show up on the console, creating much unwanted clutter. 46 | # To fix this issue, we set the root logger's StreamHandler, if any, to log 47 | # at the ERROR level. 48 | for handler in logger.root.handlers: 49 | if type(handler) is logging.StreamHandler: 50 | handler.setLevel(logging.ERROR) 51 | 52 | stream_handler = logging.StreamHandler() 53 | handlers = [stream_handler] 54 | 55 | if dist.is_available() and dist.is_initialized(): 56 | rank = dist.get_rank() 57 | else: 58 | rank = 0 59 | 60 | # only rank 0 will add a FileHandler 61 | if rank == 0 and log_file is not None: 62 | # Here, the default behaviour of the official logger is 'a'. Thus, we 63 | # provide an interface to change the file mode to the default 64 | # behaviour. 65 | file_handler = logging.FileHandler(log_file, file_mode) 66 | handlers.append(file_handler) 67 | 68 | formatter = logging.Formatter( 69 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 70 | for handler in handlers: 71 | handler.setFormatter(formatter) 72 | handler.setLevel(log_level) 73 | logger.addHandler(handler) 74 | 75 | if rank == 0: 76 | logger.setLevel(log_level) 77 | else: 78 | logger.setLevel(logging.ERROR) 79 | 80 | logger_initialized[name] = True 81 | 82 | return logger 83 | 84 | 85 | def print_log(msg, logger=None, level=logging.INFO): 86 | """Print a log message. 87 | 88 | Args: 89 | msg (str): The message to be logged. 90 | logger (logging.Logger | str | None): The logger to be used. 91 | Some special loggers are: 92 | 93 | - "silent": no message will be printed. 94 | - other str: the logger obtained with `get_root_logger(logger)`. 95 | - None: The `print()` method will be used to print log messages. 96 | level (int): Logging level. Only available when `logger` is a Logger 97 | object or "root". 
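
    Example (illustrative)::

        print_log('hello world')                        # falls back to print()
        print_log('hello world', logger='silent')       # message is suppressed
        print_log('hello world', logger='marltoolkit')  # routed to get_logger('marltoolkit')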
98 | """ 99 | if logger is None: 100 | print(msg) 101 | elif isinstance(logger, logging.Logger): 102 | logger.log(level, msg) 103 | elif logger == 'silent': 104 | pass 105 | elif isinstance(logger, str): 106 | _logger = get_logger(logger) 107 | _logger.log(level, msg) 108 | else: 109 | raise TypeError( 110 | 'logger should be either a logging.Logger object, str, ' 111 | f'"silent" or None, but got {type(logger)}') 112 | -------------------------------------------------------------------------------- /marltoolkit/utils/logger/logs.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from collections import OrderedDict, defaultdict 4 | 5 | from .logging import get_logger 6 | 7 | try: 8 | import wandb 9 | except ImportError: 10 | pass 11 | 12 | 13 | def get_root_logger(log_file=None, log_level=logging.INFO): 14 | """Get root logger. 15 | 16 | Args: 17 | log_file (str, optional): File path of log. Defaults to None. 18 | log_level (int, optional): The level of logger. 19 | Defaults to logging.INFO. 20 | 21 | Returns: 22 | :obj:`logging.Logger`: The obtained logger 23 | """ 24 | logger = get_logger(name='marltoolkit', 25 | log_file=log_file, 26 | log_level=log_level) 27 | 28 | return logger 29 | 30 | 31 | def get_outdir(path, *paths, inc=False): 32 | outdir = os.path.join(path, *paths) 33 | if not os.path.exists(outdir): 34 | os.makedirs(outdir) 35 | elif inc: 36 | count = 1 37 | outdir_inc = outdir + '-' + str(count) 38 | while os.path.exists(outdir_inc): 39 | count = count + 1 40 | outdir_inc = outdir + '-' + str(count) 41 | assert count < 100 42 | outdir = outdir_inc 43 | os.makedirs(outdir) 44 | return outdir 45 | 46 | 47 | def avg_val_from_list_of_dicts(list_of_dicts): 48 | sum_values = defaultdict(int) 49 | count_dicts = defaultdict(int) 50 | 51 | # Transpose the list of dictionaries into a list of key-value pairs 52 | for dictionary in list_of_dicts: 53 | for key, value in dictionary.items(): 54 | sum_values[key] += value 55 | count_dicts[key] += 1 56 | 57 | # Calculate the average values using a dictionary comprehension 58 | avg_val_dict = { 59 | key: sum_value / count_dicts[key] 60 | for key, sum_value in sum_values.items() 61 | } 62 | 63 | return avg_val_dict 64 | 65 | 66 | def update_summary(train_metrics, eval_metrics, log_wandb=False): 67 | rowd = OrderedDict() 68 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) 69 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) 70 | if log_wandb: 71 | wandb.log(rowd) 72 | -------------------------------------------------------------------------------- /marltoolkit/utils/logger/tensorboard.py: -------------------------------------------------------------------------------- 1 | """Code Reference Tianshou https://github.com/thu- 2 | ml/tianshou/tree/master/tianshou/utils/logger.""" 3 | 4 | from typing import Callable, Optional, Tuple 5 | 6 | from tensorboard.backend.event_processing import event_accumulator 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | from .base import LOG_DATA_TYPE, BaseLogger 10 | 11 | 12 | class TensorboardLogger(BaseLogger): 13 | """A logger that relies on tensorboard SummaryWriter by default to 14 | visualize and log statistics. 15 | 16 | :param SummaryWriter writer: the writer to log data. 17 | :param train_interval: the log interval in log_train_data(). Default to 1000. 18 | :param test_interval: the log interval in log_test_data(). Default to 1. 
19 | :param update_interval: the log interval in log_update_data(). Default to 1000. 20 | :param info_interval: the log interval in log_info_data(). Default to 1. 21 | :param save_interval: the save interval in save_data(). Default to 1 (save at 22 | the end of each epoch). 23 | :param write_flush: whether to flush tensorboard result after each 24 | add_scalar operation. Default to True. 25 | """ 26 | 27 | def __init__( 28 | self, 29 | writer: SummaryWriter, 30 | train_interval: int = 1000, 31 | test_interval: int = 1, 32 | update_interval: int = 1000, 33 | info_interval: int = 1, 34 | save_interval: int = 1, 35 | write_flush: bool = True, 36 | ) -> None: 37 | super().__init__(train_interval, test_interval, update_interval, 38 | info_interval) 39 | self.save_interval = save_interval 40 | self.write_flush = write_flush 41 | self.last_save_step = -1 42 | self.writer = writer 43 | 44 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: 45 | for k, v in data.items(): 46 | self.writer.add_scalar(k, v, global_step=step) 47 | if self.write_flush: # issue 580 48 | self.writer.flush() # issue #482 49 | 50 | def save_data( 51 | self, 52 | epoch: int, 53 | env_step: int, 54 | gradient_step: int, 55 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, 56 | ) -> None: 57 | if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval: 58 | self.last_save_step = epoch 59 | save_checkpoint_fn(epoch, env_step, gradient_step) 60 | self.write('save/epoch', epoch, {'save/epoch': epoch}) 61 | self.write('save/env_step', env_step, {'save/env_step': env_step}) 62 | self.write('save/gradient_step', gradient_step, 63 | {'save/gradient_step': gradient_step}) 64 | 65 | def restore_data(self) -> Tuple[int, int, int]: 66 | ea = event_accumulator.EventAccumulator(self.writer.log_dir) 67 | ea.Reload() 68 | 69 | try: # epoch / gradient_step 70 | epoch = ea.scalars.Items('save/epoch')[-1].step 71 | self.last_save_step = self.last_log_test_step = epoch 72 | gradient_step = ea.scalars.Items('save/gradient_step')[-1].step 73 | self.last_log_update_step = gradient_step 74 | except KeyError: 75 | epoch, gradient_step = 0, 0 76 | try: # offline trainer doesn't have env_step 77 | env_step = ea.scalars.Items('save/env_step')[-1].step 78 | self.last_log_train_step = env_step 79 | except KeyError: 80 | env_step = 0 81 | 82 | return epoch, env_step, gradient_step 83 | -------------------------------------------------------------------------------- /marltoolkit/utils/logger/wandb.py: -------------------------------------------------------------------------------- 1 | """Code Reference Tianshou https://github.com/thu- 2 | ml/tianshou/tree/master/tianshou/utils/logger.""" 3 | 4 | import argparse 5 | import contextlib 6 | import os 7 | from typing import Callable, Optional, Tuple 8 | 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | from .base import LOG_DATA_TYPE, BaseLogger 12 | from .tensorboard import TensorboardLogger 13 | 14 | with contextlib.suppress(ImportError): 15 | import wandb 16 | 17 | 18 | class WandbLogger(BaseLogger): 19 | """Weights and Biases logger that sends data to https://wandb.ai/. 20 | 21 | This logger creates three panels with plots: train, test, and update. 
22 | Make sure to select the correct access for each panel in weights and biases: 23 | 24 | Example of usage: 25 | :: 26 | 27 | logger = WandbLogger() 28 | logger.load(SummaryWriter(log_path)) 29 | result = OnpolicyTrainer(policy, train_collector, test_collector, 30 | logger=logger).run() 31 | 32 | :param train_interval: the log interval in log_train_data(). Default to 1000. 33 | :param test_interval: the log interval in log_test_data(). Default to 1. 34 | :param update_interval: the log interval in log_update_data(). 35 | Default to 1000. 36 | :param info_interval: the log interval in log_info_data(). Default to 1. 37 | :param save_interval: the save interval in save_data(). Default to 1 (save at 38 | the end of each epoch). 39 | :param write_flush: whether to flush tensorboard result after each 40 | add_scalar operation. Default to True. 41 | :param str project: W&B project name. Default to "tianshou". 42 | :param str name: W&B run name. Default to None. If None, random name is assigned. 43 | :param str entity: W&B team/organization name. Default to None. 44 | :param str run_id: run id of W&B run to be resumed. Default to None. 45 | :param argparse.Namespace config: experiment configurations. Default to None. 46 | """ 47 | 48 | def __init__( 49 | self, 50 | dir: str = None, 51 | train_interval: int = 1000, 52 | test_interval: int = 1, 53 | update_interval: int = 1000, 54 | info_interval: int = 1, 55 | save_interval: int = 1000, 56 | write_flush: bool = True, 57 | project: Optional[str] = None, 58 | name: Optional[str] = None, 59 | entity: Optional[str] = None, 60 | run_id: Optional[str] = None, 61 | config: Optional[argparse.Namespace] = None, 62 | monitor_gym: bool = True, 63 | ) -> None: 64 | super().__init__(train_interval, test_interval, update_interval, 65 | info_interval) 66 | self.last_save_step = -1 67 | self.save_interval = save_interval 68 | self.write_flush = write_flush 69 | self.restored = False 70 | if project is None: 71 | project = os.getenv('WANDB_PROJECT', 'marltoolkit') 72 | 73 | self.wandb_run = wandb.init( 74 | dir=dir, 75 | project=project, 76 | name=name, 77 | id=run_id, 78 | resume='allow', 79 | entity=entity, 80 | sync_tensorboard=True, 81 | monitor_gym=monitor_gym, 82 | config=config, # type: ignore 83 | ) if not wandb.run else wandb.run 84 | self.wandb_run._label(repo='marltoolkit') # type: ignore 85 | self.tensorboard_logger: Optional[TensorboardLogger] = None 86 | 87 | def load(self, writer: SummaryWriter) -> None: 88 | self.writer = writer 89 | self.tensorboard_logger = TensorboardLogger( 90 | writer, 91 | self.train_interval, 92 | self.test_interval, 93 | self.update_interval, 94 | self.save_interval, 95 | self.write_flush, 96 | ) 97 | 98 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None: 99 | if self.tensorboard_logger is None: 100 | raise Exception( 101 | '`logger` needs to load the Tensorboard Writer before ' 102 | 'writing data. Try `logger.load(SummaryWriter(log_path))`') 103 | else: 104 | self.tensorboard_logger.write(step_type, step, data) 105 | 106 | def save_data( 107 | self, 108 | epoch: int, 109 | env_step: int, 110 | gradient_step: int, 111 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None, 112 | ) -> None: 113 | """Use writer to log metadata when calling ``save_checkpoint_fn`` in 114 | trainer. 115 | 116 | :param epoch: the epoch in trainer. 117 | :param env_step: the env_step in trainer. 118 | :param gradient_step: the gradient_step in trainer. 
119 | :param function save_checkpoint_fn: a hook defined by user, see trainer 120 | documentation for detail. 121 | """ 122 | if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval: 123 | self.last_save_step = epoch 124 | checkpoint_path = save_checkpoint_fn(epoch, env_step, 125 | gradient_step) 126 | 127 | checkpoint_artifact = wandb.Artifact( 128 | 'run_' + self.wandb_run.id + '_checkpoint', # type: ignore 129 | type='model', 130 | metadata={ 131 | 'save/epoch': epoch, 132 | 'save/env_step': env_step, 133 | 'save/gradient_step': gradient_step, 134 | 'checkpoint_path': str(checkpoint_path), 135 | }) 136 | checkpoint_artifact.add_file(str(checkpoint_path)) 137 | self.wandb_run.log_artifact(checkpoint_artifact) # type: ignore 138 | 139 | def restore_data(self) -> Tuple[int, int, int]: 140 | checkpoint_artifact = self.wandb_run.use_artifact( # type: ignore 141 | f'run_{self.wandb_run.id}_checkpoint:latest' # type: ignore 142 | ) 143 | assert checkpoint_artifact is not None, "W&B dataset artifact doesn't exist" 144 | 145 | checkpoint_artifact.download( 146 | os.path.dirname(checkpoint_artifact.metadata['checkpoint_path'])) 147 | 148 | try: # epoch / gradient_step 149 | epoch = checkpoint_artifact.metadata['save/epoch'] 150 | self.last_save_step = self.last_log_test_step = epoch 151 | gradient_step = checkpoint_artifact.metadata['save/gradient_step'] 152 | self.last_log_update_step = gradient_step 153 | except KeyError: 154 | epoch, gradient_step = 0, 0 155 | try: # offline trainer doesn't have env_step 156 | env_step = checkpoint_artifact.metadata['save/env_step'] 157 | self.last_log_train_step = env_step 158 | except KeyError: 159 | env_step = 0 160 | return epoch, env_step, gradient_step 161 | -------------------------------------------------------------------------------- /marltoolkit/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Author: jianzhnie@126.com 3 | Date: 2022-09-01 15:17:32 4 | LastEditors: jianzhnie@126.com 5 | LastEditTime: 2022-09-01 15:17:35 6 | Description: 7 | 8 | Copyright (c) 2022 by jianzhnie, All Rights Reserved. 9 | ''' 10 | 11 | from collections import Counter 12 | from typing import List 13 | 14 | import six 15 | 16 | __all__ = ['PiecewiseScheduler', 'LinearDecayScheduler'] 17 | 18 | 19 | class PiecewiseScheduler(object): 20 | """Set hyper parameters by a predefined step-based scheduler.""" 21 | 22 | def __init__(self, scheduler_list): 23 | """Piecewise scheduler of hyper parameter. 24 | 25 | Args: 26 | scheduler_list: list of (step, value) pair. E.g. [(0, 0.001), (10000, 0.0005)] 27 | """ 28 | assert len(scheduler_list) > 0 29 | 30 | for i in six.moves.range(len(scheduler_list) - 1): 31 | assert scheduler_list[i][0] < scheduler_list[i + 1][0], \ 32 | 'step of scheduler_list should be incremental.' 
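            # e.g. scheduler_list = [(0, 0.1), (10000, 0.05), (50000, 0.01)]
            # passes this check because 0 < 10000 < 50000; step() will then
            # return 0.1 until step 10000, 0.05 until step 50000, 0.01 after.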
33 | 34 | self.scheduler_list = scheduler_list 35 | 36 | self.cur_index = 0 37 | self.cur_step = 0 38 | self.cur_value = self.scheduler_list[0][1] 39 | 40 | self.scheduler_num = len(self.scheduler_list) 41 | 42 | def step(self, step_num=1): 43 | """Step step_num and fetch value according to following rule: 44 | 45 | Given scheduler_list: [(step_0, value_0), (step_1, value_1), ..., (step_N, value_N)], 46 | function will return value_K which satisfying self.cur_step >= step_K and self.cur_step < step_K+1 47 | 48 | Args: 49 | step_num (int): number of steps (default: 1) 50 | """ 51 | assert isinstance(step_num, int) and step_num >= 1 52 | self.cur_step += step_num 53 | 54 | if self.cur_index < self.scheduler_num - 1: 55 | if self.cur_step >= self.scheduler_list[self.cur_index + 1][0]: 56 | self.cur_index += 1 57 | self.cur_value = self.scheduler_list[self.cur_index][1] 58 | 59 | return self.cur_value 60 | 61 | 62 | class LinearDecayScheduler(object): 63 | """Set hyper parameters by a step-based scheduler with linear decay 64 | values.""" 65 | 66 | def __init__(self, start_value, max_steps): 67 | """Linear decay scheduler of hyper parameter. Decay value linearly 68 | until 0. 69 | 70 | Args: 71 | start_value (float): start value 72 | max_steps (int): maximum steps 73 | """ 74 | assert max_steps > 0 75 | self.cur_step = 0 76 | self.max_steps = max_steps 77 | self.start_value = start_value 78 | 79 | def step(self, step_num=1): 80 | """Step step_num and fetch value according to following rule: 81 | 82 | return_value = start_value * (1.0 - (cur_steps / max_steps)) 83 | 84 | Args: 85 | step_num (int): number of steps (default: 1) 86 | 87 | Returns: 88 | value (float): current value 89 | """ 90 | assert isinstance(step_num, int) and step_num >= 1 91 | self.cur_step = min(self.cur_step + step_num, self.max_steps) 92 | 93 | value = self.start_value * (1.0 - 94 | ((self.cur_step * 1.0) / self.max_steps)) 95 | 96 | return value 97 | 98 | 99 | class MultiStepScheduler(object): 100 | """step learning rate scheduler.""" 101 | 102 | def __init__(self, 103 | start_value: float, 104 | max_steps: int, 105 | milestones: List = None, 106 | decay_factor: float = 0.1): 107 | assert max_steps > 0 108 | assert isinstance(decay_factor, float) 109 | assert decay_factor > 0 and decay_factor < 1 110 | self.milestones = Counter(milestones) 111 | self.cur_value = start_value 112 | self.cur_step = 0 113 | self.max_steps = max_steps 114 | self.decay_factor = decay_factor 115 | 116 | def step(self, step_num=1): 117 | assert isinstance(step_num, int) and step_num >= 1 118 | self.cur_step = min(self.cur_step + step_num, self.max_steps) 119 | 120 | if self.cur_step not in self.milestones: 121 | return self.cur_value 122 | else: 123 | self.cur_value *= self.decay_factor**self.milestones[self.cur_step] 124 | 125 | return self.cur_value 126 | 127 | 128 | if __name__ == '__main__': 129 | scheduler = MultiStepScheduler(100, 100, [50, 80], 0.5) 130 | for i in range(101): 131 | value = scheduler.step() 132 | print(value) 133 | -------------------------------------------------------------------------------- /marltoolkit/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def huber_loss(e, d): 7 | a = (abs(e) <= d).float() 8 | b = (abs(e) > d).float() 9 | return a * e**2 / 2 + b * d * (abs(e) - d / 2) 10 | 11 | 12 | def mse_loss(e): 13 | return e**2 / 2 14 | 15 | 16 | def get_gard_norm(it): 17 | sum_grad = 0 18 | for x in it: 19 
| if x.grad is None: 20 | continue 21 | sum_grad += x.grad.norm()**2 22 | return math.sqrt(sum_grad) 23 | 24 | 25 | def hard_target_update(src: nn.Module, tgt: nn.Module) -> None: 26 | """Hard update model parameters. 27 | 28 | Params 29 | ====== 30 | src: PyTorch model (weights will be copied from) 31 | tgt: PyTorch model (weights will be copied to) 32 | """ 33 | 34 | tgt.load_state_dict(src.state_dict()) 35 | 36 | 37 | def soft_target_update(src: nn.Module, tgt: nn.Module, tau=0.05) -> None: 38 | """Soft update model parameters. 39 | 40 | θ_target = τ*θ_local + (1 - τ)*θ_target 41 | 42 | Params 43 | ====== 44 | src: PyTorch model (weights will be copied from) 45 | tgt: PyTorch model (weights will be copied to) 46 | tau (float): interpolation parameter 47 | """ 48 | for src_param, tgt_param in zip(src.parameters(), tgt.parameters()): 49 | tgt_param.data.copy_(tau * src_param.data + 50 | (1.0 - tau) * tgt_param.data) 51 | 52 | 53 | def check_model_method(model, method, algo): 54 | """check method existence for input model to algo. 55 | 56 | Args: 57 | model(nn.Model): model for checking 58 | method(str): method name 59 | algo(str): algorithm name 60 | 61 | Raises: 62 | AssertionError: if method is not implemented in model 63 | """ 64 | if method == 'forward': 65 | # check if forward is overridden by the subclass 66 | assert callable( 67 | getattr(model, 'forward', 68 | None)), 'forward should be a function in model class' 69 | assert model.forward.__func__ is not super( 70 | model.__class__, model 71 | ).forward.__func__, "{}'s model needs to implement forward method. \n".format( 72 | algo) 73 | else: 74 | # check if the specified method is implemented 75 | assert hasattr(model, method) and callable(getattr( 76 | model, method, 77 | None)), "{}'s model needs to implement {} method. \n".format( 78 | algo, method) 79 | -------------------------------------------------------------------------------- /marltoolkit/utils/progressbar.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
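# Usage sketch (illustrative comment, not part of the original module):
#
#     prog_bar = ProgressBar(task_num=len(tasks))
#     for task in tasks:
#         process(task)       # `process` and `tasks` are stand-ins
#         prog_bar.update()   # advance the bar and refresh the fps/ETA display
#     prog_bar.file.write('\n')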
2 | import sys 3 | from collections.abc import Iterable 4 | from multiprocessing import Pool 5 | from shutil import get_terminal_size 6 | 7 | from .timer import Timer 8 | 9 | 10 | class ProgressBar: 11 | """A progress bar which can print the progress.""" 12 | 13 | def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout): 14 | self.task_num = task_num 15 | self.bar_width = bar_width 16 | self.completed = 0 17 | self.file = file 18 | if start: 19 | self.start() 20 | 21 | @property 22 | def terminal_width(self): 23 | width, _ = get_terminal_size() 24 | return width 25 | 26 | def start(self): 27 | if self.task_num > 0: 28 | self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, ' 29 | 'elapsed: 0s, ETA:') 30 | else: 31 | self.file.write('completed: 0, elapsed: 0s') 32 | self.file.flush() 33 | self.timer = Timer() 34 | 35 | def update(self, num_tasks=1): 36 | assert num_tasks > 0 37 | self.completed += num_tasks 38 | elapsed = self.timer.since_start() 39 | if elapsed > 0: 40 | fps = self.completed / elapsed 41 | else: 42 | fps = float('inf') 43 | if self.task_num > 0: 44 | percentage = self.completed / float(self.task_num) 45 | eta = int(elapsed * (1 - percentage) / percentage + 0.5) 46 | msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \ 47 | f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \ 48 | f'ETA: {eta:5}s' 49 | 50 | bar_width = min(self.bar_width, 51 | int(self.terminal_width - len(msg)) + 2, 52 | int(self.terminal_width * 0.6)) 53 | bar_width = max(2, bar_width) 54 | mark_width = int(bar_width * percentage) 55 | bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width) 56 | self.file.write(msg.format(bar_chars)) 57 | else: 58 | self.file.write( 59 | f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,' 60 | f' {fps:.1f} tasks/s') 61 | self.file.flush() 62 | 63 | 64 | def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs): 65 | """Track the progress of tasks execution with a progress bar. 66 | 67 | Tasks are done with a simple for-loop. 68 | 69 | Args: 70 | func (callable): The function to be applied to each task. 71 | tasks (list or tuple[Iterable, int]): A list of tasks or 72 | (tasks, total num). 73 | bar_width (int): Width of progress bar. 74 | 75 | Returns: 76 | list: The task results. 
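
    Example (illustrative; ``square`` is a stand-in task function)::

        def square(x):
            return x * x

        # prints a progress bar to stdout and returns [1, 4, 9]
        results = track_progress(square, [1, 2, 3])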
77 | """ 78 | if isinstance(tasks, tuple): 79 | assert len(tasks) == 2 80 | assert isinstance(tasks[0], Iterable) 81 | assert isinstance(tasks[1], int) 82 | task_num = tasks[1] 83 | tasks = tasks[0] 84 | elif isinstance(tasks, Iterable): 85 | task_num = len(tasks) 86 | else: 87 | raise TypeError( 88 | '"tasks" must be an iterable object or a (iterator, int) tuple') 89 | prog_bar = ProgressBar(task_num, bar_width, file=file) 90 | results = [] 91 | for task in tasks: 92 | results.append(func(task, **kwargs)) 93 | prog_bar.update() 94 | prog_bar.file.write('\n') 95 | return results 96 | 97 | 98 | def init_pool(process_num, initializer=None, initargs=None): 99 | if initializer is None: 100 | return Pool(process_num) 101 | elif initargs is None: 102 | return Pool(process_num, initializer) 103 | else: 104 | if not isinstance(initargs, tuple): 105 | raise TypeError('"initargs" must be a tuple') 106 | return Pool(process_num, initializer, initargs) 107 | 108 | 109 | def track_parallel_progress(func, 110 | tasks, 111 | nproc, 112 | initializer=None, 113 | initargs=None, 114 | bar_width=50, 115 | chunksize=1, 116 | skip_first=False, 117 | keep_order=True, 118 | file=sys.stdout): 119 | """Track the progress of parallel task execution with a progress bar. 120 | 121 | The built-in :mod:`multiprocessing` module is used for process pools and 122 | tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`. 123 | 124 | Args: 125 | func (callable): The function to be applied to each task. 126 | tasks (list or tuple[Iterable, int]): A list of tasks or 127 | (tasks, total num). 128 | nproc (int): Process (worker) number. 129 | initializer (None or callable): Refer to :class:`multiprocessing.Pool` 130 | for details. 131 | initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for 132 | details. 133 | chunksize (int): Refer to :class:`multiprocessing.Pool` for details. 134 | bar_width (int): Width of progress bar. 135 | skip_first (bool): Whether to skip the first sample for each worker 136 | when estimating fps, since the initialization step may takes 137 | longer. 138 | keep_order (bool): If True, :func:`Pool.imap` is used, otherwise 139 | :func:`Pool.imap_unordered` is used. 140 | 141 | Returns: 142 | list: The task results. 143 | """ 144 | if isinstance(tasks, tuple): 145 | assert len(tasks) == 2 146 | assert isinstance(tasks[0], Iterable) 147 | assert isinstance(tasks[1], int) 148 | task_num = tasks[1] 149 | tasks = tasks[0] 150 | elif isinstance(tasks, Iterable): 151 | task_num = len(tasks) 152 | else: 153 | raise TypeError( 154 | '"tasks" must be an iterable object or a (iterator, int) tuple') 155 | pool = init_pool(nproc, initializer, initargs) 156 | start = not skip_first 157 | task_num -= nproc * chunksize * int(skip_first) 158 | prog_bar = ProgressBar(task_num, bar_width, start, file=file) 159 | results = [] 160 | if keep_order: 161 | gen = pool.imap(func, tasks, chunksize) 162 | else: 163 | gen = pool.imap_unordered(func, tasks, chunksize) 164 | for result in gen: 165 | results.append(result) 166 | if skip_first: 167 | if len(results) < nproc * chunksize: 168 | continue 169 | elif len(results) == nproc * chunksize: 170 | prog_bar.start() 171 | continue 172 | prog_bar.update() 173 | prog_bar.file.write('\n') 174 | pool.close() 175 | pool.join() 176 | return results 177 | 178 | 179 | def track_iter_progress(tasks, bar_width=50, file=sys.stdout): 180 | """Track the progress of tasks iteration or enumeration with a progress 181 | bar. 
182 | 183 | Tasks are yielded with a simple for-loop. 184 | 185 | Args: 186 | tasks (list or tuple[Iterable, int]): A list of tasks or 187 | (tasks, total num). 188 | bar_width (int): Width of progress bar. 189 | 190 | Yields: 191 | list: The task results. 192 | """ 193 | if isinstance(tasks, tuple): 194 | assert len(tasks) == 2 195 | assert isinstance(tasks[0], Iterable) 196 | assert isinstance(tasks[1], int) 197 | task_num = tasks[1] 198 | tasks = tasks[0] 199 | elif isinstance(tasks, Iterable): 200 | task_num = len(tasks) 201 | else: 202 | raise TypeError( 203 | '"tasks" must be an iterable object or a (iterator, int) tuple') 204 | prog_bar = ProgressBar(task_num, bar_width, file=file) 205 | for task in tasks: 206 | yield task 207 | prog_bar.update() 208 | prog_bar.file.write('\n') 209 | -------------------------------------------------------------------------------- /marltoolkit/utils/timer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from time import time 3 | 4 | 5 | class TimerError(Exception): 6 | 7 | def __init__(self, message): 8 | self.message = message 9 | super().__init__(message) 10 | 11 | 12 | class Timer: 13 | """A flexible Timer class. 14 | 15 | Examples: 16 | >>> import time 17 | >>> import mmcv 18 | >>> with mmcv.Timer(): 19 | >>> # simulate a code block that will run for 1s 20 | >>> time.sleep(1) 21 | 1.000 22 | >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'): 23 | >>> # simulate a code block that will run for 1s 24 | >>> time.sleep(1) 25 | it takes 1.0 seconds 26 | >>> timer = mmcv.Timer() 27 | >>> time.sleep(0.5) 28 | >>> print(timer.since_start()) 29 | 0.500 30 | >>> time.sleep(0.5) 31 | >>> print(timer.since_last_check()) 32 | 0.500 33 | >>> print(timer.since_start()) 34 | 1.000 35 | """ 36 | 37 | def __init__(self, start=True, print_tmpl=None): 38 | self._is_running = False 39 | self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}' 40 | if start: 41 | self.start() 42 | 43 | @property 44 | def is_running(self): 45 | """bool: indicate whether the timer is running""" 46 | return self._is_running 47 | 48 | def __enter__(self): 49 | self.start() 50 | return self 51 | 52 | def __exit__(self, type, value, traceback): 53 | print(self.print_tmpl.format(self.since_last_check())) 54 | self._is_running = False 55 | 56 | def start(self): 57 | """Start the timer.""" 58 | if not self._is_running: 59 | self._t_start = time() 60 | self._is_running = True 61 | self._t_last = time() 62 | 63 | def since_start(self): 64 | """Total time since the timer is started. 65 | 66 | Returns: 67 | float: Time in seconds. 68 | """ 69 | if not self._is_running: 70 | raise TimerError('timer is not running') 71 | self._t_last = time() 72 | return self._t_last - self._t_start 73 | 74 | def since_last_check(self): 75 | """Time since the last checking. 76 | 77 | Either :func:`since_start` or :func:`since_last_check` is a checking 78 | operation. 79 | 80 | Returns: 81 | float: Time in seconds. 82 | """ 83 | if not self._is_running: 84 | raise TimerError('timer is not running') 85 | dur = time() - self._t_last 86 | self._t_last = time() 87 | return dur 88 | 89 | 90 | _g_timers = {} # global timers 91 | 92 | 93 | def check_time(timer_id): 94 | """Add check points in a single line. 95 | 96 | This method is suitable for running a task on a list of items. A timer will 97 | be registered when the method is called for the first time. 
98 | 99 | Examples: 100 | >>> import time 101 | >>> import mmcv 102 | >>> for i in range(1, 6): 103 | >>> # simulate a code block 104 | >>> time.sleep(i) 105 | >>> mmcv.check_time('task1') 106 | 2.000 107 | 3.000 108 | 4.000 109 | 5.000 110 | 111 | Args: 112 | str: Timer identifier. 113 | """ 114 | if timer_id not in _g_timers: 115 | _g_timers[timer_id] = Timer() 116 | return 0 117 | else: 118 | return _g_timers[timer_id].since_last_check() 119 | -------------------------------------------------------------------------------- /marltoolkit/utils/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class OneHotTransform(object): 5 | """One hot transform, convert index to one hot vector. 6 | 7 | Args: 8 | out_dim (int): The dimension of one hot vector. 9 | """ 10 | 11 | def __init__(self, out_dim: int) -> None: 12 | self.out_dim = out_dim 13 | 14 | def __call__(self, index: int) -> np.ndarray: 15 | assert index < self.out_dim 16 | one_hot_id = np.zeros([self.out_dim]) 17 | one_hot_id[index] = 1.0 18 | return one_hot_id 19 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/oxwhirl/smac.git 2 | git+https://github.com/oxwhirl/smacv2.git 3 | numpy 4 | protobuf 5 | pygame 6 | pyglet 7 | PyYAML 8 | scipy 9 | six 10 | tensorboard 11 | tensorboard-logger 12 | tensorboardX 13 | torch 14 | torchvision 15 | tqdm 16 | wandb 17 | -------------------------------------------------------------------------------- /scripts/main_coma.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import time 5 | 6 | import torch 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | sys.path.append('../') 10 | 11 | from configs.arguments import get_common_args 12 | from configs.coma_config import ComaConfig 13 | from marltoolkit.agents.coma_agent import ComaAgent 14 | from marltoolkit.data import ReplayBuffer 15 | from marltoolkit.envs.smacv1.smac_env import SMACWrapperEnv 16 | from marltoolkit.modules.actors import RNNActorModel 17 | from marltoolkit.modules.critics.coma import MLPCriticModel 18 | from marltoolkit.runners.episode_runner import (run_eval_episode, 19 | run_train_episode) 20 | from marltoolkit.utils import (ProgressBar, TensorboardLogger, WandbLogger, 21 | get_outdir, get_root_logger) 22 | 23 | 24 | def main(): 25 | """Main function for running the QMix algorithm. 26 | 27 | This function initializes the necessary configurations, environment, logger, models, and agents. 28 | It then runs training episodes and evaluates the agent's performance periodically. 
29 | 30 | Returns: 31 | None 32 | """ 33 | coma_config = ComaConfig() 34 | common_args = get_common_args() 35 | args = argparse.Namespace(**vars(common_args), **vars(coma_config)) 36 | device = (torch.device('cuda') if torch.cuda.is_available() and args.cuda 37 | else torch.device('cpu')) 38 | 39 | env = SMACWrapperEnv(map_name=args.scenario, difficulty=args.difficulty) 40 | args.episode_limit = env.episode_limit 41 | args.obs_dim = env.obs_dim 42 | args.obs_shape = env.obs_shape 43 | args.state_dim = env.state_dim 44 | args.state_shape = env.state_shape 45 | args.num_agents = env.num_agents 46 | args.n_actions = env.n_actions 47 | args.action_shape = env.action_shape 48 | args.reward_shape = env.reward_shape 49 | args.done_shape = env.done_shape 50 | args.device = device 51 | 52 | # init the logger before other steps 53 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 54 | # log 55 | log_name = os.path.join(args.project, args.scenario, args.algo_name, 56 | timestamp).replace(os.path.sep, '_') 57 | log_path = os.path.join(args.log_dir, args.project, args.scenario, 58 | args.algo_name) 59 | tensorboard_log_path = get_outdir(log_path, 'tensorboard_log_dir') 60 | log_file = os.path.join(log_path, log_name + '.log') 61 | text_logger = get_root_logger(log_file=log_file, log_level='INFO') 62 | 63 | if args.logger == 'wandb': 64 | logger = WandbLogger( 65 | train_interval=args.train_log_interval, 66 | test_interval=args.test_log_interval, 67 | update_interval=args.train_log_interval, 68 | project=args.project, 69 | name=log_name, 70 | save_interval=1, 71 | config=args, 72 | ) 73 | writer = SummaryWriter(tensorboard_log_path) 74 | writer.add_text('args', str(args)) 75 | if args.logger == 'tensorboard': 76 | logger = TensorboardLogger(writer) 77 | else: # wandb 78 | logger.load(writer) 79 | 80 | rpm = ReplayBuffer( 81 | max_size=args.replay_buffer_size, 82 | num_agents=args.num_agents, 83 | episode_limit=args.episode_limit, 84 | obs_space=args.obs_shape, 85 | state_space=args.state_shape, 86 | action_space=args.action_shape, 87 | reward_space=args.reward_shape, 88 | done_space=args.done_shape, 89 | device=device, 90 | ) 91 | 92 | actor_model = RNNActorModel( 93 | input_dim=args.obs_dim, 94 | rnn_hidden_dim=args.rnn_hidden_dim, 95 | n_actions=args.n_actions, 96 | ) 97 | 98 | critic_model = MLPCriticModel( 99 | input_dim=args.obs_dim, 100 | hidden_dim=args.fc_hidden_dim, 101 | output_dim=args.n_actions, 102 | ) 103 | agent = ComaAgent( 104 | actor_model=actor_model, 105 | critic_model=critic_model, 106 | num_agents=args.num_agents, 107 | double_q=args.double_q, 108 | total_steps=args.total_steps, 109 | gamma=args.gamma, 110 | actor_lr=args.learning_rate, 111 | critic_lr=args.learning_rate, 112 | egreedy_exploration=args.egreedy_exploration, 113 | min_exploration=args.min_exploration, 114 | target_update_interval=args.target_update_interval, 115 | learner_update_freq=args.learner_update_freq, 116 | clip_grad_norm=args.clip_grad_norm, 117 | device=args.device, 118 | ) 119 | 120 | progress_bar = ProgressBar(args.memory_warmup_size) 121 | while rpm.size() < args.memory_warmup_size: 122 | run_train_episode(env, agent, rpm, args) 123 | progress_bar.update() 124 | 125 | steps_cnt = 0 126 | episode_cnt = 0 127 | progress_bar = ProgressBar(args.total_steps) 128 | while steps_cnt < args.total_steps: 129 | train_res_dict = run_train_episode(env, agent, rpm, args) 130 | # update episodes and steps 131 | episode_cnt += 1 132 | steps_cnt += train_res_dict['episode_step'] 133 | 134 | # learning rate 
decay 135 | agent.learning_rate = max( 136 | agent.lr_scheduler.step(train_res_dict['episode_step']), 137 | agent.min_learning_rate, 138 | ) 139 | 140 | train_res_dict.update({ 141 | 'exploration': agent.exploration, 142 | 'learning_rate': agent.learning_rate, 143 | 'replay_max_size': rpm.size(), 144 | 'target_update_count': agent.target_update_count, 145 | }) 146 | if episode_cnt % args.train_log_interval == 0: 147 | text_logger.info( 148 | '[Train], episode: {}, train_episode_step: {}, train_win_rate: {:.2f}, train_reward: {:.2f}' 149 | .format( 150 | episode_cnt, 151 | train_res_dict['episode_step'], 152 | train_res_dict['win_rate'], 153 | train_res_dict['episode_reward'], 154 | )) 155 | logger.log_train_data(train_res_dict, steps_cnt) 156 | 157 | if episode_cnt % args.test_log_interval == 0: 158 | eval_res_dict = run_eval_episode(env, agent, args=args) 159 | text_logger.info( 160 | '[Eval], episode: {}, eval_episode_step:{:.2f}, eval_win_rate: {:.2f}, eval_reward: {:.2f}' 161 | .format( 162 | episode_cnt, 163 | eval_res_dict['episode_step'], 164 | eval_res_dict['win_rate'], 165 | eval_res_dict['episode_reward'], 166 | )) 167 | logger.log_test_data(eval_res_dict, steps_cnt) 168 | 169 | progress_bar.update(train_res_dict['episode_step']) 170 | 171 | 172 | if __name__ == '__main__': 173 | main() 174 | -------------------------------------------------------------------------------- /scripts/main_idqn.py: -------------------------------------------------------------------------------- 1 | """This script contains the main function for running the IDQN (Independent 2 | Deep Q-Network) algorithm in a StarCraft II environment.""" 3 | 4 | import argparse 5 | import os 6 | import sys 7 | import time 8 | 9 | import torch 10 | from smac.env import StarCraft2Env 11 | from torch.utils.tensorboard import SummaryWriter 12 | 13 | sys.path.append('../') 14 | from configs.arguments import get_common_args 15 | from configs.idqn_config import IDQNConfig 16 | from marltoolkit.agents import IDQNAgent 17 | from marltoolkit.data import ReplayBuffer 18 | from marltoolkit.envs.smacv1.env_wrapper import SC2EnvWrapper 19 | from marltoolkit.modules.actors import RNNActorModel 20 | from marltoolkit.runners.episode_runner import (run_eval_episode, 21 | run_train_episode) 22 | from marltoolkit.utils import (ProgressBar, TensorboardLogger, WandbLogger, 23 | get_outdir, get_root_logger) 24 | 25 | 26 | def main(): 27 | """The main function for running the IDQN algorithm. 28 | 29 | It initializes the necessary components such as the environment, agent, 30 | logger, and replay buffer. Then, it performs training episodes and 31 | evaluation episodes, logging the results at specified intervals. 
32 | """ 33 | 34 | idn_config = IDQNConfig() 35 | common_args = get_common_args() 36 | args = argparse.Namespace(**vars(common_args), **vars(idn_config)) 37 | device = (torch.device('cuda') if torch.cuda.is_available() and args.cuda 38 | else torch.device('cpu')) 39 | 40 | env = StarCraft2Env(map_name=args.scenario, difficulty=args.difficulty) 41 | 42 | env = SC2EnvWrapper(env) 43 | args.episode_limit = env.episode_limit 44 | args.obs_shape = env.obs_shape 45 | args.state_shape = env.state_shape 46 | args.num_agents = env.num_agents 47 | args.n_actions = env.n_actions 48 | args.device = device 49 | 50 | # init the logger before other steps 51 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 52 | # log 53 | log_name = os.path.join(args.project, args.scenario, args.algo_name, 54 | timestamp).replace(os.path.sep, '_') 55 | log_path = os.path.join(args.log_dir, args.project, args.scenario, 56 | args.algo_name) 57 | tensorboard_log_path = get_outdir(log_path, 'tensorboard_log_dir') 58 | log_file = os.path.join(log_path, log_name + '.log') 59 | text_logger = get_root_logger(log_file=log_file, log_level='INFO') 60 | 61 | if args.logger == 'wandb': 62 | logger = WandbLogger( 63 | train_interval=args.train_log_interval, 64 | test_interval=args.test_log_interval, 65 | update_interval=args.train_log_interval, 66 | project=args.project, 67 | name=log_name, 68 | save_interval=1, 69 | config=args, 70 | ) 71 | writer = SummaryWriter(tensorboard_log_path) 72 | writer.add_text('args', str(args)) 73 | if args.logger == 'tensorboard': 74 | logger = TensorboardLogger(writer) 75 | else: # wandb 76 | logger.load(writer) 77 | 78 | rpm = ReplayBuffer( 79 | max_size=args.replay_buffer_size, 80 | num_agents=args.num_agents, 81 | num_actions=args.n_actions, 82 | episode_limit=args.episode_limit, 83 | obs_shape=args.obs_shape, 84 | state_shape=args.state_shape, 85 | device=args.device, 86 | ) 87 | actor_model = RNNActorModel( 88 | input_dim=args.obs_shape, 89 | rnn_hidden_dim=args.rnn_hidden_dim, 90 | n_actions=args.n_actions, 91 | ) 92 | 93 | marl_agent = IDQNAgent( 94 | actor_model=actor_model, 95 | mixer_model=None, 96 | num_agents=args.num_agents, 97 | double_q=args.double_q, 98 | total_steps=args.total_steps, 99 | gamma=args.gamma, 100 | learning_rate=args.learning_rate, 101 | min_learning_rate=args.min_learning_rate, 102 | egreedy_exploration=args.egreedy_exploration, 103 | min_exploration=args.min_exploration, 104 | target_update_interval=args.target_update_interval, 105 | learner_update_freq=args.learner_update_freq, 106 | clip_grad_norm=args.clip_grad_norm, 107 | device=args.device, 108 | ) 109 | 110 | progress_bar = ProgressBar(args.memory_warmup_size) 111 | while rpm.size() < args.memory_warmup_size: 112 | run_train_episode(env, marl_agent, rpm, args) 113 | progress_bar.update() 114 | 115 | steps_cnt = 0 116 | episode_cnt = 0 117 | progress_bar = ProgressBar(args.total_steps) 118 | while steps_cnt < args.total_steps: 119 | ( 120 | episode_reward, 121 | episode_step, 122 | is_win, 123 | mean_loss, 124 | mean_td_error, 125 | ) = run_train_episode(env, marl_agent, rpm, args) 126 | # update episodes and steps 127 | episode_cnt += 1 128 | steps_cnt += episode_step 129 | 130 | # learning rate decay 131 | marl_agent.learning_rate = max( 132 | marl_agent.lr_scheduler.step(episode_step), 133 | marl_agent.min_learning_rate) 134 | 135 | train_results = { 136 | 'env_step': episode_step, 137 | 'rewards': episode_reward, 138 | 'win_rate': is_win, 139 | 'mean_loss': mean_loss, 140 | 'mean_td_error': 
mean_td_error, 141 | 'exploration': marl_agent.exploration, 142 | 'learning_rate': marl_agent.learning_rate, 143 | 'replay_buffer_size': rpm.size(), 144 | 'target_update_count': marl_agent.target_update_count, 145 | } 146 | if episode_cnt % args.train_log_interval == 0: 147 | text_logger.info( 148 | '[Train], episode: {}, train_win_rate: {:.2f}, train_reward: {:.2f}' 149 | .format(episode_cnt, is_win, episode_reward)) 150 | logger.log_train_data(train_results, steps_cnt) 151 | 152 | if episode_cnt % args.test_log_interval == 0: 153 | eval_rewards, eval_steps, eval_win_rate = run_eval_episode( 154 | env, marl_agent, num_eval_episodes=5) 155 | text_logger.info( 156 | '[Eval], episode: {}, eval_win_rate: {:.2f}, eval_rewards: {:.2f}' 157 | .format(episode_cnt, eval_win_rate, eval_rewards)) 158 | 159 | test_results = { 160 | 'env_step': eval_steps, 161 | 'rewards': eval_rewards, 162 | 'win_rate': eval_win_rate, 163 | } 164 | logger.log_test_data(test_results, steps_cnt) 165 | 166 | progress_bar.update(episode_step) 167 | 168 | 169 | if __name__ == '__main__': 170 | main() 171 | -------------------------------------------------------------------------------- /scripts/main_mappo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | 5 | import torch 6 | from torch.utils.tensorboard import SummaryWriter 7 | 8 | sys.path.append('../') 9 | 10 | from configs.arguments import get_common_args 11 | from marltoolkit.agents.mappo_agent import MAPPOAgent 12 | from marltoolkit.data.shared_buffer import SharedReplayBuffer 13 | from marltoolkit.envs.smacv1 import SMACWrapperEnv 14 | from marltoolkit.runners.episode_runner import (run_eval_episode, 15 | run_train_episode) 16 | from marltoolkit.utils import (ProgressBar, TensorboardLogger, WandbLogger, 17 | get_outdir, get_root_logger) 18 | 19 | 20 | def main() -> None: 21 | """Main function for running the MAPPO algorithm. 22 | 23 | This function initializes the necessary configurations, environment, logger, models, and agents. 24 | It then runs training episodes and evaluates the agent's performance periodically. 25 | 26 | Returns: 27 | None 28 | """ 29 | args = get_common_args() 30 | if args.algorithm_name == 'rmappo': 31 | args.use_recurrent_policy = True 32 | args.use_naive_recurrent_policy = False 33 | 34 | if args.algorithm_name == 'rmappo': 35 | print( 36 | 'You are choosing rmappo, so use_recurrent_policy is set to True' 37 | ) 38 | args.use_recurrent_policy = True 39 | args.use_naive_recurrent_policy = False 40 | elif (args.algorithm_name == 'mappo' or args.algorithm_name == 'mat' 41 | or args.algorithm_name == 'mat_dec'): 42 | assert (args.use_recurrent_policy is False 43 | and args.use_naive_recurrent_policy is False 44 | ), 'check recurrent policy!' 45 | print( 46 | 'You are choosing mappo, so use_recurrent_policy and use_naive_recurrent_policy are set to False' 47 | ) 48 | args.use_recurrent_policy = False 49 | args.use_naive_recurrent_policy = False 50 | 51 | elif args.algorithm_name == 'ippo': 52 | print( 53 | 'You are choosing ippo, so use_centralized_v is set to False') 54 | args.use_centralized_v = False 55 | elif args.algorithm_name == 'happo' or args.algorithm_name == 'hatrpo': 56 | # HAPPO / HATRPO: recurrent policies are disabled below.
57 | print('using', args.algorithm_name, 'without recurrent network') 58 | args.use_recurrent_policy = False 59 | args.use_naive_recurrent_policy = False 60 | else: 61 | raise NotImplementedError 62 | 63 | if args.algorithm_name == 'mat_dec': 64 | args.dec_actor = True 65 | args.share_actor = True 66 | 67 | # cuda 68 | if args.cuda and torch.cuda.is_available(): 69 | print('choose to use gpu...') 70 | args.device = torch.device('cuda') 71 | torch.set_num_threads(args.n_training_threads) 72 | if args.cuda_deterministic: 73 | torch.backends.cudnn.benchmark = False 74 | torch.backends.cudnn.deterministic = True 75 | else: 76 | print('choose to use cpu...') 77 | args.device = torch.device('cpu') 78 | torch.set_num_threads(args.n_training_threads) 79 | 80 | # Environment 81 | env = SMACWrapperEnv(map_name=args.scenario, args=args) 82 | args.episode_limit = env.episode_limit 83 | args.obs_dim = env.obs_dim 84 | args.obs_shape = env.obs_shape 85 | args.state_dim = env.state_dim 86 | args.state_shape = env.state_shape 87 | args.num_agents = env.num_agents 88 | args.n_actions = env.n_actions 89 | args.action_shape = env.action_shape 90 | args.reward_shape = env.reward_shape 91 | args.done_shape = env.done_shape 92 | 93 | # init the logger before other steps 94 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 95 | # log 96 | log_name = os.path.join(args.project, args.scenario, args.algo_name, 97 | timestamp).replace(os.path.sep, '_') 98 | log_path = os.path.join(args.log_dir, args.project, args.scenario, 99 | args.algo_name) 100 | tensorboard_log_path = get_outdir(log_path, 'tensorboard_log_dir') 101 | log_file = os.path.join(log_path, log_name + '.log') 102 | text_logger = get_root_logger(log_file=log_file, log_level='INFO') 103 | 104 | if args.logger == 'wandb': 105 | logger = WandbLogger( 106 | train_interval=args.train_log_interval, 107 | test_interval=args.test_log_interval, 108 | update_interval=args.train_log_interval, 109 | project=args.project, 110 | name=log_name, 111 | save_interval=1, 112 | config=args, 113 | ) 114 | writer = SummaryWriter(tensorboard_log_path) 115 | writer.add_text('args', str(args)) 116 | if args.logger == 'tensorboard': 117 | logger = TensorboardLogger(writer) 118 | else: # wandb 119 | logger.load(writer) 120 | 121 | args.obs_shape = env.get_actor_input_shape() 122 | 123 | # policy network 124 | agent = MAPPOAgent(args) 125 | # buffer 126 | rpm = SharedReplayBuffer( 127 | args.num_envs, 128 | args.num_agents, 129 | args.episode_limit, 130 | args.obs_shape, 131 | args.state_shape, 132 | args.action_shape, 133 | args.reward_shape, 134 | args.done_shape, 135 | args, 136 | ) 137 | 138 | progress_bar = ProgressBar(args.memory_warmup_size) 139 | while rpm.size() < args.memory_warmup_size: 140 | run_train_episode(env, agent, rpm, args) 141 | progress_bar.update() 142 | 143 | steps_cnt = 0 144 | episode_cnt = 0 145 | progress_bar = ProgressBar(args.total_steps) 146 | while steps_cnt < args.total_steps: 147 | train_res_dict = run_train_episode(env, agent, rpm, args) 148 | # update episodes and steps 149 | episode_cnt += 1 150 | steps_cnt += train_res_dict['episode_step'] 151 | 152 | # learning rate decay 153 | agent.learning_rate = max( 154 | agent.lr_scheduler.step(train_res_dict['episode_step']), 155 | agent.min_learning_rate, 156 | ) 157 | 158 | train_res_dict.update({ 159 | 'exploration': agent.exploration, 160 | 'learning_rate': agent.learning_rate, 161 | 'replay_max_size': rpm.size(), 162 | 'target_update_count': agent.target_update_count, 163 | }) 164 | if 
episode_cnt % args.train_log_interval == 0: 165 | text_logger.info( 166 | '[Train], episode: {}, train_episode_step: {}, train_win_rate: {:.2f}, train_reward: {:.2f}' 167 | .format( 168 | episode_cnt, 169 | train_res_dict['episode_step'], 170 | train_res_dict['win_rate'], 171 | train_res_dict['episode_reward'], 172 | )) 173 | logger.log_train_data(train_res_dict, steps_cnt) 174 | 175 | if episode_cnt % args.test_log_interval == 0: 176 | eval_res_dict = run_eval_episode(env, agent, args=args) 177 | text_logger.info( 178 | '[Eval], episode: {}, eval_episode_step:{:.2f}, eval_win_rate: {:.2f}, eval_reward: {:.2f}' 179 | .format( 180 | episode_cnt, 181 | eval_res_dict['episode_step'], 182 | eval_res_dict['win_rate'], 183 | eval_res_dict['episode_reward'], 184 | )) 185 | logger.log_test_data(eval_res_dict, steps_cnt) 186 | 187 | progress_bar.update(train_res_dict['episode_step']) 188 | 189 | 190 | if __name__ == '__main__': 191 | main() 192 | -------------------------------------------------------------------------------- /scripts/main_qmix.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import time 5 | 6 | import torch 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | sys.path.append('../') 10 | 11 | from configs.arguments import get_common_args 12 | from configs.qmix_config import QMixConfig 13 | from marltoolkit.agents.qmix_agent import QMixAgent 14 | from marltoolkit.data import ReplayBuffer 15 | from marltoolkit.envs.smacv1.smac_env import SMACWrapperEnv 16 | from marltoolkit.modules.actors import RNNActorModel 17 | from marltoolkit.modules.mixers import QMixerModel 18 | from marltoolkit.runners.episode_runner import (run_eval_episode, 19 | run_train_episode) 20 | from marltoolkit.utils import (ProgressBar, TensorboardLogger, WandbLogger, 21 | get_outdir, get_root_logger) 22 | 23 | 24 | def main(): 25 | """Main function for running the QMix algorithm. 26 | 27 | This function initializes the necessary configurations, environment, logger, models, and agents. 28 | It then runs training episodes and evaluates the agent's performance periodically. 
29 | 30 | Returns: 31 | None 32 | """ 33 | qmix_config = QMixConfig() 34 | common_args = get_common_args() 35 | args = argparse.Namespace(**vars(qmix_config)) 36 | args.__dict__.update(**vars(common_args)) 37 | device = torch.device('cuda') if torch.cuda.is_available( 38 | ) and args.cuda else torch.device('cpu') 39 | 40 | env = SMACWrapperEnv(map_name=args.scenario, args=args) 41 | args.episode_limit = env.episode_limit 42 | args.obs_dim = env.obs_dim 43 | args.obs_shape = env.obs_shape 44 | args.state_dim = env.state_dim 45 | args.state_shape = env.state_shape 46 | args.num_agents = env.num_agents 47 | args.n_actions = env.n_actions 48 | args.action_shape = env.action_shape 49 | args.reward_shape = env.reward_shape 50 | args.done_shape = env.done_shape 51 | args.device = device 52 | 53 | # init the logger before other steps 54 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 55 | # log 56 | log_name = os.path.join(args.project, args.scenario, args.algo_name, 57 | timestamp).replace(os.path.sep, '_') 58 | log_path = os.path.join(args.log_dir, args.project, args.scenario, 59 | args.algo_name) 60 | tensorboard_log_path = get_outdir(log_path, 'tensorboard_log_dir') 61 | log_file = os.path.join(log_path, log_name + '.log') 62 | text_logger = get_root_logger(log_file=log_file, log_level='INFO') 63 | 64 | if args.logger == 'wandb': 65 | logger = WandbLogger( 66 | train_interval=args.train_log_interval, 67 | test_interval=args.test_log_interval, 68 | update_interval=args.train_log_interval, 69 | project=args.project, 70 | name=log_name, 71 | save_interval=1, 72 | config=args, 73 | ) 74 | writer = SummaryWriter(tensorboard_log_path) 75 | writer.add_text('args', str(args)) 76 | if args.logger == 'tensorboard': 77 | logger = TensorboardLogger(writer) 78 | else: # wandb 79 | logger.load(writer) 80 | 81 | args.obs_shape = env.get_actor_input_shape() 82 | 83 | rpm = ReplayBuffer( 84 | max_size=args.replay_buffer_size, 85 | num_agents=args.num_agents, 86 | episode_limit=args.episode_limit, 87 | obs_shape=args.obs_shape, 88 | state_shape=args.state_shape, 89 | action_shape=args.action_shape, 90 | reward_shape=args.reward_shape, 91 | done_shape=args.done_shape, 92 | device=device, 93 | ) 94 | 95 | actor_model = RNNActorModel(args) 96 | 97 | mixer_model = QMixerModel( 98 | num_agents=args.num_agents, 99 | state_dim=args.state_dim, 100 | mixing_embed_dim=args.mixing_embed_dim, 101 | hypernet_layers=args.hypernet_layers, 102 | hypernet_embed_dim=args.hypernet_embed_dim, 103 | ) 104 | 105 | agent = QMixAgent( 106 | actor_model=actor_model, 107 | mixer_model=mixer_model, 108 | num_agents=args.num_agents, 109 | double_q=args.double_q, 110 | total_steps=args.total_steps, 111 | gamma=args.gamma, 112 | learning_rate=args.learning_rate, 113 | min_learning_rate=args.min_learning_rate, 114 | egreedy_exploration=args.egreedy_exploration, 115 | min_exploration=args.min_exploration, 116 | target_update_tau=args.target_update_tau, 117 | target_update_interval=args.target_update_interval, 118 | learner_update_freq=args.learner_update_freq, 119 | clip_grad_norm=args.clip_grad_norm, 120 | device=args.device, 121 | ) 122 | 123 | progress_bar = ProgressBar(args.memory_warmup_size) 124 | while rpm.size() < args.memory_warmup_size: 125 | run_train_episode(env, agent, rpm, args) 126 | progress_bar.update() 127 | 128 | steps_cnt = 0 129 | episode_cnt = 0 130 | progress_bar = ProgressBar(args.total_steps) 131 | while steps_cnt < args.total_steps: 132 | train_res_dict = run_train_episode(env, agent, rpm, args) 133 | 
# update episodes and steps 134 | episode_cnt += 1 135 | steps_cnt += train_res_dict['episode_step'] 136 | 137 | # learning rate decay 138 | agent.learning_rate = max( 139 | agent.lr_scheduler.step(train_res_dict['episode_step']), 140 | agent.min_learning_rate, 141 | ) 142 | 143 | train_res_dict.update({ 144 | 'exploration': agent.exploration, 145 | 'learning_rate': agent.learning_rate, 146 | 'replay_max_size': rpm.size(), 147 | 'target_update_count': agent.target_update_count, 148 | }) 149 | if episode_cnt % args.train_log_interval == 0: 150 | text_logger.info( 151 | '[Train], episode: {}, train_episode_step: {}, train_win_rate: {:.2f}, train_reward: {:.2f}' 152 | .format( 153 | episode_cnt, 154 | train_res_dict['episode_step'], 155 | train_res_dict['win_rate'], 156 | train_res_dict['episode_reward'], 157 | )) 158 | logger.log_train_data(train_res_dict, steps_cnt) 159 | 160 | if episode_cnt % args.test_log_interval == 0: 161 | eval_res_dict = run_eval_episode(env, agent, args=args) 162 | text_logger.info( 163 | '[Eval], episode: {}, eval_episode_step:{:.2f}, eval_win_rate: {:.2f}, eval_reward: {:.2f}' 164 | .format( 165 | episode_cnt, 166 | eval_res_dict['episode_step'], 167 | eval_res_dict['win_rate'], 168 | eval_res_dict['episode_reward'], 169 | )) 170 | logger.log_test_data(eval_res_dict, steps_cnt) 171 | 172 | progress_bar.update(train_res_dict['episode_step']) 173 | 174 | 175 | if __name__ == '__main__': 176 | main() 177 | -------------------------------------------------------------------------------- /scripts/main_qtran.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import time 5 | from copy import deepcopy 6 | 7 | import torch 8 | from smac.env import StarCraft2Env 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | sys.path.append('../') 12 | from configs.arguments import get_common_args 13 | from configs.qtran_config import QTranConfig 14 | from marltoolkit.agents.qtran_agent import QTranAgent 15 | from marltoolkit.data import MaReplayBuffer 16 | from marltoolkit.envs import SC2EnvWrapper 17 | from marltoolkit.modules.actors import RNNModel 18 | from marltoolkit.modules.mixers.qtran_mixer import QTransModel 19 | from marltoolkit.runners.episode_runner import (run_eval_episode, 20 | run_train_episode) 21 | from marltoolkit.utils import (ProgressBar, TensorboardLogger, WandbLogger, 22 | get_outdir, get_root_logger) 23 | 24 | 25 | def main(): 26 | device = torch.device( 27 | 'cuda') if torch.cuda.is_available() else torch.device('cpu') 28 | 29 | config = deepcopy(QTranConfig) 30 | common_args = get_common_args() 31 | common_dict = vars(common_args) 32 | config.update(common_dict) 33 | 34 | env = StarCraft2Env(map_name=config['scenario'], 35 | difficulty=config['difficulty']) 36 | 37 | env = SC2EnvWrapper(env) 38 | config['episode_limit'] = env.episode_limit 39 | config['obs_shape'] = env.obs_shape 40 | config['state_shape'] = env.state_shape 41 | config['n_agents'] = env.n_agents 42 | config['n_actions'] = env.n_actions 43 | 44 | args = argparse.Namespace(**config) 45 | 46 | # init the logger before other steps 47 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 48 | # log 49 | log_name = os.path.join(args.project, args.scenario, args.algo, timestamp) 50 | text_log_path = os.path.join(args.log_dir, args.project, args.scenario, 51 | args.algo) 52 | tensorboard_log_path = get_outdir(text_log_path, 'log_dir') 53 | log_file = os.path.join(text_log_path, 
f'{timestamp}.log') 54 | text_logger = get_root_logger(log_file=log_file, log_level='INFO') 55 | 56 | if args.logger == 'wandb': 57 | logger = WandbLogger(train_interval=args.train_log_interval, 58 | test_interval=args.test_log_interval, 59 | update_interval=args.train_log_interval, 60 | project=args.project, 61 | name=log_name.replace(os.path.sep, '_'), 62 | save_interval=1, 63 | config=args, 64 | entity='jianzhnie') 65 | writer = SummaryWriter(tensorboard_log_path) 66 | writer.add_text('args', str(args)) 67 | if args.logger == 'tensorboard': 68 | logger = TensorboardLogger(writer) 69 | else: # wandb 70 | logger.load(writer) 71 | 72 | rpm = MaReplayBuffer(max_size=config['replay_buffer_size'], 73 | episode_limit=config['episode_limit'], 74 | state_shape=config['state_shape'], 75 | obs_shape=config['obs_shape'], 76 | num_agents=config['n_agents'], 77 | num_actions=config['n_actions'], 78 | batch_size=config['batch_size'], 79 | device=device) 80 | 81 | actor_model = RNNModel( 82 | input_shape=config['obs_shape'], 83 | rnn_hidden_dim=config['rnn_hidden_dim'], 84 | n_actions=config['n_actions'], 85 | ) 86 | mixer_model = QTransModel(n_agents=config['n_agents'], 87 | n_actions=config['n_actions'], 88 | state_dim=config['state_shape'], 89 | rnn_hidden_dim=config['rnn_hidden_dim'], 90 | mixing_embed_dim=config['mixing_embed_dim']) 91 | 92 | qmix_agent = QTranAgent( 93 | actor_model=actor_model, 94 | mixer_model=mixer_model, 95 | n_agents=config['n_agents'], 96 | n_actions=config['n_actions'], 97 | double_q=config['double_q'], 98 | total_steps=config['total_steps'], 99 | gamma=config['gamma'], 100 | learning_rate=config['learning_rate'], 101 | min_learning_rate=config['min_learning_rate'], 102 | egreedy_exploration=config['egreedy_exploration'], 103 | min_exploration=config['min_exploration'], 104 | target_update_interval=config['target_update_interval'], 105 | learner_update_freq=config['learner_update_freq'], 106 | clip_grad_norm=config['clip_grad_norm'], 107 | opt_loss_coef=config['opt_loss_coef'], 108 | nopt_min_loss_coef=config['nopt_min_loss_coef'], 109 | device=device, 110 | ) 111 | 112 | progress_bar = ProgressBar(config['memory_warmup_size']) 113 | while rpm.size() < config['memory_warmup_size']: 114 | run_train_episode(env, qmix_agent, rpm, config) 115 | progress_bar.update() 116 | 117 | steps_cnt = 0 118 | episode_cnt = 0 119 | progress_bar = ProgressBar(config['total_steps']) 120 | while steps_cnt < config['total_steps']: 121 | episode_reward, episode_step, is_win, mean_loss, mean_td_loss, mean_opt_loss, mean_nopt_loss = run_train_episode( 122 | env, qmix_agent, rpm, config) 123 | # update episodes and steps 124 | episode_cnt += 1 125 | steps_cnt += episode_step 126 | 127 | # learning rate decay 128 | qmix_agent.learning_rate = max( 129 | qmix_agent.lr_scheduler.step(episode_step), 130 | qmix_agent.min_learning_rate) 131 | 132 | train_results = { 133 | 'env_step': episode_step, 134 | 'rewards': episode_reward, 135 | 'win_rate': is_win, 136 | 'mean_loss': mean_loss, 137 | 'mean_td_loss': mean_td_loss, 138 | 'exploration': qmix_agent.exploration, 139 | 'learning_rate': qmix_agent.learning_rate, 140 | 'replay_buffer_size': rpm.size(), 141 | 'target_update_count': qmix_agent.target_update_count, 142 | } 143 | if episode_cnt % config['train_log_interval'] == 0: 144 | text_logger.info( 145 | '[Train], episode: {}, train_win_rate: {:.2f}, train_reward: {:.2f}' 146 | .format(episode_cnt, is_win, episode_reward)) 147 | logger.log_train_data(train_results, steps_cnt) 148 | 149 | if episode_cnt 
% config['test_log_interval'] == 0: 150 | eval_rewards, eval_steps, eval_win_rate = run_eval_episode( 151 | env, qmix_agent, num_eval_episodes=5) 152 | text_logger.info( 153 | '[Eval], episode: {}, eval_win_rate: {:.2f}, eval_rewards: {:.2f}' 154 | .format(episode_cnt, eval_win_rate, eval_rewards)) 155 | 156 | test_results = { 157 | 'env_step': eval_steps, 158 | 'rewards': eval_rewards, 159 | 'win_rate': eval_win_rate 160 | } 161 | logger.log_test_data(test_results, steps_cnt) 162 | 163 | progress_bar.update(episode_step) 164 | 165 | 166 | if __name__ == '__main__': 167 | main() 168 | -------------------------------------------------------------------------------- /scripts/main_vdn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import time 5 | 6 | import torch 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | sys.path.append('../') 10 | from configs.arguments import get_common_args 11 | from configs.qmix_config import QMixConfig 12 | from marltoolkit.agents.vdn_agent import VDNAgent 13 | from marltoolkit.data import OffPolicyBufferRNN 14 | from marltoolkit.envs.smacv1 import SMACWrapperEnv 15 | from marltoolkit.modules.actors import RNNActorModel 16 | from marltoolkit.modules.mixers import VDNMixer 17 | from marltoolkit.runners.parallel_episode_runner import (run_eval_episode, 18 | run_train_episode) 19 | from marltoolkit.utils import (ProgressBar, TensorboardLogger, WandbLogger, 20 | get_outdir, get_root_logger) 21 | from marltoolkit.utils.env_utils import make_vec_env 22 | 23 | 24 | def main(): 25 | vdn_config = QMixConfig() 26 | common_args = get_common_args() 27 | args = argparse.Namespace(**vars(common_args), **vars(vdn_config)) 28 | device = torch.device('cuda') if torch.cuda.is_available( 29 | ) and args.cuda else torch.device('cpu') 30 | 31 | train_envs, test_envs = make_vec_env( 32 | env_id=args.env_id, 33 | map_name=args.scenario, 34 | num_train_envs=args.num_train_envs, 35 | num_test_envs=args.num_test_envs, 36 | difficulty=args.difficulty, 37 | ) 38 | env = SMACWrapperEnv(map_name=args.scenario) 39 | args.episode_limit = env.episode_limit 40 | args.obs_dim = env.obs_dim 41 | args.obs_shape = env.obs_shape 42 | args.state_shape = env.state_shape 43 | args.num_agents = env.num_agents 44 | args.n_actions = env.n_actions 45 | args.action_shape = env.action_shape 46 | args.reward_shape = env.reward_shape 47 | args.done_shape = env.done_shape 48 | args.device = device 49 | 50 | # init the logger before other steps 51 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 52 | # log 53 | log_name = os.path.join(args.project, args.scenario, args.algo_name, 54 | timestamp).replace(os.path.sep, '_') 55 | log_path = os.path.join(args.log_dir, args.project, args.scenario, 56 | args.algo_name) 57 | tensorboard_log_path = get_outdir(log_path, 'tensorboard_log_dir') 58 | log_file = os.path.join(log_path, log_name + '.log') 59 | text_logger = get_root_logger(log_file=log_file, log_level='INFO') 60 | 61 | if args.logger == 'wandb': 62 | logger = WandbLogger( 63 | train_interval=args.train_log_interval, 64 | test_interval=args.test_log_interval, 65 | update_interval=args.train_log_interval, 66 | project=args.project, 67 | name=log_name, 68 | save_interval=1, 69 | config=args, 70 | ) 71 | writer = SummaryWriter(tensorboard_log_path) 72 | writer.add_text('args', str(args)) 73 | if args.logger == 'tensorboard': 74 | logger = TensorboardLogger(writer) 75 | else: # wandb 76 | 
logger.load(writer) 77 | 78 | rpm = OffPolicyBufferRNN( 79 | max_size=args.replay_buffer_size, 80 | num_envs=args.num_train_envs, 81 | num_agents=args.num_agents, 82 | episode_limit=args.episode_limit, 83 | obs_space=args.obs_shape, 84 | state_space=args.state_shape, 85 | action_space=args.action_shape, 86 | reward_space=args.reward_shape, 87 | done_space=args.done_shape, 88 | device=args.device, 89 | ) 90 | actor_model = RNNActorModel( 91 | input_dim=args.obs_dim, 92 | fc_hidden_dim=args.fc_hidden_dim, 93 | rnn_hidden_dim=args.rnn_hidden_dim, 94 | n_actions=args.n_actions, 95 | ) 96 | 97 | mixer_model = VDNMixer() 98 | 99 | marl_agent = VDNAgent( 100 | actor_model=actor_model, 101 | mixer_model=mixer_model, 102 | num_envs=args.num_train_envs, 103 | num_agents=args.num_agents, 104 | action_dim=args.n_actions, 105 | double_q=args.double_q, 106 | total_steps=args.total_steps, 107 | gamma=args.gamma, 108 | learning_rate=args.learning_rate, 109 | min_learning_rate=args.min_learning_rate, 110 | egreedy_exploration=args.egreedy_exploration, 111 | min_exploration=args.min_exploration, 112 | target_update_interval=args.target_update_interval, 113 | learner_update_freq=args.learner_update_freq, 114 | clip_grad_norm=args.clip_grad_norm, 115 | device=args.device, 116 | ) 117 | 118 | progress_bar = ProgressBar(args.memory_warmup_size) 119 | while rpm.size() < args.memory_warmup_size: 120 | run_train_episode(train_envs, marl_agent, rpm, args) 121 | progress_bar.update() 122 | 123 | steps_cnt = 0 124 | episode_cnt = 0 125 | progress_bar = ProgressBar(args.total_steps) 126 | while steps_cnt < args.total_steps: 127 | episode_reward, episode_step, is_win, mean_loss, mean_td_error = ( 128 | run_train_episode(train_envs, marl_agent, rpm, args)) 129 | # update episodes and steps 130 | episode_cnt += 1 131 | steps_cnt += episode_step 132 | 133 | # learning rate decay 134 | marl_agent.learning_rate = max( 135 | marl_agent.lr_scheduler.step(episode_step), 136 | marl_agent.min_learning_rate) 137 | 138 | train_results = { 139 | 'env_step': episode_step, 140 | 'rewards': episode_reward, 141 | 'win_rate': is_win, 142 | 'mean_loss': mean_loss, 143 | 'mean_td_error': mean_td_error, 144 | 'exploration': marl_agent.exploration, 145 | 'learning_rate': marl_agent.learning_rate, 146 | 'replay_buffer_size': rpm.size(), 147 | 'target_update_count': marl_agent.target_update_count, 148 | } 149 | if episode_cnt % args.train_log_interval == 0: 150 | text_logger.info( 151 | '[Train], episode: {}, train_win_rate: {:.2f}, train_reward: {:.2f}' 152 | .format(episode_cnt, is_win, episode_reward)) 153 | logger.log_train_data(train_results, steps_cnt) 154 | 155 | if episode_cnt % args.test_log_interval == 0: 156 | eval_rewards, eval_steps, eval_win_rate = run_eval_episode( 157 | test_envs, marl_agent, num_eval_episodes=5) 158 | text_logger.info( 159 | '[Eval], episode: {}, eval_win_rate: {:.2f}, eval_rewards: {:.2f}' 160 | .format(episode_cnt, eval_win_rate, eval_rewards)) 161 | 162 | test_results = { 163 | 'env_step': eval_steps, 164 | 'rewards': eval_rewards, 165 | 'win_rate': eval_win_rate 166 | } 167 | logger.log_test_data(test_results, steps_cnt) 168 | 169 | progress_bar.update(episode_step) 170 | 171 | 172 | if __name__ == '__main__': 173 | main() 174 | -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | # # main 2 | # nohup python main.py --scenario 3m --total_steps 1000000 > runoob1.log 2>&1 & 3 | # 
nohup python main.py --scenario 8m --total_steps 1000000 > runoob2.log 2>&1 & 4 | # nohup python main.py --scenario 5m_vs_6m --total_steps 1000000 > runoob3.log 2>&1 & 5 | # nohup python main.py --scenario 8m_vs_9m --total_steps 1000000 > runoob4.log 2>&1 & 6 | # nohup python main.py --scenario MMM --total_steps 1000000 > runoob5.log 2>&1 & 7 | # nohup python main.py --scenario 2s3z --total_steps 1000000 > runoob6.log 2>&1 & 8 | # nohup python main.py --scenario 3s5z --total_steps 1000000 > runoob7.log 2>&1 & 9 | 10 | 11 | # main 12 | nohup python main_qmix.py --scenario 3m --total_steps 1000000 > runoob1.log 2>&1 & 13 | nohup python main_qmix.py --scenario 8m --total_steps 1000000 > runoob2.log 2>&1 & 14 | nohup python main_qmix.py --scenario 5m_vs_6m --total_steps 1000000 > runoob3.log 2>&1 & 15 | nohup python main_qmix.py --scenario 8m_vs_9m --total_steps 1000000 > runoob4.log 2>&1 & 16 | nohup python main_qmix.py --scenario MMM --total_steps 1000000 > runoob5.log 2>&1 & 17 | nohup python main_qmix.py --scenario 2s3z --total_steps 1000000 > runoob6.log 2>&1 & 18 | nohup python main_qmix.py --scenario 3s5z --total_steps 1000000 > runoob7.log 2>&1 & 19 | -------------------------------------------------------------------------------- /scripts/test_multi_env.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('../') 4 | import torch 5 | from torch.distributions import Categorical 6 | 7 | from marltoolkit.envs.smacv1.smac_env import SMACWrapperEnv 8 | 9 | if __name__ == '__main__': 10 | env = SMACWrapperEnv(map_name='3m') 11 | env_info = env.get_env_info() 12 | print('env_info:', env_info) 13 | 14 | action_dim = env_info['n_actions'] 15 | num_agents = env_info['num_agents'] 16 | results = env.reset() 17 | print('Reset:', results) 18 | 19 | avail_actions = env.get_available_actions() 20 | print('avail_actions:', avail_actions) 21 | available_actions = torch.tensor(avail_actions) 22 | actions_dist = Categorical(available_actions) 23 | random_actions = actions_dist.sample().numpy().tolist() 24 | 25 | print('random_actions: ', random_actions) 26 | results = env.step(random_actions) 27 | print('Step:', results) 28 | env.close() 29 | -------------------------------------------------------------------------------- /scripts/test_multi_env2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import torch 4 | from torch.distributions import Categorical 5 | 6 | sys.path.append('../') 7 | from marltoolkit.envs.smacv1.smac_env import SMACWrapperEnv 8 | from marltoolkit.envs.vec_env import DummyVecEnv, SubprocVecEnv 9 | 10 | 11 | def make_subproc_envs(map_name='3m', parallels=8) -> SubprocVecEnv: 12 | 13 | def _thunk(): 14 | env = SMACWrapperEnv(map_name=map_name) 15 | return env 16 | 17 | return SubprocVecEnv([_thunk for _ in range(parallels)], 18 | shared_memory=False) 19 | 20 | 21 | def make_dummy_envs(map_name='3m', parallels=8) -> DummyVecEnv: 22 | 23 | def _thunk(): 24 | env = SMACWrapperEnv(map_name=map_name) 25 | return env 26 | 27 | return DummyVecEnv([_thunk for _ in range(parallels)]) 28 | 29 | 30 | def test_dummy_envs(): 31 | parallels = 2 32 | env = SMACWrapperEnv(map_name='3m') 33 | env.reset() 34 | env_info = env.get_env_info() 35 | print('env_info:', env_info) 36 | 37 | train_envs = make_dummy_envs(parallels=parallels) 38 | results = train_envs.reset() 39 | print('Reset:', results) 40 | 41 | num_envs = parallels 42 | avail_actions = env.get_available_actions() 43 
| print('avail_actions:', avail_actions) 44 | available_actions = torch.tensor(avail_actions) 45 | actions_dist = Categorical(available_actions) 46 | random_actions = actions_dist.sample().numpy().tolist() 47 | 48 | print('random_actions: ', random_actions) 49 | dummy_actions = [random_actions for _ in range(num_envs)] 50 | print('dummy_actions:', dummy_actions) 51 | results = train_envs.step(dummy_actions) 52 | print('Step:', results) 53 | train_envs.close() 54 | 55 | 56 | def test_subproc_envs(): 57 | parallels = 10 58 | train_envs = make_subproc_envs(parallels=parallels) 59 | results = train_envs.reset() 60 | obs, state, info = results 61 | print('Env Reset:', '*' * 100) 62 | print('obs:', obs.shape) 63 | print('state:', state.shape) 64 | print('info:', info) 65 | 66 | avail_actions = train_envs.get_available_actions() 67 | print('avail_actions:', avail_actions) 68 | available_actions = torch.tensor(avail_actions) 69 | actions_dist = Categorical(available_actions) 70 | random_actions = actions_dist.sample().numpy().tolist() 71 | print('random_actions: ', random_actions) 72 | results = train_envs.step(random_actions) 73 | obs, state, rewards, dones, info = results 74 | print('Env Step:', '*' * 100) 75 | print('obs:', obs.shape) 76 | print('state:', state.shape) 77 | print('rewards:', rewards) 78 | print('dones:', dones) 79 | print('info:', info) 80 | train_envs.close()  # terminate the subprocess workers 81 | 82 | 83 | if __name__ == '__main__': 84 | test_subproc_envs() 85 | --------------------------------------------------------------------------------