├── .flake8
├── .gitignore
├── .isort.cfg
├── .pre-commit-config.yaml
├── LICENSE
├── README.md
├── configs
├── __init__.py
├── arguments.py
├── coma_config.py
├── idqn_config.py
├── qmix_config.py
└── qtran_config.py
├── docs
├── awesome_marl.md
├── harl.md
├── marllib.md
└── pymarl.md
├── marltoolkit
├── __init__.py
├── __version__.py
├── agents
│ ├── __init__.py
│ ├── ac_agent.py
│ ├── base_agent.py
│ ├── coma_agent.py
│ ├── idqn_agent.py
│ ├── maddpg_agent.py
│ ├── mappo_agent.py
│ ├── qatten_agent.py
│ ├── qmix2_agent.py
│ ├── qmix_agent.py
│ ├── qtran_agent.py
│ └── vdn_agent.py
├── data
│ ├── __init__.py
│ ├── base_buffer.py
│ ├── ma_buffer.py
│ ├── offpolicy_buffer.py
│ ├── onpolicy_buffer.py
│ └── shared_buffer.py
├── envs
│ ├── __init__.py
│ ├── base_env.py
│ ├── pettingzoo
│ │ ├── __init__.py
│ │ ├── custom_env1.py
│ │ ├── custom_env2.py
│ │ └── pettingzoo_env.py
│ ├── smacv1
│ │ ├── __init__.py
│ │ ├── smac_env.py
│ │ └── smac_vec_env.py
│ ├── smacv2
│ │ ├── __init__.py
│ │ └── smacv2_env.py
│ ├── vec_env
│ │ ├── __init__.py
│ │ ├── base_vec_env.py
│ │ ├── dummy_vec_env.py
│ │ ├── subproc_vec_env.py
│ │ ├── utils.py
│ │ └── vec_monitor.py
│ └── waregame
│ │ ├── __init__.py
│ │ └── wargame_wrapper.py
├── modules
│ ├── __init__.py
│ ├── actors
│ │ ├── __init__.py
│ │ ├── mlp.py
│ │ ├── r_actor.py
│ │ └── rnn.py
│ ├── critics
│ │ ├── __init__.py
│ │ ├── coma.py
│ │ ├── maddpg.py
│ │ ├── mlp.py
│ │ └── r_critic.py
│ ├── mixers
│ │ ├── __init__.py
│ │ ├── qatten.py
│ │ ├── qmixer.py
│ │ ├── qtran.py
│ │ └── vdn.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── act.py
│ │ ├── common.py
│ │ ├── distributions.py
│ │ ├── popart.py
│ │ └── valuenorm.py
├── runners
│ ├── __init__.py
│ ├── episode_runner.py
│ ├── onpolicy_runner.py
│ ├── parallel_episode_runner.py
│ └── qtran_runner.py.py
└── utils
│ ├── __init__.py
│ ├── env_utils.py
│ ├── logger
│ ├── __init__.py
│ ├── base.py
│ ├── base_orig.py
│ ├── logging.py
│ ├── logs.py
│ ├── tensorboard.py
│ └── wandb.py
│ ├── lr_scheduler.py
│ ├── model_utils.py
│ ├── progressbar.py
│ ├── timer.py
│ └── transforms.py
├── requirements.txt
└── scripts
├── main_coma.py
├── main_idqn.py
├── main_mappo.py
├── main_qmix.py
├── main_qtran.py
├── main_vdn.py
├── run.sh
├── smac_runner.py
├── test_multi_env.py
└── test_multi_env2.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore = E501, W503, E126, W504, E402, E251
3 | max-line-length = 79
4 | show-source = False
5 | application-import-names = marltoolkit
6 | exclude =
7 | .git
8 | docs/
9 | venv
10 | build
11 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | environment/
108 | ENVIORNMENT/
109 | venv/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | .pyre/
131 | work_dir/
132 | work_dirs/
133 | .DS_Store
134 | events*
135 | logs/
136 | wandb/
137 | nohup.out
138 | protein_data
139 | **/checkpoints
140 | *.npy
141 | hub/
142 | 周报/
143 | multiagent-particle-envs/
144 | results/
145 | *.mp4
146 | *.json
147 |
--------------------------------------------------------------------------------
/.isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | known_third_party =numpy,smac,torch,wandb
3 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://gitee.com/openmmlab/mirrors-flake8
3 | rev: 5.0.4
4 | hooks:
5 | - id: flake8
6 | - repo: https://gitee.com/openmmlab/mirrors-isort
7 | rev: 5.11.5
8 | hooks:
9 | - id: isort
10 | - repo: https://gitee.com/openmmlab/mirrors-yapf
11 | rev: v0.32.0
12 | hooks:
13 | - id: yapf
14 | - repo: https://gitee.com/openmmlab/mirrors-pre-commit-hooks
15 | rev: v4.3.0
16 | hooks:
17 | - id: trailing-whitespace
18 | - id: check-yaml
19 | - id: end-of-file-fixer
20 | - id: requirements-txt-fixer
21 | - id: double-quote-string-fixer
22 | - id: check-merge-conflict
23 | - id: fix-encoding-pragma
24 | args: ["--remove"]
25 | - id: mixed-line-ending
26 | args: ["--fix=lf"]
27 | - repo: https://gitee.com/openmmlab/mirrors-mdformat
28 | rev: 0.7.9
29 | hooks:
30 | - id: mdformat
31 | args: ["--number"]
32 | additional_dependencies:
33 | - mdformat-openmmlab
34 | - mdformat_frontmatter
35 | - linkify-it-py
36 | - repo: https://gitee.com/openmmlab/mirrors-docformatter
37 | rev: v1.3.1
38 | hooks:
39 | - id: docformatter
40 | args: ["--in-place", "--wrap-descriptions", "79"]
41 |
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/configs/__init__.py
--------------------------------------------------------------------------------
/configs/coma_config.py:
--------------------------------------------------------------------------------
1 | class ComaConfig:
2 | """Configuration class for Coma model.
3 |
4 | ComaConfig contains parameters used to instantiate a Coma model.
5 | These parameters define the model architecture, behavior, and training settings.
6 |
7 | Args:
8 | rnn_hidden_dim (int, optional): Dimension of GRU's hidden state.
9 | gamma (float, optional): Discount factor in reinforcement learning.
10 | td_lambda (float, optional): Lambda in TD(lambda).
11 | egreedy_exploration (float, optional): Initial 'epsilon' in epsilon-greedy exploration.
12 | min_exploration (float, optional): Minimum 'epsilon' in epsilon-greedy.
13 | target_update_interval (int, optional): Sync parameters to target model after 'target_update_interval' times.
14 | learning_rate (float, optional): Learning rate of the optimizer.
15 | min_learning_rate (float, optional): Minimum learning rate of the optimizer.
16 | clip_grad_norm (float, optional): Clipped value of the global norm of gradients.
17 | learner_update_freq (int, optional): Update learner frequency.
18 | double_q (bool, optional): Use Double-DQN.
19 | algo_name (str, optional): Name of the algorithm.
20 | """
21 |
22 | model_type: str = 'coma'
23 |
24 | def __init__(
25 | self,
26 | fc_hidden_dim: int = 64,
27 | rnn_hidden_dim: int = 64,
28 | gamma: float = 0.99,
29 | td_lambda: float = 0.8,
30 | egreedy_exploration: float = 1.0,
31 | min_exploration: float = 0.1,
32 | target_update_interval: int = 1000,
33 | learning_rate: float = 0.0005,
34 | min_learning_rate: float = 0.0001,
35 | clip_grad_norm: float = 10,
36 | learner_update_freq: int = 2,
37 | double_q: bool = True,
38 | algo_name: str = 'coma',
39 | ) -> None:
40 | # Network architecture parameters
41 | self.fc_hidden_dim = fc_hidden_dim
42 | self.rnn_hidden_dim = rnn_hidden_dim
43 |
44 | # Training parameters
45 | self.gamma = gamma
46 | self.td_lambda = td_lambda
47 | self.egreedy_exploration = egreedy_exploration
48 | self.min_exploration = min_exploration
49 | self.target_update_interval = target_update_interval
50 | self.learning_rate = learning_rate
51 | self.min_learning_rate = min_learning_rate
52 | self.clip_grad_norm = clip_grad_norm
53 | self.learner_update_freq = learner_update_freq
54 | self.double_q = double_q
55 |
56 | # Logging parameters
57 | self.algo_name = algo_name
58 |
--------------------------------------------------------------------------------
/configs/idqn_config.py:
--------------------------------------------------------------------------------
1 | class IDQNConfig:
2 |     """Configuration class for IDQN model.
3 |
4 |     IDQNConfig contains parameters used to instantiate an IDQN model.
5 | These parameters define the model architecture, behavior, and training settings.
6 |
7 | Args:
8 | rnn_hidden_dim (int, optional): Dimension of GRU's hidden state.
9 | gamma (float, optional): Discount factor in reinforcement learning.
10 | egreedy_exploration (float, optional): Initial 'epsilon' in epsilon-greedy exploration.
11 | min_exploration (float, optional): Minimum 'epsilon' in epsilon-greedy.
12 | target_update_interval (int, optional): Sync parameters to target model after 'target_update_interval' times.
13 | learning_rate (float, optional): Learning rate of the optimizer.
14 | min_learning_rate (float, optional): Minimum learning rate of the optimizer.
15 | clip_grad_norm (float, optional): Clipped value of the global norm of gradients.
18 | learner_update_freq (int, optional): Update learner frequency.
19 | double_q (bool, optional): Use Double-DQN.
20 | algo_name (str, optional): Name of the algorithm.
21 | """
22 |
23 | def __init__(
24 | self,
25 | rnn_hidden_dim: int = 64,
26 | gamma: float = 0.99,
27 | egreedy_exploration: float = 1.0,
28 | min_exploration: float = 0.1,
29 | target_update_interval: int = 1000,
30 | learning_rate: float = 0.0005,
31 | min_learning_rate: float = 0.0001,
32 | clip_grad_norm: float = 10,
33 | learner_update_freq: int = 2,
34 | double_q: bool = True,
35 | algo_name: str = 'idqn',
36 | ) -> None:
37 |
38 | # Network architecture parameters
39 | self.rnn_hidden_dim = rnn_hidden_dim
40 |
41 | # Training parameters
42 | self.gamma = gamma
43 | self.egreedy_exploration = egreedy_exploration
44 | self.min_exploration = min_exploration
45 | self.target_update_interval = target_update_interval
46 | self.learning_rate = learning_rate
47 | self.min_learning_rate = min_learning_rate
48 | self.clip_grad_norm = clip_grad_norm
49 | self.learner_update_freq = learner_update_freq
50 | self.double_q = double_q
51 |
52 | # Logging parameters
53 | self.algo_name = algo_name
54 |
--------------------------------------------------------------------------------
/configs/qmix_config.py:
--------------------------------------------------------------------------------
1 | class QMixConfig:
2 | """Configuration class for QMix model.
3 |
4 | QMixConfig contains parameters used to instantiate a QMix model.
5 | These parameters define the model architecture, behavior, and training settings.
6 |
7 | Args:
8 | mixing_embed_dim (int, optional): Embedding dimension of the mixing network.
9 | rnn_hidden_dim (int, optional): Dimension of GRU's hidden state.
10 | gamma (float, optional): Discount factor in reinforcement learning.
11 | egreedy_exploration (float, optional): Initial 'epsilon' in epsilon-greedy exploration.
12 | min_exploration (float, optional): Minimum 'epsilon' in epsilon-greedy.
13 | target_update_interval (int, optional): Sync parameters to target model after 'target_update_interval' times.
14 | learning_rate (float, optional): Learning rate of the optimizer.
15 | min_learning_rate (float, optional): Minimum learning rate of the optimizer.
16 | clip_grad_norm (float, optional): Clipped value of the global norm of gradients.
17 | hypernet_layers (int, optional): Number of layers in hypernetwork.
18 | hypernet_embed_dim (int, optional): Embedding dimension for hypernetwork.
19 | learner_update_freq (int, optional): Update learner frequency.
20 | double_q (bool, optional): Use Double-DQN.
21 | algo_name (str, optional): Name of the algorithm.
22 | """
23 |
24 | model_type: str = 'qmix'
25 |
26 | def __init__(
27 | self,
28 | fc_hidden_dim: int = 64,
29 | rnn_hidden_dim: int = 64,
30 | gamma: float = 0.99,
31 | egreedy_exploration: float = 1.0,
32 | min_exploration: float = 0.01,
33 | target_update_tau: float = 0.05,
34 | target_update_interval: int = 100,
35 | learning_rate: float = 0.1,
36 | min_learning_rate: float = 0.00001,
37 | clip_grad_norm: float = 10,
38 | hypernet_layers: int = 2,
39 | hypernet_embed_dim: int = 64,
40 | mixing_embed_dim: int = 32,
41 | learner_update_freq: int = 3,
42 | double_q: bool = True,
43 | algo_name: str = 'qmix',
44 | ) -> None:
45 | # Network architecture parameters
46 | self.fc_hidden_dim = fc_hidden_dim
47 | self.rnn_hidden_dim = rnn_hidden_dim
48 | self.hypernet_layers = hypernet_layers
49 | self.hypernet_embed_dim = hypernet_embed_dim
50 | self.mixing_embed_dim = mixing_embed_dim
51 |
52 | # Training parameters
53 | self.gamma = gamma
54 | self.egreedy_exploration = egreedy_exploration
55 | self.min_exploration = min_exploration
56 | self.target_update_tau = target_update_tau
57 | self.target_update_interval = target_update_interval
58 | self.learning_rate = learning_rate
59 | self.min_learning_rate = min_learning_rate
60 | self.clip_grad_norm = clip_grad_norm
61 | self.learner_update_freq = learner_update_freq
62 | self.double_q = double_q
63 |
64 | # Logging parameters
65 | self.algo_name = algo_name
66 |
--------------------------------------------------------------------------------
/configs/qtran_config.py:
--------------------------------------------------------------------------------
1 | QTranConfig = {
2 | 'project': 'StarCraft-II',
3 | 'scenario': '3m',
4 | 'replay_buffer_size': 5000,
5 | 'mixing_embed_dim': 32,
6 | 'rnn_hidden_dim': 64,
7 | 'learning_rate': 0.0005,
8 | 'min_learning_rate': 0.0001,
9 | 'memory_warmup_size': 32,
10 | 'gamma': 0.99,
11 | 'egreedy_exploration': 1.0,
12 | 'min_exploration': 0.1,
13 | 'target_update_interval': 2000,
14 | 'batch_size': 32,
15 | 'total_steps': 1000000,
16 |     'train_log_interval': 5,  # log every 5 episodes
17 |     'test_log_interval': 20,  # log every 20 episodes
18 | 'clip_grad_norm': 10,
19 | 'learner_update_freq': 2,
20 | 'double_q': True,
21 | 'opt_loss_coef': 1.0,
22 | 'nopt_min_loss_coef': 0.1,
23 | 'difficulty': '7',
24 | 'algo': 'qtran',
25 | 'log_dir': 'work_dirs/',
26 | 'logger': 'wandb',
27 | }
28 |
--------------------------------------------------------------------------------
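Unlike the class-based configs above, QTranConfig is a plain dict. Below is a minimal sketch of one way such a dict could be consumed, assuming `configs/` is importable from the repository root and that attribute-style access is wanted; this is an illustration, not the toolkit's documented entry point:

```python
import argparse

from configs.qtran_config import QTranConfig

# Expose the dict entries as attributes, mirroring the argparse-style access
# used by the class-based configs; override any field before training.
args = argparse.Namespace(**QTranConfig)
args.batch_size = 64
print(args.algo, args.learning_rate, args.batch_size)
```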
/docs/harl.md:
--------------------------------------------------------------------------------
1 | # Heterogeneous-Agent Reinforcement Learning
2 |
3 | HARL algorithms are a novel solution for achieving effective multi-agent cooperation in general heterogeneous-agent settings, without relying on the restrictive parameter-sharing trick.
4 |
5 | ## Supported Algorithms
6 |
7 | - HAPPO
8 | - HATRPO
9 | - HAA2C
10 | - HADDPG
11 | - HATD3
12 | - HAD3QN
13 | - HASAC
14 |
15 | ## Key Features
16 |
17 | - HARL algorithms coordinate agent updates through a sequential update scheme, unlike the simultaneous update scheme adopted by MAPPO and MADDPG.
18 | - HARL algorithms enjoy theoretical guarantees of monotonic improvement and convergence to an equilibrium, ensuring their effectiveness in fostering cooperative behaviour among agents.
19 | - Both on-policy and off-policy HARL algorithms, exemplified by HAPPO and HASAC respectively, demonstrate superior performance across a wide range of benchmarks.
20 |
--------------------------------------------------------------------------------
/docs/marllib.md:
--------------------------------------------------------------------------------
1 | # MARLlib
2 |
3 | ## Intro
4 |
5 | MARLlib is a comprehensive multi-agent reinforcement learning (MARL) algorithm library built on Ray and its toolkit RLlib. It provides the MARL research community with a unified platform for building, training, and evaluating multi-agent reinforcement learning algorithms.
6 |
7 |
8 |
9 | > Overview of the MARLlib architecture.
10 |
11 | MARLlib offers several prominent key features:
12 |
13 | 1. **Unified algorithm pipeline:** MARLlib unifies diverse algorithm pipelines through an agent-level distributed dataflow, allowing researchers to develop, test, and evaluate MARL algorithms across different tasks and environments.
14 | 2. **Support for all task modes:** MARLlib supports all task modes, including cooperative, collaborative, competitive, and mixed.
15 | 3. **Gym-style interface:** MARLlib provides a new environment interface that follows the Gym structure, making it easier for researchers to handle multi-agent environments.
16 | 4. **Flexible parameter-sharing strategies:** MARLlib offers flexible and customizable parameter-sharing strategies.
17 |
18 | With MARLlib, you can enjoy various benefits, for example:
19 |
20 | 1. **Zero MARL knowledge barrier:** MARLlib provides 18 pre-built algorithms with a simple API, enabling researchers to start experimenting without prior knowledge of multi-agent reinforcement learning.
21 | 2. **Support for all task modes:** MARLlib supports almost all multi-agent environments, making it easier for researchers to try different task modes.
22 | 3. **Customizable model architectures:** Researchers can choose their preferred model architecture from the model zoo or build their own.
23 | 4. **Customizable policy sharing:** MARLlib provides grouping options for policy sharing, and researchers can also create their own sharing strategies.
24 | 5. **Access to over a thousand released experiments:** Researchers can access more than a thousand released experiments to see how other researchers use MARLlib.
25 |
26 | ## MARLlib Framework
27 |
28 | ## Environment Interface
29 |
30 |
31 |
32 | Agent-Environment Interface in MARLlib
33 |
34 | The environment interface in MARLlib supports the following features:
35 |
36 | 1. Agent-agnostic: each agent has isolated data during the training stage.
37 | 2. Task-agnostic: diverse environments are supported through one unified interface.
38 | 3. Asynchronous sampling: flexible agent-environment interaction patterns.
39 |
40 | First, MARLlib treats MARL as a combination of single-agent RL processes.
41 |
42 | Second, MARLlib unifies all ten environments under one abstract interface, which eases the burden of algorithm design. Any environment instance can sit behind this interface, enabling multi-task or task-agnostic learning.
43 |
44 | Third, unlike most existing MARL frameworks that only support synchronous interaction between agents and environments, MARLlib supports an asynchronous interaction style. This is thanks to RLlib's flexible data-collection mechanism, which allows data from different agents to be collected and stored both synchronously and asynchronously.
45 |
46 | ## Workflow
47 |
48 | ### Stage 1: Pre-learning
49 |
50 | MARLlib starts the reinforcement learning process by instantiating the environment and the agent model. It then generates a simulated batch according to the environment characteristics and feeds it into the sampling/training pipeline of the chosen algorithm. Once the learning workflow completes without errors, MARLlib proceeds to the next stage.
51 |
52 | 
53 |
54 | > Pre-learning stage
55 |
56 | ### Stage 2: Sampling and Training
57 |
58 | After the pre-learning stage completes, MARLlib assigns the real jobs to the workers and the learner, and schedules these processes under the execution plan to launch the learning procedure.
59 |
60 | During a standard learning iteration, each worker uses the agent model to interact with its environment instance and sample data, which are then passed to the replay buffer. The replay buffer is initialized according to the algorithm and decides how data are stored: for on-policy algorithms the buffer is a concatenation operation, while for off-policy algorithms it is a FIFO queue.
61 |
62 | Subsequently, a pre-defined policy mapping function distributes the collected data to the different agents. Once all data for one training iteration have been collected, the learner starts optimizing the policies with these data and broadcasts the new model to every worker for the next round of sampling.
63 |
64 | 
65 |
66 | Sampling and training stage
67 |
68 | ### Algorithm Pipelines
69 |
70 | 
71 |
72 | ### Independent Learning
73 |
74 | In MARLlib, implementing independent learning (left) is straightforward, since RLlib provides many algorithms. To start training, one simply picks an algorithm from RLlib and applies it to a multi-agent environment, with no extra work compared to RLlib. Although independent learning in MARL requires no data exchange, its performance is usually inferior to centralized training strategies on most tasks.
75 |
76 | ### Centralized Critic
77 |
78 | The centralized critic is one of the two centralized training strategies in the CTDE framework supported by MARLlib. Under this approach, agents share information with one another after obtaining the policy output but before the critic value is computed. The shared information includes individual observations, actions, and the global state (if available).
79 |
80 | The exchanged data are collected during the sampling phase and stored as transition data, where each transition contains both self-collected and exchanged data. These data are then used to optimize the centralized critic function and the decentralized policy functions. Information sharing is implemented mainly in the postprocessing function for on-policy algorithms. For off-policy algorithms such as MADDPG, additional data, for example the action values provided by other agents, are collected before the data enter the training iteration batch.
81 |
82 | ### Value Decomposition
83 |
84 | In MARLlib, value decomposition (VD) is the other category of centralized training strategies; it differs from the centralized critic in the information agents are required to share. Specifically, only the predicted Q value or critic value needs to be shared among agents, and additional data may be required depending on the algorithm used. For example, QMIX needs the global state to compute the mixed Q value.
85 |
86 | The data collection and storage mechanism of VD is similar to that of the centralized critic: agents collect and store transition data during the sampling phase. The joint Q-learning methods (VDN, QMIX) are based on the original PyMARL; of the five VD algorithms, only FACMAC, VDA2C, and VDPPO follow the standard RLlib training pipeline.
87 |
88 | ## Key Components
89 |
90 | ### Postprocessing Before Data Collection
91 |
92 | MARL algorithms that adopt the centralized training with decentralized execution (CTDE) paradigm require information sharing among agents during the learning phase. In value-decomposition algorithms such as QMIX, FACMAC, and VDA2C, computing the total Q or V value requires the agents to provide their individual Q or V estimates. Conversely, centralized-critic-based algorithms such as MADDPG, MAPPO, and HAPPO require agents to share their observations and actions in order to determine the centralized critic value. The postprocessing module is the ideal place for agents to exchange data with their peers. For centralized-critic algorithms, an agent can obtain additional information from other agents to compute the centralized critic value. For value-decomposition algorithms, on the other hand, an agent must provide its predicted Q or V value to the other agents. In addition, the postprocessing module is responsible for computing various learning targets using techniques such as GAE or N-step reward adjustment.
93 |
94 | 
95 |
96 | Postprocessing before data collection
97 |
98 | ### Postprocessing Before Batch Learning
99 |
100 | In the context of MARL algorithms, not all of them can make use of the postprocessing module. One example is off-policy algorithms such as MADDPG and FACMAC, which face the challenge that outdated data in the replay buffer cannot be used for the current training iteration. To address this challenge, an additional "before batch learning" function is implemented to compute the Q or V values with the current model right before the sampled batch enters the training loop. This ensures that the data used for training are up to date and accurate, improving training effectiveness.
101 |
102 | 
103 |
104 | Postprocessing before batch learning
105 |
106 | ### Centralized Value Function
107 |
108 | In the centralized-critic agent model, the conventional value function based only on an agent's own observation is replaced by a centralized critic that adapts to the algorithm's requirements. The centralized critic processes the information received from other agents and produces a centralized value as output.
109 |
110 | ### Mixing Value Function
111 |
112 | In the value-decomposition agent model, the original value function is kept, but a new mixing value function is introduced to obtain the overall mixed value. The mixing function is flexible and can be customized to user requirements. Currently the VDN and QMIX mixing functions are supported.
113 |
114 | ### Heterogeneous Optimization
115 |
116 | In heterogeneous optimization, the parameters of each agent are updated independently, so the policy function is not shared across agents. However, as shown by the algorithmic proof, sequentially updating the agents' policies and setting the value of the corresponding loss-related terms appropriately allows the gains of any positive updates to accumulate incrementally.
117 |
118 | To guarantee the incremental monotonicity of the algorithm, a trust region is used to obtain suitable parameter updates, as in the HATRPO algorithm. To speed up the policy and critic update process while keeping computational efficiency in mind, proximal policy optimization techniques are adopted in the HAPPO algorithm.
119 |
120 | 
121 |
122 | Heterogeneous-agent critic optimization
123 |
124 | ### Policy Mapping
125 |
126 | Policy mapping plays a crucial role in standardizing the interface of multi-agent reinforcement learning (MARL) environments. In MARLlib, policy mapping is implemented as a dictionary with a hierarchical structure. The top-level key is the scenario name, the second-level keys contain group information, and four additional keys (**description**, **team_prefix**, **all_agents_one_policy**, and **one_agent_one_policy**) define the various policy settings. The team_prefix key groups the agents according to their names, while the last two keys indicate whether a fully shared or a fully separate policy strategy is applicable to the given scenario. The policy mapping is used to initialize the policies and assign them to different agents, and each policy is trained only with data sampled by the agents in its corresponding policy group (a sketch of this structure follows after this file).
127 |
--------------------------------------------------------------------------------
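The policy-mapping dictionary described at the end of docs/marllib.md could look roughly like the sketch below. This only illustrates the schema the text describes (scenario name, group info, and the four policy-setting keys); it is not a verbatim excerpt from MARLlib, and the scenario name and values are made up:

```python
policy_mapping_info = {
    'simple_spread': {                    # top-level key: scenario name
        'description': 'one team of homogeneous agents',
        'team_prefix': ('agent_', ),      # group agents by name prefix
        'all_agents_one_policy': True,    # a fully shared policy is allowed
        'one_agent_one_policy': True,     # one policy per agent is allowed
    },
}
```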
/marltoolkit/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/marltoolkit/__init__.py
--------------------------------------------------------------------------------
/marltoolkit/__version__.py:
--------------------------------------------------------------------------------
1 | """Version information."""
2 |
3 | # The following line *must* be the last in the module, exactly as formatted:
4 | __version__ = '0.10.0'
5 |
--------------------------------------------------------------------------------
/marltoolkit/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_agent import BaseAgent
2 | from .idqn_agent import IDQNAgent
3 | from .qmix_agent import QMixAgent
4 | from .vdn_agent import VDNAgent
5 |
6 | __all__ = ['BaseAgent', 'QMixAgent', 'IDQNAgent', 'VDNAgent']
7 |
--------------------------------------------------------------------------------
/marltoolkit/agents/base_agent.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 |
4 | class BaseAgent(ABC):
5 |
6 | def __init__(self):
7 | pass
8 |
9 | def init_hidden_states(self, **kwargs):
10 | raise NotImplementedError
11 |
12 | def sample(self, **kwargs):
13 | raise NotImplementedError
14 |
15 | def predict(self, **kwargs):
16 | raise NotImplementedError
17 |
18 | def learn(self, **kwargs):
19 | raise NotImplementedError
20 |
21 | def save_model(self, **kwargs):
22 | raise NotImplementedError
23 |
24 | def load_model(self, **kwargs):
25 | raise NotImplementedError
26 |
--------------------------------------------------------------------------------
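BaseAgent above only defines the interface that concrete agents fill in. Below is a minimal toy subclass, assuming the package is installed or on PYTHONPATH; it is not one of the toolkit's real agents, which live in the same package:

```python
import torch

from marltoolkit.agents import BaseAgent


class RandomAgent(BaseAgent):
    """Toy agent that satisfies the BaseAgent interface with random actions."""

    def __init__(self, num_agents: int, n_actions: int) -> None:
        super().__init__()
        self.num_agents = num_agents
        self.n_actions = n_actions

    def init_hidden_states(self, **kwargs):
        # A random policy keeps no recurrent state.
        return None

    def sample(self, obs, **kwargs):
        # Exploration: one random action per agent.
        return torch.randint(self.n_actions, (self.num_agents, ))

    def predict(self, obs, **kwargs):
        # Greedy behaviour is identical for a random policy.
        return self.sample(obs)

    def learn(self, **kwargs):
        # Nothing to optimise; return an empty metrics dict.
        return {}

    def save_model(self, **kwargs):
        pass

    def load_model(self, **kwargs):
        pass
```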
/marltoolkit/agents/maddpg_agent.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from marltoolkit.modules.actors.mlp import MLPActorModel
8 | from marltoolkit.modules.critics.mlp import MLPCriticModel
9 |
10 | from .base_agent import BaseAgent
11 |
12 |
13 | class MaddpgAgent(BaseAgent):
14 | """MADDPG algorithm.
15 |
16 | Args:
17 | model (parl.Model): forward network of actor and critic.
18 | The function get_actor_params() of model should be implemented.
19 | agent_index (int): index of agent, in multiagent env
20 | action_space (list): action_space, gym space
21 | gamma (float): discounted factor for reward computation.
22 | tau (float): decay coefficient when updating the weights of self.target_model with self.model
23 | critic_lr (float): learning rate of the critic model
24 | actor_lr (float): learning rate of the actor model
25 | """
26 |
27 | def __init__(
28 | self,
29 | actor_model: nn.Module,
30 | critic_model: nn.Module,
31 | n_agents: int = None,
32 | obs_shape: int = None,
33 | n_actions: int = None,
34 | gamma: float = 0.95,
35 | tau: float = 0.01,
36 | actor_lr: float = 0.01,
37 | critic_lr: float = 0.01,
38 | device: str = 'cpu',
39 | ):
40 |
41 | # checks
42 | assert isinstance(gamma, float)
43 | assert isinstance(tau, float)
44 | assert isinstance(actor_lr, float)
45 | assert isinstance(critic_lr, float)
46 |
47 | self.actor_lr = actor_lr
48 | self.critic_lr = critic_lr
49 | self.global_steps = 0
50 | self.device = device
51 |
52 | self.actor_model = [
53 | MLPActorModel(obs_shape, n_actions) for _ in range(n_agents)
54 | ]
55 | self.critic_model = [
56 | MLPCriticModel(n_agents, obs_shape, n_actions)
57 | for _ in range(n_agents)
58 | ]
59 | self.actor_target = copy.deepcopy(actor_model)
60 | self.critic_target = copy.deepcopy(critic_model)
61 | self.actor_optimizer = [
62 | torch.optim.Adam(model.parameters(), lr=self.actor_lr)
63 | for model in self.actor_model
64 | ]
65 | self.critic_optimizer = [
66 | torch.optim.Adam(model.parameters(), lr=self.critic_lr)
67 | for model in self.critic_model
68 | ]
69 |
70 | def sample(self, obs, use_target_model=False):
71 | """use the policy model to sample actions.
72 |
73 | Args:
74 | obs (torch tensor): observation, shape([B] + shape of obs_n[agent_index])
75 | use_target_model (bool): use target_model or not
76 |
77 | Returns:
78 | act (torch tensor): action, shape([B] + shape of act_n[agent_index]),
79 | noted that in the discrete case we take the argmax along the last axis as action
80 | """
81 | obs = torch.FloatTensor(obs.reshape(1, -1)).to(self.device)
82 |
83 | if use_target_model:
84 | policy = self.actor_target(obs)
85 | else:
86 | policy = self.actor_model(obs)
87 |
88 | # add noise for action exploration
89 | if self.continuous_actions:
90 | random_normal = torch.randn(size=policy[0].shape).to(self.device)
91 | action = policy[0] + torch.exp(policy[1]) * random_normal
92 | action = torch.tanh(action)
93 | else:
94 | uniform = torch.rand_like(policy)
95 | soft_uniform = torch.log(-1.0 * torch.log(uniform)).to(self.device)
96 | action = F.softmax(policy - soft_uniform, dim=-1)
97 |
98 | action = action.detach().cpu().numpy().flatten()
99 | return action
100 |
101 | def predict(self, obs: torch.Tensor):
102 | """use the policy model to predict actions.
103 |
104 | Args:
105 | obs (torch tensor): observation, shape([B] + shape of obs_n[agent_index])
106 |
107 | Returns:
108 | act (torch tensor): action, shape([B] + shape of act_n[agent_index]),
109 | noted that in the discrete case we take the argmax along the last axis as action
110 | """
111 | obs = torch.FloatTensor(obs.reshape(1, -1)).to(self.device)
112 | policy = self.actor_model(obs)
113 | if self.continuous_actions:
114 | action = policy[0]
115 | action = torch.tanh(action)
116 | else:
117 | action = torch.argmax(policy, dim=-1)
118 |
119 | action = action.detach().cpu().numpy().flatten()
120 | return action
121 |
122 | def q_value(self, obs_n, act_n, use_target_model=False):
123 | """use the value model to predict Q values.
124 |
125 | Args:
126 | obs_n (list of torch tensor): all agents' observation, len(agent's num) + shape([B] + shape of obs_n)
127 | act_n (list of torch tensor): all agents' action, len(agent's num) + shape([B] + shape of act_n)
128 | use_target_model (bool): use target_model or not
129 |
130 | Returns:
131 | Q (torch tensor): Q value of this agent, shape([B])
132 | """
133 | if use_target_model:
134 | return self.critic_target.value(obs_n, act_n)
135 | else:
136 | return self.critic_model.value(obs_n, act_n)
137 |
138 | def agent_learn(self, obs_n, act_n, target_q):
139 | """update actor and critic model with MADDPG algorithm."""
140 | self.global_steps += 1
141 |
142 | def learn(self, obs_n, act_n, target_q):
143 | """update actor and critic model with MADDPG algorithm."""
144 |         actor_loss = self.actor_learn(obs_n, act_n)
145 |         critic_loss = self.critic_learn(obs_n, act_n, target_q)
146 |         return actor_loss, critic_loss
147 |
148 | def actor_learn(self, obs_n, act_n):
149 | i = self.agent_index
150 |
151 | sample_this_action = self.sample(obs_n[i])
152 | action_input_n = act_n + []
153 | action_input_n[i] = sample_this_action
154 | eval_q = self.q_value(obs_n, action_input_n)
155 | act_cost = torch.mean(-1.0 * eval_q)
156 |
157 | this_policy = self.actor_model.policy(obs_n[i])
158 | # when continuous, 'this_policy' will be a tuple with two element: (mean, std)
159 | if self.continuous_actions:
160 | this_policy = torch.cat(this_policy, dim=-1)
161 | act_reg = torch.mean(torch.square(this_policy))
162 |
163 | cost = act_cost + act_reg * 1e-3
164 |
165 | self.actor_optimizer.zero_grad()
166 | cost.backward()
167 | torch.nn.utils.clip_grad_norm_(self.actor_model.parameters(), 0.5)
168 | self.actor_optimizer.step()
169 | return cost
170 |
171 | def critic_learn(self, obs_n, act_n, target_q):
172 | pred_q = self.q_value(obs_n, act_n)
173 | cost = F.mse_loss(pred_q, target_q)
174 |
175 | self.critic_optimizer.zero_grad()
176 | cost.backward()
177 | torch.nn.utils.clip_grad_norm_(self.critic_model.parameters(), 0.5)
178 | self.critic_optimizer.step()
179 | return cost
180 |
181 | def update_target(self, tau):
182 | for target_param, param in zip(self.actor_target.parameters(),
183 | self.actor_model.parameters()):
184 | target_param.data.copy_(tau * param.data +
185 | (1 - tau) * target_param.data)
186 | for target_param, param in zip(self.critic_target.parameters(),
187 | self.critic_model.parameters()):
188 | target_param.data.copy_(tau * param.data +
189 | (1 - tau) * target_param.data)
190 |
--------------------------------------------------------------------------------
/marltoolkit/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .ma_buffer import EpisodeData, ReplayBuffer
2 | from .offpolicy_buffer import (BaseBuffer, MaEpisodeData, OffPolicyBuffer,
3 | OffPolicyBufferRNN)
4 |
5 | __all__ = [
6 | 'ReplayBuffer',
7 | 'EpisodeData',
8 | 'MaEpisodeData',
9 | 'OffPolicyBuffer',
10 | 'BaseBuffer',
11 | 'OffPolicyBufferRNN',
12 | ]
13 |
--------------------------------------------------------------------------------
/marltoolkit/data/base_buffer.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 |
7 |
8 | class BaseBuffer(ABC):
9 | """Abstract base class for all buffers.
10 |
11 | This class provides a basic structure for reinforcement learning experience
12 | replay buffers.
13 |
14 | Args:
15 | :param num_envs: Number of environments.
16 | :param num_agents: Number of agents.
17 | :param obs_shape: Dimensionality of the observation space.
18 | :param state_shape: Dimensionality of the state space.
19 | :param action_shape: Dimensionality of the action space.
20 | :param reward_shape: Dimensionality of the reward space.
21 | :param done_shape: Dimensionality of the done space.
22 | :param device: Device on which to store the buffer data.
23 | :param kwargs: Additional keyword arguments.
24 | """
25 |
26 | def __init__(
27 | self,
28 | num_envs: int,
29 | num_agents: int,
30 | obs_shape: Union[int, Tuple],
31 | state_shape: Union[int, Tuple],
32 | action_shape: Union[int, Tuple],
33 | reward_shape: Union[int, Tuple],
34 | done_shape: Union[int, Tuple],
35 | device: Union[torch.device, str] = 'cpu',
36 | **kwargs,
37 | ) -> None:
38 | super().__init__()
39 | self.num_envs = num_envs
40 | self.num_agents = num_agents
41 | self.obs_shape = obs_shape
42 | self.state_shape = state_shape
43 | self.action_shape = action_shape
44 | self.reward_shape = reward_shape
45 | self.done_shape = done_shape
46 | self.device = device
47 | self.curr_ptr = 0
48 | self.curr_size = 0
49 |
50 | def reset(self) -> None:
51 | """Reset the buffer."""
52 | self.curr_ptr = 0
53 | self.curr_size = 0
54 |
55 | @abstractmethod
56 | def store(self, *args) -> None:
57 | """Add elements to the buffer."""
58 | raise NotImplementedError
59 |
60 | @abstractmethod
61 | def extend(self, *args, **kwargs) -> None:
62 | """Add a new batch of transitions to the buffer."""
63 | for data in zip(*args):
64 | self.store(*data)
65 |
66 | @abstractmethod
67 | def sample(self, **kwargs):
68 | """Sample elements from the buffer."""
69 | raise NotImplementedError
70 |
71 | @abstractmethod
72 | def store_transitions(self, **kwargs):
73 | """Store transitions in the buffer."""
74 | raise NotImplementedError
75 |
76 | @abstractmethod
77 | def store_episodes(self, **kwargs):
78 | """Store episodes in the buffer."""
79 | raise NotImplementedError
80 |
81 | @abstractmethod
82 | def finish_path(self, **kwargs):
83 | """Finish a trajectory path in the buffer."""
84 | raise NotImplementedError
85 |
86 | def size(self) -> int:
87 | """Get the current size of the buffer.
88 |
89 | :return: The current size of the buffer.
90 | """
91 | return self.curr_size
92 |
93 | def to_torch(self, array: np.ndarray, copy: bool = True) -> torch.Tensor:
94 | """Convert a numpy array to a PyTorch tensor.
95 |
96 | :param array: Numpy array to be converted.
97 | :param copy: Whether to copy the data or not.
98 | :return: PyTorch tensor.
99 | """
100 | if copy:
101 | return torch.tensor(array).to(self.device)
102 | return torch.as_tensor(array).to(self.device)
103 |
--------------------------------------------------------------------------------
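For reference, here is a minimal concrete subclass of BaseBuffer. This toy FIFO transition buffer only illustrates the abstract API (store/sample plus the to_torch helper) under the assumption that the shape arguments are tuples; it is not the toolkit's own OffPolicyBuffer, and its storage layout is made up for the example:

```python
import numpy as np

from marltoolkit.data.base_buffer import BaseBuffer


class ToyTransitionBuffer(BaseBuffer):
    """Tiny FIFO buffer of per-agent observations, actions and rewards."""

    def __init__(self, max_size, num_envs, num_agents, obs_shape, state_shape,
                 action_shape, reward_shape, done_shape, device='cpu'):
        super().__init__(num_envs, num_agents, obs_shape, state_shape,
                         action_shape, reward_shape, done_shape, device)
        self.max_size = max_size
        self.obs = np.zeros((max_size, num_agents, *obs_shape), dtype=np.float32)
        self.actions = np.zeros((max_size, num_agents, *action_shape), dtype=np.float32)
        self.rewards = np.zeros((max_size, *reward_shape), dtype=np.float32)

    def store(self, obs, action, reward) -> None:
        # Overwrite the oldest slot once the buffer is full.
        self.obs[self.curr_ptr] = obs
        self.actions[self.curr_ptr] = action
        self.rewards[self.curr_ptr] = reward
        self.curr_ptr = (self.curr_ptr + 1) % self.max_size
        self.curr_size = min(self.curr_size + 1, self.max_size)

    def extend(self, *args, **kwargs) -> None:
        # Reuse the base-class loop that calls store() per transition.
        super().extend(*args, **kwargs)

    def sample(self, batch_size: int = 32):
        idx = np.random.randint(0, self.curr_size, size=batch_size)
        return (self.to_torch(self.obs[idx]), self.to_torch(self.actions[idx]),
                self.to_torch(self.rewards[idx]))

    # The episode-oriented hooks are not needed for this toy example.
    def store_transitions(self, **kwargs):
        raise NotImplementedError

    def store_episodes(self, **kwargs):
        raise NotImplementedError

    def finish_path(self, **kwargs):
        raise NotImplementedError
```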
/marltoolkit/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_env import MultiAgentEnv
2 | from .smacv1.smac_env import SMACWrapperEnv
3 | from .vec_env import BaseVecEnv, DummyVecEnv, SubprocVecEnv
4 |
5 | __all__ = [
6 | 'SMACWrapperEnv',
7 | 'BaseVecEnv',
8 | 'DummyVecEnv',
9 | 'SubprocVecEnv',
10 | 'MultiAgentEnv',
11 | ]
12 |
--------------------------------------------------------------------------------
/marltoolkit/envs/base_env.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import gymnasium as gym
4 |
5 |
6 | class MultiAgentEnv(gym.Env):
7 | """A multi-agent environment wrapper. An environment that hosts multiple
8 | independent agents.
9 |
10 | Agents are identified by (string) agent ids. Note that these "agents" here
11 | are not to be confused with RLlib Algorithms, which are also sometimes
12 | referred to as "agents" or "RL agents".
13 |
14 | The preferred format for action- and observation space is a mapping from agent
15 | ids to their individual spaces. If that is not provided, the respective methods'
16 | observation_space_contains(), action_space_contains(),
17 | action_space_sample() and observation_space_sample() have to be overwritten.
18 | """
19 |
20 | def __init__(self) -> None:
21 | self.num_agents = None
22 | self.n_actions = None
23 | self.episode_limit = None
24 |         if not hasattr(self, 'observation_space'):
25 |             self.observation_space = None
26 |         if not hasattr(self, 'state_space'):
27 |             self.state_space = None
28 | if not hasattr(self, 'action_space'):
29 | self.action_space = None
30 | if not hasattr(self, '_agent_ids'):
31 | self._agent_ids = set()
32 |
33 | def step(self, actions):
34 | """Returns reward, terminated, info."""
35 | raise NotImplementedError
36 |
37 | def get_obs(self):
38 | """Returns all agent observations in a list."""
39 | raise NotImplementedError
40 |
41 | def get_obs_agent(self, agent_id):
42 | """Returns observation for agent_id."""
43 | raise NotImplementedError
44 |
45 | def get_obs_size(self):
46 | """Returns the shape of the observation."""
47 | raise NotImplementedError
48 |
49 | def get_state(self):
50 | raise NotImplementedError
51 |
52 | def get_state_size(self):
53 | """Returns the shape of the state."""
54 | raise NotImplementedError
55 |
56 | def get_avail_agent_actions(self, agent_id):
57 | """Returns the available actions for agent_id."""
58 | raise NotImplementedError
59 |
60 | def get_available_actions(self):
61 | raise NotImplementedError
62 |
63 | def get_actions_one_hot(self):
64 | raise NotImplementedError
65 |
66 | def get_agents_id_one_hot(self):
67 | raise NotImplementedError
68 |
69 | def reset(
70 | self,
71 | seed: Optional[int] = None,
72 | options: Optional[dict] = None,
73 | ):
74 | """Returns initial observations and states."""
75 | super().reset(seed=seed, options=options)
76 |
77 | def render(self):
78 | raise NotImplementedError
79 |
80 | def close(self):
81 | raise NotImplementedError
82 |
83 | def seed(self):
84 | raise NotImplementedError
85 |
86 | def save_replay(self):
87 | raise NotImplementedError
88 |
89 | def get_env_info(self):
90 | env_info = {
91 | 'obs_shape': self.get_obs_size(),
92 | 'state_shape': self.get_state_size(),
93 | 'num_agents': self.num_agents,
94 | 'actions_shape': self.n_actions,
95 | 'episode_limit': self.episode_limit,
96 | }
97 | return env_info
98 |
--------------------------------------------------------------------------------
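The get_env_info() dict above is what downstream agents and buffers are typically built from. A short usage sketch with SMACWrapperEnv follows; the map_name keyword is an assumption about that wrapper's constructor, while the dict keys come from get_env_info above:

```python
from marltoolkit.envs import SMACWrapperEnv

env = SMACWrapperEnv(map_name='3m')        # constructor arguments are assumed
env_info = env.get_env_info()

obs_shape = env_info['obs_shape']          # per-agent observation size
state_shape = env_info['state_shape']      # global state size
num_agents = env_info['num_agents']
n_actions = env_info['actions_shape']      # number of discrete actions
episode_limit = env_info['episode_limit']  # max steps per episode
```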
/marltoolkit/envs/pettingzoo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/marltoolkit/envs/pettingzoo/__init__.py
--------------------------------------------------------------------------------
/marltoolkit/envs/pettingzoo/custom_env2.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import gymnasium
4 | from gymnasium.spaces import Discrete
5 | from pettingzoo import ParallelEnv
6 | from pettingzoo.utils import parallel_to_aec, wrappers
7 |
8 | ROCK = 0
9 | PAPER = 1
10 | SCISSORS = 2
11 | NONE = 3
12 | MOVES = ['ROCK', 'PAPER', 'SCISSORS', 'None']
13 | NUM_ITERS = 100
14 | REWARD_MAP = {
15 | (ROCK, ROCK): (0, 0),
16 | (ROCK, PAPER): (-1, 1),
17 | (ROCK, SCISSORS): (1, -1),
18 | (PAPER, ROCK): (1, -1),
19 | (PAPER, PAPER): (0, 0),
20 | (PAPER, SCISSORS): (-1, 1),
21 | (SCISSORS, ROCK): (-1, 1),
22 | (SCISSORS, PAPER): (1, -1),
23 | (SCISSORS, SCISSORS): (0, 0),
24 | }
25 |
26 |
27 | def env(render_mode=None):
28 | """The env function often wraps the environment in wrappers by default.
29 |
30 | You can find full documentation for these methods elsewhere in the
31 | developer documentation.
32 | """
33 | internal_render_mode = render_mode if render_mode != 'ansi' else 'human'
34 | env = raw_env(render_mode=internal_render_mode)
35 | # This wrapper is only for environments which print results to the terminal
36 | if render_mode == 'ansi':
37 | env = wrappers.CaptureStdoutWrapper(env)
38 | # this wrapper helps error handling for discrete action spaces
39 | env = wrappers.AssertOutOfBoundsWrapper(env)
40 |     # Provides a wide variety of helpful user errors
41 | # Strongly recommended
42 | env = wrappers.OrderEnforcingWrapper(env)
43 | return env
44 |
45 |
46 | def raw_env(render_mode=None):
47 | """To support the AEC API, the raw_env() function just uses the
48 | from_parallel function to convert from a ParallelEnv to an AEC env."""
49 | env = parallel_env(render_mode=render_mode)
50 | env = parallel_to_aec(env)
51 | return env
52 |
53 |
54 | class parallel_env(ParallelEnv):
55 | metadata = {'render_modes': ['human'], 'name': 'rps_v2'}
56 |
57 | def __init__(self, render_mode=None):
58 | """The init method takes in environment arguments and should define the
59 | following attributes:
60 |
61 | - possible_agents
62 | - render_mode
63 |
64 | Note: as of v1.18.1, the action_spaces and observation_spaces attributes are deprecated.
65 | Spaces should be defined in the action_space() and observation_space() methods.
66 | If these methods are not overridden, spaces will be inferred from self.observation_spaces/action_spaces, raising a warning.
67 |
68 | These attributes should not be changed after initialization.
69 | """
70 | self.possible_agents = ['player_' + str(r) for r in range(2)]
71 |
72 | # optional: a mapping between agent name and ID
73 | self.agent_name_mapping = dict(
74 | zip(self.possible_agents, list(range(len(self.possible_agents)))))
75 | self.render_mode = render_mode
76 |
77 | # Observation space should be defined here.
78 | # lru_cache allows observation and action spaces to be memoized, reducing clock cycles required to get each agent's space.
79 | # If your spaces change over time, remove this line (disable caching).
80 | @functools.lru_cache(maxsize=None)
81 | def observation_space(self, agent):
82 | # gymnasium spaces are defined and documented here: https://gymnasium.farama.org/api/spaces/
83 | return Discrete(4)
84 |
85 | # Action space should be defined here.
86 | # If your spaces change over time, remove this line (disable caching).
87 | @functools.lru_cache(maxsize=None)
88 | def action_space(self, agent):
89 | return Discrete(3)
90 |
91 | def render(self):
92 | """Renders the environment.
93 |
94 | In human mode, it can print to terminal, open up a graphical window, or
95 | open up some other display that a human can see and understand.
96 | """
97 | if self.render_mode is None:
98 | gymnasium.logger.warn(
99 | 'You are calling render method without specifying any render mode.'
100 | )
101 | return
102 |
103 | if len(self.agents) == 2:
104 | string = 'Current state: Agent1: {} , Agent2: {}'.format(
105 | MOVES[self.state[self.agents[0]]],
106 | MOVES[self.state[self.agents[1]]])
107 | else:
108 | string = 'Game over'
109 | print(string)
110 |
111 | def close(self):
112 | """Close should release any graphical displays, subprocesses, network
113 | connections or any other environment data which should not be kept
114 | around after the user is no longer using the environment."""
115 | pass
116 |
117 | def reset(self, seed=None, options=None):
118 | """Reset needs to initialize the `agents` attribute and must set up the
119 | environment so that render(), and step() can be called without issues.
120 |
121 | Here it initializes the `num_moves` variable which counts the number of
122 | hands that are played. Returns the observations for each agent
123 | """
124 | self.agents = self.possible_agents[:]
125 | self.num_moves = 0
126 | observations = {agent: NONE for agent in self.agents}
127 | infos = {agent: {} for agent in self.agents}
128 | self.state = observations
129 |
130 | return observations, infos
131 |
132 | def step(self, actions):
133 | """step(action) takes in an action for each agent and should return
134 | the.
135 |
136 | - observations
137 | - rewards
138 | - terminations
139 | - truncations
140 | - infos
141 | dicts where each dict looks like {agent_1: item_1, agent_2: item_2}
142 | """
143 | # If a user passes in actions with no agents, then just return empty observations, etc.
144 | if not actions:
145 | self.agents = []
146 | return {}, {}, {}, {}, {}
147 |
148 | # rewards for all agents are placed in the rewards dictionary to be returned
149 | rewards = {}
150 | rewards[self.agents[0]], rewards[self.agents[1]] = REWARD_MAP[(
151 | actions[self.agents[0]], actions[self.agents[1]])]
152 |
153 | terminations = {agent: False for agent in self.agents}
154 |
155 | self.num_moves += 1
156 | env_truncation = self.num_moves >= NUM_ITERS
157 | truncations = {agent: env_truncation for agent in self.agents}
158 |
159 | # current observation is just the other player's most recent action
160 | observations = {
161 | self.agents[i]: int(actions[self.agents[1 - i]])
162 | for i in range(len(self.agents))
163 | }
164 | self.state = observations
165 |
166 | # typically there won't be any information in the infos, but there must
167 | # still be an entry for each agent
168 | infos = {agent: {} for agent in self.agents}
169 |
170 | if env_truncation:
171 | self.agents = []
172 |
173 | if self.render_mode == 'human':
174 | self.render()
175 | return observations, rewards, terminations, truncations, infos
176 |
--------------------------------------------------------------------------------
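The example environment above follows the standard PettingZoo ParallelEnv API, so a random rollout can be driven like the following sketch (the import path assumes the repository root is on PYTHONPATH):

```python
from marltoolkit.envs.pettingzoo.custom_env2 import parallel_env

env = parallel_env(render_mode=None)
observations, infos = env.reset(seed=0)

# env.agents is emptied once the episode is truncated after NUM_ITERS moves.
while env.agents:
    # Sample a random legal action for every live agent.
    actions = {agent: env.action_space(agent).sample() for agent in env.agents}
    observations, rewards, terminations, truncations, infos = env.step(actions)

env.close()
```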
/marltoolkit/envs/pettingzoo/pettingzoo_env.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from typing import Any, Dict, List, Tuple
3 |
4 | import numpy as np
5 | from gymnasium import spaces
6 | from pettingzoo.utils.env import AECEnv
7 | from pettingzoo.utils.wrappers import BaseWrapper
8 |
9 |
10 | class PettingZooEnv(AECEnv, ABC):
11 |
12 | def __init__(self, env: BaseWrapper):
13 | super(PettingZooEnv, self).__init__()
14 | self.env = env
15 | self.agents = self.env.possible_agents
16 | self.agent_idx = {}
17 | for i, agent_id in enumerate(self.agents):
18 | self.agent_idx[agent_id] = i
19 |
20 | self.action_spaces = {
21 | k: self.env.action_space(k)
22 | for k in self.env.agents
23 | }
24 | self.observation_spaces = {
25 | k: self.env.observation_space(k)
26 | for k in self.env.agents
27 | }
28 | assert all(
29 | self.observation_space(agent) == self.env.observation_space(
30 | self.agents[0])
31 | for agent in self.agents), (
32 | 'Observation spaces for all agents must be identical. Perhaps '
33 |                 "SuperSuit's pad_observations wrapper can help (usage: "
34 | '`supersuit.aec_wrappers.pad_observations(env)`')
35 |
36 | assert all(
37 | self.action_space(agent) == self.env.action_space(self.agents[0])
38 | for agent in self.agents), (
39 | 'Action spaces for all agents must be identical. Perhaps '
40 |                 "SuperSuit's pad_action_space wrapper can help (usage: "
41 | '`supersuit.aec_wrappers.pad_action_space(env)`')
42 | try:
43 | self.state_space = self.env.state_space
44 | except Exception:
45 | self.state_space = None
46 |
47 | self.rewards = [0] * len(self.agents)
48 | self.metadata = self.env.metadata
49 | self.env.reset()
50 |
51 | def reset(self, *args: Any, **kwargs: Any) -> Tuple[dict, dict]:
52 | self.env.reset(*args, **kwargs)
53 |
54 |         observation, reward, terminated, truncated, info = self.env.last()
55 |
56 | if isinstance(observation, dict) and 'action_mask' in observation:
57 | observation_dict = {
58 | 'agent_id':
59 | self.env.agent_selection,
60 | 'obs':
61 | observation['observation'],
62 | 'mask': [
63 | True if obm == 1 else False
64 | for obm in observation['action_mask']
65 | ],
66 | }
67 | else:
68 | if isinstance(self.action_space, spaces.Discrete):
69 | observation_dict = {
70 | 'agent_id':
71 | self.env.agent_selection,
72 | 'obs':
73 | observation,
74 | 'mask':
75 | [True] * self.env.action_space(self.env.agent_selection).n,
76 | }
77 | else:
78 | observation_dict = {
79 | 'agent_id': self.env.agent_selection,
80 | 'obs': observation,
81 | }
82 |
83 | return observation_dict, info
84 |
85 | def step(self, action: Any) -> Tuple[Dict, List[int], bool, bool, Dict]:
86 | self.env.step(action)
87 |
88 | observation, reward, term, trunc, info = self.env.last()
89 |
90 | if isinstance(observation, dict) and 'action_mask' in observation:
91 | obs = {
92 | 'agent_id':
93 | self.env.agent_selection,
94 | 'obs':
95 | observation['observation'],
96 | 'mask': [
97 | True if obm == 1 else False
98 | for obm in observation['action_mask']
99 | ],
100 | }
101 | else:
102 | if isinstance(self.action_space, spaces.Discrete):
103 | obs = {
104 | 'agent_id':
105 | self.env.agent_selection,
106 | 'obs':
107 | observation,
108 | 'mask':
109 | [True] * self.env.action_space(self.env.agent_selection).n,
110 | }
111 | else:
112 | obs = {
113 | 'agent_id': self.env.agent_selection,
114 | 'obs': observation
115 | }
116 |
117 | for agent_id, reward in self.env.rewards.items():
118 | self.rewards[self.agent_idx[agent_id]] = reward
119 | return obs, self.rewards, term, trunc, info
120 |
121 | def state(self):
122 | try:
123 | return np.array(self.env.state())
124 | except Exception:
125 | return None
126 |
127 | def seed(self, seed: Any = None) -> None:
128 | try:
129 | self.env.seed(seed)
130 | except (NotImplementedError, AttributeError):
131 | self.env.reset(seed=seed)
132 |
133 | def render(self) -> Any:
134 | return self.env.render()
135 |
136 | def close(self):
137 | self.env.close()
138 |
--------------------------------------------------------------------------------
/marltoolkit/envs/smacv1/__init__.py:
--------------------------------------------------------------------------------
1 | from .smac_env import SMACWrapperEnv
2 |
3 | __all__ = ['SMACWrapperEnv']
4 |
--------------------------------------------------------------------------------
/marltoolkit/envs/smacv2/__init__.py:
--------------------------------------------------------------------------------
1 | from .smacv2_env import SMACv2Env
2 |
3 | __all__ = ['SMACv2Env']
4 |
--------------------------------------------------------------------------------
/marltoolkit/envs/smacv2/smacv2_env.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | from gymnasium.spaces import Discrete
5 | from smacv2.env.starcraft2.wrapper import StarCraftCapabilityEnvWrapper
6 |
7 |
8 | class SMACv2Env(StarCraftCapabilityEnvWrapper):
9 |
10 | def __init__(self, **kwargs):
11 | super(SMACv2Env, self).__init__(obs_last_action=False, **kwargs)
12 | self.action_space = []
13 | self.observation_space = []
14 | self.share_observation_space = []
15 |
16 | self.n_agents = self.env.n_agents
17 |
18 | for i in range(self.env.n_agents):
19 | self.action_space.append(Discrete(self.env.n_actions))
20 | self.observation_space.append([self.env.get_obs_size()])
21 | self.share_observation_space.append([self.env.get_state_size()])
22 |
23 | def seed(self, seed):
24 | random.seed(seed)
25 | np.random.seed(seed)
26 |
27 | def step(self, actions):
28 | reward, terminated, info = self.env.step(actions)
29 | obs = self.env.get_obs()
30 | state = self.env.get_state()
31 | global_state = [state] * self.n_agents
32 | rewards = [[reward]] * self.n_agents
33 | dones = [terminated] * self.n_agents
34 | infos = [info] * self.n_agents
35 | avail_actions = self.env.get_avail_actions()
36 |
37 |         bad_transition = self.env._episode_steps >= self.env.episode_limit
38 | for info in infos:
39 | info['bad_transition'] = bad_transition
40 | info['battles_won'] = self.env.battles_won
41 | info['battles_game'] = self.env.battles_game
42 | info['battles_draw'] = self.env.timeouts
43 | info['restarts'] = self.env.force_restarts
44 | info['won'] = self.env.win_counted
45 |
46 | return obs, global_state, rewards, dones, infos, avail_actions
47 |
48 | def reset(self):
49 | obs, state = super().reset()
50 | state = [state for i in range(self.env.n_agents)]
51 | avail_actions = self.env.get_avail_actions()
52 | return obs, state, avail_actions
53 |
--------------------------------------------------------------------------------
/marltoolkit/envs/vec_env/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_vec_env import BaseVecEnv, CloudpickleWrapper
2 | from .dummy_vec_env import DummyVecEnv
3 | from .subproc_vec_env import SubprocVecEnv
4 |
5 | __all__ = [
6 | 'BaseVecEnv',
7 | 'DummyVecEnv',
8 | 'SubprocVecEnv',
9 | 'CloudpickleWrapper',
10 | ]
11 |
--------------------------------------------------------------------------------
/marltoolkit/envs/waregame/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/marltoolkit/envs/waregame/__init__.py
--------------------------------------------------------------------------------
/marltoolkit/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/marltoolkit/modules/__init__.py
--------------------------------------------------------------------------------
/marltoolkit/modules/actors/__init__.py:
--------------------------------------------------------------------------------
1 | """Actors package."""
2 |
3 | from .mlp import MLPActorModel
4 | from .rnn import RNNActorModel
5 |
6 | __all__ = [
7 | 'RNNActorModel',
8 | 'MLPActorModel',
9 | ]
10 |
--------------------------------------------------------------------------------
/marltoolkit/modules/actors/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | # Initialize Policy weights
6 | def weights_init_(module: nn.Module):
7 | if isinstance(module, nn.Linear):
8 | nn.init.xavier_uniform_(module.weight, gain=1)
9 | nn.init.constant_(module.bias, 0)
10 |
11 |
12 | class MLPActorModel(nn.Module):
13 |
14 | def __init__(
15 | self,
16 | input_dim: int = None,
17 | hidden_dim: int = 64,
18 | n_actions: int = None,
19 | ) -> None:
20 | """Initialize the Actor network.
21 |
22 | Args:
23 | input_dim (int, optional): obs, include the agent's id and last action,
24 | shape: (batch, obs_shape + n_action + n_agents)
25 | hidden_dim (int, optional): hidden size of the network. Defaults to 64.
26 | n_actions (int, optional): number of actions. Defaults to None.
27 | """
28 | super(MLPActorModel, self).__init__()
29 |
30 | self.fc1 = nn.Linear(input_dim, hidden_dim)
31 | self.fc2 = nn.Linear(hidden_dim, hidden_dim)
32 | self.fc3 = nn.Linear(hidden_dim, n_actions)
33 | self.relu1 = nn.ReLU(inplace=True)
34 | self.relu2 = nn.ReLU(inplace=True)
35 | self.apply(weights_init_)
36 |
37 | def forward(self, obs: torch.Tensor):
38 | hid1 = self.relu1(self.fc1(obs))
39 | hid2 = self.relu2(self.fc2(hid1))
40 | policy = self.fc3(hid2)
41 | return policy
42 |
--------------------------------------------------------------------------------
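A quick shape check for MLPActorModel; the dimensions below are arbitrary and only illustrate the expected input/output shapes:

```python
import torch

from marltoolkit.modules.actors import MLPActorModel

# 32 observations of size 10 mapped to logits over 5 discrete actions.
model = MLPActorModel(input_dim=10, hidden_dim=64, n_actions=5)
obs = torch.randn(32, 10)
logits = model(obs)                        # shape: (32, 5)
probs = torch.softmax(logits, dim=-1)      # per-action probabilities
```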
/marltoolkit/modules/actors/r_actor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from marltoolkit.modules.utils.act import ACTLayer
5 | from marltoolkit.modules.utils.common import MLPBase, RNNLayer
6 |
7 |
8 | class R_Actor(nn.Module):
9 | """Actor network class for MAPPO. Outputs actions given observations.
10 |
11 | :param args: (argparse.Namespace) arguments containing relevant model information.
12 | :param obs_space: (gym.Space) observation space.
13 | :param action_space: (gym.Space) action space.
14 | :param device: (torch.device) specifies the device to run on (cpu/gpu).
15 | """
16 |
17 | def __init__(self, args) -> None:
18 | super(R_Actor, self).__init__()
19 |
20 | self.use_recurrent_policy = args.use_recurrent_policy
21 | self.use_policy_active_masks = args.use_policy_active_masks
22 | self.base = MLPBase(
23 | input_dim=args.actor_input_dim,
24 | hidden_dim=args.hidden_size,
25 | activation=args.activation,
26 | use_orthogonal=args.use_orthogonal,
27 | use_feature_normalization=args.use_feature_normalization,
28 | )
29 | if args.use_recurrent_policy:
30 | self.rnn = RNNLayer(
31 | args.hidden_size,
32 | args.hidden_size,
33 | args.rnn_layers,
34 | args.use_orthogonal,
35 | )
36 | self.act = ACTLayer(args)
37 | self.algorithm_name = args.algorithm_name
38 |
39 | def forward(
40 | self,
41 | obs: torch.Tensor,
42 | masks: torch.Tensor,
43 | available_actions: torch.Tensor = None,
44 | rnn_hidden_states: torch.Tensor = None,
45 | deterministic: bool = False,
46 | ):
47 | """Compute actions from the given inputs.
48 |
49 | :param obs: (np.ndarray / torch.Tensor) observation inputs into network.
50 | :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros.
51 | :param available_actions: (np.ndarray / torch.Tensor) denotes which actions are available to agent
52 | (if None, all actions available)
53 | :param rnn_hidden_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
54 |
55 | :param deterministic: (bool) whether to sample from action distribution or return the mode.
56 |
57 | :return actions: (torch.Tensor) actions to take.
58 | :return action_log_probs: (torch.Tensor) log probabilities of taken actions.
59 | :return rnn_hidden_states: (torch.Tensor) updated RNN hidden states.
60 | """
61 | actor_features = self.base(obs)
62 |
63 | if self.use_recurrent_policy:
64 | actor_features, rnn_hidden_states = self.rnn(
65 | actor_features, rnn_hidden_states, masks)
66 |
67 | actions, action_log_probs = self.act(actor_features, available_actions,
68 | deterministic)
69 |
70 | return actions, action_log_probs, rnn_hidden_states
71 |
72 | def evaluate_actions(
73 | self,
74 | obs: torch.Tensor,
75 | action: torch.Tensor,
76 | masks: torch.Tensor,
77 | active_masks: torch.Tensor = None,
78 | available_actions: torch.Tensor = None,
79 | rnn_hidden_states: torch.Tensor = None,
80 | ):
81 | """Compute log probability and entropy of given actions.
82 |
83 | :param obs: (torch.Tensor) observation inputs into network.
84 | :param action: (torch.Tensor) actions whose entropy and log probability to evaluate.
85 | :param rnn_hidden_states: (torch.Tensor) if RNN network, hidden states for RNN.
86 | :param masks: (torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros.
87 | :param available_actions: (torch.Tensor) denotes which actions are available to agent
88 | (if None, all actions available)
89 | :param active_masks: (torch.Tensor) denotes whether an agent is active or dead.
90 |
91 | :return action_log_probs: (torch.Tensor) log probabilities of the input actions.
92 | :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs.
93 | """
94 | actor_features = self.base(obs)
95 |
96 | if self.use_recurrent_policy:
97 | actor_features, rnn_hidden_states = self.rnn(
98 | actor_features, rnn_hidden_states, masks)
99 |
100 | if self.algorithm_name == 'hatrpo':
101 | action_log_probs, dist_entropy, action_mu, action_std, all_probs = (
102 | self.act.evaluate_actions_trpo(
103 | actor_features,
104 | action,
105 | available_actions,
106 | active_masks=active_masks
107 | if self.use_policy_active_masks else None,
108 | ))
109 |
110 | return action_log_probs, dist_entropy, action_mu, action_std, all_probs
111 | else:
112 | action_log_probs, dist_entropy = self.act.evaluate_actions(
113 | actor_features,
114 | action,
115 | available_actions,
116 | active_masks=active_masks
117 | if self.use_policy_active_masks else None,
118 | )
119 |
120 | return action_log_probs, dist_entropy
121 |
--------------------------------------------------------------------------------
/marltoolkit/modules/actors/rnn.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class RNNActorModel(nn.Module):
9 | """Because all the agents share the same network,
10 | input_shape=obs_shape+n_actions+n_agents.
11 |
12 |     Args:
13 |         args (argparse.Namespace): hyperparameters, expected to provide
14 |             actor_input_dim (int), fc_hidden_dim (int), rnn_hidden_dim (int)
15 |             and n_actions (int): the input dimension, the hidden sizes of the
16 |             fully connected layer and the GRU cell, and the number of actions.
17 |     """
18 |
19 | def __init__(self, args: argparse.Namespace = None) -> None:
20 | super(RNNActorModel, self).__init__()
21 | self.args = args
22 | self.rnn_hidden_dim = args.rnn_hidden_dim
23 |         self.fc1 = nn.Linear(args.actor_input_dim, args.fc_hidden_dim)
24 | self.rnn = nn.GRUCell(input_size=args.fc_hidden_dim,
25 | hidden_size=args.rnn_hidden_dim)
26 | self.fc2 = nn.Linear(args.rnn_hidden_dim, args.n_actions)
27 |
28 | def init_hidden(self):
29 | # make hidden states on same device as model
30 | return self.fc1.weight.new(1, self.rnn_hidden_dim).zero_()
31 |
32 | def forward(
33 | self,
34 | inputs: torch.Tensor = None,
35 | hidden_state: torch.Tensor = None,
36 | ) -> tuple[torch.Tensor, torch.Tensor]:
37 | out = F.relu(self.fc1(inputs), inplace=True)
38 |
39 | if hidden_state is not None:
40 | h_in = hidden_state.reshape(-1, self.rnn_hidden_dim)
41 | else:
42 | h_in = torch.zeros(out.shape[0],
43 | self.rnn_hidden_dim).to(inputs.device)
44 |
45 | hidden_state = self.rnn(out, h_in)
46 | out = self.fc2(hidden_state) # (batch_size, n_actions)
47 | return out, hidden_state
48 |
49 | def update(self, model: nn.Module) -> None:
50 | self.load_state_dict(model.state_dict())
51 |
52 |
53 | class MultiLayerRNNActorModel(nn.Module):
54 | """Because all the agents share the same network,
55 | input_shape=obs_shape+n_actions+n_agents.
56 |
57 | Args:
58 | input_dim (int): The input dimension.
59 | fc_hidden_dim (int): The hidden dimension of the fully connected layer.
60 | rnn_hidden_dim (int): The hidden dimension of the RNN layer.
61 | n_actions (int): The number of actions.
62 | """
63 |
64 | def __init__(
65 | self,
66 | input_dim: int = None,
67 | fc_hidden_dim: int = 64,
68 | rnn_hidden_dim: int = 64,
69 | rnn_num_layers: int = 2,
70 | n_actions: int = None,
71 | **kwargs,
72 | ) -> None:
73 | super(MultiLayerRNNActorModel, self).__init__()
74 |
75 | self.rnn_hidden_dim = rnn_hidden_dim
76 | self.rnn_num_layers = rnn_num_layers
77 |
78 | self.fc1 = nn.Linear(input_dim, fc_hidden_dim)
79 | self.rnn = nn.GRU(
80 | input_size=fc_hidden_dim,
81 | hidden_size=rnn_hidden_dim,
82 | num_layers=rnn_num_layers,
83 | batch_first=True,
84 | )
85 | self.fc2 = nn.Linear(rnn_hidden_dim, n_actions)
86 |
87 | def init_hidden(self, batch_size: int) -> torch.Tensor:
88 | # make hidden states on same device as model
89 | return self.fc1.weight.new(self.rnn_num_layers, batch_size,
90 | self.rnn_hidden_dim).zero_()
91 |
92 | def forward(
93 | self,
94 | input: torch.Tensor = None,
95 | hidden_state: torch.Tensor = None,
96 | ) -> tuple[torch.Tensor, torch.Tensor]:
97 | # input: (batch_size, episode_length, obs_dim)
98 | out = F.relu(self.fc1(input), inplace=True)
99 | # out: (batch_size, episode_length, fc_hidden_dim)
100 | batch_size = input.shape[0]
101 | if hidden_state is None:
102 | hidden_state = self.init_hidden(batch_size)
103 | else:
104 | hidden_state = hidden_state.reshape(
105 | self.rnn_num_layers, batch_size,
106 | self.rnn_hidden_dim).to(input.device)
107 |
108 | out, hidden_state = self.rnn(out, hidden_state)
109 | # out: (batch_size, seq_len, rnn_hidden_dim)
110 |         # hidden_state: (num_layers, batch_size, rnn_hidden_dim)
111 | logits = self.fc2(out)
112 | return logits, hidden_state
113 |
114 | def update(self, model: nn.Module) -> None:
115 | self.load_state_dict(model.state_dict())
116 |
117 |
118 | if __name__ == '__main__':
119 | rnn = nn.GRU(input_size=512,
120 | hidden_size=256,
121 | num_layers=2,
122 | batch_first=True)
123 |     input = torch.randn(32, 3, 512)  # (batch_size, seq_len, input_size)
124 |     h0 = torch.randn(2, 32, 256)  # (num_layers, batch_size, hidden_size)
125 |     output, hn = rnn(input, h0)
126 |     print(output.shape, hn.shape)
127 |     # torch.Size([32, 3, 256]) torch.Size([2, 32, 256])
128 |
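A minimal usage sketch for MultiLayerRNNActorModel, assuming illustrative sizes (obs_dim=30, 9 actions, a batch of 8 episodes of length 25); it only exercises the shapes described in the forward comments above.

import torch

from marltoolkit.modules.actors.rnn import MultiLayerRNNActorModel

model = MultiLayerRNNActorModel(input_dim=30,
                                fc_hidden_dim=64,
                                rnn_hidden_dim=64,
                                rnn_num_layers=2,
                                n_actions=9)
obs = torch.randn(8, 25, 30)  # (batch_size, episode_length, obs_dim)
logits, hidden = model(obs)   # hidden_state=None -> zero-initialized hidden states
print(logits.shape, hidden.shape)
# torch.Size([8, 25, 9]) torch.Size([2, 8, 64])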
--------------------------------------------------------------------------------
/marltoolkit/modules/critics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/marltoolkit/modules/critics/__init__.py
--------------------------------------------------------------------------------
/marltoolkit/modules/critics/coma.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from marltoolkit.modules.actors.mlp import MLPActorModel
4 | from marltoolkit.modules.critics.mlp import MLPCriticModel
5 |
6 |
7 | class ComaModel(nn.Module):
8 |
9 | def __init__(
10 | self,
11 | num_agents: int = None,
12 | n_actions: int = None,
13 | obs_shape: int = None,
14 | state_shape: int = None,
15 | hidden_dim: int = 64,
16 | **kwargs,
17 | ):
18 | super(ComaModel, self).__init__()
19 |
20 | self.num_agents = num_agents
21 | self.obs_shape = obs_shape
22 | self.state_shape = state_shape
23 | self.n_actions = n_actions
24 | self.actions_onehot_shape = self.n_actions * self.num_agents
25 | actor_input_dim = self._get_actor_input_dim()
26 | critic_input_dim = self._get_critic_input_dim()
27 |
28 | # Set up network layers
29 | self.actor_model = MLPActorModel(input_dim=actor_input_dim,
30 | hidden_dim=hidden_dim,
31 | n_actions=n_actions)
32 | self.critic_model = MLPCriticModel(input_dim=critic_input_dim,
33 | hidden_dim=hidden_dim,
34 | output_dim=1)
35 |
36 |     def policy(self, obs, hidden_state=None):
37 |         return self.actor_model(obs)  # MLPActorModel is feed-forward; hidden_state is unused
38 |
39 | def value(self, inputs):
40 | return self.critic_model(inputs)
41 |
42 | def get_actor_params(self):
43 | return self.actor_model.parameters()
44 |
45 | def get_critic_params(self):
46 | return self.critic_model.parameters()
47 |
48 | def _get_actor_input_dim(self):
49 | # observation
50 | input_dim = self.obs_shape
51 | # agent id
52 | input_dim += self.num_agents
53 | # actions and last actions
54 | input_dim += self.actions_onehot_shape * self.num_agents * 2
55 | return input_dim
56 |
57 | def _get_critic_input_dim(self):
58 | input_dim = self.state_shape # state: 48 in 3m map
59 | input_dim += self.obs_shape # obs: 30 in 3m map
60 | input_dim += self.num_agents # agent_id: 3 in 3m map
61 | input_dim += (
62 | self.n_actions * self.num_agents * 2
63 | ) # all agents' action and last_action (one-hot): 54 in 3m map
64 |         return input_dim  # 48 + 30 + 3 + 54 = 135 in the 3m map
65 |
--------------------------------------------------------------------------------
/marltoolkit/modules/critics/maddpg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class MADDPGCritic(nn.Module):
7 |
8 | def __init__(
9 | self,
10 | n_agents,
11 | state_shape,
12 | obs_shape,
13 | hidden_dim,
14 | n_actions,
15 | obs_agent_id: bool,
16 | obs_last_action: bool,
17 | obs_individual_obs: bool,
18 | ):
19 | super(MADDPGCritic, self).__init__()
20 | self.n_agents = n_agents
21 | self.n_actions = n_actions
22 | self.state_shape = state_shape
23 | self.obs_shape = obs_shape
24 | self.obs_individual_obs = obs_individual_obs
25 | self.obs_agent_id = obs_agent_id
26 | self.obs_last_action = obs_last_action
27 |
28 | self.input_shape = self._get_input_shape(
29 | ) + self.n_actions * self.n_agents
30 | if self.obs_last_action:
31 | self.input_shape += self.n_actions
32 |
33 | # Set up network layers
34 | self.fc1 = nn.Linear(self.input_shape, hidden_dim)
35 | self.fc2 = nn.Linear(hidden_dim, hidden_dim)
36 | self.fc3 = nn.Linear(hidden_dim, 1)
37 |
38 | def forward(self, inputs, actions):
39 | inputs = torch.cat((inputs, actions), dim=-1)
40 | x = F.relu(self.fc1(inputs))
41 | x = F.relu(self.fc2(x))
42 | q = self.fc3(x)
43 | return q
44 |
45 | def _get_input_shape(self):
46 | # state_shape
47 | input_shape = self.state_shape
48 | # whether to add the individual observation
49 | if self.obs_individual_obs:
50 | input_shape += self.obs_shape
51 | # agent id
52 | if self.obs_agent_id:
53 | input_shape += self.n_agents
54 | return input_shape
55 |
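A shape sketch for MADDPGCritic with illustrative SMAC-like sizes (3 agents, 9 actions, state of 48, observations of 30) and arbitrarily chosen observation flags; the comment spells out how the first linear layer's input size is assembled.

import torch

from marltoolkit.modules.critics.maddpg import MADDPGCritic

critic = MADDPGCritic(n_agents=3,
                      state_shape=48,
                      obs_shape=30,
                      hidden_dim=64,
                      n_actions=9,
                      obs_agent_id=True,
                      obs_last_action=False,
                      obs_individual_obs=False)
# _get_input_shape() = state (48) + agent id (3) = 51; fc1 then expects 51 + 9 * 3 = 78
inputs = torch.randn(16, 51)   # state features plus agent-id one-hot
actions = torch.randn(16, 27)  # joint one-hot actions of all 3 agents
q_values = critic(inputs, actions)
print(q_values.shape)          # torch.Size([16, 1])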
--------------------------------------------------------------------------------
/marltoolkit/modules/critics/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | # Initialize Policy weights
7 | def weights_init_(module: nn.Module):
8 | if isinstance(module, nn.Linear):
9 | nn.init.xavier_uniform_(module.weight, gain=1)
10 | nn.init.constant_(module.bias, 0)
11 |
12 |
13 | class MLPCriticModel(nn.Module):
14 |     """MLP Critic Network.
15 | 
16 |     Args:
17 |         input_dim (int), hidden_dim (int), output_dim (int): input, hidden and output sizes.
18 |     """
19 |
20 | def __init__(
21 | self,
22 | input_dim: int = None,
23 | hidden_dim: int = 64,
24 | output_dim: int = 1,
25 | ) -> None:
26 | super(MLPCriticModel, self).__init__()
27 |
28 | # Set up network layers
29 | self.fc1 = nn.Linear(input_dim, hidden_dim)
30 | self.fc2 = nn.Linear(hidden_dim, hidden_dim)
31 | self.fc3 = nn.Linear(hidden_dim, output_dim)
32 | self.apply(weights_init_)
33 |
34 | def forward(self, inputs: torch.Tensor) -> torch.Tensor:
35 |         """
36 |         Args:
37 |             inputs (torch.Tensor): critic inputs, shape (batch_size, input_dim)
38 |         Returns:
39 |             qvalue (torch.Tensor): estimated value, shape (batch_size, output_dim)
40 |         """
41 | hidden = F.relu(self.fc1(inputs))
42 | hidden = F.relu(self.fc2(hidden))
43 | qvalue = self.fc3(hidden)
44 | return qvalue
45 |
--------------------------------------------------------------------------------
/marltoolkit/modules/critics/r_critic.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 | from marltoolkit.modules.utils.common import MLPBase, RNNLayer
7 | from marltoolkit.modules.utils.popart import PopArt
8 |
9 |
10 | class R_Critic(nn.Module):
11 | """Critic network class for MAPPO. Outputs value function predictions given
12 | centralized input (MAPPO) or local observations (IPPO).
13 |
14 |     :param args: (argparse.Namespace) arguments containing the relevant model
15 |         information, e.g. state_dim, hidden_size, use_orthogonal, use_popart,
16 |         use_feature_normalization, use_recurrent_policy, rnn_layers and gain.
17 |     """
18 |
19 | def __init__(self, args: argparse.Namespace):
20 | super(R_Critic, self).__init__()
21 | self.use_recurrent_policy = args.use_recurrent_policy
22 | init_method = [nn.init.xavier_uniform_,
23 | nn.init.orthogonal_][args.use_orthogonal]
24 | self.base = MLPBase(
25 | input_dim=args.state_dim,
26 | hidden_dim=args.hidden_size,
27 | use_orthogonal=args.use_orthogonal,
28 | use_feature_normalization=args.use_feature_normalization,
29 | )
30 | if self.use_recurrent_policy:
31 | self.rnn = RNNLayer(
32 | args.hidden_size,
33 | args.hidden_size,
34 | args.rnn_layers,
35 | args.use_orthogonal,
36 | )
37 |
38 | def init_weight(module: nn.Module) -> None:
39 | if isinstance(module, nn.Linear):
40 | init_method(module.weight, gain=args.gain)
41 | if module.bias is not None:
42 | nn.init.constant_(module.bias, 0)
43 |
44 | if args.use_popart:
45 | self.v_out = PopArt(args.hidden_size, 1)
46 | else:
47 | self.v_out = nn.Linear(args.hidden_size, 1)
48 |
49 | self.apply(init_weight)
50 |
51 | def forward(
52 | self,
53 | state: torch.Tensor,
54 | masks: torch.Tensor,
55 | rnn_hidden_states: torch.Tensor,
56 | ):
57 | """Compute actions from the given inputs.
58 |
59 | :param state: (np.ndarray / torch.Tensor) global observation inputs into network.
60 | :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if RNN states should be reinitialized to zeros.
61 | :param rnn_hidden_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN.
62 |
63 | :return values: (torch.Tensor) value function predictions.
64 | :return rnn_hidden_states: (torch.Tensor) updated RNN hidden states.
65 | """
66 | critic_features = self.base(state)
67 | if self.use_recurrent_policy:
68 | critic_features, rnn_hidden_states = self.rnn(
69 | critic_features, rnn_hidden_states, masks)
70 | values = self.v_out(critic_features)
71 | return values, rnn_hidden_states
72 |
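A standalone sketch of R_Critic driven by a hand-built argparse.Namespace; only the field names come from the code above, the values are illustrative, and the recurrent branch is switched off to keep it short.

import argparse

import torch

from marltoolkit.modules.critics.r_critic import R_Critic

args = argparse.Namespace(
    use_recurrent_policy=False,
    use_orthogonal=True,
    use_feature_normalization=True,
    use_popart=False,
    state_dim=48,
    hidden_size=64,
    rnn_layers=1,
    gain=0.01,
)
critic = R_Critic(args)
state = torch.randn(3, 48)  # (num_agents, state_dim)
masks = torch.ones(3, 1)    # only used by the recurrent branch
values, rnn_states = critic(state, masks, rnn_hidden_states=None)
print(values.shape)         # torch.Size([3, 1])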
--------------------------------------------------------------------------------
/marltoolkit/modules/mixers/__init__.py:
--------------------------------------------------------------------------------
1 | """Utils package."""
2 |
3 | from .qmixer import QMixerModel
4 | from .vdn import VDNMixer
5 |
6 | __all__ = ['QMixerModel', 'VDNMixer']
7 |
--------------------------------------------------------------------------------
/marltoolkit/modules/mixers/qatten.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | class QattenMixer(nn.Module):
8 |
9 | def __init__(self,
10 | n_agents: int = None,
11 | state_dim: int = None,
12 | agent_own_state_size: int = None,
13 | n_query_embedding_layer1: int = 64,
14 | n_query_embedding_layer2: int = 32,
15 | n_key_embedding_layer1: int = 32,
16 | n_head_embedding_layer1: int = 64,
17 | n_head_embedding_layer2: int = 4,
18 | num_attention_heads: int = 4,
19 |                  n_constrant_value: int = 32, weighted_head: bool = False):
20 | super(QattenMixer, self).__init__()
21 | self.n_agents = n_agents
22 | self.state_dim = state_dim
23 | self.agent_own_state_size = agent_own_state_size
24 |
25 | self.n_query_embedding_layer1 = n_query_embedding_layer1
26 | self.n_query_embedding_layer2 = n_query_embedding_layer2
27 | self.n_key_embedding_layer1 = n_key_embedding_layer1
28 | self.n_head_embedding_layer1 = n_head_embedding_layer1
29 | self.n_head_embedding_layer2 = n_head_embedding_layer2
30 | self.num_attention_heads = num_attention_heads
31 |         self.n_constrant_value = n_constrant_value
32 |         self.weighted_head = weighted_head  # if True, weight each attention head by |w_h(s)|
33 | self.query_embedding_layers = nn.ModuleList()
34 | for _ in range(self.num_attention_heads):
35 | self.query_embedding_layers.append(
36 | nn.Sequential(
37 | nn.Linear(state_dim, n_query_embedding_layer1),
38 | nn.ReLU(inplace=True),
39 | nn.Linear(n_query_embedding_layer1,
40 | n_query_embedding_layer2)))
41 |
42 | self.key_embedding_layers = nn.ModuleList()
43 | for _ in range(num_attention_heads):
44 | self.key_embedding_layers.append(
45 | nn.Linear(self.agent_own_state_size, n_key_embedding_layer1))
46 |
47 | self.scaled_product_value = np.sqrt(n_query_embedding_layer2)
48 |
49 | self.head_embedding_layer = nn.Sequential(
50 | nn.Linear(state_dim, n_head_embedding_layer1),
51 | nn.ReLU(inplace=True),
52 | nn.Linear(n_head_embedding_layer1, n_head_embedding_layer2))
53 |
54 | self.constrant_value_layer = nn.Sequential(
55 | nn.Linear(state_dim, n_constrant_value), nn.ReLU(inplace=True),
56 | nn.Linear(n_constrant_value, 1))
57 |
58 | def forward(self, agent_qs: torch.Tensor, states: torch.Tensor):
59 | '''
60 | Args:
61 | agent_qs (torch.Tensor): (batch_size, T, n_agents)
62 | states (torch.Tensor): (batch_size, T, state_shape)
63 | Returns:
64 | q_total (torch.Tensor): (batch_size, T, 1)
65 | '''
66 | bs = agent_qs.size(0)
67 | # states : (batch_size * T, state_shape)
68 | states = states.reshape(-1, self.state_dim)
69 | # agent_qs: (batch_size * T, 1, n_agents)
70 | agent_qs = agent_qs.view(-1, 1, self.n_agents)
71 | us = self._get_us(states)
72 |
73 | q_lambda_list = []
74 | for i in range(self.num_attention_heads):
75 | state_embedding = self.query_embedding_layers[i](states)
76 | u_embedding = self.key_embedding_layers[i](us)
77 |
78 |             # shape: [batch_size * T, 1, n_query_embedding_layer2]
79 | state_embedding = state_embedding.reshape(
80 | -1, 1, self.n_query_embedding_layer2)
81 |             # shape after permute: [batch_size * T, n_key_embedding_layer1, n_agents]
82 | u_embedding = u_embedding.reshape(-1, self.n_agents,
83 | self.n_key_embedding_layer1)
84 | u_embedding = u_embedding.permute(0, 2, 1)
85 |
86 | # shape: [batch_size * T, 1, n_agent]
87 | raw_lambda = torch.matmul(state_embedding,
88 | u_embedding) / self.scaled_product_value
89 | q_lambda = F.softmax(raw_lambda, dim=-1)
90 |
91 | q_lambda_list.append(q_lambda)
92 |
93 | # shape: [batch_size * T, num_attention_heads, n_agent]
94 | q_lambda_list = torch.stack(q_lambda_list, dim=1).squeeze(-2)
95 |
96 | # shape: [batch_size * T, n_agent, num_attention_heads]
97 | q_lambda_list = q_lambda_list.permute(0, 2, 1)
98 |
99 | # shape: [batch_size * T, 1, num_attention_heads]
100 | q_h = torch.matmul(agent_qs, q_lambda_list)
101 |
102 |         if self.weighted_head:  # nn.Module already defines .type(), so the flag is not stored there
103 | # shape: [batch_size * T, num_attention_heads]
104 | w_h = torch.abs(self.head_embedding_layer(states))
105 | # shape: [batch_size * T, num_attention_heads, 1]
106 | w_h = w_h.reshape(-1, self.n_head_embedding_layer2, 1)
107 |
108 | # shape: [batch_size * T, 1,1]
109 | sum_q_h = torch.matmul(q_h, w_h)
110 | # shape: [batch_size * T, 1]
111 | sum_q_h = sum_q_h.reshape(-1, 1)
112 | else:
113 | # shape: [-1, 1]
114 | sum_q_h = q_h.sum(-1)
115 | sum_q_h = sum_q_h.reshape(-1, 1)
116 |
117 | c = self.constrant_value_layer(states)
118 | q_tot = sum_q_h + c
119 | q_tot = q_tot.view(bs, -1, 1)
120 | return q_tot
121 |
122 | def _get_us(self, states):
123 | agent_own_state_size = self.agent_own_state_size
124 | with torch.no_grad():
125 | us = states[:, :agent_own_state_size * self.n_agents].reshape(
126 | -1, agent_own_state_size)
127 | return us
128 |
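A shape check for QattenMixer with illustrative sizes (3 agents, state_dim=48, agent_own_state_size=10, so the first 30 state entries are treated as per-agent features); it follows the default unweighted-head path of the forward pass.

import torch

from marltoolkit.modules.mixers.qatten import QattenMixer

mixer = QattenMixer(n_agents=3, state_dim=48, agent_own_state_size=10)
agent_qs = torch.randn(2, 5, 3)  # (batch_size, T, n_agents)
states = torch.randn(2, 5, 48)   # (batch_size, T, state_dim)
q_tot = mixer(agent_qs, states)
print(q_tot.shape)               # torch.Size([2, 5, 1])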
--------------------------------------------------------------------------------
/marltoolkit/modules/mixers/qmixer.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: jianzhnie
3 | LastEditors: jianzhnie
4 | Description: RLToolKit is a flexible and high-efficient reinforcement learning framework.
5 | Copyright (c) 2022 by jianzhnie@126.com, All Rights Reserved.
6 | '''
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class QMixerModel(nn.Module):
13 | """Implementation of the QMixer network.
14 |
15 | Args:
16 | num_agents (int): The number of agents.
17 |         state_dim (int): The dimension of the global state.
18 | hypernet_layers (int): The number of layers in the hypernetwork.
19 | mixing_embed_dim (int): The dimension of the mixing embedding.
20 | hypernet_embed_dim (int): The dimension of the hypernetwork embedding.
21 | """
22 |
23 | def __init__(
24 | self,
25 | num_agents: int = None,
26 | state_dim: int = None,
27 | hypernet_layers: int = 2,
28 | mixing_embed_dim: int = 32,
29 | hypernet_embed_dim: int = 64,
30 | ):
31 | super(QMixerModel, self).__init__()
32 |
33 | self.num_agents = num_agents
34 | self.state_dim = state_dim
35 | self.mixing_embed_dim = mixing_embed_dim
36 | if hypernet_layers == 1:
37 | self.hyper_w_1 = nn.Linear(state_dim,
38 | mixing_embed_dim * num_agents)
39 | self.hyper_w_2 = nn.Linear(state_dim, mixing_embed_dim)
40 | elif hypernet_layers == 2:
41 | self.hyper_w_1 = nn.Sequential(
42 | nn.Linear(state_dim, hypernet_embed_dim),
43 | nn.ReLU(inplace=True),
44 | nn.Linear(hypernet_embed_dim, mixing_embed_dim * num_agents),
45 | )
46 | self.hyper_w_2 = nn.Sequential(
47 | nn.Linear(state_dim, hypernet_embed_dim),
48 | nn.ReLU(inplace=True),
49 | nn.Linear(hypernet_embed_dim, mixing_embed_dim),
50 | )
51 | else:
52 | raise ValueError('hypernet_layers should be "1" or "2"!')
53 |
54 | # State dependent bias for hidden layer
55 | self.hyper_b_1 = nn.Linear(state_dim, mixing_embed_dim)
56 | self.hyper_b_2 = nn.Sequential(
57 | nn.Linear(state_dim, mixing_embed_dim),
58 | nn.ReLU(inplace=True),
59 | nn.Linear(mixing_embed_dim, 1),
60 | )
61 |
62 | def forward(self, agent_qs: torch.Tensor, states: torch.Tensor):
63 | """
64 | Args:
65 | agent_qs (torch.Tensor): (batch_size, T, num_agents)
66 | states (torch.Tensor): (batch_size, T, state_dim)
67 | Returns:
68 | q_total (torch.Tensor): (batch_size, T, 1)
69 | """
70 | batch_size = agent_qs.size(0)
71 | # states : (batch_size * T, state_dim)
72 | states = states.reshape(-1, self.state_dim)
73 | # agent_qs: (batch_size * T, 1, num_agents)
74 | agent_qs = agent_qs.view(-1, 1, self.num_agents)
75 |
76 | # First layer w and b
77 | w1 = torch.abs(self.hyper_w_1(states))
78 | # w1: (batch_size * T, num_agents, embed_dim)
79 | w1 = w1.view(-1, self.num_agents, self.mixing_embed_dim)
80 | b1 = self.hyper_b_1(states)
81 | b1 = b1.view(-1, 1, self.mixing_embed_dim)
82 |
83 | # Second layer w and b
84 | # w2 : (batch_size * T, embed_dim)
85 | w2 = torch.abs(self.hyper_w_2(states))
86 | # w2 : (batch_size * T, embed_dim, 1)
87 | w2 = w2.view(-1, self.mixing_embed_dim, 1)
88 | # State-dependent bias
89 | b2 = self.hyper_b_2(states).view(-1, 1, 1)
90 |
91 | # First hidden layer
92 | # hidden: (batch_size * T, 1, embed_dim)
93 | hidden = F.elu(torch.bmm(agent_qs, w1) + b1)
94 | # Compute final output
95 | # y: (batch_size * T, 1, 1)
96 | y = torch.bmm(hidden, w2) + b2
97 | # Reshape and return
98 | # q_total: (batch_size, T, 1)
99 | q_total = y.view(batch_size, -1, 1)
100 | return q_total
101 |
102 | def update(self, model):
103 | self.load_state_dict(model.state_dict())
104 |
105 |
106 | class QMixerCentralFF(nn.Module):
107 |
108 | def __init__(self, num_agents, state_dim, central_mixing_embed_dim,
109 | central_action_embed):
110 | super(QMixerCentralFF, self).__init__()
111 |
112 | self.num_agents = num_agents
113 | self.state_dim = state_dim
114 | self.input_dim = num_agents * central_action_embed + state_dim
115 | self.central_mixing_embed_dim = central_mixing_embed_dim
116 | self.central_action_embed = central_action_embed
117 |
118 | self.net = nn.Sequential(
119 | nn.Linear(self.input_dim, central_mixing_embed_dim),
120 | nn.ReLU(inplace=True),
121 | nn.Linear(central_mixing_embed_dim, central_mixing_embed_dim),
122 | nn.ReLU(inplace=True),
123 | nn.Linear(central_mixing_embed_dim, central_mixing_embed_dim),
124 | nn.ReLU(inplace=True), nn.Linear(central_mixing_embed_dim, 1))
125 |
126 | # V(s) instead of a bias for the last layers
127 | self.vnet = nn.Sequential(
128 | nn.Linear(state_dim, central_mixing_embed_dim),
129 | nn.ReLU(inplace=True),
130 | nn.Linear(central_mixing_embed_dim, 1),
131 | )
132 |
133 | def forward(self, agent_qs, states):
134 | bs = agent_qs.size(0)
135 | states = states.reshape(-1, self.state_dim)
136 | agent_qs = agent_qs.reshape(
137 | -1, self.num_agents * self.central_action_embed)
138 |
139 | inputs = torch.cat([states, agent_qs], dim=1)
140 |
141 | advs = self.net(inputs)
142 | vs = self.vnet(states)
143 | y = advs + vs
144 | q_tot = y.view(bs, -1, 1)
145 | return q_tot
146 |
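The matching shape check for QMixerModel with the same illustrative sizes: per-agent Q-values and the global state go in, one mixed total Q per timestep comes out, as the forward docstring describes.

import torch

from marltoolkit.modules.mixers.qmixer import QMixerModel

mixer = QMixerModel(num_agents=3, state_dim=48)
agent_qs = torch.randn(2, 5, 3)  # (batch_size, T, num_agents)
states = torch.randn(2, 5, 48)   # (batch_size, T, state_dim)
q_total = mixer(agent_qs, states)
print(q_total.shape)             # torch.Size([2, 5, 1])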
--------------------------------------------------------------------------------
/marltoolkit/modules/mixers/vdn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class VDNMixer(nn.Module):
7 | """Computes total Q values given agent q values and global states."""
8 |
9 | def __init__(self):
10 | super(VDNMixer, self).__init__()
11 |
12 | def forward(self, agent_qs: torch.Tensor, states: torch.Tensor):
13 | if type(agent_qs) == np.ndarray:
14 | agent_qs = torch.FloatTensor(agent_qs)
15 | return torch.sum(agent_qs, dim=2, keepdim=True)
16 |
--------------------------------------------------------------------------------
/marltoolkit/modules/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/marltoolkit/modules/utils/__init__.py
--------------------------------------------------------------------------------
/marltoolkit/modules/utils/common.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class MLPBase(nn.Module):
8 |
9 | def __init__(
10 | self,
11 | input_dim: int,
12 | hidden_dim: int,
13 | activation: str = 'relu',
14 | use_orthogonal: bool = False,
15 | use_feature_normalization: bool = False,
16 | ) -> None:
17 | super(MLPBase, self).__init__()
18 |
19 | use_relu = 1 if activation == 'relu' else 0
20 |         active_func = [nn.Tanh(), nn.ReLU()][use_relu]
21 | if use_orthogonal:
22 | init_method = nn.init.orthogonal_
23 | else:
24 | init_method = nn.init.xavier_uniform_
25 |
26 | gain = nn.init.calculate_gain(['tanh', 'relu'][use_relu])
27 | self.use_feature_normalization = use_feature_normalization
28 | if use_feature_normalization:
29 | self.feature_norm = nn.LayerNorm(input_dim)
30 |
31 | def init_weight(module: nn.Module) -> None:
32 | if isinstance(module, nn.Linear):
33 | init_method(module.weight, gain=gain)
34 | if module.bias is not None:
35 | nn.init.constant_(module.bias, 0)
36 |
37 | self.fc1 = nn.Sequential(nn.Linear(input_dim, hidden_dim), active_func,
38 | nn.LayerNorm(hidden_dim))
39 | self.fc2 = nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
40 | active_func, nn.LayerNorm(hidden_dim))
41 | self.apply(init_weight)
42 |
43 | def forward(self, inputs: torch.Tensor):
44 | """Forward method for MLPBase.
45 |
46 | Args:
47 | inputs (torch.Tensor): Input tensor. Shape (batch_size, input_dim)
48 |
49 | Returns:
50 | output (torch.Tensor): Output tensor. Shape (batch_size, hidden_dim)
51 | """
52 |         if self.use_feature_normalization:
53 |             inputs = self.feature_norm(inputs)
54 | 
55 |         output = self.fc1(inputs)
56 | output = self.fc2(output)
57 |
58 | return output
59 |
60 |
61 | class Flatten(nn.Module):
62 |
63 | def forward(self, inputs: torch.Tensor):
64 | return inputs.view(inputs.size(0), -1)
65 |
66 |
67 | class CNNBase(nn.Module):
68 |
69 | def __init__(
70 | self,
71 | obs_shape: Tuple,
72 | hidden_dim: int,
73 | kernel_size: int = 3,
74 | stride: int = 1,
75 | activation: str = 'relu',
76 | use_orthogonal: bool = False,
77 | ) -> None:
78 | super(CNNBase, self).__init__()
79 |
80 | use_relu = 1 if activation == 'relu' else 0
81 |         active_func = [nn.Tanh(), nn.ReLU()][use_relu]
82 | init_method = [nn.init.xavier_uniform_,
83 | nn.init.orthogonal_][use_orthogonal]
84 | gain = nn.init.calculate_gain(['tanh', 'relu'][use_relu])
85 |
86 | (in_channel, width, height) = obs_shape
87 |
88 | def init_weight(module: nn.Module) -> None:
89 | if isinstance(module, nn.Linear):
90 | init_method(module.weight, gain=gain)
91 | if module.bias is not None:
92 | nn.init.constant_(module.bias, 0)
93 |
94 | self.cnn = nn.Sequential(
95 | nn.Conv2d(
96 | in_channels=in_channel,
97 | out_channels=hidden_dim // 2,
98 | kernel_size=kernel_size,
99 | stride=stride,
100 | ),
101 | active_func,
102 | Flatten(),
103 | nn.Linear(
104 | hidden_dim // 2 * (width - kernel_size + stride) *
105 | (height - kernel_size + stride),
106 | hidden_dim,
107 | ),
108 | active_func,
109 | nn.Linear(hidden_dim, hidden_dim),
110 | active_func,
111 | )
112 | self.apply(init_weight)
113 |
114 | def forward(self, inputs: torch.Tensor):
115 | inputs = inputs / 255.0
116 | output = self.cnn(inputs)
117 | return output
118 |
119 |
120 | class RNNLayer(nn.Module):
121 |
122 | def __init__(
123 | self,
124 | input_dim: int,
125 | rnn_hidden_dim: int,
126 | rnn_layers: int,
127 | use_orthogonal: bool = True,
128 | ) -> None:
129 | super(RNNLayer, self).__init__()
130 | self.rnn_layers = rnn_layers
131 | self.use_orthogonal = use_orthogonal
132 |
133 | self.rnn = nn.GRU(input_dim,
134 | rnn_hidden_dim,
135 | num_layers=rnn_layers,
136 | batch_first=True)
137 | for name, param in self.rnn.named_parameters():
138 | if 'bias' in name:
139 | nn.init.constant_(param, 0)
140 | elif 'weight' in name:
141 | if self.use_orthogonal:
142 | nn.init.orthogonal_(param)
143 | else:
144 | nn.init.xavier_uniform_(param)
145 | self.layer_norm = nn.LayerNorm(rnn_hidden_dim)
146 |
147 | def forward(
148 | self,
149 | inputs: torch.Tensor,
150 | hidden_state: torch.Tensor,
151 | masks: torch.Tensor,
152 | ) -> tuple[Any, torch.Tensor]:
153 | """Forward method for RNNLayer.
154 |
155 | Args:
156 | inputs (torch.Tensor): (num_agents, input_dim)
157 | hidden_state (torch.Tensor): (num_agents, rnn_layers, rnn_hidden_dim)
158 | masks (torch.Tensor): (num_agents, 1)
159 |
160 | Returns:
161 | tuple[Any, torch.Tensor]: (output, hidden_state)
162 | """
163 | print('inputs.shape: ', inputs.shape)
164 | print('hidden_state.shape: ', hidden_state.shape)
165 | print('masks.shape: ', masks.shape)
166 |
167 | if inputs.size(0) == hidden_state.size(0):
168 | # If the batch size is the same, we can just run the RNN
169 | masks = masks.repeat(1, self.rnn_layers).unsqueeze(-1)
170 | # mask shape (num_agents, rnn_layers, 1)
171 | hidden_state = (hidden_state * masks).transpose(0, 1).contiguous()
172 | # hidden_state shape (rnn_layers, num_agents, rnn_hidden_dim)
173 | inputs = inputs.unsqueeze(0).transpose(0, 1)
174 | # inputs shape (1, num_agents, input_dim)
175 | output, hidden_state = self.rnn(inputs, hidden_state)
176 | output = output.squeeze(1)
177 | hidden_state = hidden_state.transpose(0, 1)
178 | else:
179 | # x is a (T, N, -1) tensor that has been flatten to (T * N, -1)
180 | N = hidden_state.size(0)
181 | T = int(inputs.size(0) / N)
182 | # unflatten
183 | inputs = inputs.view(T, N, inputs.size(1))
184 | # Same deal with masks
185 | masks = masks.view(T, N)
186 | # Let's figure out which steps in the sequence have a zero for any agent
187 | # We will always assume t=0 has a zero in it as that makes the logic cleaner
188 | has_zeros = (masks[1:] == 0.0).any(
189 | dim=-1).nonzero().squeeze().cpu()
190 |
191 | # +1 to correct the masks[1:]
192 | if has_zeros.dim() == 0:
193 | # Deal with scalar
194 | has_zeros = [has_zeros.item() + 1]
195 | else:
196 | has_zeros = (has_zeros + 1).numpy().tolist()
197 |
198 | # add t=0 and t=T to the list
199 | has_zeros = [0] + has_zeros + [T]
200 |
201 | hidden_state = hidden_state.transpose(0, 1)
202 |
203 | outputs = []
204 | for i in range(len(has_zeros) - 1):
205 | # We can now process steps that don't have any zeros in masks together!
206 | # This is much faster
207 | start_idx = has_zeros[i]
208 | end_idx = has_zeros[i + 1]
209 | temp = (hidden_state * masks[start_idx].view(1, -1, 1).repeat(
210 | self.rnn_layers, 1, 1)).contiguous()
211 | rnn_scores, hidden_state = self.rnn(inputs[start_idx:end_idx],
212 | temp)
213 | outputs.append(rnn_scores)
214 |
215 | # assert len(outputs) == T
216 |             # concatenate the per-chunk outputs back into a (T, N, -1) tensor
217 |             output = torch.cat(outputs, dim=0)
218 | 
219 |             # flatten
220 |             output = output.reshape(T * N, -1)
221 |             hidden_state = hidden_state.transpose(0, 1)
222 | 
223 |         output = self.layer_norm(output)
224 | return output, hidden_state
225 |
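A small sketch of MLPBase feeding RNNLayer on the single-step path (observations and hidden states share the same leading dimension), with illustrative sizes; note that RNNLayer.forward also prints the debug shapes above.

import torch

from marltoolkit.modules.utils.common import MLPBase, RNNLayer

base = MLPBase(input_dim=30, hidden_dim=64, use_feature_normalization=True)
rnn = RNNLayer(input_dim=64, rnn_hidden_dim=64, rnn_layers=1)

obs = torch.randn(3, 30)        # (num_agents, input_dim)
features = base(obs)            # (num_agents, hidden_dim)
hidden = torch.zeros(3, 1, 64)  # (num_agents, rnn_layers, rnn_hidden_dim)
masks = torch.ones(3, 1)        # 1 = keep the hidden state, 0 = reset it
features, hidden = rnn(features, hidden, masks)
print(features.shape, hidden.shape)
# torch.Size([3, 64]) torch.Size([3, 1, 64])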
--------------------------------------------------------------------------------
/marltoolkit/modules/utils/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.distributions import Bernoulli, Categorical, Normal
4 |
5 |
6 | #
7 | # Standardize distribution interfaces
8 | #
9 | # Categorical
10 | class FixedCategorical(Categorical):
11 |
12 | def sample(self):
13 | return super().sample().unsqueeze(-1)
14 |
15 | def log_probs(self, actions):
16 | return (super().log_prob(actions.squeeze(-1)).view(
17 | actions.size(0), -1).sum(-1).unsqueeze(-1))
18 |
19 | def mode(self):
20 | return self.probs.argmax(dim=-1, keepdim=True)
21 |
22 |
23 | # Normal
24 | class FixedNormal(Normal):
25 |
26 | def log_probs(self, actions):
27 | return super().log_prob(actions).sum(-1, keepdim=True)
28 |
29 | def entropy(self):
30 | return super().entropy().sum(-1)
31 |
32 | def mode(self):
33 | return self.mean
34 |
35 |
36 | # Bernoulli
37 | class FixedBernoulli(Bernoulli):
38 |
39 | def log_probs(self, actions):
40 |         return super().log_prob(actions).view(actions.size(0),
41 |                                                -1).sum(-1).unsqueeze(-1)
42 |
43 | def entropy(self):
44 | return super().entropy().sum(-1)
45 |
46 | def mode(self):
47 | return torch.gt(self.probs, 0.5).float()
48 |
49 |
50 | class CustomCategorical(nn.Module):
51 |
52 | def __init__(self,
53 | num_inputs,
54 | num_outputs,
55 | use_orthogonal=True,
56 | gain=0.01):
57 | super(CustomCategorical, self).__init__()
58 | init_method = [nn.init.xavier_uniform_,
59 | nn.init.orthogonal_][use_orthogonal]
60 |
61 | def init_weight(module: nn.Module) -> None:
62 | if isinstance(module, nn.Linear):
63 | init_method(module.weight, gain=gain)
64 | if module.bias is not None:
65 | nn.init.constant_(module.bias, 0)
66 |
67 | self.linear = nn.Linear(num_inputs, num_outputs)
68 | self.apply(init_weight)
69 |
70 | def forward(self, x, available_actions=None):
71 | x = self.linear(x)
72 | if available_actions is not None:
73 | x[available_actions == 0] = -1e10
74 | return FixedCategorical(logits=x)
75 |
76 |
77 | class DiagGaussian(nn.Module):
78 |
79 | def __init__(self,
80 | num_inputs,
81 | num_outputs,
82 | use_orthogonal=True,
83 | gain=0.01):
84 | super(DiagGaussian, self).__init__()
85 |
86 | init_method = [nn.init.xavier_uniform_,
87 | nn.init.orthogonal_][use_orthogonal]
88 |
89 | def init_weight(module: nn.Module) -> None:
90 | if isinstance(module, nn.Linear):
91 | init_method(module.weight, gain=gain)
92 | if module.bias is not None:
93 | nn.init.constant_(module.bias, 0)
94 |
95 | self.fc_mean = nn.Linear(num_inputs, num_outputs)
96 | self.logstd = AddBias(torch.zeros(num_outputs))
97 | self.apply(init_weight)
98 |
99 | def forward(self, x):
100 | action_mean = self.fc_mean(x)
101 |
102 | # An ugly hack for my KFAC implementation.
103 | zeros = torch.zeros(action_mean.size())
104 | if x.is_cuda:
105 | zeros = zeros.cuda()
106 |
107 | action_logstd = self.logstd(zeros)
108 | return FixedNormal(action_mean, action_logstd.exp())
109 |
110 |
111 | class CustomBernoulli(nn.Module):
112 |
113 | def __init__(self,
114 | num_inputs,
115 | num_outputs,
116 | use_orthogonal=True,
117 | gain=0.01):
118 | super(CustomBernoulli, self).__init__()
119 | init_method = [nn.init.xavier_uniform_,
120 | nn.init.orthogonal_][use_orthogonal]
121 |
122 | def init_weight(module: nn.Module) -> None:
123 | if isinstance(module, nn.Linear):
124 | init_method(module.weight, gain=gain)
125 | if module.bias is not None:
126 | nn.init.constant_(module.bias, 0)
127 |
128 | self.linear = nn.Linear(num_inputs, num_outputs)
129 | self.apply(init_weight)
130 |
131 | def forward(self, x):
132 | x = self.linear(x)
133 | return FixedBernoulli(logits=x)
134 |
135 |
136 | class AddBias(nn.Module):
137 |
138 | def __init__(self, bias):
139 | super(AddBias, self).__init__()
140 | self._bias = nn.Parameter(bias.unsqueeze(1))
141 |
142 | def forward(self, x):
143 | if x.dim() == 2:
144 | bias = self._bias.t().view(1, -1)
145 | else:
146 | bias = self._bias.t().view(1, -1, 1, 1)
147 |
148 | return x + bias
149 |
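A sketch of CustomCategorical with an action mask, assuming 64 input features and 9 actions: unavailable actions get their logits pushed to -1e10 before the FixedCategorical distribution is built, so they are effectively never sampled.

import torch

from marltoolkit.modules.utils.distributions import CustomCategorical

head = CustomCategorical(num_inputs=64, num_outputs=9)
features = torch.randn(4, 64)
available_actions = torch.ones(4, 9)
available_actions[:, -2:] = 0        # mask out the last two actions
dist = head(features, available_actions)
actions = dist.sample()              # (4, 1)
log_probs = dist.log_probs(actions)  # (4, 1)
greedy = dist.mode()                 # (4, 1)
print(actions.shape, log_probs.shape, greedy.shape)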
--------------------------------------------------------------------------------
/marltoolkit/modules/utils/popart.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class PopArt(torch.nn.Module):
10 |
11 | def __init__(
12 | self,
13 | input_shape,
14 | output_shape,
15 | norm_axes=1,
16 | beta=0.99999,
17 | epsilon=1e-5,
18 | device=torch.device('cpu'),
19 | ):
20 | super(PopArt, self).__init__()
21 |
22 | self.beta = beta
23 | self.epsilon = epsilon
24 | self.norm_axes = norm_axes
25 | self.tpdv = dict(dtype=torch.float32, device=device)
26 |
27 | self.input_shape = input_shape
28 | self.output_shape = output_shape
29 |
30 | self.weight = nn.Parameter(torch.Tensor(output_shape,
31 | input_shape)).to(**self.tpdv)
32 | self.bias = nn.Parameter(torch.Tensor(output_shape)).to(**self.tpdv)
33 |
34 | self.stddev = nn.Parameter(torch.ones(output_shape),
35 | requires_grad=False).to(**self.tpdv)
36 | self.mean = nn.Parameter(torch.zeros(output_shape),
37 | requires_grad=False).to(**self.tpdv)
38 | self.mean_sq = nn.Parameter(torch.zeros(output_shape),
39 | requires_grad=False).to(**self.tpdv)
40 | self.debiasing_term = nn.Parameter(torch.tensor(0.0),
41 | requires_grad=False).to(**self.tpdv)
42 |
43 | self.reset_parameters()
44 |
45 | def reset_parameters(self):
46 | torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
47 | if self.bias is not None:
48 | fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(
49 | self.weight)
50 | bound = 1 / math.sqrt(fan_in)
51 | torch.nn.init.uniform_(self.bias, -bound, bound)
52 | self.mean.zero_()
53 | self.mean_sq.zero_()
54 | self.debiasing_term.zero_()
55 |
56 | def forward(self, input_vector):
57 | if type(input_vector) == np.ndarray:
58 | input_vector = torch.from_numpy(input_vector)
59 | input_vector = input_vector.to(**self.tpdv)
60 |
61 | return F.linear(input_vector, self.weight, self.bias)
62 |
63 | @torch.no_grad()
64 | def update(self, input_vector):
65 | if type(input_vector) == np.ndarray:
66 | input_vector = torch.from_numpy(input_vector)
67 | input_vector = input_vector.to(**self.tpdv)
68 |
69 | old_mean, old_var = self.debiased_mean_var()
70 | old_stddev = torch.sqrt(old_var)
71 |
72 | batch_mean = input_vector.mean(dim=tuple(range(self.norm_axes)))
73 | batch_sq_mean = (input_vector**2).mean(
74 | dim=tuple(range(self.norm_axes)))
75 |
76 | self.mean.mul_(self.beta).add_(batch_mean * (1.0 - self.beta))
77 | self.mean_sq.mul_(self.beta).add_(batch_sq_mean * (1.0 - self.beta))
78 | self.debiasing_term.mul_(self.beta).add_(1.0 * (1.0 - self.beta))
79 |
80 |         self.stddev.data = (self.mean_sq - self.mean**2).sqrt().clamp(min=1e-4)
81 |
82 | new_mean, new_var = self.debiased_mean_var()
83 | new_stddev = torch.sqrt(new_var)
84 |
85 |         self.weight.data = self.weight * old_stddev / new_stddev  # update .data in place; re-assigning a registered Parameter raises a TypeError
86 |         self.bias.data = (old_stddev * self.bias + old_mean - new_mean) / new_stddev
87 |
88 | def debiased_mean_var(self):
89 | debiased_mean = self.mean / self.debiasing_term.clamp(min=self.epsilon)
90 | debiased_mean_sq = self.mean_sq / self.debiasing_term.clamp(
91 | min=self.epsilon)
92 | debiased_var = (debiased_mean_sq - debiased_mean**2).clamp(min=1e-2)
93 | return debiased_mean, debiased_var
94 |
95 | def normalize(self, input_vector):
96 | if type(input_vector) == np.ndarray:
97 | input_vector = torch.from_numpy(input_vector)
98 | input_vector = input_vector.to(**self.tpdv)
99 |
100 | mean, var = self.debiased_mean_var()
101 | out = (input_vector - mean[(None, ) * self.norm_axes]
102 | ) / torch.sqrt(var)[(None, ) * self.norm_axes]
103 |
104 | return out
105 |
106 | def denormalize(self, input_vector):
107 | if type(input_vector) == np.ndarray:
108 | input_vector = torch.from_numpy(input_vector)
109 | input_vector = input_vector.to(**self.tpdv)
110 |
111 | mean, var = self.debiased_mean_var()
112 | out = (input_vector * torch.sqrt(var)[(None, ) * self.norm_axes] +
113 | mean[(None, ) * self.norm_axes])
114 |
115 | out = out.cpu().numpy()
116 |
117 | return out
118 |
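A sketch of PopArt used as a value head, with illustrative sizes: forward acts as a plain linear layer, while update refreshes the running return statistics and rescales weight and bias so earlier predictions keep their meaning.

import torch

from marltoolkit.modules.utils.popart import PopArt

v_out = PopArt(input_shape=64, output_shape=1)
features = torch.randn(32, 64)
values = v_out(features)            # (32, 1) value predictions
returns = torch.randn(32, 1) * 10 + 5
v_out.update(returns)               # update statistics, rescale weight / bias
targets = v_out.normalize(returns)  # normalized regression targets
print(values.shape, targets.shape)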
--------------------------------------------------------------------------------
/marltoolkit/modules/utils/valuenorm.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class ValueNorm(nn.Module):
7 | """Normalize a vector of observations - across the first norm_axes dimensions"""
8 |
9 | def __init__(
10 | self,
11 | input_shape,
12 | norm_axes=1,
13 | beta=0.99999,
14 | per_element_update=False,
15 | epsilon=1e-5,
16 | device=torch.device('cpu'),
17 | ):
18 | super(ValueNorm, self).__init__()
19 |
20 | self.input_shape = input_shape
21 | self.norm_axes = norm_axes
22 | self.epsilon = epsilon
23 | self.beta = beta
24 | self.per_element_update = per_element_update
25 | self.tpdv = dict(dtype=torch.float32, device=device)
26 |
27 | self.running_mean = nn.Parameter(torch.zeros(input_shape),
28 | requires_grad=False).to(**self.tpdv)
29 | self.running_mean_sq = nn.Parameter(
30 | torch.zeros(input_shape), requires_grad=False).to(**self.tpdv)
31 | self.debiasing_term = nn.Parameter(torch.tensor(0.0),
32 | requires_grad=False).to(**self.tpdv)
33 |
34 | self.reset_parameters()
35 |
36 | def reset_parameters(self):
37 | self.running_mean.zero_()
38 | self.running_mean_sq.zero_()
39 | self.debiasing_term.zero_()
40 |
41 | def running_mean_var(self):
42 | debiased_mean = self.running_mean / self.debiasing_term.clamp(
43 | min=self.epsilon)
44 | debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(
45 | min=self.epsilon)
46 | debiased_var = (debiased_mean_sq - debiased_mean**2).clamp(min=1e-2)
47 | return debiased_mean, debiased_var
48 |
49 | @torch.no_grad()
50 | def update(self, input_vector):
51 | if type(input_vector) == np.ndarray:
52 | input_vector = torch.from_numpy(input_vector)
53 | input_vector = input_vector.to(**self.tpdv)
54 |
55 | batch_mean = input_vector.mean(dim=tuple(range(self.norm_axes)))
56 | batch_sq_mean = (input_vector**2).mean(
57 | dim=tuple(range(self.norm_axes)))
58 |
59 | if self.per_element_update:
60 | batch_size = np.prod(input_vector.size()[:self.norm_axes])
61 | weight = self.beta**batch_size
62 | else:
63 | weight = self.beta
64 |
65 | self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight))
66 | self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight))
67 | self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight))
68 |
69 | def normalize(self, input_vector):
70 | # Make sure input is float32
71 | if type(input_vector) == np.ndarray:
72 | input_vector = torch.from_numpy(input_vector)
73 | input_vector = input_vector.to(**self.tpdv)
74 |
75 | mean, var = self.running_mean_var()
76 | out = (input_vector - mean[(None, ) * self.norm_axes]
77 | ) / torch.sqrt(var)[(None, ) * self.norm_axes]
78 |
79 | return out
80 |
81 | def denormalize(self, input_vector):
82 | """Transform normalized data back into original distribution."""
83 | if type(input_vector) == np.ndarray:
84 | input_vector = torch.from_numpy(input_vector)
85 | input_vector = input_vector.to(**self.tpdv)
86 |
87 | mean, var = self.running_mean_var()
88 | out = (input_vector * torch.sqrt(var)[(None, ) * self.norm_axes] +
89 | mean[(None, ) * self.norm_axes])
90 |
91 | out = out.cpu().numpy()
92 |
93 | return out
94 |
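A sketch of ValueNorm for return normalization, with illustrative sizes: update the running statistics with a batch of returns, normalize the training targets, and denormalize predictions back to the original scale (denormalize returns a numpy array).

import torch

from marltoolkit.modules.utils.valuenorm import ValueNorm

value_norm = ValueNorm(input_shape=1)
returns = torch.randn(256, 1) * 10 + 5      # unnormalized returns
value_norm.update(returns)                  # refresh running mean / mean_sq
targets = value_norm.normalize(returns)     # ~zero-mean, unit-variance tensor
restored = value_norm.denormalize(targets)  # numpy array on the original scale
print(targets.mean().item(), restored.mean())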
--------------------------------------------------------------------------------
/marltoolkit/runners/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jianzhnie/deep-marl-toolkit/b8e061a7e4b1de9a22d59509d4cb98a796ddc5c4/marltoolkit/runners/__init__.py
--------------------------------------------------------------------------------
/marltoolkit/runners/episode_runner.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Tuple
3 |
4 | from marltoolkit.agents.base_agent import BaseAgent
5 | from marltoolkit.data.ma_buffer import ReplayBuffer
6 | from marltoolkit.envs import MultiAgentEnv
7 | from marltoolkit.utils.logger.logs import avg_val_from_list_of_dicts
8 |
9 |
10 | def run_train_episode(
11 | env: MultiAgentEnv,
12 | agent: BaseAgent,
13 | rpm: ReplayBuffer,
14 | args: argparse.Namespace = None,
15 | ) -> dict[str, float]:
16 | episode_reward = 0.0
17 | episode_step = 0
18 | done = False
19 | agent.init_hidden_states(batch_size=1)
20 | (obs, state, info) = env.reset()
21 | while not done:
22 | available_actions = env.get_available_actions()
23 | actions = agent.sample(obs=obs, available_actions=available_actions)
24 | next_obs, next_state, reward, terminated, truncated, info = env.step(
25 | actions)
26 | done = terminated or truncated
27 | transitions = {
28 | 'obs': obs,
29 | 'state': state,
30 | 'actions': actions,
31 | 'available_actions': available_actions,
32 | 'rewards': reward,
33 | 'dones': done,
34 | 'filled': False,
35 | }
36 |
37 | rpm.store_transitions(transitions)
38 | obs, state = next_obs, next_state
39 |
40 | episode_reward += reward
41 | episode_step += 1
42 |
43 | # fill the episode
44 | for _ in range(episode_step, args.episode_limit):
45 | rpm.episode_data.fill_mask()
46 |
47 | # ReplayBuffer store the episode data
48 | rpm.store_episodes()
49 | is_win = env.win_counted
50 |
51 | train_res_lst = []
52 | if rpm.size() > args.memory_warmup_size:
53 | for _ in range(args.learner_update_freq):
54 | batch = rpm.sample(args.batch_size)
55 | results = agent.learn(batch)
56 | train_res_lst.append(results)
57 |
58 | train_res_dict = avg_val_from_list_of_dicts(train_res_lst)
59 |
60 | train_res_dict['episode_reward'] = episode_reward
61 | train_res_dict['episode_step'] = episode_step
62 | train_res_dict['win_rate'] = is_win
63 | return train_res_dict
64 |
65 |
66 | def run_eval_episode(
67 | env: MultiAgentEnv,
68 | agent: BaseAgent,
69 | args: argparse.Namespace = None,
70 | ) -> dict[str, float]:
71 | eval_res_list = []
72 | for _ in range(args.num_eval_episodes):
73 | agent.init_hidden_states(batch_size=1)
74 | episode_reward = 0.0
75 | episode_step = 0
76 | done = False
77 | obs, state, info = env.reset()
78 |
79 | while not done:
80 | available_actions = env.get_available_actions()
81 | actions = agent.predict(
82 | obs=obs,
83 | available_actions=available_actions,
84 | )
85 | next_obs, next_state, reward, terminated, truncated, info = env.step(
86 | actions)
87 | done = terminated or truncated
88 | obs = next_obs
89 | episode_step += 1
90 | episode_reward += reward
91 |
92 | is_win = env.win_counted
93 | eval_res_list.append({
94 | 'episode_reward': episode_reward,
95 | 'episode_step': episode_step,
96 | 'win_rate': is_win,
97 | })
98 | eval_res_dict = avg_val_from_list_of_dicts(eval_res_list)
99 | return eval_res_dict
100 |
--------------------------------------------------------------------------------
/marltoolkit/runners/onpolicy_runner.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from typing import Tuple
3 |
4 | from marltoolkit.agents.mappo_agent import MAPPOAgent
5 | from marltoolkit.data.shared_buffer import SharedReplayBuffer
6 | from marltoolkit.envs.smacv1 import SMACWrapperEnv
7 | from marltoolkit.utils.logger.logs import avg_val_from_list_of_dicts
8 |
9 |
10 | def run_train_episode(
11 | env: SMACWrapperEnv,
12 | agent: MAPPOAgent,
13 | rpm: SharedReplayBuffer,
14 | args: argparse.Namespace = None,
15 | ) -> dict[str, float]:
16 | episode_reward = 0.0
17 | episode_step = 0
18 | done = False
19 | agent.init_hidden_states(batch_size=1)
20 | (obs, state, info) = env.reset()
21 | available_actions = env.get_available_actions()
22 |
23 | rpm.obs[0] = obs.copy()
24 | rpm.state[0] = state.copy()
25 | rpm.available_actions[0] = available_actions.copy()
26 |
27 | for step in range(args.episode_limit):
28 | available_actions = env.get_available_actions()
29 | actions = agent.sample(obs=obs, available_actions=available_actions)
30 | next_obs, next_state, reward, terminated, truncated, info = env.step(
31 | actions)
32 | done = terminated or truncated
33 | transitions = {
34 | 'obs': obs,
35 | 'state': state,
36 | 'actions': actions,
37 | 'available_actions': available_actions,
38 | 'rewards': reward,
39 | 'dones': done,
40 | 'filled': False,
41 | }
42 |
43 | rpm.store_transitions(transitions)
44 | obs, state = next_obs, next_state
45 |
46 | episode_reward += reward
47 | episode_step += 1
48 |
49 | # fill the episode
50 | for _ in range(episode_step, args.episode_limit):
51 | rpm.episode_data.fill_mask()
52 |
53 | # ReplayBuffer store the episode data
54 | rpm.store_episodes()
55 | is_win = env.win_counted
56 |
57 | train_res_lst = []
58 | if rpm.size() > args.memory_warmup_size:
59 | for _ in range(args.learner_update_freq):
60 | batch = rpm.sample(args.batch_size)
61 | results = agent.learn(batch)
62 | train_res_lst.append(results)
63 |
64 | train_res_dict = avg_val_from_list_of_dicts(train_res_lst)
65 |
66 | train_res_dict['episode_reward'] = episode_reward
67 | train_res_dict['episode_step'] = episode_step
68 | train_res_dict['win_rate'] = is_win
69 | return train_res_dict
70 |
71 |
72 | def run_eval_episode(
73 | env: SMACWrapperEnv,
74 | agent: MAPPOAgent,
75 | args: argparse.Namespace = None,
76 | ) -> dict[str, float]:
77 | eval_res_list = []
78 | for _ in range(args.num_eval_episodes):
79 | agent.init_hidden_states(batch_size=1)
80 | episode_reward = 0.0
81 | episode_step = 0
82 | done = False
83 | obs, state, info = env.reset()
84 |
85 | while not done:
86 | available_actions = env.get_available_actions()
87 | actions = agent.predict(
88 | obs=obs,
89 | available_actions=available_actions,
90 | )
91 | next_obs, next_state, reward, terminated, truncated, info = env.step(
92 | actions)
93 | done = terminated or truncated
94 | obs = next_obs
95 | episode_step += 1
96 | episode_reward += reward
97 |
98 | is_win = env.win_counted
99 | eval_res_list.append({
100 | 'episode_reward': episode_reward,
101 | 'episode_step': episode_step,
102 | 'win_rate': is_win,
103 | })
104 | eval_res_dict = avg_val_from_list_of_dicts(eval_res_list)
105 | return eval_res_dict
106 |
--------------------------------------------------------------------------------
/marltoolkit/runners/parallel_episode_runner.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import numpy as np
4 |
5 | from marltoolkit.agents import BaseAgent
6 | from marltoolkit.data import MaEpisodeData, OffPolicyBufferRNN
7 | from marltoolkit.envs import BaseVecEnv
8 |
9 |
10 | def run_train_episode(
11 | envs: BaseVecEnv,
12 | agent: BaseAgent,
13 | rpm: OffPolicyBufferRNN,
14 | args: argparse.Namespace = None,
15 | ):
16 | agent.reset_agent()
17 | # reset the environment
18 | obs, state, info = envs.reset()
19 | num_envs = envs.num_envs
20 | env_dones = envs.buf_dones
21 | episode_step = 0
22 | episode_score = np.zeros(num_envs, dtype=np.float32)
23 | filled = np.zeros([args.episode_limit, num_envs, 1], dtype=np.int32)
24 | episode_data = MaEpisodeData(
25 | num_envs,
26 | args.num_agents,
27 | args.episode_limit,
28 | args.obs_shape,
29 | args.state_shape,
30 | args.action_shape,
31 | args.reward_shape,
32 | args.done_shape,
33 | )
34 | while not env_dones.all():
35 | available_actions = envs.get_available_actions()
36 | # Get actions from the agent
37 | actions = agent.sample(obs, available_actions)
38 | # Environment step
39 | next_obs, next_state, rewards, env_dones, info = envs.step(actions)
40 | # Fill the episode buffer
41 | filled[episode_step, :] = np.ones([num_envs, 1])
42 | transitions = dict(
43 | obs=obs,
44 | state=state,
45 | actions=actions,
46 | rewards=rewards,
47 | env_dones=env_dones,
48 | available_actions=available_actions,
49 | )
50 | # Store the transitions
51 | episode_data.store_transitions(transitions)
52 | # Check if the episode is done
53 | for env_idx in range(num_envs):
54 | if env_dones[env_idx]:
55 | # Fill the rest of the episode with zeros
56 | filled[episode_step, env_idx, :] = 0
57 | # Get the episode score from the info
58 | final_info = info['final_info']
59 | episode_score[env_idx] = final_info[env_idx]['episode_score']
60 | # Get the current available actions
61 | available_actions = envs.get_available_actions()
62 | terminal_data = (next_obs, next_state, available_actions,
63 | filled)
64 | # Finish the episode
65 | rpm.finish_path(env_idx, episode_step, *terminal_data)
66 |
67 | # Update the episode step
68 | episode_step += 1
69 | obs, state = next_obs, next_state
70 | # Store the episode data
71 | rpm.store_episodes(episode_data.episode_buffer)
72 |
73 | mean_loss = []
74 | mean_td_error = []
75 | if rpm.size() > args.memory_warmup_size:
76 | for _ in range(args.learner_update_freq):
77 | batch = rpm.sample_batch(args.batch_size)
78 | loss, td_error = agent.learn(**batch)
79 | mean_loss.append(loss)
80 | mean_td_error.append(td_error)
81 |
82 | mean_loss = np.mean(mean_loss) if mean_loss else None
83 | mean_td_error = np.mean(mean_td_error) if mean_td_error else None
84 |
85 | return episode_score, episode_step, mean_loss, mean_td_error
86 |
87 |
88 | def run_eval_episode(
89 | env: BaseVecEnv,
90 | agent: BaseAgent,
91 | num_eval_episodes: int = 5,
92 | ):
93 | eval_is_win_buffer = []
94 | eval_reward_buffer = []
95 | eval_steps_buffer = []
96 | for _ in range(num_eval_episodes):
97 | agent.reset_agent()
98 | episode_reward = 0.0
99 | episode_step = 0
100 | terminated = False
101 | obs, state = env.reset()
102 | while not terminated:
103 | available_actions = env.get_available_actions()
104 | actions = agent.predict(obs, available_actions)
105 | state, obs, reward, terminated = env.step(actions)
106 | episode_step += 1
107 | episode_reward += reward
108 |
109 | is_win = env.win_counted
110 |
111 | eval_reward_buffer.append(episode_reward)
112 | eval_steps_buffer.append(episode_step)
113 | eval_is_win_buffer.append(is_win)
114 |
115 | eval_rewards = np.mean(eval_reward_buffer)
116 | eval_steps = np.mean(eval_steps_buffer)
117 | eval_win_rate = np.mean(eval_is_win_buffer)
118 |
119 | return eval_rewards, eval_steps, eval_win_rate
120 |
--------------------------------------------------------------------------------
/marltoolkit/runners/qtran_runner.py.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from marltoolkit.data.ma_buffer import EpisodeData, ReplayBuffer
4 |
5 |
6 | def run_train_episode(env, agent, rpm: ReplayBuffer, config: dict = None):
7 |
8 | episode_limit = config['episode_limit']
9 | agent.reset_agent()
10 | episode_reward = 0.0
11 | episode_step = 0
12 | terminated = False
13 | state, obs = env.reset()
14 | episode_experience = EpisodeData(
15 | episode_limit=episode_limit,
16 | state_shape=config['state_shape'],
17 | obs_shape=config['obs_shape'],
18 | num_actions=config['n_actions'],
19 | num_agents=config['n_agents'],
20 | )
21 |
22 | while not terminated:
23 | available_actions = env.get_available_actions()
24 | actions = agent.sample(obs, available_actions)
25 | actions_onehot = env._get_actions_one_hot(actions)
26 | next_state, next_obs, reward, terminated = env.step(actions)
27 | episode_reward += reward
28 | episode_step += 1
29 | episode_experience.add(state, obs, actions, actions_onehot,
30 | available_actions, reward, terminated, 0)
31 | state = next_state
32 | obs = next_obs
33 |
34 | # fill the episode
35 | for _ in range(episode_step, episode_limit):
36 | episode_experience.fill_mask()
37 |
38 | episode_data = episode_experience.get_data()
39 |
40 | rpm.store(**episode_data)
41 | is_win = env.win_counted
42 |
43 | mean_loss = []
44 | mean_td_loss = []
45 | mean_opt_loss = []
46 | mean_nopt_loss = []
47 | if rpm.size() > config['memory_warmup_size']:
48 | for _ in range(config['learner_update_freq']):
49 | batch = rpm.sample_batch(config['batch_size'])
50 | loss, td_loss, opt_loss, nopt_loss = agent.learn(**batch)
51 | mean_loss.append(loss)
52 | mean_td_loss.append(td_loss)
53 | mean_opt_loss.append(opt_loss)
54 | mean_nopt_loss.append(nopt_loss)
55 |
56 | mean_loss = np.mean(mean_loss) if mean_loss else None
57 | mean_td_loss = np.mean(mean_td_loss) if mean_td_loss else None
58 | mean_opt_loss = np.mean(mean_opt_loss) if mean_opt_loss else None
59 | mean_nopt_loss = np.mean(mean_nopt_loss) if mean_nopt_loss else None
60 |
61 | return episode_reward, episode_step, is_win, mean_loss, mean_td_loss, mean_opt_loss, mean_nopt_loss
62 |
63 |
64 | def run_eval_episode(env, agent, num_eval_episodes=5):
65 | eval_is_win_buffer = []
66 | eval_reward_buffer = []
67 | eval_steps_buffer = []
68 | for _ in range(num_eval_episodes):
69 | agent.reset_agent()
70 | episode_reward = 0.0
71 | episode_step = 0
72 | terminated = False
73 | state, obs = env.reset()
74 | while not terminated:
75 | available_actions = env.get_available_actions()
76 | actions = agent.predict(obs, available_actions)
77 | state, obs, reward, terminated = env.step(actions)
78 | episode_step += 1
79 | episode_reward += reward
80 |
81 | is_win = env.win_counted
82 |
83 | eval_reward_buffer.append(episode_reward)
84 | eval_steps_buffer.append(episode_step)
85 | eval_is_win_buffer.append(is_win)
86 |
87 | eval_rewards = np.mean(eval_reward_buffer)
88 | eval_steps = np.mean(eval_steps_buffer)
89 | eval_win_rate = np.mean(eval_is_win_buffer)
90 |
91 | return eval_rewards, eval_steps, eval_win_rate
92 |
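# Hedged usage sketch for the two helpers above: `build_env`, `build_agent`
# and `build_buffer` are hypothetical factories standing in for the real SMAC
# environment, QTRAN agent and ReplayBuffer construction; only the config keys
# actually read by run_train_episode are listed.
if __name__ == '__main__':
    config = dict(episode_limit=120, state_shape=48, obs_shape=30,
                  n_actions=9, n_agents=3, memory_warmup_size=32,
                  learner_update_freq=2, batch_size=32)
    env, agent, rpm = build_env(), build_agent(config), build_buffer(config)
    for episode in range(1000):
        (reward, steps, is_win, loss, td_loss, opt_loss,
         nopt_loss) = run_train_episode(env, agent, rpm, config)
        if episode % 50 == 0:
            eval_reward, eval_steps, eval_win_rate = run_eval_episode(env, agent)
            print(f'episode {episode}: eval_reward={eval_reward:.2f}, '
                  f'eval_win_rate={eval_win_rate:.2f}')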
--------------------------------------------------------------------------------
/marltoolkit/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .logger import (BaseLogger, TensorboardLogger, WandbLogger, get_outdir,
2 | get_root_logger)
3 | from .lr_scheduler import (LinearDecayScheduler, MultiStepScheduler,
4 | PiecewiseScheduler)
5 | from .model_utils import (check_model_method, hard_target_update,
6 | soft_target_update)
7 | from .progressbar import ProgressBar
8 | from .timer import Timer
9 | from .transforms import OneHotTransform
10 |
11 | __all__ = [
12 | 'BaseLogger', 'TensorboardLogger', 'WandbLogger', 'get_outdir',
13 | 'get_root_logger', 'ProgressBar', 'Timer', 'OneHotTransform',
14 | 'hard_target_update', 'soft_target_update', 'check_model_method',
15 | 'LinearDecayScheduler', 'PiecewiseScheduler', 'MultiStepScheduler'
16 | ]
17 |
--------------------------------------------------------------------------------
/marltoolkit/utils/env_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from marltoolkit.envs.smacv1 import SMACWrapperEnv
4 | from marltoolkit.envs.vec_env import SubprocVecEnv
5 |
6 |
7 | def make_vec_env(
8 | env_id: str,
9 | map_name: str,
10 | num_train_envs: int = 10,
11 | num_test_envs: int = 10,
12 | **kwargs,
13 | ):
14 |
15 | def make_env():
16 | if env_id == 'SMAC-v1':
17 | env = SMACWrapperEnv(map_name, **kwargs)
18 | else:
19 | raise ValueError(f'Unknown environment: {env_id}')
20 | return env
21 |
22 | train_envs = SubprocVecEnv([make_env for _ in range(num_train_envs)])
23 | test_envs = SubprocVecEnv([make_env for _ in range(num_test_envs)])
24 | return train_envs, test_envs
25 |
26 |
27 | def get_actor_input_dim(args: argparse.Namespace) -> int:
28 |     """Get the input dimension of the actor model.
29 |
30 |     Args:
31 |         args (argparse.Namespace): The arguments.
32 |     Returns:
33 |         input_dim (int): The input dimension of the actor model.
34 |     """
35 | input_dim = args.obs_dim
36 | if args.use_global_state:
37 | input_dim += args.state_dim
38 | if args.use_last_actions:
39 | input_dim += args.n_actions
40 | if args.use_agents_id_onehot:
41 | input_dim += args.num_agents
42 | return input_dim
43 |
44 |
45 | def get_critic_input_dim(args: argparse.Namespace) -> int:
46 |     """Get the input dimension of the critic model.
47 |
48 |     Args:
49 |         args (argparse.Namespace): The arguments.
50 |
51 |     Returns:
52 |         input_dim (int): The input dimension of the critic model.
53 |     """
54 | input_dim = args.obs_dim
55 | if args.use_global_state:
56 | input_dim += args.state_dim
57 | if args.use_last_actions:
58 | input_dim += args.n_actions
59 | if args.use_agents_id_onehot:
60 | input_dim += args.num_agents
61 | return input_dim
62 |
63 |
64 | def get_shape_from_obs_space(obs_space):
65 | if obs_space.__class__.__name__ == 'Box':
66 | obs_shape = obs_space.shape
67 | elif obs_space.__class__.__name__ == 'Tuple':
68 | obs_shape = obs_space
69 | else:
70 | raise NotImplementedError
71 | return obs_shape
72 |
73 |
74 | def get_shape_from_act_space(act_space):
75 | if act_space.__class__.__name__ == 'Discrete':
76 | act_shape = 1
77 | elif act_space.__class__.__name__ == 'MultiDiscrete':
78 | act_shape = act_space.shape
79 | elif act_space.__class__.__name__ == 'Box':
80 | act_shape = act_space.shape[0]
81 | elif act_space.__class__.__name__ == 'MultiBinary':
82 | act_shape = act_space.shape[0]
83 | else: # agar
84 | act_shape = act_space[0].shape[0] + 1
85 | return act_shape
86 |
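# A small, self-contained check of get_actor_input_dim: the argparse.Namespace
# below is a hypothetical stand-in carrying only the fields the function reads.
if __name__ == '__main__':
    demo_args = argparse.Namespace(
        obs_dim=30,
        state_dim=48,
        n_actions=9,
        num_agents=3,
        use_global_state=True,
        use_last_actions=True,
        use_agents_id_onehot=True,
    )
    # obs (30) + state (48) + last actions (9) + agent id one-hot (3) = 90
    print(get_actor_input_dim(demo_args))  # -> 90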
--------------------------------------------------------------------------------
/marltoolkit/utils/logger/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import BaseLogger
2 | from .logs import get_outdir, get_root_logger
3 | from .tensorboard import TensorboardLogger
4 | from .wandb import WandbLogger
5 |
6 | __all__ = [
7 | 'BaseLogger', 'TensorboardLogger', 'WandbLogger', 'get_root_logger',
8 | 'get_outdir'
9 | ]
10 |
--------------------------------------------------------------------------------
/marltoolkit/utils/logger/base.py:
--------------------------------------------------------------------------------
1 | """Code Reference Tianshou https://github.com/thu-
2 | ml/tianshou/tree/master/tianshou/utils/logger."""
3 | from abc import ABC, abstractmethod
4 | from enum import Enum
5 | from numbers import Number
6 | from typing import Callable, Dict, Optional, Tuple, Union
7 |
8 | import numpy as np
9 |
10 | LOG_DATA_TYPE = Dict[str, Union[int, Number, np.number, np.ndarray]]
11 |
12 |
13 | class DataScope(Enum):
14 | TRAIN = 'train'
15 | TEST = 'test'
16 | UPDATE = 'update'
17 | INFO = 'info'
18 |
19 |
20 | class BaseLogger(ABC):
21 | """The base class for any logger which is compatible with trainer.
22 |
23 | Try to overwrite write() method to use your own writer.
24 |
25 | :param train_interval: the log interval in log_train_data(). Default to 1000.
26 | :param test_interval: the log interval in log_test_data(). Default to 1.
27 | :param update_interval: the log interval in log_update_data(). Default to 1000.
28 | :param info_interval: the log interval in log_info_data(). Default to 1.
29 | """
30 |
31 | def __init__(
32 | self,
33 | train_interval: int = 1000,
34 | test_interval: int = 1,
35 | update_interval: int = 1000,
36 | info_interval: int = 1,
37 | ) -> None:
38 | super().__init__()
39 | self.train_interval = train_interval
40 | self.test_interval = test_interval
41 | self.update_interval = update_interval
42 | self.info_interval = info_interval
43 | self.last_log_train_step = -1
44 | self.last_log_test_step = -1
45 | self.last_log_update_step = -1
46 | self.last_log_info_step = -1
47 |
48 | @abstractmethod
49 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
50 | """Specify how the writer is used to log data.
51 |
52 | :param str step_type: namespace which the data dict belongs to.
53 | :param step: stands for the ordinate of the data dict.
54 | :param data: the data to write with format ``{key: value}``.
55 | """
56 | pass
57 |
58 | def log_train_data(self, log_data: dict, step: int) -> None:
59 | """Use writer to log statistics generated during training.
60 |
61 | :param log_data: a dict containing information of data collected in
62 | training stage, i.e., returns of collector.collect().
63 | :param int step: stands for the timestep the log_data being logged.
64 | """
65 | if step - self.last_log_train_step >= self.train_interval:
66 | log_data = {f'train/{k}': v for k, v in log_data.items()}
67 | self.write('train/env_step', step, log_data)
68 | self.last_log_train_step = step
69 |
70 | def log_test_data(self, log_data: dict, step: int) -> None:
71 | """Use writer to log statistics generated during evaluating.
72 |
73 | :param log_data: a dict containing information of data collected in
74 | evaluating stage, i.e., returns of collector.collect().
75 | :param int step: stands for the timestep the log_data being logged.
76 | """
77 | if step - self.last_log_test_step >= self.test_interval:
78 | log_data = {f'test/{k}': v for k, v in log_data.items()}
79 | self.write('test/env_step', step, log_data)
80 | self.last_log_test_step = step
81 |
82 | def log_update_data(self, log_data: dict, step: int) -> None:
83 | """Use writer to log statistics generated during updating.
84 |
85 | :param log_data: a dict containing information of data collected in
86 | updating stage, i.e., returns of policy.update().
87 | :param int step: stands for the timestep the log_data being logged.
88 | """
89 | if step - self.last_log_update_step >= self.update_interval:
90 | log_data = {f'update/{k}': v for k, v in log_data.items()}
91 | self.write('update/gradient_step', step, log_data)
92 | self.last_log_update_step = step
93 |
94 | def log_info_data(self, log_data: dict, step: int) -> None:
95 | """Use writer to log global statistics.
96 |
97 | :param log_data: a dict containing information of data collected at the end of an epoch.
98 | :param step: stands for the timestep the training info is logged.
99 | """
100 | if (step - self.last_log_info_step >= self.info_interval):
101 | log_data = {f'info/{k}': v for k, v in log_data.items()}
102 | self.write('info/epoch', step, log_data)
103 | self.last_log_info_step = step
104 |
105 | @abstractmethod
106 | def save_data(
107 | self,
108 | epoch: int,
109 | env_step: int,
110 | gradient_step: int,
111 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
112 | ) -> None:
113 | """Use writer to log metadata when calling ``save_checkpoint_fn`` in
114 | trainer.
115 |
116 | :param int epoch: the epoch in trainer.
117 | :param int env_step: the env_step in trainer.
118 | :param int gradient_step: the gradient_step in trainer.
119 | :param function save_checkpoint_fn: a hook defined by user, see trainer
120 | documentation for detail.
121 | """
122 | pass
123 |
124 | @abstractmethod
125 | def restore_data(self) -> Tuple[int, int, int]:
126 | """Return the metadata from existing log.
127 |
128 | If it finds nothing or an error occurs during the recover process, it will
129 | return the default parameters.
130 |
131 | :return: epoch, env_step, gradient_step.
132 | """
133 | pass
134 |
135 |
136 | class LazyLogger(BaseLogger):
137 | """A logger that does nothing.
138 |
139 | Used as the placeholder in trainer.
140 | """
141 |
142 | def __init__(self) -> None:
143 | super().__init__()
144 |
145 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
146 | """The LazyLogger writes nothing."""
147 | pass
148 |
149 | def save_data(
150 | self,
151 | epoch: int,
152 | env_step: int,
153 | gradient_step: int,
154 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
155 | ) -> None:
156 | pass
157 |
158 | def restore_data(self) -> Tuple[int, int, int]:
159 | return 0, 0, 0
160 |
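# A minimal concrete logger built on BaseLogger, assuming nothing beyond this
# module: write() just prints, and the save/restore hooks are no-ops. It shows
# how train_interval throttles log_train_data().
class PrintLogger(BaseLogger):

    def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
        print(step_type, step, data)

    def save_data(
        self,
        epoch: int,
        env_step: int,
        gradient_step: int,
        save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
    ) -> None:
        pass

    def restore_data(self) -> Tuple[int, int, int]:
        return 0, 0, 0


if __name__ == '__main__':
    demo_logger = PrintLogger(train_interval=100)
    demo_logger.log_train_data({'episode_reward': 1.5}, step=100)  # written
    demo_logger.log_train_data({'episode_reward': 1.7}, step=150)  # skipped: within interval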
--------------------------------------------------------------------------------
/marltoolkit/utils/logger/base_orig.py:
--------------------------------------------------------------------------------
1 | """Code Reference Tianshou https://github.com/thu-
2 | ml/tianshou/tree/master/tianshou/utils/logger."""
3 | from abc import ABC, abstractmethod
4 | from numbers import Number
5 | from typing import Callable, Dict, Optional, Tuple, Union
6 |
7 | import numpy as np
8 |
9 | LOG_DATA_TYPE = Dict[str, Union[int, Number, np.number, np.ndarray]]
10 |
11 |
12 | class BaseLogger(ABC):
13 | """The base class for any logger which is compatible with trainer.
14 |
15 | Try to overwrite write() method to use your own writer.
16 |
17 | :param int train_interval: the log interval in log_train_data(). Default to 1000.
18 | :param int test_interval: the log interval in log_test_data(). Default to 1.
19 | :param int update_interval: the log interval in log_update_data(). Default to 1000.
20 | """
21 |
22 | def __init__(
23 | self,
24 | train_interval: int = 1000,
25 | test_interval: int = 1,
26 | update_interval: int = 1000,
27 | ) -> None:
28 | super().__init__()
29 | self.train_interval = train_interval
30 | self.test_interval = test_interval
31 | self.update_interval = update_interval
32 | self.last_log_train_step = -1
33 | self.last_log_test_step = -1
34 | self.last_log_update_step = -1
35 |
36 | @abstractmethod
37 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
38 | """Specify how the writer is used to log data.
39 |
40 | :param str step_type: namespace which the data dict belongs to.
41 | :param int step: stands for the ordinate of the data dict.
42 | :param dict data: the data to write with format ``{key: value}``.
43 | """
44 | pass
45 |
46 | def log_train_data(self, collect_result: dict, step: int) -> None:
47 | """Use writer to log statistics generated during training.
48 |
49 | :param collect_result: a dict containing information of data collected in
50 | training stage, i.e., returns of collector.collect().
51 | :param int step: stands for the timestep the collect_result being logged.
52 | """
53 | if step - self.last_log_train_step >= self.train_interval:
54 | log_data = {f'train/{k}': v for k, v in collect_result.items()}
55 | self.write('train/env_step', step, log_data)
56 | self.last_log_train_step = step
57 |
58 | def log_test_data(self, collect_result: dict, step: int) -> None:
59 | """Use writer to log statistics generated during evaluating.
60 |
61 | :param collect_result: a dict containing information of data collected in
62 | evaluating stage, i.e., returns of collector.collect().
63 | :param int step: stands for the timestep the collect_result being logged.
64 | """
65 | if step - self.last_log_test_step >= self.test_interval:
66 | log_data = {f'test/{k}': v for k, v in collect_result.items()}
67 | self.write('test/env_step', step, log_data)
68 | self.last_log_test_step = step
69 |
70 | def log_update_data(self, update_result: dict, step: int) -> None:
71 | """Use writer to log statistics generated during updating.
72 |
73 | :param update_result: a dict containing information of data collected in
74 | updating stage, i.e., returns of policy.update().
75 | :param int step: stands for the timestep the collect_result being logged.
76 | """
77 | if step - self.last_log_update_step >= self.update_interval:
78 | log_data = {f'update/{k}': v for k, v in update_result.items()}
79 | self.write('update/gradient_step', step, log_data)
80 | self.last_log_update_step = step
81 |
82 | @abstractmethod
83 | def save_data(
84 | self,
85 | epoch: int,
86 | env_step: int,
87 | gradient_step: int,
88 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
89 | ) -> None:
90 | """Use writer to log metadata when calling ``save_checkpoint_fn`` in
91 | trainer.
92 |
93 | :param int epoch: the epoch in trainer.
94 | :param int env_step: the env_step in trainer.
95 | :param int gradient_step: the gradient_step in trainer.
96 | :param function save_checkpoint_fn: a hook defined by user, see trainer
97 | documentation for detail.
98 | """
99 | pass
100 |
101 | @abstractmethod
102 | def restore_data(self) -> Tuple[int, int, int]:
103 | """Return the metadata from existing log.
104 |
105 | If it finds nothing or an error occurs during the recover process, it will
106 | return the default parameters.
107 |
108 | :return: epoch, env_step, gradient_step.
109 | """
110 | pass
111 |
112 |
113 | class LazyLogger(BaseLogger):
114 | """A logger that does nothing.
115 |
116 | Used as the placeholder in trainer.
117 | """
118 |
119 | def __init__(self) -> None:
120 | super().__init__()
121 |
122 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
123 | """The LazyLogger writes nothing."""
124 | pass
125 |
126 | def save_data(
127 | self,
128 | epoch: int,
129 | env_step: int,
130 | gradient_step: int,
131 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
132 | ) -> None:
133 | pass
134 |
135 | def restore_data(self) -> Tuple[int, int, int]:
136 | return 0, 0, 0
137 |
--------------------------------------------------------------------------------
/marltoolkit/utils/logger/logging.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import logging
3 |
4 | import torch.distributed as dist
5 |
6 | logger_initialized: dict = {}
7 |
8 |
9 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'):
10 | """Initialize and get a logger by name.
11 |
12 | If the logger has not been initialized, this method will initialize the
13 | logger by adding one or two handlers, otherwise the initialized logger will
14 | be directly returned. During initialization, a StreamHandler will always be
15 | added. If `log_file` is specified and the process rank is 0, a FileHandler
16 | will also be added.
17 |
18 | Args:
19 | name (str): Logger name.
20 | log_file (str | None): The log filename. If specified, a FileHandler
21 | will be added to the logger.
22 | log_level (int): The logger level. Note that only the process of
23 | rank 0 is affected, and other processes will set the level to
24 | "Error" thus be silent most of the time.
25 | file_mode (str): The file mode used in opening log file.
26 | Defaults to 'w'.
27 |
28 | Returns:
29 | logging.Logger: The expected logger.
30 | """
31 | logger = logging.getLogger(name)
32 | if name in logger_initialized:
33 | return logger
34 | # handle hierarchical names
35 | # e.g., logger "a" is initialized, then logger "a.b" will skip the
36 | # initialization since it is a child of "a".
37 | for logger_name in logger_initialized:
38 | if name.startswith(logger_name):
39 | return logger
40 |
41 | # handle duplicate logs to the console
42 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET)
43 | # to the root logger. As logger.propagate is True by default, this root
44 | # level handler causes logging messages from rank>0 processes to
45 | # unexpectedly show up on the console, creating much unwanted clutter.
46 | # To fix this issue, we set the root logger's StreamHandler, if any, to log
47 | # at the ERROR level.
48 | for handler in logger.root.handlers:
49 | if type(handler) is logging.StreamHandler:
50 | handler.setLevel(logging.ERROR)
51 |
52 | stream_handler = logging.StreamHandler()
53 | handlers = [stream_handler]
54 |
55 | if dist.is_available() and dist.is_initialized():
56 | rank = dist.get_rank()
57 | else:
58 | rank = 0
59 |
60 | # only rank 0 will add a FileHandler
61 | if rank == 0 and log_file is not None:
62 |         # The default file mode of logging.FileHandler is 'a' (append). We
63 |         # expose `file_mode` so callers can control this behaviour; this
64 |         # function defaults to 'w'.
65 | file_handler = logging.FileHandler(log_file, file_mode)
66 | handlers.append(file_handler)
67 |
68 | formatter = logging.Formatter(
69 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
70 | for handler in handlers:
71 | handler.setFormatter(formatter)
72 | handler.setLevel(log_level)
73 | logger.addHandler(handler)
74 |
75 | if rank == 0:
76 | logger.setLevel(log_level)
77 | else:
78 | logger.setLevel(logging.ERROR)
79 |
80 | logger_initialized[name] = True
81 |
82 | return logger
83 |
84 |
85 | def print_log(msg, logger=None, level=logging.INFO):
86 | """Print a log message.
87 |
88 | Args:
89 | msg (str): The message to be logged.
90 | logger (logging.Logger | str | None): The logger to be used.
91 | Some special loggers are:
92 |
93 | - "silent": no message will be printed.
94 | - other str: the logger obtained with `get_root_logger(logger)`.
95 | - None: The `print()` method will be used to print log messages.
96 | level (int): Logging level. Only available when `logger` is a Logger
97 | object or "root".
98 | """
99 | if logger is None:
100 | print(msg)
101 | elif isinstance(logger, logging.Logger):
102 | logger.log(level, msg)
103 | elif logger == 'silent':
104 | pass
105 | elif isinstance(logger, str):
106 | _logger = get_logger(logger)
107 | _logger.log(level, msg)
108 | else:
109 | raise TypeError(
110 | 'logger should be either a logging.Logger object, str, '
111 | f'"silent" or None, but got {type(logger)}')
112 |
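# A short usage sketch for get_logger/print_log; the logger name used here is
# an arbitrary example, not something this module mandates.
if __name__ == '__main__':
    demo_logger = get_logger('marltoolkit.demo', log_level=logging.INFO)
    demo_logger.info('logger initialized')
    print_log('fallback to print()', logger=None)
    print_log('routed through the logger', logger=demo_logger)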
--------------------------------------------------------------------------------
/marltoolkit/utils/logger/logs.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from collections import OrderedDict, defaultdict
4 |
5 | from .logging import get_logger
6 |
7 | try:
8 | import wandb
9 | except ImportError:
10 | pass
11 |
12 |
13 | def get_root_logger(log_file=None, log_level=logging.INFO):
14 | """Get root logger.
15 |
16 | Args:
17 | log_file (str, optional): File path of log. Defaults to None.
18 | log_level (int, optional): The level of logger.
19 | Defaults to logging.INFO.
20 |
21 | Returns:
22 | :obj:`logging.Logger`: The obtained logger
23 | """
24 | logger = get_logger(name='marltoolkit',
25 | log_file=log_file,
26 | log_level=log_level)
27 |
28 | return logger
29 |
30 |
31 | def get_outdir(path, *paths, inc=False):
32 | outdir = os.path.join(path, *paths)
33 | if not os.path.exists(outdir):
34 | os.makedirs(outdir)
35 | elif inc:
36 | count = 1
37 | outdir_inc = outdir + '-' + str(count)
38 | while os.path.exists(outdir_inc):
39 | count = count + 1
40 | outdir_inc = outdir + '-' + str(count)
41 | assert count < 100
42 | outdir = outdir_inc
43 | os.makedirs(outdir)
44 | return outdir
45 |
46 |
47 | def avg_val_from_list_of_dicts(list_of_dicts):
48 | sum_values = defaultdict(int)
49 | count_dicts = defaultdict(int)
50 |
51 |     # Accumulate per-key sums and counts across all dictionaries
52 | for dictionary in list_of_dicts:
53 | for key, value in dictionary.items():
54 | sum_values[key] += value
55 | count_dicts[key] += 1
56 |
57 | # Calculate the average values using a dictionary comprehension
58 | avg_val_dict = {
59 | key: sum_value / count_dicts[key]
60 | for key, sum_value in sum_values.items()
61 | }
62 |
63 | return avg_val_dict
64 |
65 |
66 | def update_summary(train_metrics, eval_metrics, log_wandb=False):
67 | rowd = OrderedDict()
68 | rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
69 | rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
70 | if log_wandb:
71 | wandb.log(rowd)
72 |
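# A runnable illustration of avg_val_from_list_of_dicts and get_outdir; the
# directory names are arbitrary examples.
if __name__ == '__main__':
    metrics = [{'loss': 1.0, 'reward': 2.0}, {'loss': 3.0, 'reward': 4.0}]
    print(avg_val_from_list_of_dicts(metrics))  # {'loss': 2.0, 'reward': 3.0}
    outdir = get_outdir('work_dirs', 'demo_run', inc=True)
    print(outdir)  # work_dirs/demo_run, or work_dirs/demo_run-N if it exists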
--------------------------------------------------------------------------------
/marltoolkit/utils/logger/tensorboard.py:
--------------------------------------------------------------------------------
1 | """Code Reference Tianshou https://github.com/thu-
2 | ml/tianshou/tree/master/tianshou/utils/logger."""
3 |
4 | from typing import Callable, Optional, Tuple
5 |
6 | from tensorboard.backend.event_processing import event_accumulator
7 | from torch.utils.tensorboard import SummaryWriter
8 |
9 | from .base import LOG_DATA_TYPE, BaseLogger
10 |
11 |
12 | class TensorboardLogger(BaseLogger):
13 | """A logger that relies on tensorboard SummaryWriter by default to
14 | visualize and log statistics.
15 |
16 | :param SummaryWriter writer: the writer to log data.
17 | :param train_interval: the log interval in log_train_data(). Default to 1000.
18 | :param test_interval: the log interval in log_test_data(). Default to 1.
19 | :param update_interval: the log interval in log_update_data(). Default to 1000.
20 | :param info_interval: the log interval in log_info_data(). Default to 1.
21 | :param save_interval: the save interval in save_data(). Default to 1 (save at
22 | the end of each epoch).
23 | :param write_flush: whether to flush tensorboard result after each
24 | add_scalar operation. Default to True.
25 | """
26 |
27 | def __init__(
28 | self,
29 | writer: SummaryWriter,
30 | train_interval: int = 1000,
31 | test_interval: int = 1,
32 | update_interval: int = 1000,
33 | info_interval: int = 1,
34 | save_interval: int = 1,
35 | write_flush: bool = True,
36 | ) -> None:
37 | super().__init__(train_interval, test_interval, update_interval,
38 | info_interval)
39 | self.save_interval = save_interval
40 | self.write_flush = write_flush
41 | self.last_save_step = -1
42 | self.writer = writer
43 |
44 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
45 | for k, v in data.items():
46 | self.writer.add_scalar(k, v, global_step=step)
47 | if self.write_flush: # issue 580
48 | self.writer.flush() # issue #482
49 |
50 | def save_data(
51 | self,
52 | epoch: int,
53 | env_step: int,
54 | gradient_step: int,
55 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
56 | ) -> None:
57 | if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
58 | self.last_save_step = epoch
59 | save_checkpoint_fn(epoch, env_step, gradient_step)
60 | self.write('save/epoch', epoch, {'save/epoch': epoch})
61 | self.write('save/env_step', env_step, {'save/env_step': env_step})
62 | self.write('save/gradient_step', gradient_step,
63 | {'save/gradient_step': gradient_step})
64 |
65 | def restore_data(self) -> Tuple[int, int, int]:
66 | ea = event_accumulator.EventAccumulator(self.writer.log_dir)
67 | ea.Reload()
68 |
69 | try: # epoch / gradient_step
70 | epoch = ea.scalars.Items('save/epoch')[-1].step
71 | self.last_save_step = self.last_log_test_step = epoch
72 | gradient_step = ea.scalars.Items('save/gradient_step')[-1].step
73 | self.last_log_update_step = gradient_step
74 | except KeyError:
75 | epoch, gradient_step = 0, 0
76 | try: # offline trainer doesn't have env_step
77 | env_step = ea.scalars.Items('save/env_step')[-1].step
78 | self.last_log_train_step = env_step
79 | except KeyError:
80 | env_step = 0
81 |
82 | return epoch, env_step, gradient_step
83 |
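# A minimal wiring example; the log directory is an arbitrary placeholder.
if __name__ == '__main__':
    writer = SummaryWriter('runs/demo')
    tb_logger = TensorboardLogger(writer, train_interval=1)
    tb_logger.log_train_data({'episode_reward': 12.3, 'win_rate': 0.4}, step=1)
    tb_logger.log_test_data({'episode_reward': 15.0}, step=1)
    writer.close()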
--------------------------------------------------------------------------------
/marltoolkit/utils/logger/wandb.py:
--------------------------------------------------------------------------------
1 | """Code Reference Tianshou https://github.com/thu-
2 | ml/tianshou/tree/master/tianshou/utils/logger."""
3 |
4 | import argparse
5 | import contextlib
6 | import os
7 | from typing import Callable, Optional, Tuple
8 |
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | from .base import LOG_DATA_TYPE, BaseLogger
12 | from .tensorboard import TensorboardLogger
13 |
14 | with contextlib.suppress(ImportError):
15 | import wandb
16 |
17 |
18 | class WandbLogger(BaseLogger):
19 | """Weights and Biases logger that sends data to https://wandb.ai/.
20 |
21 | This logger creates three panels with plots: train, test, and update.
22 | Make sure to select the correct access for each panel in weights and biases:
23 |
24 | Example of usage:
25 | ::
26 |
27 | logger = WandbLogger()
28 | logger.load(SummaryWriter(log_path))
29 | result = OnpolicyTrainer(policy, train_collector, test_collector,
30 | logger=logger).run()
31 |
32 | :param train_interval: the log interval in log_train_data(). Default to 1000.
33 | :param test_interval: the log interval in log_test_data(). Default to 1.
34 | :param update_interval: the log interval in log_update_data().
35 | Default to 1000.
36 | :param info_interval: the log interval in log_info_data(). Default to 1.
37 | :param save_interval: the save interval in save_data(). Default to 1 (save at
38 | the end of each epoch).
39 | :param write_flush: whether to flush tensorboard result after each
40 | add_scalar operation. Default to True.
41 |     :param str project: W&B project name. Defaults to the WANDB_PROJECT environment variable, or "marltoolkit" if unset.
42 | :param str name: W&B run name. Default to None. If None, random name is assigned.
43 | :param str entity: W&B team/organization name. Default to None.
44 | :param str run_id: run id of W&B run to be resumed. Default to None.
45 | :param argparse.Namespace config: experiment configurations. Default to None.
46 | """
47 |
48 | def __init__(
49 | self,
50 | dir: str = None,
51 | train_interval: int = 1000,
52 | test_interval: int = 1,
53 | update_interval: int = 1000,
54 | info_interval: int = 1,
55 | save_interval: int = 1000,
56 | write_flush: bool = True,
57 | project: Optional[str] = None,
58 | name: Optional[str] = None,
59 | entity: Optional[str] = None,
60 | run_id: Optional[str] = None,
61 | config: Optional[argparse.Namespace] = None,
62 | monitor_gym: bool = True,
63 | ) -> None:
64 | super().__init__(train_interval, test_interval, update_interval,
65 | info_interval)
66 | self.last_save_step = -1
67 | self.save_interval = save_interval
68 | self.write_flush = write_flush
69 | self.restored = False
70 | if project is None:
71 | project = os.getenv('WANDB_PROJECT', 'marltoolkit')
72 |
73 | self.wandb_run = wandb.init(
74 | dir=dir,
75 | project=project,
76 | name=name,
77 | id=run_id,
78 | resume='allow',
79 | entity=entity,
80 | sync_tensorboard=True,
81 | monitor_gym=monitor_gym,
82 | config=config, # type: ignore
83 | ) if not wandb.run else wandb.run
84 | self.wandb_run._label(repo='marltoolkit') # type: ignore
85 | self.tensorboard_logger: Optional[TensorboardLogger] = None
86 |
87 | def load(self, writer: SummaryWriter) -> None:
88 | self.writer = writer
89 |         self.tensorboard_logger = TensorboardLogger(
90 |             writer,
91 |             train_interval=self.train_interval,
92 |             test_interval=self.test_interval,
93 |             update_interval=self.update_interval,
94 |             info_interval=self.info_interval,
95 |             save_interval=self.save_interval,
96 |             write_flush=self.write_flush)
97 |
98 | def write(self, step_type: str, step: int, data: LOG_DATA_TYPE) -> None:
99 | if self.tensorboard_logger is None:
100 | raise Exception(
101 | '`logger` needs to load the Tensorboard Writer before '
102 | 'writing data. Try `logger.load(SummaryWriter(log_path))`')
103 | else:
104 | self.tensorboard_logger.write(step_type, step, data)
105 |
106 | def save_data(
107 | self,
108 | epoch: int,
109 | env_step: int,
110 | gradient_step: int,
111 | save_checkpoint_fn: Optional[Callable[[int, int, int], str]] = None,
112 | ) -> None:
113 | """Use writer to log metadata when calling ``save_checkpoint_fn`` in
114 | trainer.
115 |
116 | :param epoch: the epoch in trainer.
117 | :param env_step: the env_step in trainer.
118 | :param gradient_step: the gradient_step in trainer.
119 | :param function save_checkpoint_fn: a hook defined by user, see trainer
120 | documentation for detail.
121 | """
122 | if save_checkpoint_fn and epoch - self.last_save_step >= self.save_interval:
123 | self.last_save_step = epoch
124 | checkpoint_path = save_checkpoint_fn(epoch, env_step,
125 | gradient_step)
126 |
127 | checkpoint_artifact = wandb.Artifact(
128 | 'run_' + self.wandb_run.id + '_checkpoint', # type: ignore
129 | type='model',
130 | metadata={
131 | 'save/epoch': epoch,
132 | 'save/env_step': env_step,
133 | 'save/gradient_step': gradient_step,
134 | 'checkpoint_path': str(checkpoint_path),
135 | })
136 | checkpoint_artifact.add_file(str(checkpoint_path))
137 | self.wandb_run.log_artifact(checkpoint_artifact) # type: ignore
138 |
139 | def restore_data(self) -> Tuple[int, int, int]:
140 | checkpoint_artifact = self.wandb_run.use_artifact( # type: ignore
141 | f'run_{self.wandb_run.id}_checkpoint:latest' # type: ignore
142 | )
143 | assert checkpoint_artifact is not None, "W&B dataset artifact doesn't exist"
144 |
145 | checkpoint_artifact.download(
146 | os.path.dirname(checkpoint_artifact.metadata['checkpoint_path']))
147 |
148 | try: # epoch / gradient_step
149 | epoch = checkpoint_artifact.metadata['save/epoch']
150 | self.last_save_step = self.last_log_test_step = epoch
151 | gradient_step = checkpoint_artifact.metadata['save/gradient_step']
152 | self.last_log_update_step = gradient_step
153 | except KeyError:
154 | epoch, gradient_step = 0, 0
155 | try: # offline trainer doesn't have env_step
156 | env_step = checkpoint_artifact.metadata['save/env_step']
157 | self.last_log_train_step = env_step
158 | except KeyError:
159 | env_step = 0
160 | return epoch, env_step, gradient_step
161 |
--------------------------------------------------------------------------------
/marltoolkit/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | '''
2 | Author: jianzhnie@126.com
3 | Date: 2022-09-01 15:17:32
4 | LastEditors: jianzhnie@126.com
5 | LastEditTime: 2022-09-01 15:17:35
6 | Description:
7 |
8 | Copyright (c) 2022 by jianzhnie, All Rights Reserved.
9 | '''
10 |
11 | from collections import Counter
12 | from typing import List
13 |
14 | import six
15 |
16 | __all__ = ['PiecewiseScheduler', 'LinearDecayScheduler']
17 |
18 |
19 | class PiecewiseScheduler(object):
20 | """Set hyper parameters by a predefined step-based scheduler."""
21 |
22 | def __init__(self, scheduler_list):
23 | """Piecewise scheduler of hyper parameter.
24 |
25 | Args:
26 | scheduler_list: list of (step, value) pair. E.g. [(0, 0.001), (10000, 0.0005)]
27 | """
28 | assert len(scheduler_list) > 0
29 |
30 | for i in six.moves.range(len(scheduler_list) - 1):
31 | assert scheduler_list[i][0] < scheduler_list[i + 1][0], \
32 | 'step of scheduler_list should be incremental.'
33 |
34 | self.scheduler_list = scheduler_list
35 |
36 | self.cur_index = 0
37 | self.cur_step = 0
38 | self.cur_value = self.scheduler_list[0][1]
39 |
40 | self.scheduler_num = len(self.scheduler_list)
41 |
42 | def step(self, step_num=1):
43 | """Step step_num and fetch value according to following rule:
44 |
45 | Given scheduler_list: [(step_0, value_0), (step_1, value_1), ..., (step_N, value_N)],
46 |         this function returns value_K such that step_K <= self.cur_step < step_(K+1).
47 |
48 | Args:
49 | step_num (int): number of steps (default: 1)
50 | """
51 | assert isinstance(step_num, int) and step_num >= 1
52 | self.cur_step += step_num
53 |
54 | if self.cur_index < self.scheduler_num - 1:
55 | if self.cur_step >= self.scheduler_list[self.cur_index + 1][0]:
56 | self.cur_index += 1
57 | self.cur_value = self.scheduler_list[self.cur_index][1]
58 |
59 | return self.cur_value
60 |
61 |
62 | class LinearDecayScheduler(object):
63 | """Set hyper parameters by a step-based scheduler with linear decay
64 | values."""
65 |
66 | def __init__(self, start_value, max_steps):
67 | """Linear decay scheduler of hyper parameter. Decay value linearly
68 | until 0.
69 |
70 | Args:
71 | start_value (float): start value
72 | max_steps (int): maximum steps
73 | """
74 | assert max_steps > 0
75 | self.cur_step = 0
76 | self.max_steps = max_steps
77 | self.start_value = start_value
78 |
79 | def step(self, step_num=1):
80 | """Step step_num and fetch value according to following rule:
81 |
82 | return_value = start_value * (1.0 - (cur_steps / max_steps))
83 |
84 | Args:
85 | step_num (int): number of steps (default: 1)
86 |
87 | Returns:
88 | value (float): current value
89 | """
90 | assert isinstance(step_num, int) and step_num >= 1
91 | self.cur_step = min(self.cur_step + step_num, self.max_steps)
92 |
93 | value = self.start_value * (1.0 -
94 | ((self.cur_step * 1.0) / self.max_steps))
95 |
96 | return value
97 |
98 |
99 | class MultiStepScheduler(object):
100 |     """Multi-step learning rate scheduler: decay the value at the given milestones."""
101 |
102 | def __init__(self,
103 | start_value: float,
104 | max_steps: int,
105 | milestones: List = None,
106 | decay_factor: float = 0.1):
107 | assert max_steps > 0
108 | assert isinstance(decay_factor, float)
109 | assert decay_factor > 0 and decay_factor < 1
110 | self.milestones = Counter(milestones)
111 | self.cur_value = start_value
112 | self.cur_step = 0
113 | self.max_steps = max_steps
114 | self.decay_factor = decay_factor
115 |
116 | def step(self, step_num=1):
117 | assert isinstance(step_num, int) and step_num >= 1
118 | self.cur_step = min(self.cur_step + step_num, self.max_steps)
119 |
120 | if self.cur_step not in self.milestones:
121 | return self.cur_value
122 | else:
123 | self.cur_value *= self.decay_factor**self.milestones[self.cur_step]
124 |
125 | return self.cur_value
126 |
127 |
128 | if __name__ == '__main__':
129 | scheduler = MultiStepScheduler(100, 100, [50, 80], 0.5)
130 | for i in range(101):
131 | value = scheduler.step()
132 | print(value)
133 |
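# A companion sketch for LinearDecayScheduler, e.g. to anneal an epsilon-greedy
# exploration rate; the start value and step count are arbitrary.
if __name__ == '__main__':
    eps_scheduler = LinearDecayScheduler(start_value=1.0, max_steps=10)
    for _ in range(10):
        epsilon = eps_scheduler.step()
        print(f'{epsilon:.1f}')  # 0.9, 0.8, ..., 0.0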
--------------------------------------------------------------------------------
/marltoolkit/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch.nn as nn
4 |
5 |
6 | def huber_loss(e, d):
7 | a = (abs(e) <= d).float()
8 | b = (abs(e) > d).float()
9 | return a * e**2 / 2 + b * d * (abs(e) - d / 2)
10 |
11 |
12 | def mse_loss(e):
13 | return e**2 / 2
14 |
15 |
16 | def get_gard_norm(it):
17 | sum_grad = 0
18 | for x in it:
19 | if x.grad is None:
20 | continue
21 | sum_grad += x.grad.norm()**2
22 | return math.sqrt(sum_grad)
23 |
24 |
25 | def hard_target_update(src: nn.Module, tgt: nn.Module) -> None:
26 | """Hard update model parameters.
27 |
28 | Params
29 | ======
30 | src: PyTorch model (weights will be copied from)
31 | tgt: PyTorch model (weights will be copied to)
32 | """
33 |
34 | tgt.load_state_dict(src.state_dict())
35 |
36 |
37 | def soft_target_update(src: nn.Module, tgt: nn.Module, tau=0.05) -> None:
38 | """Soft update model parameters.
39 |
40 | θ_target = τ*θ_local + (1 - τ)*θ_target
41 |
42 | Params
43 | ======
44 | src: PyTorch model (weights will be copied from)
45 | tgt: PyTorch model (weights will be copied to)
46 | tau (float): interpolation parameter
47 | """
48 | for src_param, tgt_param in zip(src.parameters(), tgt.parameters()):
49 | tgt_param.data.copy_(tau * src_param.data +
50 | (1.0 - tau) * tgt_param.data)
51 |
52 |
53 | def check_model_method(model, method, algo):
54 | """check method existence for input model to algo.
55 |
56 | Args:
57 | model(nn.Model): model for checking
58 | method(str): method name
59 | algo(str): algorithm name
60 |
61 | Raises:
62 | AssertionError: if method is not implemented in model
63 | """
64 | if method == 'forward':
65 | # check if forward is overridden by the subclass
66 | assert callable(
67 | getattr(model, 'forward',
68 | None)), 'forward should be a function in model class'
69 | assert model.forward.__func__ is not super(
70 | model.__class__, model
71 | ).forward.__func__, "{}'s model needs to implement forward method. \n".format(
72 | algo)
73 | else:
74 | # check if the specified method is implemented
75 | assert hasattr(model, method) and callable(getattr(
76 | model, method,
77 | None)), "{}'s model needs to implement {} method. \n".format(
78 | algo, method)
79 |
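# A self-contained check of the target-update helpers; the two Linear layers
# stand in for an online network and its target copy.
if __name__ == '__main__':
    import torch

    online, target = nn.Linear(4, 2), nn.Linear(4, 2)
    soft_target_update(online, target, tau=0.05)  # target <- 0.05*online + 0.95*target
    hard_target_update(online, target)            # target <- online (exact copy)
    assert all(
        torch.equal(p, q) for p, q in zip(online.state_dict().values(),
                                          target.state_dict().values()))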
--------------------------------------------------------------------------------
/marltoolkit/utils/progressbar.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import sys
3 | from collections.abc import Iterable
4 | from multiprocessing import Pool
5 | from shutil import get_terminal_size
6 |
7 | from .timer import Timer
8 |
9 |
10 | class ProgressBar:
11 | """A progress bar which can print the progress."""
12 |
13 | def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
14 | self.task_num = task_num
15 | self.bar_width = bar_width
16 | self.completed = 0
17 | self.file = file
18 | if start:
19 | self.start()
20 |
21 | @property
22 | def terminal_width(self):
23 | width, _ = get_terminal_size()
24 | return width
25 |
26 | def start(self):
27 | if self.task_num > 0:
28 | self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
29 | 'elapsed: 0s, ETA:')
30 | else:
31 | self.file.write('completed: 0, elapsed: 0s')
32 | self.file.flush()
33 | self.timer = Timer()
34 |
35 | def update(self, num_tasks=1):
36 | assert num_tasks > 0
37 | self.completed += num_tasks
38 | elapsed = self.timer.since_start()
39 | if elapsed > 0:
40 | fps = self.completed / elapsed
41 | else:
42 | fps = float('inf')
43 | if self.task_num > 0:
44 | percentage = self.completed / float(self.task_num)
45 | eta = int(elapsed * (1 - percentage) / percentage + 0.5)
46 | msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
47 | f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
48 | f'ETA: {eta:5}s'
49 |
50 | bar_width = min(self.bar_width,
51 | int(self.terminal_width - len(msg)) + 2,
52 | int(self.terminal_width * 0.6))
53 | bar_width = max(2, bar_width)
54 | mark_width = int(bar_width * percentage)
55 | bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
56 | self.file.write(msg.format(bar_chars))
57 | else:
58 | self.file.write(
59 | f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
60 | f' {fps:.1f} tasks/s')
61 | self.file.flush()
62 |
63 |
64 | def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
65 | """Track the progress of tasks execution with a progress bar.
66 |
67 | Tasks are done with a simple for-loop.
68 |
69 | Args:
70 | func (callable): The function to be applied to each task.
71 | tasks (list or tuple[Iterable, int]): A list of tasks or
72 | (tasks, total num).
73 | bar_width (int): Width of progress bar.
74 |
75 | Returns:
76 | list: The task results.
77 | """
78 | if isinstance(tasks, tuple):
79 | assert len(tasks) == 2
80 | assert isinstance(tasks[0], Iterable)
81 | assert isinstance(tasks[1], int)
82 | task_num = tasks[1]
83 | tasks = tasks[0]
84 | elif isinstance(tasks, Iterable):
85 | task_num = len(tasks)
86 | else:
87 | raise TypeError(
88 | '"tasks" must be an iterable object or a (iterator, int) tuple')
89 | prog_bar = ProgressBar(task_num, bar_width, file=file)
90 | results = []
91 | for task in tasks:
92 | results.append(func(task, **kwargs))
93 | prog_bar.update()
94 | prog_bar.file.write('\n')
95 | return results
96 |
97 |
98 | def init_pool(process_num, initializer=None, initargs=None):
99 | if initializer is None:
100 | return Pool(process_num)
101 | elif initargs is None:
102 | return Pool(process_num, initializer)
103 | else:
104 | if not isinstance(initargs, tuple):
105 | raise TypeError('"initargs" must be a tuple')
106 | return Pool(process_num, initializer, initargs)
107 |
108 |
109 | def track_parallel_progress(func,
110 | tasks,
111 | nproc,
112 | initializer=None,
113 | initargs=None,
114 | bar_width=50,
115 | chunksize=1,
116 | skip_first=False,
117 | keep_order=True,
118 | file=sys.stdout):
119 | """Track the progress of parallel task execution with a progress bar.
120 |
121 | The built-in :mod:`multiprocessing` module is used for process pools and
122 | tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
123 |
124 | Args:
125 | func (callable): The function to be applied to each task.
126 | tasks (list or tuple[Iterable, int]): A list of tasks or
127 | (tasks, total num).
128 | nproc (int): Process (worker) number.
129 | initializer (None or callable): Refer to :class:`multiprocessing.Pool`
130 | for details.
131 | initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
132 | details.
133 | chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
134 | bar_width (int): Width of progress bar.
135 | skip_first (bool): Whether to skip the first sample for each worker
136 |             when estimating fps, since the initialization step may take
137 | longer.
138 | keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
139 | :func:`Pool.imap_unordered` is used.
140 |
141 | Returns:
142 | list: The task results.
143 | """
144 | if isinstance(tasks, tuple):
145 | assert len(tasks) == 2
146 | assert isinstance(tasks[0], Iterable)
147 | assert isinstance(tasks[1], int)
148 | task_num = tasks[1]
149 | tasks = tasks[0]
150 | elif isinstance(tasks, Iterable):
151 | task_num = len(tasks)
152 | else:
153 | raise TypeError(
154 | '"tasks" must be an iterable object or a (iterator, int) tuple')
155 | pool = init_pool(nproc, initializer, initargs)
156 | start = not skip_first
157 | task_num -= nproc * chunksize * int(skip_first)
158 | prog_bar = ProgressBar(task_num, bar_width, start, file=file)
159 | results = []
160 | if keep_order:
161 | gen = pool.imap(func, tasks, chunksize)
162 | else:
163 | gen = pool.imap_unordered(func, tasks, chunksize)
164 | for result in gen:
165 | results.append(result)
166 | if skip_first:
167 | if len(results) < nproc * chunksize:
168 | continue
169 | elif len(results) == nproc * chunksize:
170 | prog_bar.start()
171 | continue
172 | prog_bar.update()
173 | prog_bar.file.write('\n')
174 | pool.close()
175 | pool.join()
176 | return results
177 |
178 |
179 | def track_iter_progress(tasks, bar_width=50, file=sys.stdout):
180 | """Track the progress of tasks iteration or enumeration with a progress
181 | bar.
182 |
183 | Tasks are yielded with a simple for-loop.
184 |
185 | Args:
186 | tasks (list or tuple[Iterable, int]): A list of tasks or
187 | (tasks, total num).
188 | bar_width (int): Width of progress bar.
189 |
190 | Yields:
191 | list: The task results.
192 | """
193 | if isinstance(tasks, tuple):
194 | assert len(tasks) == 2
195 | assert isinstance(tasks[0], Iterable)
196 | assert isinstance(tasks[1], int)
197 | task_num = tasks[1]
198 | tasks = tasks[0]
199 | elif isinstance(tasks, Iterable):
200 | task_num = len(tasks)
201 | else:
202 | raise TypeError(
203 | '"tasks" must be an iterable object or a (iterator, int) tuple')
204 | prog_bar = ProgressBar(task_num, bar_width, file=file)
205 | for task in tasks:
206 | yield task
207 | prog_bar.update()
208 | prog_bar.file.write('\n')
209 |
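# A tiny runnable example of track_progress; the sleep only makes the bar
# visible for a moment.
if __name__ == '__main__':
    import time

    def slow_square(x):
        time.sleep(0.05)
        return x * x

    results = track_progress(slow_square, list(range(20)), bar_width=40)
    print(results[:5])  # [0, 1, 4, 9, 16]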
--------------------------------------------------------------------------------
/marltoolkit/utils/timer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from time import time
3 |
4 |
5 | class TimerError(Exception):
6 |
7 | def __init__(self, message):
8 | self.message = message
9 | super().__init__(message)
10 |
11 |
12 | class Timer:
13 | """A flexible Timer class.
14 |
15 | Examples:
16 | >>> import time
17 | >>> import mmcv
18 | >>> with mmcv.Timer():
19 | >>> # simulate a code block that will run for 1s
20 | >>> time.sleep(1)
21 | 1.000
22 | >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'):
23 | >>> # simulate a code block that will run for 1s
24 | >>> time.sleep(1)
25 | it takes 1.0 seconds
26 | >>> timer = mmcv.Timer()
27 | >>> time.sleep(0.5)
28 | >>> print(timer.since_start())
29 | 0.500
30 | >>> time.sleep(0.5)
31 | >>> print(timer.since_last_check())
32 | 0.500
33 | >>> print(timer.since_start())
34 | 1.000
35 | """
36 |
37 | def __init__(self, start=True, print_tmpl=None):
38 | self._is_running = False
39 | self.print_tmpl = print_tmpl if print_tmpl else '{:.3f}'
40 | if start:
41 | self.start()
42 |
43 | @property
44 | def is_running(self):
45 | """bool: indicate whether the timer is running"""
46 | return self._is_running
47 |
48 | def __enter__(self):
49 | self.start()
50 | return self
51 |
52 | def __exit__(self, type, value, traceback):
53 | print(self.print_tmpl.format(self.since_last_check()))
54 | self._is_running = False
55 |
56 | def start(self):
57 | """Start the timer."""
58 | if not self._is_running:
59 | self._t_start = time()
60 | self._is_running = True
61 | self._t_last = time()
62 |
63 | def since_start(self):
64 | """Total time since the timer is started.
65 |
66 | Returns:
67 | float: Time in seconds.
68 | """
69 | if not self._is_running:
70 | raise TimerError('timer is not running')
71 | self._t_last = time()
72 | return self._t_last - self._t_start
73 |
74 | def since_last_check(self):
75 | """Time since the last checking.
76 |
77 | Either :func:`since_start` or :func:`since_last_check` is a checking
78 | operation.
79 |
80 | Returns:
81 | float: Time in seconds.
82 | """
83 | if not self._is_running:
84 | raise TimerError('timer is not running')
85 | dur = time() - self._t_last
86 | self._t_last = time()
87 | return dur
88 |
89 |
90 | _g_timers = {} # global timers
91 |
92 |
93 | def check_time(timer_id):
94 | """Add check points in a single line.
95 |
96 | This method is suitable for running a task on a list of items. A timer will
97 | be registered when the method is called for the first time.
98 |
99 | Examples:
100 | >>> import time
101 | >>> import mmcv
102 | >>> for i in range(1, 6):
103 | >>> # simulate a code block
104 | >>> time.sleep(i)
105 | >>> mmcv.check_time('task1')
106 | 2.000
107 | 3.000
108 | 4.000
109 | 5.000
110 |
111 | Args:
112 |         timer_id (str): Timer identifier.
113 | """
114 | if timer_id not in _g_timers:
115 | _g_timers[timer_id] = Timer()
116 | return 0
117 | else:
118 | return _g_timers[timer_id].since_last_check()
119 |
--------------------------------------------------------------------------------
/marltoolkit/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class OneHotTransform(object):
5 |     """One-hot transform: convert an integer index to a one-hot vector.
6 |
7 | Args:
8 | out_dim (int): The dimension of one hot vector.
9 | """
10 |
11 | def __init__(self, out_dim: int) -> None:
12 | self.out_dim = out_dim
13 |
14 | def __call__(self, index: int) -> np.ndarray:
15 | assert index < self.out_dim
16 | one_hot_id = np.zeros([self.out_dim])
17 | one_hot_id[index] = 1.0
18 | return one_hot_id
19 |
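# A one-line sanity check of OneHotTransform.
if __name__ == '__main__':
    one_hot = OneHotTransform(out_dim=4)
    print(one_hot(2))  # [0. 0. 1. 0.]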
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/oxwhirl/smac.git
2 | git+https://github.com/oxwhirl/smacv2.git
3 | numpy
4 | protobuf
5 | pygame
6 | pyglet
7 | PyYAML
8 | scipy
9 | six
10 | tensorboard
11 | tensorboard-logger
12 | tensorboardX
13 | torch
14 | torchvision
15 | tqdm
16 | wandb
17 |
--------------------------------------------------------------------------------
/scripts/main_coma.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import time
5 |
6 | import torch
7 | from torch.utils.tensorboard import SummaryWriter
8 |
9 | sys.path.append('../')
10 |
11 | from configs.arguments import get_common_args
12 | from configs.coma_config import ComaConfig
13 | from marltoolkit.agents.coma_agent import ComaAgent
14 | from marltoolkit.data import ReplayBuffer
15 | from marltoolkit.envs.smacv1.smac_env import SMACWrapperEnv
16 | from marltoolkit.modules.actors import RNNActorModel
17 | from marltoolkit.modules.critics.coma import MLPCriticModel
18 | from marltoolkit.runners.episode_runner import (run_eval_episode,
19 | run_train_episode)
20 | from marltoolkit.utils import (ProgressBar, TensorboardLogger, WandbLogger,
21 | get_outdir, get_root_logger)
22 |
23 |
24 | def main():
25 |     """Main function for running the COMA algorithm.
26 |
27 | This function initializes the necessary configurations, environment, logger, models, and agents.
28 | It then runs training episodes and evaluates the agent's performance periodically.
29 |
30 | Returns:
31 | None
32 | """
33 | coma_config = ComaConfig()
34 | common_args = get_common_args()
35 | args = argparse.Namespace(**vars(common_args), **vars(coma_config))
36 | device = (torch.device('cuda') if torch.cuda.is_available() and args.cuda
37 | else torch.device('cpu'))
38 |
39 | env = SMACWrapperEnv(map_name=args.scenario, difficulty=args.difficulty)
40 | args.episode_limit = env.episode_limit
41 | args.obs_dim = env.obs_dim
42 | args.obs_shape = env.obs_shape
43 | args.state_dim = env.state_dim
44 | args.state_shape = env.state_shape
45 | args.num_agents = env.num_agents
46 | args.n_actions = env.n_actions
47 | args.action_shape = env.action_shape
48 | args.reward_shape = env.reward_shape
49 | args.done_shape = env.done_shape
50 | args.device = device
51 |
52 | # init the logger before other steps
53 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
54 | # log
55 | log_name = os.path.join(args.project, args.scenario, args.algo_name,
56 | timestamp).replace(os.path.sep, '_')
57 | log_path = os.path.join(args.log_dir, args.project, args.scenario,
58 | args.algo_name)
59 | tensorboard_log_path = get_outdir(log_path, 'tensorboard_log_dir')
60 | log_file = os.path.join(log_path, log_name + '.log')
61 | text_logger = get_root_logger(log_file=log_file, log_level='INFO')
62 |
63 | if args.logger == 'wandb':
64 | logger = WandbLogger(
65 | train_interval=args.train_log_interval,
66 | test_interval=args.test_log_interval,
67 | update_interval=args.train_log_interval,
68 | project=args.project,
69 | name=log_name,
70 | save_interval=1,
71 | config=args,
72 | )
73 | writer = SummaryWriter(tensorboard_log_path)
74 | writer.add_text('args', str(args))
75 | if args.logger == 'tensorboard':
76 | logger = TensorboardLogger(writer)
77 | else: # wandb
78 | logger.load(writer)
79 |
80 | rpm = ReplayBuffer(
81 | max_size=args.replay_buffer_size,
82 | num_agents=args.num_agents,
83 | episode_limit=args.episode_limit,
84 | obs_space=args.obs_shape,
85 | state_space=args.state_shape,
86 | action_space=args.action_shape,
87 | reward_space=args.reward_shape,
88 | done_space=args.done_shape,
89 | device=device,
90 | )
91 |
92 | actor_model = RNNActorModel(
93 | input_dim=args.obs_dim,
94 | rnn_hidden_dim=args.rnn_hidden_dim,
95 | n_actions=args.n_actions,
96 | )
97 |
98 | critic_model = MLPCriticModel(
99 | input_dim=args.obs_dim,
100 | hidden_dim=args.fc_hidden_dim,
101 | output_dim=args.n_actions,
102 | )
103 | agent = ComaAgent(
104 | actor_model=actor_model,
105 | critic_model=critic_model,
106 | num_agents=args.num_agents,
107 | double_q=args.double_q,
108 | total_steps=args.total_steps,
109 | gamma=args.gamma,
110 | actor_lr=args.learning_rate,
111 | critic_lr=args.learning_rate,
112 | egreedy_exploration=args.egreedy_exploration,
113 | min_exploration=args.min_exploration,
114 | target_update_interval=args.target_update_interval,
115 | learner_update_freq=args.learner_update_freq,
116 | clip_grad_norm=args.clip_grad_norm,
117 | device=args.device,
118 | )
119 |
120 | progress_bar = ProgressBar(args.memory_warmup_size)
121 | while rpm.size() < args.memory_warmup_size:
122 | run_train_episode(env, agent, rpm, args)
123 | progress_bar.update()
124 |
125 | steps_cnt = 0
126 | episode_cnt = 0
127 | progress_bar = ProgressBar(args.total_steps)
128 | while steps_cnt < args.total_steps:
129 | train_res_dict = run_train_episode(env, agent, rpm, args)
130 | # update episodes and steps
131 | episode_cnt += 1
132 | steps_cnt += train_res_dict['episode_step']
133 |
134 | # learning rate decay
135 | agent.learning_rate = max(
136 | agent.lr_scheduler.step(train_res_dict['episode_step']),
137 | agent.min_learning_rate,
138 | )
139 |
140 | train_res_dict.update({
141 | 'exploration': agent.exploration,
142 | 'learning_rate': agent.learning_rate,
143 | 'replay_max_size': rpm.size(),
144 | 'target_update_count': agent.target_update_count,
145 | })
146 | if episode_cnt % args.train_log_interval == 0:
147 | text_logger.info(
148 | '[Train], episode: {}, train_episode_step: {}, train_win_rate: {:.2f}, train_reward: {:.2f}'
149 | .format(
150 | episode_cnt,
151 | train_res_dict['episode_step'],
152 | train_res_dict['win_rate'],
153 | train_res_dict['episode_reward'],
154 | ))
155 | logger.log_train_data(train_res_dict, steps_cnt)
156 |
157 | if episode_cnt % args.test_log_interval == 0:
158 | eval_res_dict = run_eval_episode(env, agent, args=args)
159 | text_logger.info(
160 | '[Eval], episode: {}, eval_episode_step:{:.2f}, eval_win_rate: {:.2f}, eval_reward: {:.2f}'
161 | .format(
162 | episode_cnt,
163 | eval_res_dict['episode_step'],
164 | eval_res_dict['win_rate'],
165 | eval_res_dict['episode_reward'],
166 | ))
167 | logger.log_test_data(eval_res_dict, steps_cnt)
168 |
169 | progress_bar.update(train_res_dict['episode_step'])
170 |
171 |
172 | if __name__ == '__main__':
173 | main()
174 |
--------------------------------------------------------------------------------
/scripts/main_idqn.py:
--------------------------------------------------------------------------------
1 | """This script contains the main function for running the IDQN (Independent
2 | Deep Q-Network) algorithm in a StarCraft II environment."""
3 |
4 | import argparse
5 | import os
6 | import sys
7 | import time
8 |
9 | import torch
10 | from smac.env import StarCraft2Env
11 | from torch.utils.tensorboard import SummaryWriter
12 |
13 | sys.path.append('../')
14 | from configs.arguments import get_common_args
15 | from configs.idqn_config import IDQNConfig
16 | from marltoolkit.agents import IDQNAgent
17 | from marltoolkit.data import ReplayBuffer
18 | from marltoolkit.envs.smacv1.env_wrapper import SC2EnvWrapper
19 | from marltoolkit.modules.actors import RNNActorModel
20 | from marltoolkit.runners.episode_runner import (run_eval_episode,
21 | run_train_episode)
22 | from marltoolkit.utils import (ProgressBar, TensorboardLogger, WandbLogger,
23 | get_outdir, get_root_logger)
24 |
25 |
26 | def main():
27 | """The main function for running the IDQN algorithm.
28 |
29 | It initializes the necessary components such as the environment, agent,
30 | logger, and replay buffer. Then, it performs training episodes and
31 | evaluation episodes, logging the results at specified intervals.
32 | """
33 |
34 |     idqn_config = IDQNConfig()
35 |     common_args = get_common_args()
36 |     args = argparse.Namespace(**vars(common_args), **vars(idqn_config))
37 | device = (torch.device('cuda') if torch.cuda.is_available() and args.cuda
38 | else torch.device('cpu'))
39 |
40 | env = StarCraft2Env(map_name=args.scenario, difficulty=args.difficulty)
41 |
42 | env = SC2EnvWrapper(env)
43 | args.episode_limit = env.episode_limit
44 | args.obs_shape = env.obs_shape
45 | args.state_shape = env.state_shape
46 | args.num_agents = env.num_agents
47 | args.n_actions = env.n_actions
48 | args.device = device
49 |
50 | # init the logger before other steps
51 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
52 | # log
53 | log_name = os.path.join(args.project, args.scenario, args.algo_name,
54 | timestamp).replace(os.path.sep, '_')
55 | log_path = os.path.join(args.log_dir, args.project, args.scenario,
56 | args.algo_name)
57 | tensorboard_log_path = get_outdir(log_path, 'tensorboard_log_dir')
58 | log_file = os.path.join(log_path, log_name + '.log')
59 | text_logger = get_root_logger(log_file=log_file, log_level='INFO')
60 |
61 | if args.logger == 'wandb':
62 | logger = WandbLogger(
63 | train_interval=args.train_log_interval,
64 | test_interval=args.test_log_interval,
65 | update_interval=args.train_log_interval,
66 | project=args.project,
67 | name=log_name,
68 | save_interval=1,
69 | config=args,
70 | )
71 | writer = SummaryWriter(tensorboard_log_path)
72 | writer.add_text('args', str(args))
73 | if args.logger == 'tensorboard':
74 | logger = TensorboardLogger(writer)
75 | else: # wandb
76 | logger.load(writer)
77 |
78 | rpm = ReplayBuffer(
79 | max_size=args.replay_buffer_size,
80 | num_agents=args.num_agents,
81 | num_actions=args.n_actions,
82 | episode_limit=args.episode_limit,
83 | obs_shape=args.obs_shape,
84 | state_shape=args.state_shape,
85 | device=args.device,
86 | )
87 | actor_model = RNNActorModel(
88 | input_dim=args.obs_shape,
89 | rnn_hidden_dim=args.rnn_hidden_dim,
90 | n_actions=args.n_actions,
91 | )
92 |
93 | marl_agent = IDQNAgent(
94 | actor_model=actor_model,
95 | mixer_model=None,
96 | num_agents=args.num_agents,
97 | double_q=args.double_q,
98 | total_steps=args.total_steps,
99 | gamma=args.gamma,
100 | learning_rate=args.learning_rate,
101 | min_learning_rate=args.min_learning_rate,
102 | egreedy_exploration=args.egreedy_exploration,
103 | min_exploration=args.min_exploration,
104 | target_update_interval=args.target_update_interval,
105 | learner_update_freq=args.learner_update_freq,
106 | clip_grad_norm=args.clip_grad_norm,
107 | device=args.device,
108 | )
109 |
110 | progress_bar = ProgressBar(args.memory_warmup_size)
111 | while rpm.size() < args.memory_warmup_size:
112 | run_train_episode(env, marl_agent, rpm, args)
113 | progress_bar.update()
114 |
115 | steps_cnt = 0
116 | episode_cnt = 0
117 | progress_bar = ProgressBar(args.total_steps)
118 | while steps_cnt < args.total_steps:
119 | (
120 | episode_reward,
121 | episode_step,
122 | is_win,
123 | mean_loss,
124 | mean_td_error,
125 | ) = run_train_episode(env, marl_agent, rpm, args)
126 | # update episodes and steps
127 | episode_cnt += 1
128 | steps_cnt += episode_step
129 |
130 | # learning rate decay
131 | marl_agent.learning_rate = max(
132 | marl_agent.lr_scheduler.step(episode_step),
133 | marl_agent.min_learning_rate)
134 |
135 | train_results = {
136 | 'env_step': episode_step,
137 | 'rewards': episode_reward,
138 | 'win_rate': is_win,
139 | 'mean_loss': mean_loss,
140 | 'mean_td_error': mean_td_error,
141 | 'exploration': marl_agent.exploration,
142 | 'learning_rate': marl_agent.learning_rate,
143 | 'replay_buffer_size': rpm.size(),
144 | 'target_update_count': marl_agent.target_update_count,
145 | }
146 | if episode_cnt % args.train_log_interval == 0:
147 | text_logger.info(
148 | '[Train], episode: {}, train_win_rate: {:.2f}, train_reward: {:.2f}'
149 | .format(episode_cnt, is_win, episode_reward))
150 | logger.log_train_data(train_results, steps_cnt)
151 |
152 | if episode_cnt % args.test_log_interval == 0:
153 | eval_rewards, eval_steps, eval_win_rate = run_eval_episode(
154 | env, marl_agent, num_eval_episodes=5)
155 | text_logger.info(
156 | '[Eval], episode: {}, eval_win_rate: {:.2f}, eval_rewards: {:.2f}'
157 | .format(episode_cnt, eval_win_rate, eval_rewards))
158 |
159 | test_results = {
160 | 'env_step': eval_steps,
161 | 'rewards': eval_rewards,
162 | 'win_rate': eval_win_rate,
163 | }
164 | logger.log_test_data(test_results, steps_cnt)
165 |
166 | progress_bar.update(episode_step)
167 |
168 |
169 | if __name__ == '__main__':
170 | main()
171 |
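172 | # Logger-selection sketch (the flag name is assumed to mirror args.logger above;
173 | # check configs/arguments.py for the real CLI option):
174 | #   python main_idqn.py --scenario 3m --logger tensorboard   # TensorBoard only
175 | #   python main_idqn.py --scenario 3m --logger wandb         # Weights & Biases plus TensorBoard writer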
--------------------------------------------------------------------------------
/scripts/main_mappo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 |
5 | import torch
6 | from torch.utils.tensorboard import SummaryWriter
7 |
8 | sys.path.append('../')
9 |
10 | from configs.arguments import get_common_args
11 | from marltoolkit.agents.mappo_agent import MAPPOAgent
12 | from marltoolkit.data.shared_buffer import SharedReplayBuffer
13 | from marltoolkit.envs.smacv1 import SMACWrapperEnv
14 | from marltoolkit.runners.episode_runner import (run_eval_episode,
15 | run_train_episode)
16 | from marltoolkit.utils import (ProgressBar, TensorboardLogger, WandbLogger,
17 | get_outdir, get_root_logger)
18 |
19 |
20 | def main() -> None:
21 |     """Main function for running the MAPPO algorithm.
22 |
23 | This function initializes the necessary configurations, environment, logger, models, and agents.
24 | It then runs training episodes and evaluates the agent's performance periodically.
25 |
26 | Returns:
27 | None
28 | """
29 | args = get_common_args()
30 |     # Recurrent-policy flags depend on the chosen algorithm name: rmappo uses a
31 |     # recurrent policy, while the remaining algorithms are configured below.
32 |     if args.algorithm_name == 'rmappo':
33 |         print(
34 |             'You are choosing to use rmappo, so use_recurrent_policy is set '
35 |             'to True and use_naive_recurrent_policy to False.'
36 |         )
37 |         args.use_recurrent_policy = True
38 |         args.use_naive_recurrent_policy = False
39 |
40 | elif (args.algorithm_name == 'mappo' or args.algorithm_name == 'mat'
41 | or args.algorithm_name == 'mat_dec'):
42 | assert (args.use_recurrent_policy is False
43 | and args.use_naive_recurrent_policy is False
44 | ), 'check recurrent policy!'
45 |         print(
46 |             'You are choosing to use mappo, so use_recurrent_policy and use_naive_recurrent_policy are set to False.'
47 |         )
48 | args.use_recurrent_policy = False
49 | args.use_naive_recurrent_policy = False
50 |
51 | elif args.algorithm_name == 'ippo':
52 |         print(
53 |             'You are choosing to use ippo, so use_centralized_v is set to False.')
54 | args.use_centralized_v = False
55 | elif args.algorithm_name == 'happo' or args.algorithm_name == 'hatrpo':
56 | # can or cannot use recurrent network?
57 | print('using', args.algorithm_name, 'without recurrent network')
58 | args.use_recurrent_policy = False
59 | args.use_naive_recurrent_policy = False
60 | else:
61 | raise NotImplementedError
62 |
63 | if args.algorithm_name == 'mat_dec':
64 | args.dec_actor = True
65 | args.share_actor = True
66 |
67 | # cuda
68 | if args.cuda and torch.cuda.is_available():
69 | print('choose to use gpu...')
70 | args.device = torch.device('cuda')
71 | torch.set_num_threads(args.n_training_threads)
72 | if args.cuda_deterministic:
73 | torch.backends.cudnn.benchmark = False
74 | torch.backends.cudnn.deterministic = True
75 | else:
76 | print('choose to use cpu...')
77 | args.device = torch.device('cpu')
78 | torch.set_num_threads(args.n_training_threads)
79 |
80 | # Environment
81 | env = SMACWrapperEnv(map_name=args.scenario, args=args)
82 | args.episode_limit = env.episode_limit
83 | args.obs_dim = env.obs_dim
84 | args.obs_shape = env.obs_shape
85 | args.state_dim = env.state_dim
86 | args.state_shape = env.state_shape
87 | args.num_agents = env.num_agents
88 | args.n_actions = env.n_actions
89 | args.action_shape = env.action_shape
90 | args.reward_shape = env.reward_shape
91 | args.done_shape = env.done_shape
92 |
93 | # init the logger before other steps
94 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
95 | # log
96 | log_name = os.path.join(args.project, args.scenario, args.algo_name,
97 | timestamp).replace(os.path.sep, '_')
98 | log_path = os.path.join(args.log_dir, args.project, args.scenario,
99 | args.algo_name)
100 | tensorboard_log_path = get_outdir(log_path, 'tensorboard_log_dir')
101 | log_file = os.path.join(log_path, log_name + '.log')
102 | text_logger = get_root_logger(log_file=log_file, log_level='INFO')
103 |
104 | if args.logger == 'wandb':
105 | logger = WandbLogger(
106 | train_interval=args.train_log_interval,
107 | test_interval=args.test_log_interval,
108 | update_interval=args.train_log_interval,
109 | project=args.project,
110 | name=log_name,
111 | save_interval=1,
112 | config=args,
113 | )
114 | writer = SummaryWriter(tensorboard_log_path)
115 | writer.add_text('args', str(args))
116 | if args.logger == 'tensorboard':
117 | logger = TensorboardLogger(writer)
118 | else: # wandb
119 | logger.load(writer)
120 |
121 | args.obs_shape = env.get_actor_input_shape()
122 |
123 | # policy network
124 | agent = MAPPOAgent(args)
125 | # buffer
126 | rpm = SharedReplayBuffer(
127 | args.num_envs,
128 | args.num_agents,
129 | args.episode_limit,
130 | args.obs_shape,
131 | args.state_shape,
132 | args.action_shape,
133 | args.reward_shape,
134 | args.done_shape,
135 | args,
136 | )
137 |
138 | progress_bar = ProgressBar(args.memory_warmup_size)
139 | while rpm.size() < args.memory_warmup_size:
140 | run_train_episode(env, agent, rpm, args)
141 | progress_bar.update()
142 |
143 | steps_cnt = 0
144 | episode_cnt = 0
145 | progress_bar = ProgressBar(args.total_steps)
146 | while steps_cnt < args.total_steps:
147 | train_res_dict = run_train_episode(env, agent, rpm, args)
148 | # update episodes and steps
149 | episode_cnt += 1
150 | steps_cnt += train_res_dict['episode_step']
151 |
152 | # learning rate decay
153 | agent.learning_rate = max(
154 | agent.lr_scheduler.step(train_res_dict['episode_step']),
155 | agent.min_learning_rate,
156 | )
157 |
158 | train_res_dict.update({
159 | 'exploration': agent.exploration,
160 | 'learning_rate': agent.learning_rate,
161 |             'replay_buffer_size': rpm.size(),
162 | 'target_update_count': agent.target_update_count,
163 | })
164 | if episode_cnt % args.train_log_interval == 0:
165 | text_logger.info(
166 | '[Train], episode: {}, train_episode_step: {}, train_win_rate: {:.2f}, train_reward: {:.2f}'
167 | .format(
168 | episode_cnt,
169 | train_res_dict['episode_step'],
170 | train_res_dict['win_rate'],
171 | train_res_dict['episode_reward'],
172 | ))
173 | logger.log_train_data(train_res_dict, steps_cnt)
174 |
175 | if episode_cnt % args.test_log_interval == 0:
176 | eval_res_dict = run_eval_episode(env, agent, args=args)
177 | text_logger.info(
178 |                 '[Eval], episode: {}, eval_episode_step: {:.2f}, eval_win_rate: {:.2f}, eval_reward: {:.2f}'
179 | .format(
180 | episode_cnt,
181 | eval_res_dict['episode_step'],
182 | eval_res_dict['win_rate'],
183 | eval_res_dict['episode_reward'],
184 | ))
185 | logger.log_test_data(eval_res_dict, steps_cnt)
186 |
187 | progress_bar.update(train_res_dict['episode_step'])
188 |
189 |
190 | if __name__ == '__main__':
191 | main()
192 |
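193 | # Algorithm-selection sketch (the CLI flag is assumed to mirror args.algorithm_name;
194 | # see configs/arguments.py for the real option names):
195 | #   python main_mappo.py --algorithm_name rmappo --scenario 3m   # recurrent actor
196 | #   python main_mappo.py --algorithm_name mappo --scenario 3m    # feed-forward actor
197 | #   python main_mappo.py --algorithm_name ippo --scenario 3m     # decentralised value function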
--------------------------------------------------------------------------------
/scripts/main_qmix.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import time
5 |
6 | import torch
7 | from torch.utils.tensorboard import SummaryWriter
8 |
9 | sys.path.append('../')
10 |
11 | from configs.arguments import get_common_args
12 | from configs.qmix_config import QMixConfig
13 | from marltoolkit.agents.qmix_agent import QMixAgent
14 | from marltoolkit.data import ReplayBuffer
15 | from marltoolkit.envs.smacv1.smac_env import SMACWrapperEnv
16 | from marltoolkit.modules.actors import RNNActorModel
17 | from marltoolkit.modules.mixers import QMixerModel
18 | from marltoolkit.runners.episode_runner import (run_eval_episode,
19 | run_train_episode)
20 | from marltoolkit.utils import (ProgressBar, TensorboardLogger, WandbLogger,
21 | get_outdir, get_root_logger)
22 |
23 |
24 | def main():
25 | """Main function for running the QMix algorithm.
26 |
27 | This function initializes the necessary configurations, environment, logger, models, and agents.
28 | It then runs training episodes and evaluates the agent's performance periodically.
29 |
30 | Returns:
31 | None
32 | """
33 | qmix_config = QMixConfig()
34 | common_args = get_common_args()
35 | args = argparse.Namespace(**vars(qmix_config))
36 | args.__dict__.update(**vars(common_args))
37 | device = torch.device('cuda') if torch.cuda.is_available(
38 | ) and args.cuda else torch.device('cpu')
39 |
40 | env = SMACWrapperEnv(map_name=args.scenario, args=args)
41 | args.episode_limit = env.episode_limit
42 | args.obs_dim = env.obs_dim
43 | args.obs_shape = env.obs_shape
44 | args.state_dim = env.state_dim
45 | args.state_shape = env.state_shape
46 | args.num_agents = env.num_agents
47 | args.n_actions = env.n_actions
48 | args.action_shape = env.action_shape
49 | args.reward_shape = env.reward_shape
50 | args.done_shape = env.done_shape
51 | args.device = device
52 |
53 | # init the logger before other steps
54 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
55 | # log
56 | log_name = os.path.join(args.project, args.scenario, args.algo_name,
57 | timestamp).replace(os.path.sep, '_')
58 | log_path = os.path.join(args.log_dir, args.project, args.scenario,
59 | args.algo_name)
60 | tensorboard_log_path = get_outdir(log_path, 'tensorboard_log_dir')
61 | log_file = os.path.join(log_path, log_name + '.log')
62 | text_logger = get_root_logger(log_file=log_file, log_level='INFO')
63 |
64 | if args.logger == 'wandb':
65 | logger = WandbLogger(
66 | train_interval=args.train_log_interval,
67 | test_interval=args.test_log_interval,
68 | update_interval=args.train_log_interval,
69 | project=args.project,
70 | name=log_name,
71 | save_interval=1,
72 | config=args,
73 | )
74 | writer = SummaryWriter(tensorboard_log_path)
75 | writer.add_text('args', str(args))
76 | if args.logger == 'tensorboard':
77 | logger = TensorboardLogger(writer)
78 | else: # wandb
79 | logger.load(writer)
80 |
81 | args.obs_shape = env.get_actor_input_shape()
82 |
83 | rpm = ReplayBuffer(
84 | max_size=args.replay_buffer_size,
85 | num_agents=args.num_agents,
86 | episode_limit=args.episode_limit,
87 | obs_shape=args.obs_shape,
88 | state_shape=args.state_shape,
89 | action_shape=args.action_shape,
90 | reward_shape=args.reward_shape,
91 | done_shape=args.done_shape,
92 | device=device,
93 | )
94 |
95 | actor_model = RNNActorModel(args)
96 |
97 | mixer_model = QMixerModel(
98 | num_agents=args.num_agents,
99 | state_dim=args.state_dim,
100 | mixing_embed_dim=args.mixing_embed_dim,
101 | hypernet_layers=args.hypernet_layers,
102 | hypernet_embed_dim=args.hypernet_embed_dim,
103 | )
104 |
105 | agent = QMixAgent(
106 | actor_model=actor_model,
107 | mixer_model=mixer_model,
108 | num_agents=args.num_agents,
109 | double_q=args.double_q,
110 | total_steps=args.total_steps,
111 | gamma=args.gamma,
112 | learning_rate=args.learning_rate,
113 | min_learning_rate=args.min_learning_rate,
114 | egreedy_exploration=args.egreedy_exploration,
115 | min_exploration=args.min_exploration,
116 | target_update_tau=args.target_update_tau,
117 | target_update_interval=args.target_update_interval,
118 | learner_update_freq=args.learner_update_freq,
119 | clip_grad_norm=args.clip_grad_norm,
120 | device=args.device,
121 | )
122 |
123 | progress_bar = ProgressBar(args.memory_warmup_size)
124 | while rpm.size() < args.memory_warmup_size:
125 | run_train_episode(env, agent, rpm, args)
126 | progress_bar.update()
127 |
128 | steps_cnt = 0
129 | episode_cnt = 0
130 | progress_bar = ProgressBar(args.total_steps)
131 | while steps_cnt < args.total_steps:
132 | train_res_dict = run_train_episode(env, agent, rpm, args)
133 | # update episodes and steps
134 | episode_cnt += 1
135 | steps_cnt += train_res_dict['episode_step']
136 |
137 | # learning rate decay
138 | agent.learning_rate = max(
139 | agent.lr_scheduler.step(train_res_dict['episode_step']),
140 | agent.min_learning_rate,
141 | )
142 |
143 | train_res_dict.update({
144 | 'exploration': agent.exploration,
145 | 'learning_rate': agent.learning_rate,
146 | 'replay_max_size': rpm.size(),
147 | 'target_update_count': agent.target_update_count,
148 | })
149 | if episode_cnt % args.train_log_interval == 0:
150 | text_logger.info(
151 | '[Train], episode: {}, train_episode_step: {}, train_win_rate: {:.2f}, train_reward: {:.2f}'
152 | .format(
153 | episode_cnt,
154 | train_res_dict['episode_step'],
155 | train_res_dict['win_rate'],
156 | train_res_dict['episode_reward'],
157 | ))
158 | logger.log_train_data(train_res_dict, steps_cnt)
159 |
160 | if episode_cnt % args.test_log_interval == 0:
161 | eval_res_dict = run_eval_episode(env, agent, args=args)
162 | text_logger.info(
163 |                 '[Eval], episode: {}, eval_episode_step: {:.2f}, eval_win_rate: {:.2f}, eval_reward: {:.2f}'
164 | .format(
165 | episode_cnt,
166 | eval_res_dict['episode_step'],
167 | eval_res_dict['win_rate'],
168 | eval_res_dict['episode_reward'],
169 | ))
170 | logger.log_test_data(eval_res_dict, steps_cnt)
171 |
172 | progress_bar.update(train_res_dict['episode_step'])
173 |
174 |
175 | if __name__ == '__main__':
176 | main()
177 |
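178 | # Example launches (the --scenario / --total_steps flags below are the ones used in
179 | # scripts/run.sh; the remaining hyperparameters come from QMixConfig):
180 | #   python main_qmix.py --scenario 3m --total_steps 1000000
181 | #   python main_qmix.py --scenario 5m_vs_6m --total_steps 1000000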
--------------------------------------------------------------------------------
/scripts/main_qtran.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import time
5 | from copy import deepcopy
6 |
7 | import torch
8 | from smac.env import StarCraft2Env
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 | sys.path.append('../')
12 | from configs.arguments import get_common_args
13 | from configs.qtran_config import QTranConfig
14 | from marltoolkit.agents.qtran_agent import QTranAgent
15 | from marltoolkit.data import MaReplayBuffer
16 | from marltoolkit.envs import SC2EnvWrapper
17 | from marltoolkit.modules.actors import RNNModel
18 | from marltoolkit.modules.mixers.qtran_mixer import QTransModel
19 | from marltoolkit.runners.episode_runner import (run_eval_episode,
20 | run_train_episode)
21 | from marltoolkit.utils import (ProgressBar, TensorboardLogger, WandbLogger,
22 | get_outdir, get_root_logger)
23 |
24 |
25 | def main():
26 | device = torch.device(
27 | 'cuda') if torch.cuda.is_available() else torch.device('cpu')
28 |
29 | config = deepcopy(QTranConfig)
30 | common_args = get_common_args()
31 | common_dict = vars(common_args)
32 | config.update(common_dict)
33 |
34 | env = StarCraft2Env(map_name=config['scenario'],
35 | difficulty=config['difficulty'])
36 |
37 | env = SC2EnvWrapper(env)
38 | config['episode_limit'] = env.episode_limit
39 | config['obs_shape'] = env.obs_shape
40 | config['state_shape'] = env.state_shape
41 | config['n_agents'] = env.n_agents
42 | config['n_actions'] = env.n_actions
43 |
44 | args = argparse.Namespace(**config)
45 |
46 | # init the logger before other steps
47 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
48 | # log
49 | log_name = os.path.join(args.project, args.scenario, args.algo, timestamp)
50 | text_log_path = os.path.join(args.log_dir, args.project, args.scenario,
51 | args.algo)
52 | tensorboard_log_path = get_outdir(text_log_path, 'log_dir')
53 | log_file = os.path.join(text_log_path, f'{timestamp}.log')
54 | text_logger = get_root_logger(log_file=log_file, log_level='INFO')
55 |
56 | if args.logger == 'wandb':
57 | logger = WandbLogger(train_interval=args.train_log_interval,
58 | test_interval=args.test_log_interval,
59 | update_interval=args.train_log_interval,
60 | project=args.project,
61 | name=log_name.replace(os.path.sep, '_'),
62 | save_interval=1,
63 | config=args,
64 | entity='jianzhnie')
65 | writer = SummaryWriter(tensorboard_log_path)
66 | writer.add_text('args', str(args))
67 | if args.logger == 'tensorboard':
68 | logger = TensorboardLogger(writer)
69 | else: # wandb
70 | logger.load(writer)
71 |
72 | rpm = MaReplayBuffer(max_size=config['replay_buffer_size'],
73 | episode_limit=config['episode_limit'],
74 | state_shape=config['state_shape'],
75 | obs_shape=config['obs_shape'],
76 | num_agents=config['n_agents'],
77 | num_actions=config['n_actions'],
78 | batch_size=config['batch_size'],
79 | device=device)
80 |
81 | actor_model = RNNModel(
82 | input_shape=config['obs_shape'],
83 | rnn_hidden_dim=config['rnn_hidden_dim'],
84 | n_actions=config['n_actions'],
85 | )
86 | mixer_model = QTransModel(n_agents=config['n_agents'],
87 | n_actions=config['n_actions'],
88 | state_dim=config['state_shape'],
89 | rnn_hidden_dim=config['rnn_hidden_dim'],
90 | mixing_embed_dim=config['mixing_embed_dim'])
91 |
92 | qmix_agent = QTranAgent(
93 | actor_model=actor_model,
94 | mixer_model=mixer_model,
95 | n_agents=config['n_agents'],
96 | n_actions=config['n_actions'],
97 | double_q=config['double_q'],
98 | total_steps=config['total_steps'],
99 | gamma=config['gamma'],
100 | learning_rate=config['learning_rate'],
101 | min_learning_rate=config['min_learning_rate'],
102 | egreedy_exploration=config['egreedy_exploration'],
103 | min_exploration=config['min_exploration'],
104 | target_update_interval=config['target_update_interval'],
105 | learner_update_freq=config['learner_update_freq'],
106 | clip_grad_norm=config['clip_grad_norm'],
107 | opt_loss_coef=config['opt_loss_coef'],
108 | nopt_min_loss_coef=config['nopt_min_loss_coef'],
109 | device=device,
110 | )
111 |
112 | progress_bar = ProgressBar(config['memory_warmup_size'])
113 | while rpm.size() < config['memory_warmup_size']:
114 | run_train_episode(env, qmix_agent, rpm, config)
115 | progress_bar.update()
116 |
117 | steps_cnt = 0
118 | episode_cnt = 0
119 | progress_bar = ProgressBar(config['total_steps'])
120 | while steps_cnt < config['total_steps']:
121 | episode_reward, episode_step, is_win, mean_loss, mean_td_loss, mean_opt_loss, mean_nopt_loss = run_train_episode(
122 | env, qmix_agent, rpm, config)
123 | # update episodes and steps
124 | episode_cnt += 1
125 | steps_cnt += episode_step
126 |
127 | # learning rate decay
128 | qmix_agent.learning_rate = max(
129 | qmix_agent.lr_scheduler.step(episode_step),
130 | qmix_agent.min_learning_rate)
131 |
132 | train_results = {
133 | 'env_step': episode_step,
134 | 'rewards': episode_reward,
135 | 'win_rate': is_win,
136 | 'mean_loss': mean_loss,
137 | 'mean_td_loss': mean_td_loss,
138 | 'exploration': qmix_agent.exploration,
139 | 'learning_rate': qmix_agent.learning_rate,
140 | 'replay_buffer_size': rpm.size(),
141 | 'target_update_count': qmix_agent.target_update_count,
142 | }
143 | if episode_cnt % config['train_log_interval'] == 0:
144 | text_logger.info(
145 | '[Train], episode: {}, train_win_rate: {:.2f}, train_reward: {:.2f}'
146 | .format(episode_cnt, is_win, episode_reward))
147 | logger.log_train_data(train_results, steps_cnt)
148 |
149 | if episode_cnt % config['test_log_interval'] == 0:
150 | eval_rewards, eval_steps, eval_win_rate = run_eval_episode(
151 | env, qmix_agent, num_eval_episodes=5)
152 | text_logger.info(
153 | '[Eval], episode: {}, eval_win_rate: {:.2f}, eval_rewards: {:.2f}'
154 | .format(episode_cnt, eval_win_rate, eval_rewards))
155 |
156 | test_results = {
157 | 'env_step': eval_steps,
158 | 'rewards': eval_rewards,
159 | 'win_rate': eval_win_rate
160 | }
161 | logger.log_test_data(test_results, steps_cnt)
162 |
163 | progress_bar.update(episode_step)
164 |
165 |
166 | if __name__ == '__main__':
167 | main()
168 |
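169 | # Note: besides the TD loss, run_train_episode also returns mean_opt_loss and
170 | # mean_nopt_loss, the QTRAN-specific loss components that QTranAgent weights with
171 | # opt_loss_coef and nopt_min_loss_coef; the exact loss composition lives in the
172 | # agent implementation, not in this script.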
--------------------------------------------------------------------------------
/scripts/main_vdn.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import time
5 |
6 | import torch
7 | from torch.utils.tensorboard import SummaryWriter
8 |
9 | sys.path.append('../')
10 | from configs.arguments import get_common_args
11 | from configs.qmix_config import QMixConfig
12 | from marltoolkit.agents.vdn_agent import VDNAgent
13 | from marltoolkit.data import OffPolicyBufferRNN
14 | from marltoolkit.envs.smacv1 import SMACWrapperEnv
15 | from marltoolkit.modules.actors import RNNActorModel
16 | from marltoolkit.modules.mixers import VDNMixer
17 | from marltoolkit.runners.parallel_episode_runner import (run_eval_episode,
18 | run_train_episode)
19 | from marltoolkit.utils import (ProgressBar, TensorboardLogger, WandbLogger,
20 | get_outdir, get_root_logger)
21 | from marltoolkit.utils.env_utils import make_vec_env
22 |
23 |
24 | def main():
25 | vdn_config = QMixConfig()
26 | common_args = get_common_args()
27 | args = argparse.Namespace(**vars(common_args), **vars(vdn_config))
28 | device = torch.device('cuda') if torch.cuda.is_available(
29 | ) and args.cuda else torch.device('cpu')
30 |
31 | train_envs, test_envs = make_vec_env(
32 | env_id=args.env_id,
33 | map_name=args.scenario,
34 | num_train_envs=args.num_train_envs,
35 | num_test_envs=args.num_test_envs,
36 | difficulty=args.difficulty,
37 | )
38 | env = SMACWrapperEnv(map_name=args.scenario)
39 | args.episode_limit = env.episode_limit
40 | args.obs_dim = env.obs_dim
41 | args.obs_shape = env.obs_shape
42 | args.state_shape = env.state_shape
43 | args.num_agents = env.num_agents
44 | args.n_actions = env.n_actions
45 | args.action_shape = env.action_shape
46 | args.reward_shape = env.reward_shape
47 | args.done_shape = env.done_shape
48 | args.device = device
49 |
50 | # init the logger before other steps
51 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
52 | # log
53 | log_name = os.path.join(args.project, args.scenario, args.algo_name,
54 | timestamp).replace(os.path.sep, '_')
55 | log_path = os.path.join(args.log_dir, args.project, args.scenario,
56 | args.algo_name)
57 | tensorboard_log_path = get_outdir(log_path, 'tensorboard_log_dir')
58 | log_file = os.path.join(log_path, log_name + '.log')
59 | text_logger = get_root_logger(log_file=log_file, log_level='INFO')
60 |
61 | if args.logger == 'wandb':
62 | logger = WandbLogger(
63 | train_interval=args.train_log_interval,
64 | test_interval=args.test_log_interval,
65 | update_interval=args.train_log_interval,
66 | project=args.project,
67 | name=log_name,
68 | save_interval=1,
69 | config=args,
70 | )
71 | writer = SummaryWriter(tensorboard_log_path)
72 | writer.add_text('args', str(args))
73 | if args.logger == 'tensorboard':
74 | logger = TensorboardLogger(writer)
75 | else: # wandb
76 | logger.load(writer)
77 |
78 | rpm = OffPolicyBufferRNN(
79 | max_size=args.replay_buffer_size,
80 | num_envs=args.num_train_envs,
81 | num_agents=args.num_agents,
82 | episode_limit=args.episode_limit,
83 | obs_space=args.obs_shape,
84 | state_space=args.state_shape,
85 | action_space=args.action_shape,
86 | reward_space=args.reward_shape,
87 | done_space=args.done_shape,
88 | device=args.device,
89 | )
90 | actor_model = RNNActorModel(
91 | input_dim=args.obs_dim,
92 | fc_hidden_dim=args.fc_hidden_dim,
93 | rnn_hidden_dim=args.rnn_hidden_dim,
94 | n_actions=args.n_actions,
95 | )
96 |
97 | mixer_model = VDNMixer()
98 |
99 | marl_agent = VDNAgent(
100 | actor_model=actor_model,
101 | mixer_model=mixer_model,
102 | num_envs=args.num_train_envs,
103 | num_agents=args.num_agents,
104 | action_dim=args.n_actions,
105 | double_q=args.double_q,
106 | total_steps=args.total_steps,
107 | gamma=args.gamma,
108 | learning_rate=args.learning_rate,
109 | min_learning_rate=args.min_learning_rate,
110 | egreedy_exploration=args.egreedy_exploration,
111 | min_exploration=args.min_exploration,
112 | target_update_interval=args.target_update_interval,
113 | learner_update_freq=args.learner_update_freq,
114 | clip_grad_norm=args.clip_grad_norm,
115 | device=args.device,
116 | )
117 |
118 | progress_bar = ProgressBar(args.memory_warmup_size)
119 | while rpm.size() < args.memory_warmup_size:
120 | run_train_episode(train_envs, marl_agent, rpm, args)
121 | progress_bar.update()
122 |
123 | steps_cnt = 0
124 | episode_cnt = 0
125 | progress_bar = ProgressBar(args.total_steps)
126 | while steps_cnt < args.total_steps:
127 | episode_reward, episode_step, is_win, mean_loss, mean_td_error = (
128 | run_train_episode(train_envs, marl_agent, rpm, args))
129 | # update episodes and steps
130 | episode_cnt += 1
131 | steps_cnt += episode_step
132 |
133 | # learning rate decay
134 | marl_agent.learning_rate = max(
135 | marl_agent.lr_scheduler.step(episode_step),
136 | marl_agent.min_learning_rate)
137 |
138 | train_results = {
139 | 'env_step': episode_step,
140 | 'rewards': episode_reward,
141 | 'win_rate': is_win,
142 | 'mean_loss': mean_loss,
143 | 'mean_td_error': mean_td_error,
144 | 'exploration': marl_agent.exploration,
145 | 'learning_rate': marl_agent.learning_rate,
146 | 'replay_buffer_size': rpm.size(),
147 | 'target_update_count': marl_agent.target_update_count,
148 | }
149 | if episode_cnt % args.train_log_interval == 0:
150 | text_logger.info(
151 | '[Train], episode: {}, train_win_rate: {:.2f}, train_reward: {:.2f}'
152 | .format(episode_cnt, is_win, episode_reward))
153 | logger.log_train_data(train_results, steps_cnt)
154 |
155 | if episode_cnt % args.test_log_interval == 0:
156 | eval_rewards, eval_steps, eval_win_rate = run_eval_episode(
157 | test_envs, marl_agent, num_eval_episodes=5)
158 | text_logger.info(
159 | '[Eval], episode: {}, eval_win_rate: {:.2f}, eval_rewards: {:.2f}'
160 | .format(episode_cnt, eval_win_rate, eval_rewards))
161 |
162 | test_results = {
163 | 'env_step': eval_steps,
164 | 'rewards': eval_rewards,
165 | 'win_rate': eval_win_rate
166 | }
167 | logger.log_test_data(test_results, steps_cnt)
168 |
169 | progress_bar.update(episode_step)
170 |
171 |
172 | if __name__ == '__main__':
173 | main()
174 |
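175 | # Parallel-environment launch sketch (flag names and the env_id value are assumed to
176 | # mirror the args.* fields used above, e.g. env_id / num_train_envs / num_test_envs;
177 | # check configs/arguments.py for the real options):
178 | #   python main_vdn.py --env_id smacv1 --scenario 3m --num_train_envs 8 --num_test_envs 2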
--------------------------------------------------------------------------------
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | # # main
2 | # nohup python main.py --scenario 3m --total_steps 1000000 > runoob1.log 2>&1 &
3 | # nohup python main.py --scenario 8m --total_steps 1000000 > runoob2.log 2>&1 &
4 | # nohup python main.py --scenario 5m_vs_6m --total_steps 1000000 > runoob3.log 2>&1 &
5 | # nohup python main.py --scenario 8m_vs_9m --total_steps 1000000 > runoob4.log 2>&1 &
6 | # nohup python main.py --scenario MMM --total_steps 1000000 > runoob5.log 2>&1 &
7 | # nohup python main.py --scenario 2s3z --total_steps 1000000 > runoob6.log 2>&1 &
8 | # nohup python main.py --scenario 3s5z --total_steps 1000000 > runoob7.log 2>&1 &
9 |
10 |
11 | # main
12 | nohup python main_qmix.py --scenario 3m --total_steps 1000000 > runoob1.log 2>&1 &
13 | nohup python main_qmix.py --scenario 8m --total_steps 1000000 > runoob2.log 2>&1 &
14 | nohup python main_qmix.py --scenario 5m_vs_6m --total_steps 1000000 > runoob3.log 2>&1 &
15 | nohup python main_qmix.py --scenario 8m_vs_9m --total_steps 1000000 > runoob4.log 2>&1 &
16 | nohup python main_qmix.py --scenario MMM --total_steps 1000000 > runoob5.log 2>&1 &
17 | nohup python main_qmix.py --scenario 2s3z --total_steps 1000000 > runoob6.log 2>&1 &
18 | nohup python main_qmix.py --scenario 3s5z --total_steps 1000000 > runoob7.log 2>&1 &
19 |
--------------------------------------------------------------------------------
/scripts/test_multi_env.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append('../')
4 | import torch
5 | from torch.distributions import Categorical
6 |
7 | from marltoolkit.envs.smacv1.smac_env import SMACWrapperEnv
8 |
9 | if __name__ == '__main__':
10 | env = SMACWrapperEnv(map_name='3m')
11 | env_info = env.get_env_info()
12 | print('env_info:', env_info)
13 |
14 | action_dim = env_info['n_actions']
15 | num_agents = env_info['num_agents']
16 | results = env.reset()
17 | print('Reset:', results)
18 |
19 | avail_actions = env.get_available_actions()
20 | print('avail_actions:', avail_actions)
21 | available_actions = torch.tensor(avail_actions)
22 | actions_dist = Categorical(available_actions)
23 | random_actions = actions_dist.sample().numpy().tolist()
24 |
25 | print('random_actions: ', random_actions)
26 | results = env.step(random_actions)
27 | print('Step:', results)
28 | env.close()
29 |
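30 | # Note: Categorical(available_actions) treats each agent's 0/1 availability mask as
31 | # unnormalised probabilities, so sampling from it picks a uniformly random action
32 | # among that agent's currently available actions.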
--------------------------------------------------------------------------------
/scripts/test_multi_env2.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import torch
4 | from torch.distributions import Categorical
5 |
6 | sys.path.append('../')
7 | from marltoolkit.envs.smacv1.smac_env import SMACWrapperEnv
8 | from marltoolkit.envs.vec_env import DummyVecEnv, SubprocVecEnv
9 |
10 |
11 | def make_subproc_envs(map_name='3m', parallels=8) -> SubprocVecEnv:
12 |
13 | def _thunk():
14 | env = SMACWrapperEnv(map_name=map_name)
15 | return env
16 |
17 | return SubprocVecEnv([_thunk for _ in range(parallels)],
18 | shared_memory=False)
19 |
20 |
21 | def make_dummy_envs(map_name='3m', parallels=8) -> DummyVecEnv:
22 |
23 | def _thunk():
24 | env = SMACWrapperEnv(map_name=map_name)
25 | return env
26 |
27 | return DummyVecEnv([_thunk for _ in range(parallels)])
28 |
29 |
30 | def test_dummy_envs():
31 | parallels = 2
32 | env = SMACWrapperEnv(map_name='3m')
33 | env.reset()
34 | env_info = env.get_env_info()
35 | print('env_info:', env_info)
36 |
37 | train_envs = make_dummy_envs(parallels=parallels)
38 | results = train_envs.reset()
39 | print('Reset:', results)
40 |
41 | num_envs = parallels
42 | avail_actions = env.get_available_actions()
43 | print('avail_actions:', avail_actions)
44 | available_actions = torch.tensor(avail_actions)
45 | actions_dist = Categorical(available_actions)
46 | random_actions = actions_dist.sample().numpy().tolist()
47 |
48 | print('random_actions: ', random_actions)
49 | dummy_actions = [random_actions for _ in range(num_envs)]
50 | print('dummy_actions:', dummy_actions)
51 | results = train_envs.step(dummy_actions)
52 | print('Step:', results)
53 | train_envs.close()
54 |
55 |
56 | def test_subproc_envs():
57 | parallels = 10
58 | train_envs = make_subproc_envs(parallels=parallels)
59 | results = train_envs.reset()
60 | obs, state, info = results
61 | print('Env Reset:', '*' * 100)
62 | print('obs:', obs.shape)
63 | print('state:', state.shape)
64 | print('info:', info)
65 |
66 | avail_actions = train_envs.get_available_actions()
67 | print('avail_actions:', avail_actions)
68 | available_actions = torch.tensor(avail_actions)
69 | actions_dist = Categorical(available_actions)
70 | random_actions = actions_dist.sample().numpy().tolist()
71 | print('random_actions: ', random_actions)
72 | results = train_envs.step(random_actions)
73 | obs, state, rewards, dones, info = results
74 | print('Env Step:', '*' * 100)
75 | print('obs:', obs.shape)
76 | print('state:', state.shape)
77 | print('rewards:', rewards)
78 | print('dones:', dones)
79 |     print('info:', info)
80 |     train_envs.close()
81 |
82 | if __name__ == '__main__':
83 | test_subproc_envs()
84 |
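85 | # Only the SubprocVecEnv test runs by default; call test_dummy_envs() as well to
86 | # exercise the single-process DummyVecEnv path with the same random-action sampling.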
--------------------------------------------------------------------------------