├── .gitignore ├── requirements.txt ├── README.assets └── mpe_simple_spread.gif ├── agent ├── __pycache__ │ └── agents.cpython-36.pyc └── agents.py ├── policy ├── __pycache__ │ ├── qmix.cpython-36.pyc │ ├── base_policy.cpython-36.pyc │ ├── centralized_ppo.cpython-36.pyc │ ├── independent_ppo.cpython-36.pyc │ ├── grid_wise_control.cpython-36.pyc │ ├── grid_wise_control_ppo.cpython-36.pyc │ └── grid_wise_control_ddpg.cpython-36.pyc ├── base_policy.py ├── centralized_ppo.py ├── independent_ppo.py ├── grid_wise_control_ppo.py ├── qmix.py ├── grid_wise_control.py └── grid_wise_control_ddpg.py ├── utils ├── __pycache__ │ ├── env_utils.cpython-36.pyc │ ├── config_utils.cpython-36.pyc │ ├── train_utils.cpython-36.pyc │ └── config_objects.cpython-36.pyc ├── config_objects.py ├── train_utils.py ├── config_utils.py └── env_utils.py ├── networks ├── __pycache__ │ ├── ppo_net.cpython-36.pyc │ ├── qmix_net.cpython-36.pyc │ ├── grid_net_actor.cpython-36.pyc │ └── grid_net_critic.cpython-36.pyc ├── grid_net_critic.py ├── qmix_net.py ├── ppo_net.py └── grid_net_actor.py ├── common ├── __pycache__ │ ├── reply_buffer.cpython-36.pyc │ └── pettingzoo_environment.cpython-36.pyc ├── pettingzoo_environment.py └── reply_buffer.py ├── config.yaml ├── LICENSE ├── main.py ├── README.md └── runner.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | gym==0.21.0 2 | PettingZoo==1.12.0 3 | PyYAML==5.3 4 | torch==1.6.0+cu101 5 | tqdm==4.42.1 6 | -------------------------------------------------------------------------------- /README.assets/mpe_simple_spread.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/README.assets/mpe_simple_spread.gif -------------------------------------------------------------------------------- /agent/__pycache__/agents.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/agent/__pycache__/agents.cpython-36.pyc -------------------------------------------------------------------------------- /policy/__pycache__/qmix.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/policy/__pycache__/qmix.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/env_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/utils/__pycache__/env_utils.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/ppo_net.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/networks/__pycache__/ppo_net.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/qmix_net.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/networks/__pycache__/qmix_net.cpython-36.pyc -------------------------------------------------------------------------------- /policy/__pycache__/base_policy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/policy/__pycache__/base_policy.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/utils/__pycache__/config_utils.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/train_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/utils/__pycache__/train_utils.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/reply_buffer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/common/__pycache__/reply_buffer.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/config_objects.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/utils/__pycache__/config_objects.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/grid_net_actor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/networks/__pycache__/grid_net_actor.cpython-36.pyc -------------------------------------------------------------------------------- /policy/__pycache__/centralized_ppo.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/policy/__pycache__/centralized_ppo.cpython-36.pyc -------------------------------------------------------------------------------- /policy/__pycache__/independent_ppo.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/policy/__pycache__/independent_ppo.cpython-36.pyc -------------------------------------------------------------------------------- /networks/__pycache__/grid_net_critic.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/networks/__pycache__/grid_net_critic.cpython-36.pyc -------------------------------------------------------------------------------- /policy/__pycache__/grid_wise_control.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/policy/__pycache__/grid_wise_control.cpython-36.pyc -------------------------------------------------------------------------------- /policy/__pycache__/grid_wise_control_ppo.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/policy/__pycache__/grid_wise_control_ppo.cpython-36.pyc -------------------------------------------------------------------------------- /common/__pycache__/pettingzoo_environment.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/common/__pycache__/pettingzoo_environment.cpython-36.pyc -------------------------------------------------------------------------------- /policy/__pycache__/grid_wise_control_ddpg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yangchen1997/Multi-Agent-Reinforcement-Learning/HEAD/policy/__pycache__/grid_wise_control_ddpg.cpython-36.pyc -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | environment: 2 | grid_size: 16 3 | learn_policy: "independent_ppo" 4 | max_cycles: 40 5 | train: 6 | epochs: 1000000 7 | show_evaluate_epoch: 100 8 | evaluate_epoch: 2 9 | cuda: False -------------------------------------------------------------------------------- /policy/base_policy.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class BasePolicy(object): 5 | 6 | @staticmethod 7 | def init_path(*args): 8 | for path in args: 9 | if not os.path.exists(path): 10 | os.makedirs(path) 11 | 12 | # 初始化权重 13 | def init_wight(self): 14 | raise NotImplementedError 15 | 16 | # 学习的方法,以一个batch的数据作为输入(封装成字典形式), 17 | # episode_num表示当前是第几次迭代,用于double类型的算法 18 | def learn(self, batch_data: dict, episode_num: int): 19 | raise NotImplementedError 20 | 21 | # 保存模型 22 | def save_model(self): 23 | raise NotImplementedError 24 | 25 | # 加载模型 26 | def load_model(self): 27 | raise NotImplementedError 28 | 29 | # 删除模型 30 | def del_model(self): 31 | raise NotImplementedError 32 | 33 | # 判断是否保存过模型 34 | def is_saved_model(self) -> bool: 35 | raise NotImplementedError 36 | -------------------------------------------------------------------------------- /utils/config_objects.py: -------------------------------------------------------------------------------- 1 | class EnvironmentConfig: 2 | def __init__(self): 3 | self.seed = 8 4 | self.n_agents = 3 5 | self.grid_size = 16 6 | self.max_cycles = 25 7 | self.learn_policy = "grid_wise_control" 8 | 9 | 10 | class TrainConfig: 11 | def __init__(self): 12 | self.epochs = 100000 13 | self.evaluate_epoch = 1 14 | self.show_evaluate_epoch = 20 15 | self.memory_batch = 32 16 | self.memory_size = 1000 17 | self.run_episode_before_train = 3 # 用同一个策略跑几个episode,onpolicy算法中使用 18 | self.learn_num = 2 19 | self.lr_actor = 1e-4 20 | self.lr_critic = 1e-3 21 | self.gamma = 0.99 # 衰减因子 22 | self.var = 0.05 # ddpg选择动作添加的噪声点,以输出为均值var为方差进行探索 23 | self.epsilon = 0.7 24 | self.grad_norm_clip = 10 25 | self.ppo_loss_clip = 0.2 # ppo的损失函数截取值 26 | self.target_update_cycle = 100 27 | self.save_epoch = 1000 28 | self.model_dir = r"./models" 29 | 
self.result_dir = r"./results" 30 | self.cuda = True 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 yangchen1997 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn, Tensor 2 | 3 | 4 | def weight_init(m): 5 | # weight_initialization 6 | if isinstance(m, nn.Conv2d): 7 | nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='relu') 8 | elif isinstance(m, nn.BatchNorm2d): 9 | nn.init.normal_(m.weight, mean=0, std=0.02) 10 | elif isinstance(m, nn.Linear): 11 | nn.init.normal_(m.weight, mean=0, std=0.02) 12 | nn.init.normal_(m.bias, mean=1, std=0.02) 13 | 14 | 15 | def reshape_tensor_from_list(tensor: Tensor, shape_list: list) -> list: 16 | """ 17 | 根据shape_list 切分张量, 18 | 为了避免一个batch中游戏长度不同(pettingzoo中可能没有这个问题,smac中存在此问题), 19 | 将tensor放入list 20 | :param tensor: 输入的张量 21 | :param shape_list: 每局游戏长度的list 22 | :return: 按照每局长度切分的张量,结果封装成list 23 | """ 24 | if len(tensor) != sum(shape_list): 25 | raise ValueError("value error: len(tensor.shape) not equals sum(shape_list)") 26 | if len(tensor.shape) != 1: 27 | raise ValueError("value error: len(tensor.shape) != 1") 28 | rewards = [] 29 | current_index = 0 30 | for i in shape_list: 31 | rewards.append(tensor[current_index:current_index + i]) 32 | current_index += i 33 | return rewards 34 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | 4 | import matplotlib.pyplot as plt 5 | from absl import app 6 | 7 | from common.pettingzoo_environment import SimpleSpreadEnv 8 | from runner import RunnerSimpleSpreadEnv 9 | from utils.config_utils import ConfigObjectFactory 10 | 11 | 12 | def main(args): 13 | print(args) 14 | env = SimpleSpreadEnv() 15 | try: 16 | runner = RunnerSimpleSpreadEnv(env) 17 | runner.run_marl() 18 | finally: 19 | env.close() 20 | 21 | 22 | def evaluate(): 23 | train_config = ConfigObjectFactory.get_train_config() 24 | env_config = ConfigObjectFactory.get_environment_config() 25 | csv_filename = os.path.join(train_config.result_dir, 
env_config.learn_policy, "result.csv") 26 | rewards = [] 27 | total_rewards = [] 28 | avg_rewards = 0 29 | len_csv = 0 30 | with open(csv_filename, 'r') as f: 31 | r_csv = csv.reader(f) 32 | for data in r_csv: 33 | total_rewards.append(round(float(data[0]), 2)) 34 | avg_rewards += float(data[0]) 35 | if len_csv % train_config.show_evaluate_epoch == 0 and len_csv > 0: 36 | rewards.append(round(avg_rewards / train_config.show_evaluate_epoch, 2)) 37 | avg_rewards = 0 38 | len_csv += 1 39 | 40 | plt.plot([i * train_config.show_evaluate_epoch for i in range(len_csv // train_config.show_evaluate_epoch)], 41 | rewards) 42 | plt.plot([i for i in range(len_csv)], total_rewards, alpha=0.3) 43 | plt.title("rewards") 44 | plt.show() 45 | 46 | 47 | if __name__ == "__main__": 48 | app.run(main) 49 | evaluate() 50 | -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | from utils.config_objects import * 4 | 5 | 6 | class ConfigName: 7 | environment_config_name = "environment" 8 | train_config_name = "train" 9 | 10 | 11 | class ConfigObjectFactory(object): 12 | yaml_path = r"./config.yaml" 13 | yaml_config = None 14 | environment_config = None 15 | train_config = None 16 | 17 | @classmethod 18 | def init_yaml_config(cls): 19 | if not cls.yaml_config: 20 | with open(cls.yaml_path, 'rb') as file: 21 | cls.yaml_config = yaml.safe_load(file) 22 | 23 | @staticmethod 24 | def init_config_object_attr(instance: object, attrs: dict): 25 | if not instance or not attrs: 26 | return 27 | for name, value in attrs.items(): 28 | if hasattr(instance, name): 29 | setattr(instance, name, value) 30 | 31 | @classmethod 32 | def get_environment_config(cls) -> EnvironmentConfig: 33 | if cls.environment_config is None: 34 | cls.init_yaml_config() 35 | cls.environment_config = EnvironmentConfig() 36 | if ConfigName.environment_config_name in cls.yaml_config: 37 | cls.init_config_object_attr(cls.environment_config, cls.yaml_config[ConfigName.environment_config_name]) 38 | return cls.environment_config 39 | 40 | @classmethod 41 | def get_train_config(cls) -> TrainConfig: 42 | if cls.train_config is None: 43 | cls.init_yaml_config() 44 | cls.train_config = TrainConfig() 45 | if ConfigName.train_config_name in cls.yaml_config: 46 | cls.init_config_object_attr(cls.train_config, cls.yaml_config[ConfigName.train_config_name]) 47 | return cls.train_config 48 | 49 | 50 | -------------------------------------------------------------------------------- /networks/grid_net_critic.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class StateValueModel(nn.Module, ABC): 8 | def __init__(self, grid_input_shape: list): 9 | super(StateValueModel, self).__init__() 10 | input_shape = grid_input_shape[2] * grid_input_shape[3] * 64 11 | self.fc = nn.Sequential( 12 | nn.Linear(in_features=input_shape, out_features=128), 13 | nn.ReLU(), 14 | nn.Linear(in_features=128, out_features=64), 15 | nn.ReLU(), 16 | nn.Linear(in_features=64, out_features=1), 17 | nn.ReLU() 18 | ) 19 | 20 | def forward(self, state): 21 | result = self.fc(state).squeeze() 22 | return result 23 | 24 | 25 | class QValueModelDDPG(nn.Module, ABC): 26 | def __init__(self, grid_input_shape: list, n_agent: int, action_dim: int): 27 | super(QValueModelDDPG, self).__init__() 28 | state_input_shape = 
grid_input_shape[2] * grid_input_shape[3] * 64 29 | self.actions_input_shape = n_agent * action_dim 30 | self.fc_state = nn.Sequential( 31 | nn.Linear(in_features=state_input_shape, out_features=128), 32 | nn.ReLU(), 33 | nn.Linear(in_features=128, out_features=128), 34 | nn.ReLU(), 35 | ) 36 | self.fc_actions = nn.Sequential( 37 | nn.Linear(in_features=self.actions_input_shape, out_features=128), 38 | nn.ReLU(), 39 | nn.Linear(in_features=128, out_features=128), 40 | nn.ReLU(), 41 | ) 42 | self.output_final = nn.Linear(128, 1) 43 | 44 | def forward(self, state, actions): 45 | state_output = self.fc_state(state) 46 | actions = actions.view(-1, self.actions_input_shape) 47 | action_output = self.fc_actions(actions) 48 | result = self.output_final(F.relu(state_output + action_output)) 49 | return result 50 | -------------------------------------------------------------------------------- /networks/qmix_net.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class RNN(nn.Module, ABC): 9 | 10 | def __init__(self, input_shape: int, n_actions: int, rnn_hidden_dim: int): 11 | super(RNN, self).__init__() 12 | self.rnn_hidden_dim = rnn_hidden_dim 13 | self.fc1 = nn.Linear(input_shape, self.rnn_hidden_dim) 14 | self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim) 15 | self.fc2 = nn.Linear(self.rnn_hidden_dim, n_actions) 16 | 17 | def forward(self, obs, hidden_state): 18 | x = torch.relu(self.fc1(obs)) 19 | h_in = hidden_state.reshape(-1, self.rnn_hidden_dim) 20 | h = self.rnn(x, h_in) 21 | q = self.fc2(h) 22 | return q, h 23 | 24 | 25 | class QMixNet(nn.Module, ABC): 26 | """ 27 | 因为生成的hyper_w1需要是一个矩阵,而pytorch神经网络只能输出一个向量, 28 | 所以就先输出长度为需要的 矩阵行*矩阵列 的向量,然后再转化成矩阵 29 | n_agents是使用hyper_w1作为参数的网络的输入维度 qmix_hidden_dim是网络隐藏层参数个数 30 | 从而经过hyper_w1得到(经验条数,n_agents * qmix_hidden_dim)的矩阵 31 | """ 32 | 33 | def __init__(self, n_agents: int, state_shape: int): 34 | super(QMixNet, self).__init__() 35 | self.qmix_hidden_dim = 32 36 | self.n_agents = n_agents 37 | self.state_shape = state_shape 38 | self.hyper_w1 = nn.Linear(state_shape, n_agents * self.qmix_hidden_dim) 39 | self.hyper_w2 = nn.Linear(state_shape, self.qmix_hidden_dim * 1) 40 | 41 | self.hyper_b1 = nn.Linear(state_shape, self.qmix_hidden_dim) 42 | self.hyper_b2 = nn.Sequential(nn.Linear(state_shape, self.qmix_hidden_dim), 43 | nn.ReLU(), 44 | nn.Linear(self.qmix_hidden_dim, 1) 45 | ) 46 | 47 | def forward(self, q_values, states): 48 | # states的shape为(batch_size, max_episode_len, state_shape) 49 | # 传入的q_values是三维的,shape为(batch_size, max_episode_len, n_agents) 50 | episode_num = q_values.size(0) 51 | q_values = q_values.view(-1, 1, self.n_agents) 52 | states = states.reshape(-1, self.state_shape) 53 | 54 | w1 = torch.abs(self.hyper_w1(states)) 55 | b1 = self.hyper_b1(states) 56 | 57 | w1 = w1.view(-1, self.n_agents, self.qmix_hidden_dim) 58 | b1 = b1.view(-1, 1, self.qmix_hidden_dim) 59 | 60 | hidden = F.elu(torch.bmm(q_values, w1) + b1) 61 | 62 | w2 = torch.abs(self.hyper_w2(states)) 63 | b2 = self.hyper_b2(states) 64 | 65 | w2 = w2.view(-1, self.qmix_hidden_dim, 1) 66 | b2 = b2.view(-1, 1, 1) 67 | 68 | q_total = torch.bmm(hidden, w2) + b2 69 | q_total = q_total.view(episode_num, -1, 1) 70 | return q_total 71 | -------------------------------------------------------------------------------- /utils/env_utils.py: 
-------------------------------------------------------------------------------- 1 | from numpy import ndarray 2 | 3 | from utils.config_utils import ConfigObjectFactory 4 | 5 | 6 | def map_value(data: ndarray) -> tuple: 7 | """ 8 | Map the agent's current coordinates onto the grid; for now the original coordinate range is assumed to be [-3, 3] and is mapped into the [0, grid_size) interval 9 | :param data: original coordinates 10 | :return: mapped coordinates 11 | """ 12 | grid_size = ConfigObjectFactory.get_environment_config().grid_size 13 | target_min = 0 14 | target_max = grid_size - 1 15 | pos_x, pos_y = clip_pos((float(data[0]), float(data[1])), min_value=-1.5, max_value=1.5) 16 | map_pos_x = target_min + (target_max - target_min) / (1.5 - -1.5) * (pos_x - -1.5) 17 | map_pos_y = target_min + (target_max - target_min) / (1.5 - -1.5) * (pos_y - -1.5) 18 | return map_pos_x, map_pos_y 19 | 20 | 21 | def clip_pos(old_pos: tuple, min_value: float, max_value: float) -> tuple: 22 | """ 23 | Clip coordinates so they do not fall outside the range of grid_input 24 | :param min_value: 25 | :param max_value: 26 | :param old_pos: coordinates to clip 27 | :return: clipped coordinates 28 | """ 29 | if len(old_pos) != 2: 30 | raise ValueError( 31 | 'Expecting a list of length 2 as input, but now the length of the input is {}.'.format(len(old_pos))) 32 | x = old_pos[0] 33 | y = old_pos[1] 34 | if x < min_value: 35 | x = min_value 36 | if x > max_value - 1: 37 | x = max_value - 1 38 | if y < min_value: 39 | y = min_value 40 | if y > max_value - 1: 41 | y = max_value - 1 42 | return x, y 43 | 44 | 45 | def get_approximate_pos(pos: tuple) -> tuple: 46 | """ 47 | Compute the approximate position from the current position (round the floats to integers) 48 | :param pos: 49 | :return: 50 | """ 51 | if len(pos) != 2: 52 | raise ValueError( 53 | 'Expecting a list of length 2 as input, but now the length of the input is {}.'.format(len(pos))) 54 | grid_size = ConfigObjectFactory.get_environment_config().grid_size 55 | approximate_pos = int(round(pos[0], 0)), int(round(pos[1], 0)) 56 | return clip_pos(approximate_pos, min_value=0, max_value=grid_size) 57 | 58 | 59 | def recomput_approximate_pos(pos_dict: dict, approximate_pos: tuple, pos: tuple) -> tuple: 60 | """ 61 | When several agents' approximate positions collide, recompute the approximate position 62 | :param pos_dict: 63 | :param approximate_pos: 64 | :param pos: 65 | :return: 66 | """ 67 | if not pos_dict: 68 | raise ValueError('pos_dict is null or length of pos_dict is 0') 69 | grid_size = ConfigObjectFactory.get_environment_config().grid_size 70 | exist_pos = pos_dict[approximate_pos] 71 | gap_x = abs(exist_pos[0] - pos[0]) 72 | gap_y = abs(exist_pos[1] - pos[1]) 73 | approximate_pos_x = approximate_pos[0] 74 | approximate_pos_y = approximate_pos[1] 75 | if gap_x <= gap_y: 76 | if exist_pos[0] >= pos[0]: 77 | approximate_pos_x -= 1 78 | else: 79 | approximate_pos_x += 1 80 | else: 81 | if exist_pos[1] >= pos[1]: 82 | approximate_pos_y -= 1 83 | else: 84 | approximate_pos_y += 1 85 | return clip_pos((approximate_pos_x, approximate_pos_y), min_value=0, max_value=grid_size) 86 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Abstract 4 | 5 | PyTorch implementations of multi-agent reinforcement learning algorithms, including Grid-Wise Control, QMIX and Centralized PPO. Different learning strategies can be specified for training, and models and experimental data can be saved. 6 | 7 | **Quick Start:** Run the **main.py** script to start training. Specify all parameters in the **config.yaml** file (the parameters used in this project are not optimal; adjust them according to your actual requirements).
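The snippet below is a minimal sketch of that quick-start flow; it simply mirrors the `main.py` and `utils/config_utils.py` files shown later in this repository: `ConfigObjectFactory` reads `config.yaml` once, and the runner and policies pull their settings from the resulting config objects.

```python
# Minimal training sketch mirroring main.py (run from the project root).
from common.pettingzoo_environment import SimpleSpreadEnv
from runner import RunnerSimpleSpreadEnv
from utils.config_utils import ConfigObjectFactory

env_config = ConfigObjectFactory.get_environment_config()
print(env_config.learn_policy)  # e.g. "independent_ppo", as set in config.yaml

env = SimpleSpreadEnv()
try:
    RunnerSimpleSpreadEnv(env).run_marl()  # trains with the policy named above
finally:
    env.close()
```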
8 | 9 | # Petting Zoo 10 | 11 | **MPE:** Multi Particle Environments (MPE) are a set of communication-oriented environments where particle agents can (sometimes) move, communicate, see each other, push each other around, and interact with fixed landmarks. 12 | 13 | These environments are from [OpenAI’s MPE](https://github.com/openai/multiagent-particle-envs) codebase, with several minor fixes, mostly related to making the action space discrete by default, making the rewards consistent, and cleaning up the observation space of certain environments. 14 | 15 | The environment applied in this project is **Simple Spread** (I'm also considering adding other environments in future releases). 16 | 17 | Env image 18 | 19 | 20 | 21 | # Requirement 22 | 23 | Note: the following are suggested versions only; the program may still work with other versions. 24 | 25 | | Name | Version | 26 | | ---------- | ----------- | 27 | | Python | 3.10.9 | 28 | | gymnasium | 0.28.1 | 29 | | numpy | 1.23.5 | 30 | | PettingZoo | 1.23.0 | 31 | | PyTorch | 1.12.1 | 32 | 33 | Update on 4.10.2023: [PyTorch 2.0.0+cu118](https://pytorch.org/get-started/previous-versions/) on Python 3.9.16 works. Note that Python > 3.9 will not work, because PettingZoo 1.12.0 is not available for those versions. 34 | 35 | # Corresponding Papers 36 | 37 | - [Grid-Wise Control for Multi-Agent Reinforcement Learning in Video Game AI](http://proceedings.mlr.press/v97/han19a/han19a.pdf) 38 | - [QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning](https://arxiv.org/abs/1803.11485) 39 | 40 | - [The Surprising Effectiveness of PPO in Cooperative Multi-Agent Games](https://arxiv.org/abs/2103.01955) 41 | 42 | 43 | 44 | # Reference 45 | 46 | - **petting zoo:** 47 | 48 | ``` 49 | @article{terry2020pettingzoo, 50 | Title = {PettingZoo: Gym for Multi-Agent Reinforcement Learning}, 51 | Author = {Terry, J.
K and Black, Benjamin and Grammel, Nathaniel and Jayakumar, Mario and Hari, Ananth and Sulivan, Ryan and Santos, Luis and Perez, Rodrigo and Horsch, Caroline and Dieffendahl, Clemens and Williams, Niall L and Lokesh, Yashas and Sullivan, Ryan and Ravi, Praveen}, 52 | journal={arXiv preprint arXiv:2009.14471}, 53 | year={2020} 54 | } 55 | ``` 56 | 57 | - **Qmix:** [starry-sky6688/StarCraft: Implementations of IQL, QMIX, VDN, COMA, QTRAN, MAVEN, CommNet, DyMA-CL, and G2ANet on SMAC, the decentralised micromanagement scenario of StarCraft II (github.com)](https://github.com/starry-sky6688/StarCraft) 58 | -------------------------------------------------------------------------------- /networks/ppo_net.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class CentralizedPPOActor(nn.Module, ABC): 8 | """ 9 | ppo算法属于on_policy算法,使用rnn网络来记录之前的经验。 10 | centralized_ppo 属于 centralized算法,需要通过搜集所有agent的观测值来给出动作 11 | """ 12 | 13 | def __init__(self, input_shape: int, action_dim: int, n_agents: int, rnn_hidden_dim: int): 14 | super(CentralizedPPOActor, self).__init__() 15 | self.rnn_hidden_dim = rnn_hidden_dim 16 | self.n_agents = n_agents 17 | self.action_dim = action_dim 18 | self.fc1 = nn.Linear(input_shape, self.rnn_hidden_dim) 19 | self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim) 20 | self.fc2 = nn.Linear(self.rnn_hidden_dim, self.action_dim * self.n_agents) 21 | 22 | def forward(self, obs, hidden_state): 23 | fc1_out = torch.relu(self.fc1(obs)) 24 | h_in = hidden_state.reshape(-1, self.rnn_hidden_dim) 25 | rnn_out = self.rnn(fc1_out, h_in) 26 | # 将动作空间映射到[0,1] 27 | fc2_out = torch.sigmoid(self.fc2(rnn_out)) 28 | fc2_out = fc2_out.view(-1, self.n_agents, self.action_dim) 29 | return fc2_out, rnn_out 30 | 31 | 32 | class CentralizedPPOCritic(nn.Module, ABC): 33 | """ 34 | centralized_ppo的Critic将全局的state作为输入 35 | """ 36 | 37 | def __init__(self, state_dim: int): 38 | super(CentralizedPPOCritic, self).__init__() 39 | self.fc = nn.Sequential( 40 | nn.Linear(in_features=state_dim, out_features=128), 41 | nn.ReLU(), 42 | nn.Linear(in_features=128, out_features=64), 43 | nn.ReLU(), 44 | nn.Linear(in_features=64, out_features=1), 45 | nn.ReLU() 46 | ) 47 | 48 | def forward(self, state): 49 | result = self.fc(state).squeeze() 50 | return result 51 | 52 | 53 | class IndependentPPOActor(nn.Module, ABC): 54 | """ 55 | ppo算法属于on_policy算法,使用rnn网络来记录之前的经验。 56 | independent_ppo 属于 independent算法,只需要收集单个agent的观测值就能给出动作 57 | """ 58 | 59 | def __init__(self, obs_dim: int, action_dim: int, rnn_hidden_dim: int): 60 | super(IndependentPPOActor, self).__init__() 61 | self.rnn_hidden_dim = rnn_hidden_dim 62 | self.action_dim = action_dim 63 | self.fc1 = nn.Linear(obs_dim, self.rnn_hidden_dim) 64 | self.rnn = nn.GRUCell(self.rnn_hidden_dim, self.rnn_hidden_dim) 65 | self.fc2 = nn.Linear(self.rnn_hidden_dim, self.action_dim) 66 | 67 | def forward(self, obs, hidden_state): 68 | fc1_out = torch.relu(self.fc1(obs)) 69 | rnn_out = self.rnn(fc1_out, hidden_state) 70 | # 将动作空间映射到[0,1] 71 | fc2_out = torch.sigmoid(self.fc2(rnn_out)) 72 | return fc2_out, rnn_out 73 | 74 | 75 | class IndependentPPOCritic(nn.Module, ABC): 76 | """ 77 | centralized_ppo的Critic将每个agent的obs作为输入 78 | """ 79 | 80 | def __init__(self, obs_dim: int): 81 | super(IndependentPPOCritic, self).__init__() 82 | self.fc = nn.Sequential( 83 | nn.Linear(in_features=obs_dim, out_features=128), 84 | nn.ReLU(), 85 | nn.Linear(in_features=128, 
out_features=64), 86 | nn.ReLU(), 87 | nn.Linear(in_features=64, out_features=1), 88 | nn.ReLU() 89 | ) 90 | 91 | def forward(self, state): 92 | result = self.fc(state).squeeze() 93 | return result 94 | -------------------------------------------------------------------------------- /networks/grid_net_actor.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch import Tensor 7 | 8 | 9 | # deeplabv3自编码器 10 | # 参考代码https: // blog.csdn.net / chenfang0529 / article / details / 108133672 11 | 12 | class AutoEncoder(nn.Module, ABC): 13 | def __init__(self, n_actions: int, grid_input_shape: list): 14 | super(AutoEncoder, self).__init__() 15 | self.grid_input_shape = grid_input_shape 16 | self.auto_encoder_output_shape = grid_input_shape[2] * grid_input_shape[3] * 64 17 | self.conv_block_1x1_1 = self.conv_block(input_channel=grid_input_shape[1], output_channel=64, kernel_size=1) 18 | self.conv_block_3x3_1 = self.conv_block(input_channel=grid_input_shape[1], output_channel=64, 19 | kernel_size=3, padding=6, dilation=6) 20 | self.conv_block_3x3_2 = self.conv_block(input_channel=grid_input_shape[1], output_channel=64, 21 | kernel_size=3, padding=12, dilation=12) 22 | 23 | self.avg_pool = nn.AdaptiveAvgPool2d(grid_input_shape[3] // 4) 24 | self.conv_block_1x1_2 = nn.Conv2d(grid_input_shape[1], 64, kernel_size=1) 25 | self.conv_block_1x1_3 = nn.Conv2d(256, 64, kernel_size=1) 26 | self.conv_block_1x1_4 = self.conv_block(input_channel=grid_input_shape[1], output_channel=64, kernel_size=1) 27 | self.conv_block_1x1_5 = self.conv_block(input_channel=128, output_channel=256, kernel_size=1) 28 | self.conv_block_1x1_6 = nn.Conv2d(256, n_actions, kernel_size=1) 29 | 30 | # 定义一个卷积块的静态方法,增加泛用性 31 | @staticmethod 32 | def conv_block(input_channel: int, output_channel: int, kernel_size: int, 33 | stride: int = 1, padding: int = 0, dilation: int = 1) -> nn.Sequential: 34 | one_conv_block = nn.Sequential( 35 | nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=kernel_size, 36 | stride=stride, padding=padding, dilation=dilation, bias=False), 37 | nn.BatchNorm2d(output_channel), 38 | nn.ReLU() 39 | ) 40 | return one_conv_block 41 | 42 | def encoder(self, grid_input: Tensor): 43 | grid_input_w = self.grid_input_shape[2] 44 | grid_input_h = self.grid_input_shape[3] 45 | out_1x1_1 = self.conv_block_1x1_1(grid_input) # 对应图中 E 46 | out_3x3_1 = self.conv_block_3x3_1(grid_input) # 对应图中 D 47 | out_3x3_2 = self.conv_block_3x3_2(grid_input) # 对应图中 C 48 | grid_input_avg = self.avg_pool(grid_input) # 对应图中 ImagePooling 49 | out_1x1_2 = self.conv_block_1x1_2(grid_input_avg) 50 | out_1x1_2_up = F.interpolate(out_1x1_2, size=(grid_input_h, grid_input_w), mode="bilinear", align_corners=False) 51 | out_cat = torch.cat([out_1x1_1, out_3x3_1, out_3x3_2, out_1x1_2_up], 1) 52 | encoder_out = self.conv_block_1x1_3(out_cat) # 对应图中 H out 对应图中I 53 | return encoder_out 54 | 55 | def decoder(self, grid_input: Tensor, encoder_output: Tensor): 56 | grid_input_w = self.grid_input_shape[2] 57 | grid_input_h = self.grid_input_shape[3] 58 | out_1x1_4 = self.conv_block_1x1_4(grid_input) 59 | encoder_output_up = F.interpolate(encoder_output, size=(grid_input_h, grid_input_w), mode="bilinear", 60 | align_corners=False) 61 | out_cat = torch.cat([out_1x1_4, encoder_output_up], 1) 62 | out_1x1_5 = self.conv_block_1x1_5(out_cat) 63 | decoder_out = self.conv_block_1x1_6(out_1x1_5) 64 
| return decoder_out 65 | 66 | # 前向传播 67 | def forward(self, grid_input): 68 | encoder_output = self.encoder(grid_input) 69 | decoder_output = self.decoder(grid_input, encoder_output) 70 | decoder_output_softmax = F.softmax(decoder_output, dim=1) 71 | encoder_output_clone = encoder_output.clone().detach().view(-1, self.auto_encoder_output_shape) 72 | return decoder_output_softmax, encoder_output_clone 73 | 74 | 75 | class AutoEncoderContinuousActions(AutoEncoder, ABC): 76 | def __init__(self, grid_input_shape: list, action_dim: int = 5): 77 | super(AutoEncoderContinuousActions, self).__init__(action_dim, grid_input_shape) 78 | 79 | def forward(self, grid_input): 80 | encoder_output = self.encoder(grid_input) 81 | decoder_output = self.decoder(grid_input, encoder_output) 82 | # 对action进行放缩,实际上a in [0,1] 83 | decoder_output_sigmoid = torch.sigmoid(decoder_output) 84 | encoder_output_clone = encoder_output.clone().detach().view(-1, self.auto_encoder_output_shape) 85 | return decoder_output_sigmoid, encoder_output_clone 86 | -------------------------------------------------------------------------------- /common/pettingzoo_environment.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | from pettingzoo.mpe import simple_spread_v3 5 | from pettingzoo.utils import aec_to_parallel 6 | from torch import Tensor 7 | 8 | from utils.env_utils import * 9 | 10 | 11 | class SimpleSpreadEnv(object): 12 | def __init__(self): 13 | self.env_config = ConfigObjectFactory.get_environment_config() 14 | self.continuous_actions = "ddpg" in self.env_config.learn_policy or "ppo" in self.env_config.learn_policy 15 | self.env = simple_spread_v3.env(N=self.env_config.n_agents, local_ratio=0.5, 16 | max_cycles=self.env_config.max_cycles, 17 | continuous_actions=self.continuous_actions) 18 | self.parallel_env = aec_to_parallel(self.env) 19 | self.world = self.parallel_env.unwrapped.world 20 | self.agents_name = self.parallel_env.possible_agents 21 | self.agents = self.world.agents 22 | self.grid_input_features = ["density", "mass", "size"] 23 | self.grid_size = self.env_config.grid_size 24 | 25 | def render(self, mode="human"): 26 | return self.parallel_env.render(mode=mode) 27 | 28 | def reset(self): 29 | return self.parallel_env.reset(seed=self.env_config.seed) 30 | 31 | def close(self): 32 | return self.parallel_env.close() 33 | 34 | def state(self): 35 | return self.parallel_env.state() 36 | 37 | def get_env_info(self) -> dict: 38 | map_info = { 39 | 'grid_input_shape': [0, len(self.grid_input_features), self.grid_size, self.grid_size], 40 | 'n_agents': self.env_config.n_agents, 41 | 'agents_name': self.agents_name, 42 | 'obs_space': sum(self.parallel_env.observation_spaces[self.agents_name[0]].shape), 43 | 'state_space': sum(self.parallel_env.state_space.shape) 44 | } 45 | if not self.continuous_actions: 46 | map_info['n_actions'] = 5 47 | else: 48 | map_info['action_dim'] = 5 49 | map_info['action_space'] = self.parallel_env.action_space(self.agents_name[0]) 50 | return map_info 51 | 52 | def step(self, actions: dict): 53 | # self.env.render() 54 | # time.sleep(0.05) 55 | observations, rewards, dones, _, infos = self.parallel_env.step(actions) 56 | # 所有agent都结束游戏整局游戏才算结束 57 | finish_game = not (False in dones.values()) 58 | rewards = sum(rewards.values()) / self.env_config.n_agents 59 | return observations, rewards, finish_game, infos 60 | 61 | def get_agents_approximate_pos(self) -> list: 62 | """ 63 | 
初始化grid_input,先将坐标映射,然后再离散化取整 64 | 如果当前格子存在agent,将后来的智能体进行移位 65 | :return: 66 | """ 67 | position_dict = {} 68 | approximate_poses = [] 69 | for agent in self.agents: 70 | if agent.movable: 71 | map_pos = map_value(agent.state.p_pos) 72 | approximate_pos = get_approximate_pos(map_pos) 73 | # 当前近似坐标已存在重新计算近似坐标 74 | if approximate_pos in position_dict.keys(): 75 | approximate_pos = recomput_approximate_pos(position_dict, approximate_pos, map_pos) 76 | position_dict[approximate_pos] = map_pos 77 | approximate_poses.append(approximate_pos) 78 | return approximate_poses 79 | 80 | def get_grid_input(self) -> Tensor: 81 | """ 82 | 根据当前状态初始化grid_input 83 | :return: 84 | """ 85 | approximate_pos_list = self.get_agents_approximate_pos() 86 | grid_input = np.zeros((1, len(self.grid_input_features), self.grid_size, self.grid_size), 87 | dtype=np.float32) 88 | for agent_id, pos in enumerate(approximate_pos_list): 89 | agent = self.agents[agent_id] 90 | for feature_num, feature_name in enumerate(self.grid_input_features): 91 | grid_input[0, feature_num, pos[1], pos[0]] = getattr(agent, feature_name, float(0)) 92 | tensor = torch.from_numpy(grid_input) 93 | return tensor 94 | 95 | def draw_maps(self): 96 | """ 97 | 画出每个agent的位置,以便进行验证 98 | :return: 99 | """ 100 | agents = self.world.agents 101 | agent_x = [] 102 | agent_y = [] 103 | for agent in agents: 104 | agent_x.append(agent.state.p_pos[0]) 105 | agent_y.append(agent.state.p_pos[1]) 106 | plt.scatter(agent_x, agent_y, c='red') 107 | plt.title("real_pos") 108 | plt.xlim((-3, 3)) 109 | plt.ylim((-3, 3)) 110 | plt.show() 111 | 112 | approximate_pos_list = self.get_agents_approximate_pos() 113 | plt.scatter([pos[0] for pos in approximate_pos_list], [pos[1] for pos in approximate_pos_list], c='red') 114 | plt.title("map_pos") 115 | plt.xlim((0, self.env_config.grid_size - 1)) 116 | plt.ylim((0, self.env_config.grid_size - 1)) 117 | plt.show() 118 | -------------------------------------------------------------------------------- /policy/centralized_ppo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | from torch.distributions import MultivariateNormal 7 | 8 | from networks.ppo_net import CentralizedPPOActor, CentralizedPPOCritic 9 | from policy.base_policy import BasePolicy 10 | from utils.config_utils import ConfigObjectFactory 11 | from utils.train_utils import weight_init 12 | 13 | 14 | class CentralizedPPO(BasePolicy): 15 | 16 | def __init__(self, env_info: dict): 17 | 18 | # 读取配置 19 | self.train_config = ConfigObjectFactory.get_train_config() 20 | self.env_config = ConfigObjectFactory.get_environment_config() 21 | self.n_agents = env_info['n_agents'] 22 | self.action_dim = env_info['action_dim'] 23 | 24 | # 初始化网络 25 | self.rnn_hidden_dim = 64 26 | actor_input_shape = env_info['obs_space'] * self.n_agents 27 | self.ppo_actor = CentralizedPPOActor(actor_input_shape, self.action_dim, self.n_agents, self.rnn_hidden_dim) 28 | self.ppo_critic = CentralizedPPOCritic(env_info['state_space']) 29 | self.optimizer_actor = torch.optim.Adam(params=self.ppo_actor.parameters(), 30 | lr=self.train_config.lr_actor) 31 | self.optimizer_critic = torch.optim.Adam(params=self.ppo_critic.parameters(), 32 | lr=self.train_config.lr_critic) 33 | # 初始化路径 34 | self.model_path = os.path.join(self.train_config.model_dir, self.env_config.learn_policy) 35 | self.result_path = os.path.join(self.train_config.result_dir, self.env_config.learn_policy) 36 | 
self.init_path(self.model_path, self.result_path) 37 | self.ppo_actor_path = os.path.join(self.model_path, "ppo_actor.pth") 38 | self.ppo_critic_path = os.path.join(self.model_path, "ppo_critic.pth") 39 | 40 | # 是否使用GPU加速 41 | if self.train_config.cuda: 42 | torch.cuda.empty_cache() 43 | self.device = torch.device('cuda:0') 44 | else: 45 | self.device = torch.device('cpu') 46 | 47 | self.ppo_actor.to(self.device) 48 | self.ppo_critic.to(self.device) 49 | 50 | # 初始化动作的协方差矩阵,以便动作取样 51 | self.cov_var = torch.full(size=(self.action_dim,), fill_value=0.01) 52 | self.cov_mat = torch.diag(self.cov_var).to(self.device) 53 | self.init_wight() 54 | 55 | def init_wight(self): 56 | self.ppo_actor.apply(weight_init) 57 | self.ppo_critic.apply(weight_init) 58 | 59 | def init_hidden(self, batch_size): 60 | self.rnn_hidden = torch.zeros((batch_size, self.rnn_hidden_dim)).to(self.device) 61 | 62 | def learn(self, batch_data: dict, episode_num: int): 63 | # 从batch data中提取数据 64 | obs = batch_data['obs'].to(self.device).detach() 65 | state = batch_data['state'].to(self.device) 66 | actions = batch_data['actions'].to(self.device) 67 | log_probs = batch_data['log_probs'].to(self.device) 68 | batch_size = sum(batch_data['per_episode_len']) 69 | rewards = batch_data['rewards'] 70 | obs = obs.reshape(batch_size, -1) 71 | discount_reward = self.get_discount_reward(rewards).to(self.device) 72 | self.init_hidden(batch_size) 73 | # 计算状态价值和优势函数 74 | with torch.no_grad(): 75 | state_value = self.ppo_critic(state) 76 | advantage_function = discount_reward - state_value 77 | # 标准化advantage_function减少环境不确定带来的影响 78 | advantage_function = ((advantage_function - advantage_function.mean()) / ( 79 | advantage_function.std() + 1e-10)).unsqueeze(dim=-1) 80 | # 开始学习,重度重采样。 81 | for i in range(self.train_config.learn_num): 82 | 83 | # 防止rnn_hidden再一次学习被优化多次,造成图混乱,输入的时候要禁止rnn_hidden优化(把它当做x) 84 | action_means, self.rnn_hidden = self.ppo_actor(obs, self.rnn_hidden.detach()) 85 | dist = MultivariateNormal(action_means, self.cov_mat) 86 | curr_log_probs = dist.log_prob(actions) 87 | # 计算loss 88 | ratios = torch.exp(curr_log_probs - log_probs) 89 | surr1 = ratios * advantage_function 90 | surr2 = torch.clamp(ratios, 1 - self.train_config.ppo_loss_clip, 91 | 1 + self.train_config.ppo_loss_clip) * advantage_function 92 | actor_loss = (-torch.min(surr1, surr2)).mean() 93 | 94 | # actor_loss:取两个函数的最小值 95 | self.optimizer_actor.zero_grad() 96 | actor_loss.backward() 97 | self.optimizer_actor.step() 98 | 99 | # critic_loss: td_error的均方误差 100 | curr_state_value = self.ppo_critic(state) 101 | critic_loss = nn.MSELoss()(curr_state_value, discount_reward) 102 | self.optimizer_critic.zero_grad() 103 | critic_loss.backward() 104 | self.optimizer_critic.step() 105 | 106 | def get_cov_mat(self) -> Tensor: 107 | return self.cov_mat 108 | 109 | def get_discount_reward(self, batch_reward: list) -> Tensor: 110 | discount_rewards = [] 111 | for reward in reversed(batch_reward): 112 | discounted_reward = 0 113 | for one_reward in reversed(reward): 114 | discounted_reward = one_reward + discounted_reward * self.train_config.gamma 115 | discount_rewards.insert(0, discounted_reward) 116 | return torch.Tensor(discount_rewards) 117 | 118 | def save_model(self): 119 | torch.save(self.ppo_actor.state_dict(), self.ppo_actor_path) 120 | torch.save(self.ppo_critic.state_dict(), self.ppo_critic_path) 121 | 122 | def load_model(self): 123 | self.ppo_actor.load_state_dict(torch.load(self.ppo_actor_path)) 124 | 
self.ppo_critic.load_state_dict(torch.load(self.ppo_critic_path)) 125 | 126 | def del_model(self): 127 | file_list = os.listdir(self.model_path) 128 | for file in file_list: 129 | os.remove(os.path.join(self.model_path, file)) 130 | 131 | def is_saved_model(self) -> bool: 132 | return os.path.exists(self.ppo_actor_path) and os.path.exists(self.ppo_critic_path) 133 | -------------------------------------------------------------------------------- /policy/independent_ppo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | from torch.distributions import MultivariateNormal 7 | 8 | from networks.ppo_net import IndependentPPOActor, IndependentPPOCritic 9 | from policy.base_policy import BasePolicy 10 | from utils.config_utils import ConfigObjectFactory 11 | from utils.train_utils import weight_init 12 | 13 | 14 | class IndependentPPO(BasePolicy): 15 | def __init__(self, env_info: dict): 16 | # 读取配置 17 | self.train_config = ConfigObjectFactory.get_train_config() 18 | self.env_config = ConfigObjectFactory.get_environment_config() 19 | self.n_agents = env_info['n_agents'] 20 | self.action_dim = env_info['action_dim'] 21 | 22 | # 初始化网络 23 | self.rnn_hidden_dim = 64 24 | self.ppo_actor = IndependentPPOActor(env_info['obs_space'], self.action_dim, self.rnn_hidden_dim) 25 | self.ppo_critic = IndependentPPOCritic(env_info['obs_space']) 26 | self.optimizer_actor = torch.optim.Adam(params=self.ppo_actor.parameters(), 27 | lr=self.train_config.lr_actor) 28 | self.optimizer_critic = torch.optim.Adam(params=self.ppo_critic.parameters(), 29 | lr=self.train_config.lr_critic) 30 | 31 | # 初始化路径 32 | self.model_path = os.path.join(self.train_config.model_dir, self.env_config.learn_policy) 33 | self.result_path = os.path.join(self.train_config.result_dir, self.env_config.learn_policy) 34 | self.init_path(self.model_path, self.result_path) 35 | self.ppo_actor_path = os.path.join(self.model_path, "ppo_actor.pth") 36 | self.ppo_critic_path = os.path.join(self.model_path, "ppo_critic.pth") 37 | 38 | # 是否使用GPU加速 39 | if self.train_config.cuda: 40 | torch.cuda.empty_cache() 41 | self.device = torch.device('cuda:0') 42 | else: 43 | self.device = torch.device('cpu') 44 | 45 | self.ppo_actor.to(self.device) 46 | self.ppo_critic.to(self.device) 47 | 48 | # 初始化动作的协方差矩阵,以便动作取样 49 | self.cov_var = torch.full(size=(self.action_dim,), fill_value=0.05) 50 | self.cov_mat = torch.diag(self.cov_var).to(self.device) 51 | self.init_wight() 52 | 53 | def init_wight(self): 54 | self.ppo_actor.apply(weight_init) 55 | self.ppo_critic.apply(weight_init) 56 | 57 | def init_hidden(self, batch_size: int): 58 | # 把 hidden做成一个字典,不用tenseor保存 59 | self.rnn_hidden = {} 60 | for i in range(self.n_agents): 61 | self.rnn_hidden[i] = torch.zeros((batch_size, self.rnn_hidden_dim)).to(self.device) 62 | 63 | def learn(self, batch_data: dict, episode_num: int): 64 | # 从batch data中提取数据 65 | obs = batch_data['obs'].to(self.device).detach() 66 | actions = batch_data['actions'].to(self.device) 67 | log_probs = batch_data['log_probs'].to(self.device) 68 | batch_size = sum(batch_data['per_episode_len']) 69 | rewards = batch_data['rewards'] 70 | discount_reward = self.get_discount_reward(rewards).to(self.device) 71 | self.init_hidden(batch_size) 72 | # 计算状态价值和优势函数 73 | with torch.no_grad(): 74 | state_values = [] 75 | for i in range(self.n_agents): 76 | one_state_value = self.ppo_critic(obs[:, i]) 77 | 
state_values.append(one_state_value) 78 | state_values = torch.stack(state_values, dim=0) 79 | advantage_function = discount_reward - state_values 80 | # normalize advantage_function to reduce the impact of environment stochasticity 81 | advantage_function = ((advantage_function - advantage_function.mean()) / ( 82 | advantage_function.std() + 1e-10)).unsqueeze(dim=-1) 83 | # start learning; the sampled batch is reused for several update epochs 84 | for i in range(self.train_config.learn_num): 85 | curr_log_probs = [] 86 | curr_state_values = [] 87 | for agent_num in range(self.n_agents): 88 | one_action_mean, self.rnn_hidden[agent_num] = self.ppo_actor(obs[:, agent_num], self.rnn_hidden[agent_num].detach())  # detach the hidden state (as in CentralizedPPO) so repeated updates do not backprop through a freed graph 89 | curr_state_value = self.ppo_critic(obs[:, agent_num]) 90 | dist = MultivariateNormal(one_action_mean, self.cov_mat) 91 | curr_log_prob = dist.log_prob(actions[:, agent_num]) 92 | curr_log_probs.append(curr_log_prob) 93 | curr_state_values.append(curr_state_value) 94 | curr_log_probs = torch.stack(curr_log_probs, dim=1) 95 | curr_state_values = torch.stack(curr_state_values, dim=0) 96 | # compute the loss 97 | ratios = torch.exp(curr_log_probs - log_probs) 98 | surr1 = ratios * advantage_function 99 | surr2 = torch.clamp(ratios, 1 - self.train_config.ppo_loss_clip, 100 | 1 + self.train_config.ppo_loss_clip) * advantage_function 101 | actor_loss = (-torch.min(surr1, surr2)).mean() 102 | 103 | # actor_loss: take the minimum of the two surrogate objectives 104 | self.optimizer_actor.zero_grad() 105 | actor_loss.backward() 106 | self.optimizer_actor.step() 107 | 108 | # critic_loss: mean squared error of the td_error 109 | critic_loss = nn.MSELoss()(curr_state_values, discount_reward) 110 | self.optimizer_critic.zero_grad() 111 | critic_loss.backward() 112 | self.optimizer_critic.step() 113 | 114 | def get_cov_mat(self) -> Tensor: 115 | return self.cov_mat 116 | 117 | def get_discount_reward(self, batch_reward: list) -> Tensor: 118 | discount_rewards = [] 119 | for reward in reversed(batch_reward): 120 | discounted_reward = 0 121 | for one_reward in reversed(reward): 122 | discounted_reward = one_reward + discounted_reward * self.train_config.gamma 123 | discount_rewards.insert(0, discounted_reward) 124 | return torch.Tensor(discount_rewards) 125 | 126 | def save_model(self): 127 | torch.save(self.ppo_actor.state_dict(), self.ppo_actor_path) 128 | torch.save(self.ppo_critic.state_dict(), self.ppo_critic_path) 129 | 130 | def load_model(self): 131 | self.ppo_actor.load_state_dict(torch.load(self.ppo_actor_path)) 132 | self.ppo_critic.load_state_dict(torch.load(self.ppo_critic_path)) 133 | 134 | def del_model(self): 135 | file_list = os.listdir(self.model_path) 136 | for file in file_list: 137 | os.remove(os.path.join(self.model_path, file)) 138 | 139 | def is_saved_model(self) -> bool: 140 | return os.path.exists(self.ppo_actor_path) and os.path.exists(self.ppo_critic_path) 141 | -------------------------------------------------------------------------------- /policy/grid_wise_control_ppo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import Tensor 6 | from torch.distributions import MultivariateNormal 7 | 8 | from networks.grid_net_actor import AutoEncoderContinuousActions 9 | from networks.grid_net_critic import StateValueModel 10 | from policy.base_policy import BasePolicy 11 | from utils.config_utils import ConfigObjectFactory 12 | from utils.train_utils import weight_init 13 | 14 | 15 | class GridWiseControlPPO(BasePolicy): 16 | 17 | def __init__(self, env_info: dict): 18 | # load configuration 19 | self.train_config = ConfigObjectFactory.get_train_config() 20 | self.env_config =
ConfigObjectFactory.get_environment_config() 21 | self.action_dim = env_info['action_dim'] 22 | self.grid_input_shape = env_info['grid_input_shape'] 23 | # 初始化网络 24 | self.auto_encoder = AutoEncoderContinuousActions(self.grid_input_shape) 25 | self.state_value_network = StateValueModel(self.grid_input_shape) 26 | self.optimizer_actor = torch.optim.Adam(params=self.auto_encoder.parameters(), 27 | lr=self.train_config.lr_actor) 28 | self.optimizer_critic = torch.optim.Adam(params=self.state_value_network.parameters(), 29 | lr=self.train_config.lr_critic) 30 | 31 | # 初始化路径 32 | self.model_path = os.path.join(self.train_config.model_dir, self.env_config.learn_policy) 33 | self.result_path = os.path.join(self.train_config.result_dir, self.env_config.learn_policy) 34 | self.init_path(self.model_path, self.result_path) 35 | self.state_value_network_path = os.path.join(self.model_path, "grid_wise_control_ppo_state_value.pth") 36 | self.auto_encoder_path = os.path.join(self.model_path, "grid_wise_control_ppo_auto_encoder.pth") 37 | 38 | # 是否使用GPU加速 39 | if self.train_config.cuda: 40 | torch.cuda.empty_cache() 41 | self.device = torch.device('cuda:0') 42 | else: 43 | self.device = torch.device('cpu') 44 | self.auto_encoder.to(self.device) 45 | self.state_value_network.to(self.device) 46 | self.init_wight() 47 | 48 | # 初始化动作的协方差矩阵,以便动作取样 49 | self.cov_var = torch.full(size=(self.action_dim,), fill_value=0.1) 50 | self.cov_mat = torch.diag(self.cov_var).to(self.device) 51 | 52 | def init_wight(self): 53 | self.auto_encoder.apply(weight_init) 54 | self.state_value_network.apply(weight_init) 55 | 56 | def learn(self, batch_data: dict, episode_num: int): 57 | # 从batch data中提取数据 58 | grid_inputs = batch_data['grid_inputs'].to(self.device) 59 | unit_pos = batch_data['unit_pos'].to(self.device) 60 | actions = batch_data['actions'].to(self.device) 61 | log_probs = batch_data['log_probs'].to(self.device) 62 | rewards = batch_data['rewards'] 63 | discount_reward = self.get_discount_reward(rewards).to(self.device) 64 | # 计算状态价值 65 | with torch.no_grad(): 66 | action_map, encoder_output = self.auto_encoder(grid_inputs) 67 | state_value = self.state_value_network(encoder_output) 68 | advantage_function = discount_reward - state_value 69 | # 标准化advantage_function减少环境不确定带来的影响 70 | advantage_function = ((advantage_function - advantage_function.mean()) / ( 71 | advantage_function.std() + 1e-10)).unsqueeze(dim=-1) 72 | 73 | # 开始学习,重度重采样。 74 | for i in range(self.train_config.learn_num): 75 | curr_action_map, curr_encoder_output = self.auto_encoder(grid_inputs) 76 | curr_log_probs = self.get_action_log_probs(curr_action_map, actions, unit_pos) 77 | # 计算loss 78 | ratios = torch.exp(curr_log_probs - log_probs) 79 | surr1 = ratios * advantage_function 80 | surr2 = torch.clamp(ratios, 1 - self.train_config.ppo_loss_clip, 81 | 1 + self.train_config.ppo_loss_clip) * advantage_function 82 | 83 | # actor_loss:取两个函数的最小值 84 | actor_loss = (-torch.min(surr1, surr2)).mean() 85 | self.optimizer_actor.zero_grad() 86 | actor_loss.backward() 87 | self.optimizer_actor.step() 88 | 89 | # critic_loss: td_error的均方误差 90 | curr_state_value = self.state_value_network(curr_encoder_output) 91 | critic_loss = nn.MSELoss()(curr_state_value, discount_reward) 92 | self.optimizer_critic.zero_grad() 93 | critic_loss.backward() 94 | self.optimizer_critic.step() 95 | 96 | def get_action_log_probs(self, action_map: Tensor, actions: Tensor, unit_pos: Tensor) -> Tensor: 97 | action_means = [] 98 | for i, agents_pos in enumerate(unit_pos): 99 | 
one_step_mean = [] 100 | for agent_num, pos in enumerate(agents_pos): 101 | x = int(pos[0]) 102 | y = int(pos[1]) 103 | action_mean = action_map[i, :, y, x] 104 | one_step_mean.append(action_mean) 105 | action_means.append(torch.stack(one_step_mean, dim=0)) 106 | action_means = torch.stack(action_means, dim=0) 107 | dist = MultivariateNormal(action_means, self.cov_mat) 108 | log_probs = dist.log_prob(actions) 109 | return log_probs 110 | 111 | def get_discount_reward(self, batch_reward: list) -> Tensor: 112 | discount_rewards = [] 113 | for reward in reversed(batch_reward): 114 | discounted_reward = 0 115 | for one_reward in reversed(reward): 116 | discounted_reward = one_reward + discounted_reward * self.train_config.gamma 117 | discount_rewards.insert(0, discounted_reward) 118 | return torch.Tensor(discount_rewards) 119 | 120 | def get_cov_mat(self) -> Tensor: 121 | return self.cov_mat 122 | 123 | def get_action_map(self, grid_input: Tensor) -> Tensor: 124 | with torch.no_grad(): 125 | action_map, _ = self.auto_encoder(grid_input) 126 | return action_map 127 | 128 | def save_model(self): 129 | torch.save(self.auto_encoder.state_dict(), self.auto_encoder_path) 130 | torch.save(self.state_value_network.state_dict(), self.state_value_network_path) 131 | 132 | def load_model(self): 133 | self.auto_encoder.load_state_dict(torch.load(self.auto_encoder_path)) 134 | self.state_value_network.load_state_dict(torch.load(self.state_value_network_path)) 135 | 136 | def del_model(self): 137 | file_list = os.listdir(self.model_path) 138 | for file in file_list: 139 | os.remove(os.path.join(self.model_path, file)) 140 | 141 | def is_saved_model(self) -> bool: 142 | return os.path.exists(self.auto_encoder_path) and os.path.exists(self.state_value_network_path) 143 | -------------------------------------------------------------------------------- /policy/qmix.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from networks.qmix_net import RNN, QMixNet 7 | from policy.base_policy import BasePolicy 8 | from utils.config_utils import ConfigObjectFactory 9 | from utils.train_utils import weight_init 10 | 11 | 12 | class QMix(BasePolicy): 13 | 14 | def __init__(self, env_info: dict): 15 | self.train_config = ConfigObjectFactory.get_train_config() 16 | self.env_config = ConfigObjectFactory.get_environment_config() 17 | self.n_agents = env_info['n_agents'] 18 | self.n_actions = env_info['n_actions'] 19 | input_shape = env_info['obs_space'] + self.n_agents + self.n_actions 20 | state_space = env_info['state_space'] 21 | 22 | # 神经网络 23 | self.rnn_hidden_dim = 64 24 | # 每个agent选动作的网络 25 | self.rnn_eval = RNN(input_shape, self.n_actions, self.rnn_hidden_dim) 26 | self.rnn_target = RNN(input_shape, self.n_actions, self.rnn_hidden_dim) 27 | # 把agentsQ值加起来的网络 28 | self.qmix_net_eval = QMixNet(self.n_agents, state_space) 29 | self.qmix_net_target = QMixNet(self.n_agents, state_space) 30 | self.init_wight() 31 | self.eval_parameters = list(self.qmix_net_eval.parameters()) + list(self.rnn_eval.parameters()) 32 | self.optimizer = torch.optim.RMSprop(self.eval_parameters, lr=self.train_config.lr_critic) 33 | 34 | # 初始化路径 35 | self.model_path = os.path.join(self.train_config.model_dir, self.env_config.learn_policy) 36 | self.result_path = os.path.join(self.train_config.result_dir, self.env_config.learn_policy) 37 | self.init_path(self.model_path, self.result_path) 38 | self.rnn_eval_path = os.path.join(self.model_path, 
"rnn_eval.pth") 39 | self.rnn_target_path = os.path.join(self.model_path, "rnn_target.pth") 40 | self.qmix_net_eval_path = os.path.join(self.model_path, "qmix_net_eval.pth") 41 | self.qmix_net_target_path = os.path.join(self.model_path, "qmix_net_target.pth") 42 | 43 | # 是否使用GPU加速 44 | if self.train_config.cuda: 45 | torch.cuda.empty_cache() 46 | self.device = torch.device('cuda:0') 47 | else: 48 | self.device = torch.device('cpu') 49 | self.rnn_eval.to(self.device) 50 | self.rnn_target.to(self.device) 51 | self.qmix_net_eval.to(self.device) 52 | self.qmix_net_target.to(self.device) 53 | 54 | def init_wight(self): 55 | self.rnn_eval.apply(weight_init) 56 | self.rnn_target.apply(weight_init) 57 | self.qmix_net_eval.apply(weight_init) 58 | self.qmix_net_target.apply(weight_init) 59 | 60 | def learn(self, batch_data: dict, episode_num: int): 61 | obs = batch_data['obs'].to(self.device) 62 | obs_next = batch_data['obs_next'].to(self.device) 63 | state = batch_data['state'].to(self.device) 64 | state_next = batch_data['state_next'].to(self.device) 65 | rewards = batch_data['rewards'].unsqueeze(dim=-1).to(self.device) 66 | actions = batch_data['actions'].long().to(self.device) 67 | actions_onehot = batch_data['actions_onehot'].to(self.device) 68 | terminated = batch_data['terminated'].unsqueeze(dim=-1).to(self.device) 69 | 70 | q_evals, q_targets = [], [] 71 | batch_size = batch_data['sample_size'] 72 | self.init_hidden(batch_size) 73 | for i in range(batch_data['max_step']): 74 | inputs, inputs_next = self._get_inputs(batch_size, i, obs[:, i], obs_next[:, i], 75 | actions_onehot) 76 | q_eval, self.eval_hidden = self.rnn_eval(inputs, self.eval_hidden) 77 | q_target, self.target_hidden = self.rnn_target(inputs_next, self.target_hidden) 78 | # 将q值reshape, 以n_agents分开 79 | q_eval = q_eval.view(batch_size, self.n_agents, -1) 80 | q_target = q_target.view(batch_size, self.n_agents, -1) 81 | q_evals.append(q_eval) 82 | q_targets.append(q_target) 83 | # 将上面的到的q值聚合 84 | q_evals = torch.stack(q_evals, dim=1) 85 | q_targets = torch.stack(q_targets, dim=1) 86 | # 找出所选动作对应的q值 87 | q_evals = torch.gather(q_evals, dim=3, index=actions).squeeze(3) 88 | q_targets = q_targets.max(dim=3)[0] 89 | # 将q值和state输入mix网络 90 | q_total_eval = self.qmix_net_eval(q_evals, state) 91 | q_total_target = self.qmix_net_target(q_targets, state_next) 92 | 93 | targets = rewards + self.train_config.gamma * q_total_target * terminated 94 | td_error = (q_total_eval - targets.detach()) 95 | # 抹掉填充的经验的td_error 96 | masked_td_error = terminated * td_error 97 | # 计算损失函数 98 | loss = (masked_td_error ** 2).sum() / terminated.sum() 99 | self.optimizer.zero_grad() 100 | loss.backward() 101 | torch.nn.utils.clip_grad_norm_(self.eval_parameters, self.train_config.grad_norm_clip) 102 | self.optimizer.step() 103 | if episode_num > 0 and episode_num % self.train_config.target_update_cycle == 0: 104 | self.rnn_target.load_state_dict(self.rnn_eval.state_dict()) 105 | self.qmix_net_target.load_state_dict(self.qmix_net_eval.state_dict()) 106 | 107 | def _get_inputs(self, batch_size: int, batch_index: int, obs: Tensor, obs_next: Tensor, 108 | actions_onehot: Tensor) -> tuple: 109 | """ 110 | 获取q网络的输入值, 将动作放入obs中 111 | :return: 112 | """ 113 | inputs, inputs_next = [], [] 114 | inputs.append(obs) 115 | inputs_next.append(obs_next) 116 | if batch_index == 0: 117 | inputs.append(torch.zeros_like(actions_onehot[:, batch_index])) 118 | else: 119 | inputs.append(actions_onehot[:, batch_index - 1]) 120 | inputs_next.append(actions_onehot[:, batch_index]) 121 
| inputs.append(torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)) 122 | inputs_next.append(torch.eye(self.n_agents).unsqueeze(0).expand(batch_size, -1, -1).to(self.device)) 123 | inputs = torch.cat([x.reshape(batch_size * self.n_agents, -1) for x in inputs], dim=1) 124 | inputs_next = torch.cat([x.reshape(batch_size * self.n_agents, -1) for x in inputs_next], dim=1) 125 | return inputs, inputs_next 126 | 127 | def init_hidden(self, batch_size): 128 | # 为每个episode中的每个agent都初始化一个eval_hidden、target_hidden 129 | self.eval_hidden = torch.zeros((batch_size, self.n_agents, self.rnn_hidden_dim)).to(self.device) 130 | self.target_hidden = torch.zeros((batch_size, self.n_agents, self.rnn_hidden_dim)).to(self.device) 131 | 132 | def save_model(self): 133 | torch.save(self.rnn_eval.state_dict(), self.rnn_eval_path) 134 | torch.save(self.rnn_target.state_dict(), self.rnn_target_path) 135 | torch.save(self.qmix_net_eval.state_dict(), self.qmix_net_eval_path) 136 | torch.save(self.qmix_net_target.state_dict(), self.qmix_net_target_path) 137 | 138 | def load_model(self): 139 | self.rnn_eval.load_state_dict(torch.load(self.rnn_eval_path)) 140 | self.rnn_target.load_state_dict(torch.load(self.rnn_target_path)) 141 | self.qmix_net_eval.load_state_dict(torch.load(self.qmix_net_eval_path)) 142 | self.qmix_net_target.load_state_dict(torch.load(self.qmix_net_target_path)) 143 | 144 | def del_model(self): 145 | file_list = os.listdir(self.model_path) 146 | for file in file_list: 147 | os.remove(os.path.join(self.model_path, file)) 148 | 149 | def is_saved_model(self) -> bool: 150 | return os.path.exists(self.rnn_eval_path) and os.path.exists( 151 | self.rnn_target_path) and os.path.exists(self.qmix_net_eval_path) and os.path.exists( 152 | self.qmix_net_target_path) 153 | -------------------------------------------------------------------------------- /policy/grid_wise_control.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from networks.grid_net_actor import AutoEncoder 7 | from networks.grid_net_critic import StateValueModel 8 | from policy.base_policy import BasePolicy 9 | from utils.config_utils import ConfigObjectFactory 10 | from utils.train_utils import weight_init 11 | 12 | 13 | class GridWiseControl(BasePolicy): 14 | 15 | def __init__(self, env_info: dict): 16 | self.train_config = ConfigObjectFactory.get_train_config() 17 | self.env_config = ConfigObjectFactory.get_environment_config() 18 | self.n_agents = env_info["n_agents"] 19 | self.n_actions = env_info["n_actions"] 20 | # 初始化模型可优化器 21 | self.auto_encoder = AutoEncoder(self.n_actions, env_info["grid_input_shape"]) 22 | self.state_value_network_eval = StateValueModel(env_info["grid_input_shape"]) 23 | self.state_value_network_target = StateValueModel(env_info["grid_input_shape"]).requires_grad_(False) 24 | self.optimizer_actor = torch.optim.RMSprop(params=self.auto_encoder.parameters(), lr=self.train_config.lr_actor) 25 | self.optimizer_critic = torch.optim.RMSprop(params=self.state_value_network_eval.parameters(), 26 | lr=self.train_config.lr_critic) 27 | 28 | # 初始化路径 29 | self.model_path = os.path.join(self.train_config.model_dir, self.env_config.learn_policy) 30 | self.result_path = os.path.join(self.train_config.result_dir, self.env_config.learn_policy) 31 | self.init_path(self.model_path, self.result_path) 32 | self.state_value_network_eval_path = os.path.join(self.model_path, 
"grid_wise_control_state_value_eval.pth") 33 | self.state_value_network_target_path = os.path.join(self.model_path, "grid_wise_control_state_value_target.pth") 34 | self.auto_encoder_path = os.path.join(self.model_path, "grid_wise_control_auto_encoder.pth") 35 | 36 | # 是否使用GPU加速 37 | if self.train_config.cuda: 38 | torch.cuda.empty_cache() 39 | self.device = torch.device('cuda:0') 40 | else: 41 | self.device = torch.device('cpu') 42 | self.auto_encoder.to(self.device) 43 | self.state_value_network_eval.to(self.device) 44 | self.state_value_network_target.to(self.device) 45 | self.init_wight() 46 | 47 | def init_wight(self): 48 | self.auto_encoder.apply(weight_init) 49 | self.state_value_network_eval.apply(weight_init) 50 | self.state_value_network_target.apply(weight_init) 51 | 52 | def learn(self, batch_data: dict, episode_num: int): 53 | grid_inputs = batch_data['grid_inputs'].to(self.device) 54 | grid_inputs_next = batch_data['grid_inputs_next'].to(self.device) 55 | unit_pos = batch_data['unit_pos'].to(self.device) 56 | reward = batch_data['reward'].to(self.device) 57 | actions = batch_data['actions'].long().to(self.device) 58 | terminated = batch_data['terminated'].to(self.device) 59 | mask = terminated.unsqueeze(dim=-1).repeat(1, 1, self.n_agents).to(self.device) 60 | 61 | state_values = [] 62 | state_values_next = [] 63 | for i in range(batch_data['max_step']): 64 | one_grid_input = grid_inputs[:, i] 65 | _, one_encoder_out = self.auto_encoder(one_grid_input) 66 | state_value = self.state_value_network_eval(one_encoder_out) 67 | with torch.no_grad(): 68 | one_grid_input_next = grid_inputs_next[:, i] 69 | _, one_encoder_out_next = self.auto_encoder(one_grid_input_next) 70 | state_value_next = self.state_value_network_target(one_encoder_out_next) 71 | state_values.append(state_value) 72 | state_values_next.append(state_value_next) 73 | # 获取状态价值 74 | state_values = torch.stack(state_values, dim=1) 75 | state_values_next = torch.stack(state_values_next, dim=1) 76 | # 计算td-error,再除去填充的部分 77 | targets = reward + self.train_config.gamma * state_values_next 78 | td_error = (state_values - targets.detach()) 79 | masked_td_error = td_error * terminated 80 | 81 | # 优化critic 82 | loss_critic = (masked_td_error ** 2).sum() / terminated.sum() 83 | self.optimizer_critic.zero_grad() 84 | loss_critic.backward() 85 | # 梯度截断 86 | torch.nn.utils.clip_grad_norm_(list(self.state_value_network_eval.parameters()), 87 | self.train_config.grad_norm_clip) 88 | self.optimizer_critic.step() 89 | 90 | actions_probes = [] 91 | for i in range(batch_data['max_step']): 92 | one_grid_input = grid_inputs[:, i] 93 | one_unit_pos = unit_pos[:, i] 94 | one_action_map, _ = self.auto_encoder(one_grid_input) 95 | one_actions_prob = self.get_actions_prob(one_action_map, one_unit_pos).to(self.device) 96 | actions_probes.append(one_actions_prob) 97 | # 获取动作概率 98 | actions_probes = torch.stack(actions_probes, dim=1) 99 | # 取每个动作对应的概率 100 | pi_taken = torch.gather(actions_probes, dim=3, index=actions).squeeze() 101 | # 因为要取对数,对于那些填充的经验,所有概率都为0,取了log就是负无穷了,所以让它们变成1 102 | pi_taken[mask == 0] = 1.0 103 | log_pi_taken = torch.log(pi_taken) 104 | advantage = masked_td_error.detach().unsqueeze(dim=-1) 105 | # 优化actor 106 | loss_actor = - ((advantage * log_pi_taken) * mask).sum() / mask.sum() 107 | self.optimizer_actor.zero_grad() 108 | loss_actor.backward() 109 | self.optimizer_actor.step() 110 | # 参数截断, 防止梯度爆炸 111 | for parm in self.auto_encoder.parameters(): 112 | parm.data.clamp_(-10, 10) 113 | # 到一定回合数时,target加载eval的最新网络参数 114 | if 
episode_num > 0 and episode_num % self.train_config.target_update_cycle == 0: 115 | self.state_value_network_target.load_state_dict(self.state_value_network_eval.state_dict()) 116 | 117 | def get_action_map(self, grid_input: Tensor) -> Tensor: 118 | with torch.no_grad(): 119 | action_map, _ = self.auto_encoder(grid_input) 120 | return action_map 121 | 122 | @staticmethod 123 | def get_actions_prob(action_map: Tensor, unit_pos: Tensor) -> Tensor: 124 | actions_prob = [] 125 | for batch_num, pos in enumerate(unit_pos): 126 | batch_actions_prob = [] 127 | for agent_num, one_agent_pos in enumerate(pos): 128 | batch_actions_prob.append(action_map[batch_num, :, int(one_agent_pos[1]), int(one_agent_pos[0])]) 129 | actions_prob.append(torch.stack(batch_actions_prob, dim=0)) 130 | actions_probes = torch.stack(actions_prob, dim=0) 131 | # 进行归一化 132 | actions_probes_final = actions_probes / actions_probes.sum(dim=-1, keepdim=True) 133 | return actions_probes_final 134 | 135 | def save_model(self): 136 | torch.save(self.state_value_network_eval.state_dict(), self.state_value_network_eval_path) 137 | torch.save(self.state_value_network_target.state_dict(), self.state_value_network_target_path) 138 | torch.save(self.auto_encoder.state_dict(), self.auto_encoder_path) 139 | 140 | def load_model(self): 141 | self.state_value_network_eval.load_state_dict(torch.load(self.state_value_network_eval_path)) 142 | self.state_value_network_target.load_state_dict(torch.load(self.state_value_network_target_path)) 143 | self.auto_encoder.load_state_dict(torch.load(self.auto_encoder_path)) 144 | 145 | def del_model(self): 146 | file_list = os.listdir(self.model_path) 147 | for file in file_list: 148 | os.remove(os.path.join(self.model_path, file)) 149 | 150 | def is_saved_model(self) -> bool: 151 | return os.path.exists(self.auto_encoder_path) and os.path.exists( 152 | self.state_value_network_eval_path) and os.path.exists(self.state_value_network_target_path) 153 | -------------------------------------------------------------------------------- /agent/agents.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from torch import Tensor 6 | from torch.distributions import Categorical, MultivariateNormal 7 | 8 | from policy.centralized_ppo import CentralizedPPO 9 | from policy.grid_wise_control import GridWiseControl 10 | from policy.grid_wise_control_ddpg import GridWiseControlDDPG 11 | from policy.grid_wise_control_ppo import GridWiseControlPPO 12 | from policy.independent_ppo import IndependentPPO 13 | from policy.qmix import QMix 14 | from utils.config_utils import ConfigObjectFactory 15 | 16 | 17 | class MyAgents: 18 | def __init__(self, env_info: dict): 19 | self.env_info = env_info 20 | self.train_config = ConfigObjectFactory.get_train_config() 21 | self.env_config = ConfigObjectFactory.get_environment_config() 22 | self.n_agents = self.env_info['n_agents'] 23 | 24 | if self.train_config.cuda: 25 | torch.cuda.empty_cache() 26 | self.device = torch.device('cuda:0') 27 | else: 28 | self.device = torch.device('cpu') 29 | 30 | # 初始化学习策略,需要注意的是不同算法对应不同的动作空间(连续/离散) 31 | if self.env_config.learn_policy == "grid_wise_control": 32 | self.n_actions = self.env_info['n_actions'] 33 | self.policy = GridWiseControl(self.env_info) 34 | 35 | elif self.env_config.learn_policy == "grid_wise_control+ddpg": 36 | self.action_space = self.env_info['action_space'] 37 | self.policy = GridWiseControlDDPG(self.env_info) 38 | 39 | elif 
self.env_config.learn_policy == "grid_wise_control+ppo": 40 | self.action_space = self.env_info['action_space'] 41 | self.policy = GridWiseControlPPO(self.env_info) 42 | 43 | # 下面三个算法作为baseline 44 | elif self.env_config.learn_policy == "qmix": 45 | self.n_actions = self.env_info['n_actions'] 46 | self.policy = QMix(self.env_info) 47 | 48 | elif self.env_config.learn_policy == "centralized_ppo": 49 | self.action_space = self.env_info['action_space'] 50 | self.policy = CentralizedPPO(self.env_info) 51 | 52 | elif self.env_config.learn_policy == "independent_ppo": 53 | self.action_space = self.env_info['action_space'] 54 | self.policy = IndependentPPO(self.env_info) 55 | 56 | else: 57 | raise ValueError( 58 | "learn_policy error, just support grid_wise_control, grid_wise_control+ddpg, grid_wise_control+ppo, " 59 | "qmix, centralized_ppo") 60 | 61 | def learn(self, batch_data: dict, episode_num: int = 0): 62 | self.policy.learn(batch_data, episode_num) 63 | 64 | def choose_actions_in_grid(self, unit_pos: list, grid_input: Tensor) -> tuple: 65 | actions_with_name = {} 66 | actions = [] 67 | log_probs = [] 68 | action_map = None 69 | if self.train_config.cuda: 70 | grid_input = grid_input.to(self.device) 71 | if isinstance(self.policy, GridWiseControl) or isinstance(self.policy, GridWiseControlDDPG) or isinstance( 72 | self.policy, GridWiseControlPPO): 73 | action_map = self.policy.get_action_map(grid_input) 74 | for agent_name, pos in zip(self.env_info['agents_name'], unit_pos): 75 | pos_x = pos[0] 76 | pos_y = pos[1] 77 | action_prop = action_map[0, :, pos_y, pos_x] 78 | if self.env_config.learn_policy == "grid_wise_control": 79 | action = Categorical(action_prop).sample().int() 80 | actions_with_name[agent_name] = (int(action)) 81 | actions.append(int(action)) 82 | elif isinstance(self.policy, GridWiseControlPPO): 83 | dist = MultivariateNormal(action_prop, self.policy.get_cov_mat()) 84 | action = np.clip(dist.sample().cpu().numpy(), self.action_space.low, 85 | self.action_space.high).astype(dtype=np.float32) 86 | log_probs.append(dist.log_prob(torch.Tensor(action).to(self.device))) 87 | actions_with_name[agent_name] = action 88 | actions.append(action) 89 | else: 90 | action_with_noise = np.clip( 91 | np.random.normal(action_prop.cpu().numpy(), self.train_config.var), self.action_space.low, 92 | self.action_space.high).astype(dtype=np.float32) 93 | actions_with_name[agent_name] = action_with_noise 94 | actions.append(action_with_noise) 95 | return actions_with_name, actions, log_probs 96 | 97 | def choose_actions(self, obs: dict) -> tuple: 98 | actions_with_name = {} 99 | actions = [] 100 | log_probs = [] 101 | obs = torch.stack([torch.Tensor(value) for value in obs.values()], dim=0) 102 | self.policy.init_hidden(1) 103 | if isinstance(self.policy, QMix): 104 | actions_ind = [i for i in range(self.n_actions)] 105 | for i, agent in enumerate(self.env_info['agents_name']): 106 | inputs = list() 107 | inputs.append(obs[i, :]) 108 | inputs.append(torch.zeros(self.n_actions)) 109 | agent_id = torch.zeros(self.n_agents) 110 | agent_id[i] = 1 111 | inputs.append(agent_id) 112 | inputs = torch.cat(inputs).unsqueeze(dim=0).to(self.device) 113 | with torch.no_grad(): 114 | hidden_state = self.policy.eval_hidden[:, i, :] 115 | q_value, _ = self.policy.rnn_eval(inputs, hidden_state) 116 | if random.uniform(0, 1) > self.train_config.epsilon: 117 | action = random.sample(actions_ind, 1)[0] 118 | else: 119 | action = int(torch.argmax(q_value.squeeze())) 120 | actions_with_name[agent] = action 121 | 
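# A small standalone sketch of epsilon-greedy selection over per-agent Q-values,
# matching the QMix branch above. Note that in choose_actions() the random
# branch fires when the draw exceeds train_config.epsilon, i.e. epsilon there
# acts as the probability of the greedy action; the helper below follows the
# more common convention where epsilon is the exploration probability.
import random

import torch


def epsilon_greedy(q_values: torch.Tensor, epsilon: float) -> int:
    # q_values: shape [n_actions]; epsilon: probability of a random action.
    if random.random() < epsilon:
        return random.randrange(q_values.shape[0])
    return int(torch.argmax(q_values))


if __name__ == "__main__":
    q = torch.tensor([0.1, 0.7, 0.2])
    picks = [epsilon_greedy(q, epsilon=0.1) for _ in range(1000)]
    print("greedy share:", picks.count(1) / len(picks))  # roughly 0.93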
actions.append(action) 122 | elif isinstance(self.policy, CentralizedPPO): 123 | obs = obs.reshape(1, -1).to(self.device) 124 | with torch.no_grad(): 125 | action_means, _ = self.policy.ppo_actor(obs, self.policy.rnn_hidden) 126 | for i, agent_name in enumerate(self.env_info['agents_name']): 127 | action_mean = action_means[:, i].squeeze() 128 | dist = MultivariateNormal(action_mean, self.policy.get_cov_mat()) 129 | action = np.clip(dist.sample().cpu().numpy(), self.action_space.low, 130 | self.action_space.high).astype(dtype=np.float32) 131 | log_probs.append(dist.log_prob(torch.Tensor(action).to(self.device))) 132 | actions_with_name[agent_name] = action 133 | actions.append(action) 134 | elif isinstance(self.policy, IndependentPPO): 135 | obs = obs.to(self.device) 136 | for i, agent_name in enumerate(self.env_info['agents_name']): 137 | with torch.no_grad(): 138 | action_mean, _ = self.policy.ppo_actor(obs[i].unsqueeze(dim=0), self.policy.rnn_hidden[i]) 139 | action_mean = action_mean.squeeze() 140 | dist = MultivariateNormal(action_mean, self.policy.get_cov_mat()) 141 | action = np.clip(dist.sample().cpu().numpy(), self.action_space.low, 142 | self.action_space.high).astype(dtype=np.float32) 143 | log_probs.append(dist.log_prob(torch.Tensor(action).to(self.device))) 144 | actions_with_name[agent_name] = action 145 | actions.append(action) 146 | return actions_with_name, actions, log_probs 147 | 148 | def save_model(self): 149 | self.policy.save_model() 150 | 151 | def load_model(self): 152 | self.policy.load_model() 153 | 154 | def del_model(self): 155 | self.policy.del_model() 156 | 157 | def is_saved_model(self) -> bool: 158 | return self.policy.is_saved_model() 159 | 160 | def get_results_path(self): 161 | return self.policy.result_path 162 | -------------------------------------------------------------------------------- /policy/grid_wise_control_ddpg.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from networks.grid_net_actor import AutoEncoderContinuousActions 7 | from networks.grid_net_critic import QValueModelDDPG 8 | from policy.base_policy import BasePolicy 9 | from utils.config_utils import ConfigObjectFactory 10 | from utils.train_utils import weight_init 11 | 12 | 13 | class GridWiseControlDDPG(BasePolicy): 14 | def __init__(self, env_info: dict): 15 | # 读取配置 16 | self.train_config = ConfigObjectFactory.get_train_config() 17 | self.env_config = ConfigObjectFactory.get_environment_config() 18 | self.n_agents = env_info["n_agents"] 19 | self.action_dim = env_info["action_dim"] 20 | # 初始化各种网络 21 | self.auto_encoder_eval = AutoEncoderContinuousActions(env_info["grid_input_shape"]) 22 | self.auto_encoder_target = AutoEncoderContinuousActions(env_info["grid_input_shape"]).requires_grad_(False) 23 | self.q_value_network_eval = QValueModelDDPG(env_info["grid_input_shape"], self.n_agents, 24 | self.action_dim) 25 | self.q_value_network_target = QValueModelDDPG(env_info["grid_input_shape"], 26 | self.n_agents, self.action_dim).requires_grad_(False) 27 | self.optimizer_actor = torch.optim.RMSprop(params=self.auto_encoder_eval.parameters(), 28 | lr=self.train_config.lr_actor) 29 | self.optimizer_critic = torch.optim.RMSprop(params=self.q_value_network_eval.parameters(), 30 | lr=self.train_config.lr_critic) 31 | # 初始化路径 32 | self.model_path = os.path.join(self.train_config.model_dir, self.env_config.learn_policy) 33 | self.result_path = os.path.join(self.train_config.result_dir, 
self.env_config.learn_policy) 34 | self.init_path(self.model_path, self.result_path) 35 | self.q_value_network_eval_path = os.path.join(self.model_path, 36 | "grid_wise_control_ddpg_q_value_eval.pth") 37 | self.q_value_network_target_path = os.path.join(self.model_path, 38 | "grid_wise_control_ddpg_q_value_target.pth") 39 | self.auto_encoder_eval_path = os.path.join(self.model_path, "grid_wise_control_ddpg_auto_encoder_eval.pth") 40 | self.auto_encoder_target_path = os.path.join(self.model_path, "grid_wise_control_ddpg_auto_encoder_target.pth") 41 | 42 | # 是否使用GPU加速 43 | if self.train_config.cuda: 44 | torch.cuda.empty_cache() 45 | self.device = torch.device('cuda:0') 46 | else: 47 | self.device = torch.device('cpu') 48 | self.auto_encoder_eval.to(self.device) 49 | self.auto_encoder_target.to(self.device) 50 | self.q_value_network_eval.to(self.device) 51 | self.q_value_network_target.to(self.device) 52 | self.init_wight() 53 | 54 | def init_wight(self): 55 | self.auto_encoder_eval.apply(weight_init) 56 | self.auto_encoder_target.apply(weight_init) 57 | self.q_value_network_eval.apply(weight_init) 58 | self.q_value_network_target.apply(weight_init) 59 | 60 | def learn(self, batch_data: dict, episode_num: int): 61 | grid_inputs = batch_data['grid_inputs'].to(self.device) 62 | grid_inputs_next = batch_data['grid_inputs_next'].to(self.device) 63 | unit_pos = batch_data['unit_pos'].to(self.device) 64 | reward = batch_data['reward'].to(self.device) 65 | actions = batch_data['actions'].to(self.device).squeeze() 66 | terminated = batch_data['terminated'].to(self.device) 67 | q_eval = [] 68 | q_target = [] 69 | for i in range(batch_data['max_step']): 70 | one_grid_input = grid_inputs[:, i] 71 | one_unit_pos = unit_pos[:, i] 72 | one_action_map, one_encoder_out = self.auto_encoder_eval(one_grid_input) 73 | one_actions_output = self.get_actions_output(one_action_map, one_unit_pos).to(self.device) 74 | one_q_eval = self.q_value_network_eval(one_encoder_out, one_actions_output) 75 | with torch.no_grad(): 76 | one_grid_input_next = grid_inputs_next[:, i] 77 | one_action_map_next, one_encoder_out_next = self.auto_encoder_target(one_grid_input_next) 78 | one_actions_output_next = self.get_actions_output(one_action_map_next, one_unit_pos).to(self.device) 79 | one_q_target = self.q_value_network_target(one_encoder_out_next, one_actions_output_next) 80 | 81 | q_eval.append(one_q_eval) 82 | q_target.append(one_q_target) 83 | # 获取动作价值 84 | q_eval = torch.stack(q_eval, dim=1).squeeze() 85 | q_target = torch.stack(q_target, dim=1).squeeze().detach() 86 | # 计算td-error,再除去填充的部分 87 | targets = reward + self.train_config.gamma * q_target 88 | td_error = (q_eval - targets.detach()) 89 | masked_td_error = td_error * terminated 90 | # 优化critic 91 | loss_critic = (masked_td_error ** 2).sum() / terminated.sum() 92 | self.optimizer_critic.zero_grad() 93 | loss_critic.backward() 94 | # 梯度截断 95 | torch.nn.utils.clip_grad_norm_(list(self.q_value_network_eval.parameters()), 96 | self.train_config.grad_norm_clip) 97 | self.optimizer_critic.step() 98 | 99 | # 获取actor的q值 100 | q_value = [] 101 | for i in range(batch_data['max_step']): 102 | one_grid_input = grid_inputs[:, i] 103 | one_action = actions[:, i] 104 | _, one_encoder_out = self.auto_encoder_eval(one_grid_input) 105 | one_q_value = self.q_value_network_eval(one_encoder_out, one_action) 106 | q_value.append(one_q_value) 107 | q_value = torch.stack(q_value, dim=1).squeeze() 108 | 109 | # 优化actor 110 | loss_actor = - (q_value * terminated).sum() / terminated.sum() 111 | 
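# A compact, standalone sketch of the masked loss pattern used in this file
# (and in qmix.py / grid_wise_control.py): `terminated` is 1 for real steps
# and 0 for padded steps, so padding contributes nothing to either the critic
# TD loss or the actor loss. Toy tensors; gamma = 0.99 is an assumed value.
import torch


def masked_losses(q_eval, q_target, q_actor, rewards, mask, gamma=0.99):
    targets = rewards + gamma * q_target
    td_error = (q_eval - targets.detach()) * mask
    critic_loss = (td_error ** 2).sum() / mask.sum()
    actor_loss = -(q_actor * mask).sum() / mask.sum()
    return critic_loss, actor_loss


if __name__ == "__main__":
    mask = torch.tensor([[1., 1., 1., 0., 0.],
                         [1., 1., 1., 1., 1.]])  # 1 = real step, 0 = padding
    shape = mask.shape
    print(masked_losses(torch.randn(shape), torch.randn(shape),
                        torch.randn(shape), torch.randn(shape), mask))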
self.optimizer_actor.zero_grad() 112 | loss_actor.backward() 113 | self.optimizer_actor.step() 114 | # 参数截断, 防止梯度爆炸 115 | for parm in self.auto_encoder_eval.parameters(): 116 | parm.data.clamp_(-10, 10) 117 | # 到一定回合数时,target加载eval的最新网络参数 118 | if episode_num > 0 and episode_num % self.train_config.target_update_cycle == 0: 119 | self.auto_encoder_target.load_state_dict(self.auto_encoder_eval.state_dict()) 120 | self.q_value_network_target.load_state_dict(self.q_value_network_eval.state_dict()) 121 | 122 | def get_action_map(self, grid_input: Tensor) -> Tensor: 123 | with torch.no_grad(): 124 | action_map, _ = self.auto_encoder_eval(grid_input) 125 | return action_map 126 | 127 | @staticmethod 128 | def get_actions_output(action_map: Tensor, unit_pos: Tensor) -> Tensor: 129 | actions_output = [] 130 | for batch_num, pos in enumerate(unit_pos): 131 | batch_actions_output = [] 132 | for agent_num, one_agent_pos in enumerate(pos): 133 | batch_actions_output.append(action_map[batch_num, :, int(one_agent_pos[1]), int(one_agent_pos[0])]) 134 | actions_output.append(torch.stack(batch_actions_output, dim=0)) 135 | actions_outputs = torch.stack(actions_output, dim=0) 136 | return actions_outputs 137 | 138 | def save_model(self): 139 | torch.save(self.q_value_network_eval.state_dict(), self.q_value_network_eval_path) 140 | torch.save(self.q_value_network_target.state_dict(), self.q_value_network_target_path) 141 | torch.save(self.auto_encoder_eval.state_dict(), self.auto_encoder_eval_path) 142 | torch.save(self.auto_encoder_target.state_dict(), self.auto_encoder_target_path) 143 | 144 | def load_model(self): 145 | self.q_value_network_eval.load_state_dict(torch.load(self.q_value_network_eval_path)) 146 | self.q_value_network_target.load_state_dict(torch.load(self.q_value_network_target_path)) 147 | self.auto_encoder_eval.load_state_dict(torch.load(self.auto_encoder_eval_path)) 148 | self.auto_encoder_target.load_state_dict(torch.load(self.auto_encoder_target_path)) 149 | 150 | def del_model(self): 151 | file_list = os.listdir(self.model_path) 152 | for file in file_list: 153 | os.remove(os.path.join(self.model_path, file)) 154 | 155 | def is_saved_model(self) -> bool: 156 | return os.path.exists(self.auto_encoder_eval_path) and os.path.exists( 157 | self.auto_encoder_target_path) and os.path.exists(self.q_value_network_eval_path) and os.path.exists( 158 | self.q_value_network_target_path) 159 | -------------------------------------------------------------------------------- /runner.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | import pickle 4 | 5 | from agent.agents import MyAgents 6 | from common.pettingzoo_environment import SimpleSpreadEnv 7 | from common.reply_buffer import * 8 | 9 | 10 | class RunnerSimpleSpreadEnv(object): 11 | 12 | def __init__(self, env: SimpleSpreadEnv): 13 | self.env = env 14 | self.train_config = ConfigObjectFactory.get_train_config() 15 | self.env_config = ConfigObjectFactory.get_environment_config() 16 | self.current_epoch = 0 17 | self.result_buffer = [] 18 | self.env_info = self.env.get_env_info() 19 | self.agents = MyAgents(self.env_info) 20 | # 初始化reply_buffer 21 | if "grid_wise_control+ppo" == self.env_config.learn_policy: 22 | # 需要注意的是ppo算法是每局游戏内进行更新的,因此只需要搜集每局游戏的数据,并不需要把所有局的游戏整合在一起再采样 23 | self.memory = None 24 | self.batch_episode_memory = GridBatchEpisodeMemory() 25 | elif "grid_wise_control" in self.env_config.learn_policy: 26 | self.memory = GridMemory() 27 | self.batch_episode_memory = 
GridBatchEpisodeMemory() 28 | elif "centralized_ppo" == self.env_config.learn_policy or "independent_ppo" == self.env_config.learn_policy: 29 | self.memory = None 30 | self.batch_episode_memory = CommBatchEpisodeMemory(continuous_actions=True) 31 | else: 32 | self.memory = CommMemory() 33 | self.batch_episode_memory = CommBatchEpisodeMemory(continuous_actions=False, 34 | n_actions=self.env_info['n_actions'], 35 | n_agents=self.env_info['n_agents']) 36 | self.lock = threading.Lock() 37 | # 初始化路径 38 | self.results_path = self.agents.get_results_path() 39 | self.memory_path = os.path.join(self.results_path, "memory.txt") 40 | self.result_path = os.path.join(self.results_path, "result.csv") 41 | 42 | def run_marl(self): 43 | self.init_saved_model() 44 | run_episode = self.train_config.run_episode_before_train if "ppo" in self.env_config.learn_policy else 1 45 | for epoch in range(self.current_epoch, self.train_config.epochs + 1): 46 | # 在正式开始训练之前做一些动作并将信息存进记忆单元中 47 | # grid_wise_control系列算法和常规marl算法不同, 是以格子作为观测空间。 48 | # ppo 属于on policy算法,训练数据要是同策略的 49 | total_reward = 0 50 | if "grid_wise_control" in self.env_config.learn_policy and isinstance(self.batch_episode_memory, 51 | GridBatchEpisodeMemory): 52 | for i in range(run_episode): 53 | self.env.reset() 54 | finish_game = False 55 | cycle = 0 56 | while not finish_game and cycle < self.env_config.max_cycles: 57 | grid_input = self.env.get_grid_input() 58 | unit_pos = self.env.get_agents_approximate_pos() 59 | actions_with_name, actions, log_probs = self.agents.choose_actions_in_grid(unit_pos=unit_pos, 60 | grid_input=grid_input) 61 | observations, rewards, finish_game, infos = self.env.step(actions_with_name) 62 | grid_input_next = self.env.get_grid_input() 63 | self.batch_episode_memory.store_one_episode(grid_input, grid_input_next, unit_pos, 64 | actions, rewards, log_probs) 65 | total_reward += rewards 66 | cycle += 1 67 | self.batch_episode_memory.set_per_episode_len(cycle) 68 | elif isinstance(self.batch_episode_memory, CommBatchEpisodeMemory): 69 | 70 | for i in range(run_episode): 71 | obs = self.env.reset()[0] 72 | finish_game = False 73 | cycle = 0 74 | while not finish_game and cycle < self.env_config.max_cycles: 75 | state = self.env.state() 76 | actions_with_name, actions, log_probs = self.agents.choose_actions(obs) 77 | obs_next, rewards, finish_game, infos = self.env.step(actions_with_name) 78 | state_next = self.env.state() 79 | if "ppo" in self.env_config.learn_policy: 80 | self.batch_episode_memory.store_one_episode(one_obs=obs, one_state=state, action=actions, 81 | reward=rewards, log_probs=log_probs) 82 | else: 83 | self.batch_episode_memory.store_one_episode(one_obs=obs, one_state=state, action=actions, 84 | reward=rewards, one_obs_next=obs_next, 85 | one_state_next=state_next) 86 | total_reward += rewards 87 | obs = obs_next 88 | cycle += 1 89 | self.batch_episode_memory.set_per_episode_len(cycle) 90 | if "ppo" in self.env_config.learn_policy: 91 | # 可以用一个policy跑一个batch的数据来收集,由于性能问题假设batch=1,后续来优化 92 | batch_data = self.batch_episode_memory.get_batch_data() 93 | self.agents.learn(batch_data) 94 | self.batch_episode_memory.clear_memories() 95 | else: 96 | self.memory.store_episode(self.batch_episode_memory) 97 | self.batch_episode_memory.clear_memories() 98 | if self.memory.get_memory_real_size() >= 10: 99 | for i in range(self.train_config.learn_num): 100 | batch = self.memory.sample(self.train_config.memory_batch) 101 | self.agents.learn(batch, epoch) 102 | # avg_reward = self.evaluate() 103 | avg_reward = total_reward 
/ run_episode 104 | one_result_buffer = [avg_reward] 105 | self.result_buffer.append(one_result_buffer) 106 | if epoch % self.train_config.save_epoch == 0 and epoch != 0: 107 | self.save_model_and_result(epoch) 108 | print("episode_{} over,avg_reward {}".format(epoch, avg_reward)) 109 | 110 | def init_saved_model(self): 111 | if os.path.exists(self.result_path) and ( 112 | os.path.exists(self.memory_path) or "ppo" in self.env_config.learn_policy) \ 113 | and self.agents.is_saved_model(): 114 | if "ppo" not in self.env_config.learn_policy: 115 | with open(self.memory_path, 'rb') as f: 116 | self.memory = pickle.load(f) 117 | self.current_epoch = self.memory.episode + 1 118 | self.result_buffer.clear() 119 | else: 120 | with open(self.result_path, 'r') as f: 121 | count = 0 122 | for _ in csv.reader(f): 123 | count += 1 124 | self.current_epoch = count 125 | self.result_buffer.clear() 126 | self.agents.load_model() 127 | else: 128 | self.agents.del_model() 129 | file_list = os.listdir(self.results_path) 130 | for file in file_list: 131 | os.remove(os.path.join(self.results_path, file)) 132 | 133 | def save_model_and_result(self, episode: int): 134 | self.agents.save_model() 135 | with open(self.result_path, 'a', newline='') as f: 136 | f_csv = csv.writer(f) 137 | f_csv.writerows(self.result_buffer) 138 | self.result_buffer.clear() 139 | if "ppo" not in self.env_config.learn_policy: 140 | with open(self.memory_path, 'wb') as f: 141 | self.memory.episode = episode 142 | pickle.dump(self.memory, f) 143 | 144 | def evaluate(self): 145 | total_rewards = 0 146 | for i in range(self.train_config.evaluate_epoch): 147 | if "grid_wise_control" in self.env_config.learn_policy: 148 | self.env.reset() 149 | terminated = False 150 | cycle = 0 151 | while not terminated and cycle < self.env_config.max_cycles: 152 | grid_input = self.env.get_grid_input() 153 | unit_pos = self.env.get_agents_approximate_pos() 154 | actions_with_name, _, _ = self.agents.choose_actions_in_grid(unit_pos=unit_pos, 155 | grid_input=grid_input) 156 | _, rewards, finish_game, _ = self.env.step(actions_with_name) 157 | total_rewards += rewards 158 | cycle += 1 159 | else: 160 | obs = self.env.reset()[0] 161 | finish_game = False 162 | cycle = 0 163 | while not finish_game and cycle < self.env_config.max_cycles: 164 | actions_with_name, actions, _ = self.agents.choose_actions(obs) 165 | obs_next, rewards, finish_game, _ = self.env.step(actions_with_name) 166 | total_rewards += rewards 167 | obs = obs_next 168 | cycle += 1 169 | return total_rewards / self.train_config.evaluate_epoch 170 | -------------------------------------------------------------------------------- /common/reply_buffer.py: -------------------------------------------------------------------------------- 1 | import threading 2 | from random import sample 3 | 4 | import torch 5 | from numpy import ndarray 6 | from torch import Tensor 7 | 8 | from utils.config_utils import * 9 | from utils.train_utils import reshape_tensor_from_list 10 | 11 | LOCK_MEMORY = threading.Lock() 12 | 13 | 14 | class GridBatchEpisodeMemory(object): 15 | """ 16 | 每一局游戏的临时存储单元, 17 | 一局游戏完毕后将临时数据放入Memory中, 18 | 并清空数据。 19 | """ 20 | 21 | def __init__(self): 22 | self.grid_inputs = [] 23 | self.grid_inputs_next = [] 24 | self.rewards = [] 25 | self.unit_poses = [] 26 | self.unit_actions = [] 27 | self.log_probs = [] 28 | self.per_episode_len = [] 29 | self.n_step = 0 30 | 31 | def store_one_episode(self, grid_input: Tensor, grid_input_next: Tensor, 32 | unit_pos: list, action: list, reward: 
float, log_probs: list): 33 | self.grid_inputs.append(grid_input) 34 | self.grid_inputs_next.append(grid_input_next) 35 | self.unit_poses.append(unit_pos) 36 | self.unit_actions.append(action) 37 | self.rewards.append(reward) 38 | self.log_probs.append(log_probs) 39 | self.n_step += 1 40 | 41 | def clear_memories(self): 42 | self.grid_inputs.clear() 43 | self.grid_inputs_next.clear() 44 | self.unit_poses.clear() 45 | self.unit_actions.clear() 46 | self.rewards.clear() 47 | self.log_probs.clear() 48 | self.per_episode_len.clear() 49 | self.n_step = 0 50 | 51 | def set_per_episode_len(self, episode_len: int): 52 | self.per_episode_len.append(episode_len) 53 | 54 | def get_batch_data(self) -> dict: 55 | """ 56 | 将一个bact的reward恢复成[batch_num, episode_len]的形式, 57 | 以便于后序计算每个episode的discount reward 58 | 其他的数据直接融合[batch_num * episode_len, -1] 的形式 59 | :return:一个batch的数据使用字典封装 60 | """ 61 | grid_inputs = torch.cat(self.grid_inputs, dim=0) 62 | unit_pos = torch.Tensor(self.unit_poses) 63 | actions = torch.Tensor(self.unit_actions) 64 | rewards = reshape_tensor_from_list(torch.Tensor(self.rewards), self.per_episode_len) 65 | log_probs = torch.Tensor(self.log_probs) 66 | data = { 67 | 'grid_inputs': grid_inputs, 68 | 'unit_pos': unit_pos, 69 | 'rewards': rewards, 70 | 'actions': actions, 71 | 'log_probs': log_probs 72 | } 73 | return data 74 | 75 | 76 | class GridMemory(object): 77 | """ 78 | 记忆单元,保存了之前所有局游戏的数据 79 | """ 80 | 81 | def __init__(self): 82 | self.train_config = ConfigObjectFactory.get_train_config() 83 | self.memory_size = self.train_config.memory_size 84 | self.current_idx = 0 85 | self.memory = [] 86 | 87 | def store_episode(self, one_episode_memory: GridBatchEpisodeMemory): 88 | with LOCK_MEMORY: 89 | grid_inputs = torch.cat(one_episode_memory.grid_inputs, dim=0) 90 | grid_inputs_next = torch.cat(one_episode_memory.grid_inputs_next, dim=0) 91 | unit_pos = torch.Tensor(one_episode_memory.unit_poses) 92 | actions = torch.Tensor(one_episode_memory.unit_actions) 93 | reward = torch.Tensor(one_episode_memory.rewards) 94 | data = { 95 | 'grid_inputs': grid_inputs, 96 | 'grid_inputs_next': grid_inputs_next, 97 | 'unit_pos': unit_pos, 98 | 'reward': reward, 99 | 'actions': actions, 100 | 'n_step': one_episode_memory.n_step 101 | } 102 | if len(self.memory) < self.memory_size: 103 | self.memory.append(data) 104 | else: 105 | self.memory[self.current_idx % self.memory_size] = data 106 | self.current_idx += 1 107 | 108 | def sample(self, batch_size) -> dict: 109 | """ 110 | 从记忆单元中随机抽样,但是每一局游戏的step不同,找出这个batch中 111 | 最大的那个,将其他游戏的数据补齐 112 | :param batch_size: 一个batch的大小 113 | :return: 一个batch的数据 114 | """ 115 | sample_size = min(len(self.memory), batch_size) 116 | sample_list = sample(self.memory, sample_size) 117 | n_step = torch.Tensor([one_data['n_step'] for one_data in sample_list]) 118 | max_step = int(torch.max(n_step)) 119 | 120 | grid_inputs = torch.stack( 121 | [torch.cat([one_data['grid_inputs'], 122 | torch.zeros([max_step - one_data['grid_inputs'].shape[0]] + 123 | list(one_data['grid_inputs'].shape[1:])) 124 | ]) 125 | for one_data in sample_list], dim=0).detach() 126 | 127 | grid_inputs_next = torch.stack( 128 | [torch.cat([one_data['grid_inputs_next'], 129 | torch.zeros(size=[max_step - one_data['grid_inputs_next'].shape[0]] + 130 | list(one_data['grid_inputs_next'].shape[1:]))]) 131 | for one_data in sample_list], dim=0).detach() 132 | 133 | unit_pos = torch.stack( 134 | [torch.cat([one_data['unit_pos'], 135 | torch.zeros([max_step - one_data['unit_pos'].shape[0]] + 136 | 
list(one_data['unit_pos'].shape[1:]))]) 137 | for one_data in sample_list], dim=0).detach() 138 | 139 | reward = torch.stack( 140 | [torch.cat([one_data['reward'], 141 | torch.zeros([max_step - one_data['reward'].shape[0]] + 142 | list(one_data['reward'].shape[1:]))]) 143 | for one_data in sample_list], dim=0).detach() 144 | 145 | actions = torch.stack( 146 | [torch.cat([one_data['actions'], 147 | torch.zeros([max_step - one_data['actions'].shape[0]] + 148 | list(one_data['actions'].shape[1:]))]) 149 | for one_data in sample_list], dim=0).unsqueeze(dim=-1).detach() 150 | 151 | terminated = torch.stack( 152 | [torch.cat([torch.ones(one_data['n_step']), torch.zeros(max_step - one_data['n_step'])]) 153 | for one_data in sample_list], dim=0).detach() 154 | 155 | batch_data = { 156 | 'grid_inputs': grid_inputs, 157 | 'grid_inputs_next': grid_inputs_next, 158 | 'unit_pos': unit_pos, 159 | 'reward': reward, 160 | 'actions': actions, 161 | 'max_step': max_step, 162 | 'sample_size': sample_size, 163 | 'terminated': terminated 164 | } 165 | return batch_data 166 | 167 | def get_memory_real_size(self): 168 | return len(self.memory) 169 | 170 | 171 | class CommBatchEpisodeMemory(object): 172 | """ 173 | 存储每局游戏的记忆单元, 适用于常规marl算法(grid_net除外) 174 | """ 175 | 176 | def __init__(self, continuous_actions: bool, n_actions: int = 0, n_agents: int = 0): 177 | self.continuous_actions = continuous_actions 178 | self.n_actions = n_actions 179 | self.n_agents = n_agents 180 | self.obs = [] 181 | self.obs_next = [] 182 | self.state = [] 183 | self.state_next = [] 184 | self.rewards = [] 185 | self.unit_actions = [] 186 | self.log_probs = [] 187 | self.unit_actions_onehot = [] 188 | self.per_episode_len = [] 189 | self.n_step = 0 190 | 191 | def store_one_episode(self, one_obs: dict, one_state: ndarray, action: list, reward: float, 192 | one_obs_next: dict = None, one_state_next: ndarray = None, log_probs: list = None): 193 | one_obs = torch.stack([torch.Tensor(value) for value in one_obs.values()], dim=0) 194 | self.obs.append(one_obs) 195 | self.state.append(torch.Tensor(one_state)) 196 | self.rewards.append(reward) 197 | self.unit_actions.append(action) 198 | if one_obs_next is not None: 199 | one_obs_next = torch.stack([torch.Tensor(value) for value in one_obs_next.values()], dim=0) 200 | self.obs_next.append(one_obs_next) 201 | if one_state_next is not None: 202 | self.state_next.append(torch.Tensor(one_state_next)) 203 | if log_probs is not None: 204 | self.log_probs.append(log_probs) 205 | if not self.continuous_actions: 206 | self.unit_actions_onehot.append( 207 | torch.zeros(self.n_agents, self.n_actions).scatter_(1, torch.LongTensor(action).unsqueeze(dim=-1), 1)) 208 | self.n_step += 1 209 | 210 | def clear_memories(self): 211 | self.obs.clear() 212 | self.obs_next.clear() 213 | self.state.clear() 214 | self.state_next.clear() 215 | self.rewards.clear() 216 | self.log_probs.clear() 217 | self.unit_actions.clear() 218 | self.unit_actions_onehot.clear() 219 | self.per_episode_len.clear() 220 | self.n_step = 0 221 | 222 | def set_per_episode_len(self, episode_len: int): 223 | self.per_episode_len.append(episode_len) 224 | 225 | def get_batch_data(self) -> dict: 226 | """ 227 | 获取一个batch的数据 228 | :return:一个batch的数据使用字典封装 229 | """ 230 | obs = torch.stack(self.obs, dim=0) 231 | state = torch.stack(self.state, dim=0) 232 | rewards = reshape_tensor_from_list(torch.Tensor(self.rewards), self.per_episode_len) 233 | actions = torch.Tensor(self.unit_actions) 234 | log_probs = torch.Tensor(self.log_probs) 235 | data = 
{ 236 | 'obs': obs, 237 | 'state': state, 238 | 'rewards': rewards, 239 | 'actions': actions, 240 | 'log_probs': log_probs, 241 | 'per_episode_len': self.per_episode_len 242 | } 243 | return data 244 | 245 | 246 | class CommMemory(object): 247 | """ 248 | 存储所有游戏的记忆单元, 适用于常规marl算法(grid_net除外) 249 | """ 250 | 251 | def __init__(self): 252 | self.train_config = ConfigObjectFactory.get_train_config() 253 | self.memory_size = self.train_config.memory_size 254 | self.current_idx = 0 255 | self.memory = [] 256 | 257 | def store_episode(self, one_episode_memory: CommBatchEpisodeMemory): 258 | with LOCK_MEMORY: 259 | obs = torch.stack(one_episode_memory.obs, dim=0) 260 | obs_next = torch.stack(one_episode_memory.obs_next, dim=0) 261 | state = torch.stack(one_episode_memory.state, dim=0) 262 | state_next = torch.stack(one_episode_memory.state_next, dim=0) 263 | actions = torch.Tensor(one_episode_memory.unit_actions) 264 | actions_onehot = torch.stack(one_episode_memory.unit_actions_onehot, dim=0) 265 | reward = torch.Tensor(one_episode_memory.rewards) 266 | data = { 267 | 'obs': obs, 268 | 'obs_next': obs_next, 269 | 'state': state, 270 | 'state_next': state_next, 271 | 'rewards': reward, 272 | 'actions': actions, 273 | 'actions_onehot': actions_onehot, 274 | 'n_step': one_episode_memory.n_step 275 | } 276 | if len(self.memory) < self.memory_size: 277 | self.memory.append(data) 278 | else: 279 | self.memory[self.current_idx % self.memory_size] = data 280 | self.current_idx += 1 281 | 282 | def sample(self, batch_size) -> dict: 283 | """ 284 | 从记忆单元中随机抽样,但是每一局游戏的step不同,找出这个batch中 285 | 最大的那个,将其他游戏的数据补齐 286 | :param batch_size: 一个batch的大小 287 | :return: 一个batch的数据 288 | """ 289 | sample_size = min(len(self.memory), batch_size) 290 | sample_list = sample(self.memory, sample_size) 291 | n_step = torch.Tensor([one_data['n_step'] for one_data in sample_list]) 292 | max_step = int(torch.max(n_step)) 293 | 294 | obs = torch.stack( 295 | [torch.cat([one_data['obs'], 296 | torch.zeros([max_step - one_data['obs'].shape[0]] + 297 | list(one_data['obs'].shape[1:])) 298 | ]) 299 | for one_data in sample_list], dim=0).detach() 300 | 301 | obs_next = torch.stack( 302 | [torch.cat([one_data['obs_next'], 303 | torch.zeros(size=[max_step - one_data['obs_next'].shape[0]] + 304 | list(one_data['obs_next'].shape[1:]))]) 305 | for one_data in sample_list], dim=0).detach() 306 | 307 | state = torch.stack( 308 | [torch.cat([one_data['state'], 309 | torch.zeros([max_step - one_data['state'].shape[0]] + 310 | list(one_data['state'].shape[1:]))]) 311 | for one_data in sample_list], dim=0).detach() 312 | 313 | state_next = torch.stack( 314 | [torch.cat([one_data['state_next'], 315 | torch.zeros([max_step - one_data['state_next'].shape[0]] + 316 | list(one_data['state_next'].shape[1:]))]) 317 | for one_data in sample_list], dim=0).detach() 318 | 319 | rewards = torch.stack( 320 | [torch.cat([one_data['rewards'], 321 | torch.zeros([max_step - one_data['rewards'].shape[0]] + 322 | list(one_data['rewards'].shape[1:]))]) 323 | for one_data in sample_list], dim=0).detach() 324 | 325 | actions = torch.stack( 326 | [torch.cat([one_data['actions'], 327 | torch.zeros([max_step - one_data['actions'].shape[0]] + 328 | list(one_data['actions'].shape[1:]))]) 329 | for one_data in sample_list], dim=0).unsqueeze(dim=-1).detach() 330 | 331 | actions_onehot = torch.stack( 332 | [torch.cat([one_data['actions_onehot'], 333 | torch.zeros([max_step - one_data['actions_onehot'].shape[0]] + 334 | list(one_data['actions_onehot'].shape[1:]))]) 335 | for 
one_data in sample_list], dim=0).detach() 336 | 337 | terminated = torch.stack( 338 | [torch.cat([torch.ones(one_data['n_step']), torch.zeros(max_step - one_data['n_step'])]) 339 | for one_data in sample_list], dim=0).detach() 340 | 341 | batch_data = { 342 | 'obs': obs, 343 | 'obs_next': obs_next, 344 | 'state': state, 345 | 'state_next': state_next, 346 | 'rewards': rewards, 347 | 'actions': actions, 348 | 'actions_onehot': actions_onehot, 349 | 'max_step': max_step, 350 | 'sample_size': sample_size, 351 | 'terminated': terminated 352 | } 353 | return batch_data 354 | 355 | def get_memory_real_size(self): 356 | return len(self.memory) 357 | --------------------------------------------------------------------------------
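# A compact sketch of the padding convention that GridMemory.sample and
# CommMemory.sample in common/reply_buffer.py above both rely on: every sampled
# episode is zero-padded along the time axis up to the longest episode in the
# batch, and a `terminated` mask marks real steps with 1 and padding with 0 so
# the training losses can ignore the padded tail. Toy shapes only; `feat_dim`
# is an illustrative placeholder.
import torch


def pad_episodes(episodes: list) -> dict:
    # episodes: list of tensors shaped [n_step_i, feat_dim] with varying n_step_i.
    max_step = max(ep.shape[0] for ep in episodes)
    data = torch.stack(
        [torch.cat([ep, torch.zeros(max_step - ep.shape[0], ep.shape[1])])
         for ep in episodes], dim=0)
    terminated = torch.stack(
        [torch.cat([torch.ones(ep.shape[0]), torch.zeros(max_step - ep.shape[0])])
         for ep in episodes], dim=0)
    return {'data': data, 'terminated': terminated, 'max_step': max_step}


if __name__ == "__main__":
    batch = pad_episodes([torch.randn(3, 4), torch.randn(5, 4)])
    print(batch['data'].shape, batch['terminated'])  # torch.Size([2, 5, 4]) and the 0/1 mask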