├── ACORM_MAPPO.jpg ├── ACORM_MAPPO ├── algorithm │ ├── __pycache__ │ │ ├── acorm.cpython-37.pyc │ │ └── mappo.cpython-37.pyc │ ├── acorm.py │ └── mappo.py ├── main.py ├── result │ └── sacred │ │ ├── acorm │ │ ├── 2s3z_seed0.npy │ │ ├── 2s3z_seed1.npy │ │ ├── 2s3z_seed2.npy │ │ ├── 2s3z_seed3.npy │ │ ├── 3s5z_seed0.npy │ │ ├── 3s5z_seed1.npy │ │ ├── 3s5z_seed2.npy │ │ ├── 3s5z_seed3.npy │ │ ├── 3s5z_vs_3s6z_seed0.npy │ │ ├── 3s5z_vs_3s6z_seed1.npy │ │ ├── 3s5z_vs_3s6z_seed2.npy │ │ ├── 3s5z_vs_3s6z_seed3.npy │ │ ├── 5m_vs_6m_seed0.npy │ │ ├── 5m_vs_6m_seed1.npy │ │ ├── 5m_vs_6m_seed2.npy │ │ ├── 5m_vs_6m_seed3.npy │ │ ├── MMM2_seed0.npy │ │ ├── MMM2_seed1.npy │ │ ├── MMM2_seed2.npy │ │ ├── MMM2_seed3.npy │ │ ├── corridor_seed0.npy │ │ ├── corridor_seed1.npy │ │ ├── corridor_seed2.npy │ │ └── corridor_seed3.npy │ │ └── mappo │ │ ├── 2s3z_seed0.npy │ │ ├── 2s3z_seed1.npy │ │ ├── 2s3z_seed2.npy │ │ ├── 2s3z_seed3.npy │ │ ├── 3s5z_seed0.npy │ │ ├── 3s5z_seed1.npy │ │ ├── 3s5z_seed2.npy │ │ ├── 3s5z_seed3.npy │ │ ├── 3s5z_vs_3s6z_seed0.npy │ │ ├── 3s5z_vs_3s6z_seed1.npy │ │ ├── 3s5z_vs_3s6z_seed2.npy │ │ ├── 3s5z_vs_3s6z_seed3.npy │ │ ├── 5m_vs_6m_seed0.npy │ │ ├── 5m_vs_6m_seed1.npy │ │ ├── 5m_vs_6m_seed2.npy │ │ ├── 5m_vs_6m_seed3.npy │ │ ├── MMM2_seed0.npy │ │ ├── MMM2_seed1.npy │ │ ├── MMM2_seed2.npy │ │ ├── MMM2_seed3.npy │ │ ├── corridor_seed0.npy │ │ ├── corridor_seed1.npy │ │ ├── corridor_seed2.npy │ │ └── corridor_seed3.npy ├── run.py ├── run.sh └── util │ ├── __pycache__ │ ├── acorm_net.cpython-37.pyc │ ├── attention.cpython-37.pyc │ ├── net.cpython-37.pyc │ └── replay_buffer.cpython-37.pyc │ ├── acorm_net.py │ ├── attention.py │ ├── net.py │ └── replay_buffer.py ├── ACORM_QMIX.jpg ├── ACORM_QMIX ├── algorithm │ ├── __pycache__ │ │ ├── acorm.cpython-310.pyc │ │ ├── acorm.cpython-37.pyc │ │ ├── recl.cpython-37.pyc │ │ ├── vdn_qmix.cpython-310.pyc │ │ ├── vdn_qmix.cpython-37.pyc │ │ ├── vdn_qmix_reuse.cpython-37.pyc │ │ ├── vdn_qmix_reuse_v1.cpython-37.pyc │ │ └── vdn_qmix_reuse_v2.cpython-37.pyc │ ├── acorm.py │ └── vdn_qmix.py ├── main.py ├── result │ └── sacred │ │ ├── acorm │ │ ├── 10m_vs_11m_seed0.npy │ │ ├── 10m_vs_11m_seed1.npy │ │ ├── 10m_vs_11m_seed2.npy │ │ ├── 10m_vs_11m_seed3.npy │ │ ├── 1c3s5z_seed0.npy │ │ ├── 1c3s5z_seed1.npy │ │ ├── 1c3s5z_seed2.npy │ │ ├── 1c3s5z_seed3.npy │ │ ├── 27m_vs_30m_seed0.npy │ │ ├── 27m_vs_30m_seed1.npy │ │ ├── 27m_vs_30m_seed2.npy │ │ ├── 27m_vs_30m_seed3.npy │ │ ├── 2c_vs_64zg_seed0.npy │ │ ├── 2c_vs_64zg_seed1.npy │ │ ├── 2c_vs_64zg_seed2.npy │ │ ├── 2c_vs_64zg_seed3.npy │ │ ├── 2c_vs_64zg_seed4.npy │ │ ├── 2s3z_seed0.npy │ │ ├── 2s3z_seed1.npy │ │ ├── 2s3z_seed2.npy │ │ ├── 2s3z_seed3.npy │ │ ├── 3s5z_seed0.npy │ │ ├── 3s5z_seed1.npy │ │ ├── 3s5z_seed2.npy │ │ ├── 3s5z_seed3.npy │ │ ├── 3s5z_vs_3s6z_seed0.npy │ │ ├── 3s5z_vs_3s6z_seed1.npy │ │ ├── 3s5z_vs_3s6z_seed2.npy │ │ ├── 3s5z_vs_3s6z_seed3.npy │ │ ├── 5m_vs_6m_seed0.npy │ │ ├── 5m_vs_6m_seed1.npy │ │ ├── 5m_vs_6m_seed2.npy │ │ ├── 5m_vs_6m_seed3.npy │ │ ├── 6h_vs_8z_seed0.npy │ │ ├── 6h_vs_8z_seed1.npy │ │ ├── 6h_vs_8z_seed2.npy │ │ ├── 6h_vs_8z_seed3.npy │ │ ├── MMM2_seed0.npy │ │ ├── MMM2_seed1.npy │ │ ├── MMM2_seed2.npy │ │ ├── MMM2_seed3.npy │ │ ├── MMM2_seed4.npy │ │ ├── bane_vs_bane_seed0.npy │ │ ├── bane_vs_bane_seed1.npy │ │ ├── bane_vs_bane_seed2.npy │ │ ├── bane_vs_bane_seed3.npy │ │ ├── corridor_seed0.npy │ │ ├── corridor_seed1.npy │ │ ├── corridor_seed2.npy │ │ └── corridor_seed3.npy │ │ └── qmix │ │ ├── 10m_vs_11m_seed0.npy │ │ ├── 10m_vs_11m_seed1.npy │ │ ├── 10m_vs_11m_seed2.npy │ │ ├── 10m_vs_11m_seed3.npy │ │ ├── 1c3s5z_seed0.npy │ │ ├── 1c3s5z_seed1.npy │ │ ├── 1c3s5z_seed2.npy │ │ ├── 1c3s5z_seed3.npy │ │ ├── 27m_vs_30m_seed0.npy │ │ ├── 27m_vs_30m_seed1.npy │ │ ├── 27m_vs_30m_seed2.npy │ │ ├── 27m_vs_30m_seed3.npy │ │ ├── 2c_vs_64zg_seed0.npy │ │ ├── 2c_vs_64zg_seed1.npy │ │ ├── 2c_vs_64zg_seed2.npy │ │ ├── 2c_vs_64zg_seed3.npy │ │ ├── 2s3z_seed0.npy │ │ ├── 2s3z_seed1.npy │ │ ├── 2s3z_seed2.npy │ │ ├── 2s3z_seed3.npy │ │ ├── 3s5z_seed0.npy │ │ ├── 3s5z_seed1.npy │ │ ├── 3s5z_seed2.npy │ │ ├── 3s5z_seed3.npy │ │ ├── 3s5z_vs_3s6z_seed0.npy │ │ ├── 3s5z_vs_3s6z_seed1.npy │ │ ├── 3s5z_vs_3s6z_seed2.npy │ │ ├── 3s5z_vs_3s6z_seed3.npy │ │ ├── 5m_vs_6m_seed0.npy │ │ ├── 5m_vs_6m_seed1.npy │ │ ├── 5m_vs_6m_seed2.npy │ │ ├── 5m_vs_6m_seed3.npy │ │ ├── 6h_vs_8z_seed0.npy │ │ ├── 6h_vs_8z_seed1.npy │ │ ├── 6h_vs_8z_seed2.npy │ │ ├── 6h_vs_8z_seed3.npy │ │ ├── MMM2_seed0.npy │ │ ├── MMM2_seed1.npy │ │ ├── MMM2_seed2.npy │ │ ├── MMM2_seed3.npy │ │ ├── bane_vs_bane_seed0.npy │ │ ├── bane_vs_bane_seed1.npy │ │ ├── bane_vs_bane_seed2.npy │ │ ├── bane_vs_bane_seed3.npy │ │ ├── corridor_seed0.npy │ │ ├── corridor_seed1.npy │ │ ├── corridor_seed2.npy │ │ └── corridor_seed3.npy ├── run.py ├── run.sh └── util │ ├── __pycache__ │ ├── attention.cpython-310.pyc │ ├── attention.cpython-37.pyc │ ├── net.cpython-310.pyc │ ├── net.cpython-37.pyc │ ├── replay_buffer.cpython-310.pyc │ ├── replay_buffer.cpython-37.pyc │ └── replay_buffer_v1.cpython-37.pyc │ ├── attention.py │ ├── net.py │ └── replay_buffer.py ├── README.md ├── ablation.jpg ├── ablation_k_means.jpg ├── plot.py ├── requirements.txt ├── results_grf.jpg ├── results_mappo.jpg ├── results_smac.jpg ├── visual_mha.jpg └── visual_t.jpg /ACORM_MAPPO.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO.jpg -------------------------------------------------------------------------------- /ACORM_MAPPO/algorithm/__pycache__/acorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/algorithm/__pycache__/acorm.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_MAPPO/algorithm/__pycache__/mappo.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/algorithm/__pycache__/mappo.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_MAPPO/algorithm/acorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim.lr_scheduler import StepLR 3 | from torch.distributions import Categorical 4 | from torch.utils.data.sampler import * 5 | import numpy as np 6 | import copy 7 | from util.acorm_net import ACORM_Actor, ACORM_Critic 8 | from sklearn.cluster import KMeans 9 | 10 | 11 | class ACORM(object): 12 | def __init__(self, args): 13 | self.N = args.N 14 | self.obs_dim = args.obs_dim 15 | self.action_dim = args.action_dim 16 | self.state_dim = args.state_dim 17 | self.agent_embedding_dim = args.agent_embedding_dim 18 | self.role_embedding_dim = args.role_embedding_dim 19 | self.rnn_hidden_dim = args.rnn_hidden_dim 20 | 21 | self.batch_size = args.batch_size 22 | self.mini_batch_size = args.mini_batch_size 23 | self.max_train_steps = args.max_train_steps 24 | 25 | self.actor_lr = args.actor_lr 26 | self.critic_lr = args.critic_lr 27 | self.lr = args.lr 28 | self.gamma = args.gamma 29 | self.lamda = args.lamda 30 | self.clip_epsilon = args.clip_epsilon 31 | self.K_epochs = args.K_epochs 32 | self.entropy_coef = args.entropy_coef 33 | self.tau = args.tau 34 | 35 | self.use_lr_decay = args.use_lr_decay 36 | self.lr_decay_steps = args.lr_decay_steps 37 | self.lr_decay_rate = args.lr_decay_rate 38 | self.use_adv_norm = args.use_adv_norm 39 | self.use_grad_clip = args.use_grad_clip 40 | 41 | # recl 42 | self.cl_lr = args.cl_lr 43 | self.multi_steps = args.multi_steps 44 | self.train_recl_freq = args.train_recl_freq 45 | self.cluster_num = args.cluster_num 46 | 47 | self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu') 48 | 49 | self.actor = ACORM_Actor(args) 50 | self.critic =ACORM_Critic(args) 51 | 52 | # self.actor_parameters = list(self.actor.actor_net.parameters()) \ 53 | # + list(self.actor.embedding_net.agent_embedding_net.parameters()) \ 54 | # + list(self.actor.embedding_net.role_embedding_net.encoder.parameters()) 55 | # self.actor_optimizer = torch.optim.Adam(self.actor_parameters, lr=self.actor_lr) 56 | # self.actor_lr_decay = StepLR(self.actor_optimizer, step_size=self.lr_decay_steps, gamma=self.lr_decay_rate) 57 | 58 | # self.critic_parameters = self.critic.parameters() 59 | # self.critic_optimizer = torch.optim.Adam(self.critic_parameters, lr=self.critic_lr) 60 | # self.critic_lr_decay = StepLR(self.critic_optimizer, step_size=self.lr_decay_steps, gamma=self.lr_decay_rate) 61 | 62 | self.encoder_decoder_para = list(self.actor.embedding_net.agent_embedding_net.parameters()) \ 63 | + list(self.actor.embedding_net.agent_embedding_decoder.parameters()) 64 | self.encoder_decoder_optimizer = torch.optim.Adam(self.encoder_decoder_para, lr=args.agent_embedding_lr) 65 | 66 | self.ac_parameters = list(self.actor.actor_net.parameters()) \ 67 | + list(self.actor.embedding_net.agent_embedding_net.parameters()) \ 68 | + list(self.actor.embedding_net.role_embedding_net.encoder.parameters()) \ 69 | + list(self.critic.parameters()) 70 | self.ac_optimizer = torch.optim.Adam(self.ac_parameters, lr=self.lr) 71 | self.ac_lr_decay = StepLR(self.ac_optimizer, step_size=self.lr_decay_steps, gamma=self.lr_decay_rate) 72 | 73 | self.cl_parameters = self.actor.embedding_net.parameters() 74 | self.cl_optimizer = torch.optim.Adam(self.cl_parameters, lr=self.cl_lr) 75 | self.cl_lr_decay = StepLR(self.cl_optimizer, step_size=self.lr_decay_steps, gamma=self.lr_decay_rate) 76 | 77 | self.actor.to(self.device) 78 | self.critic.to(self.device) 79 | 80 | self.train_step = 0 81 | 82 | def choose_action(self, agent_embedding, role_embedding, avail_a_n, evaluate): 83 | with torch.no_grad(): 84 | avail_a_n = torch.tensor(avail_a_n, dtype=torch.float32) # avail_a_n.shape=(N, action_dim) 85 | avail_a_n = avail_a_n.to(self.device) 86 | prob = self.actor.actor_forward(agent_embedding, role_embedding, avail_a_n) 87 | 88 | if evaluate: 89 | a_n = prob.argmax(dim=-1).to('cpu') 90 | return a_n.numpy(), None 91 | else: 92 | dist = Categorical(probs=prob) 93 | a_n = dist.sample() 94 | a_logprob_n = dist.log_prob(a_n) 95 | return a_n.to('cpu').numpy(), a_logprob_n.to('cpu').numpy() 96 | 97 | def get_value(self, s, obs_n, role_embed_n): 98 | with torch.no_grad(): 99 | # obs_n = torch.tensor(obs_n,dtype=torch.float32).to(self.device) # (N, obs_dim) 100 | state = torch.tensor(np.array(s), dtype=torch.float32).unsqueeze(0).to(self.device) # (state_dim,)->(1, state_dim) 101 | v_n = self.critic(obs_n, state, role_embed_n) 102 | return v_n.to('cpu').numpy().flatten() 103 | 104 | def train(self, replay_buffer): 105 | self.train_step += 1 106 | batch = replay_buffer.get_training_data() 107 | max_episode_len = replay_buffer.max_episode_len 108 | batch_obs, batch_s, batch_r, batch_v_n, batch_dw, batch_active, batch_avail_a_n, batch_a_n, batch_a_logprob_n = self.get_inputs(batch) 109 | 110 | if self.train_step % self.train_recl_freq == 0: 111 | self.update_recl(batch_obs, batch_active, max_episode_len) 112 | self.soft_update_params(self.actor.embedding_net.role_embedding_net.encoder, self.actor.embedding_net.role_embedding_net.target_encoder, self.tau) 113 | actor_loss, critic_loss = self.update_ppo(max_episode_len, batch_obs, batch_s, batch_r, batch_v_n, batch_dw, batch_active, batch_avail_a_n, batch_a_n, batch_a_logprob_n) 114 | self.soft_update_params(self.actor.embedding_net.role_embedding_net.encoder, self.actor.embedding_net.role_embedding_net.target_encoder, self.tau) 115 | return actor_loss, critic_loss 116 | 117 | def pretrain_agent_embedding(self, replay_buffer): 118 | batch = replay_buffer.get_training_data() 119 | max_episode_len = replay_buffer.max_episode_len 120 | batch_o = batch['obs_n'].to(self.device) # (batch, max_len, N, obs_dim) 121 | batch_active = batch['active'].to(self.device) # (batch, max_len, N) 122 | 123 | self.actor.embedding_net.agent_embedding_net.rnn_hidden = None 124 | agent_embeddings = [] 125 | for t in range(max_episode_len-1): 126 | agent_embedding = self.actor.embedding_net.agent_embed_forward(batch_o[:, t].reshape(-1, self.obs_dim), 127 | detach=False) 128 | agent_embeddings.append(agent_embedding.reshape(-1, self.N, self.agent_embedding_dim)) # (batch_size, N, agent_embedding_dim) 129 | agent_embeddings = torch.stack(agent_embeddings, dim=1) #(batch_size, max_episode_len, N, agent_embedding_dim) 130 | decoder_output = self.actor.embedding_net.agent_embedding_decoder(agent_embeddings.reshape(-1,self.agent_embedding_dim)).reshape(-1, max_episode_len-1, self.N, self.obs_dim+self.N) 131 | batch_obs_hat = batch_o[:,1:] 132 | agent_id_one_hot = torch.eye(self.N).unsqueeze(0).unsqueeze(0).repeat(batch_o.shape[0], max_episode_len-1, 1, 1).to(self.device) 133 | decoder_target = torch.cat([batch_obs_hat, agent_id_one_hot], dim=-1) # (batch_size, max_len, N, obs_dim+N) 134 | mask = batch_active[:,1:].unsqueeze(-1).repeat(1, 1, 1, self.obs_dim+self.N) 135 | loss = (((decoder_output - decoder_target) * mask)**2).sum()/mask.sum() 136 | 137 | self.encoder_decoder_optimizer.zero_grad() 138 | loss.backward() 139 | self.encoder_decoder_optimizer.step() 140 | return loss 141 | 142 | def pretrain_recl(self, replay_buffer): 143 | batch = replay_buffer.get_training_data() 144 | max_episode_len = replay_buffer.max_episode_len 145 | batch_o = batch['obs_n'].to(self.device) 146 | batch_active = batch['active'].to(self.device) 147 | recl_loss = self.update_recl(batch_o, batch_active, max_episode_len) 148 | self.soft_update_params(self.actor.embedding_net.role_embedding_net.encoder, self.actor.embedding_net.role_embedding_net.target_encoder, self.tau) 149 | return recl_loss 150 | 151 | def update_ppo(self, max_episode_len, batch_obs, batch_s, batch_r, batch_v_n, batch_dw, batch_active, batch_avail_a_n, batch_a_n, batch_a_logprob_n): 152 | adv = [] 153 | gae = 0 154 | with torch.no_grad(): # adv and v_target have no gradient 155 | # deltas.shape = (batch, max_episode_len, N) 156 | deltas = batch_r + self.gamma * (1-batch_dw) * batch_v_n[:, 1:] - batch_v_n[:, :-1] 157 | for t in reversed(range(max_episode_len)): 158 | gae = deltas[:, t] + self.gamma * self.lamda * gae 159 | adv.insert(0, gae) 160 | adv = torch.stack(adv, dim=1) # adv.shape=(batch, max_len, N) 161 | v_target = adv + batch_v_n[:, :-1] # v_target.shape=(batch, max_len, N) 162 | 163 | if self.use_adv_norm: 164 | adv_copy = copy.deepcopy(adv.to('cpu').numpy()) 165 | adv_copy[batch_active.to('cpu').numpy() == 0] = np.nan 166 | adv = ((adv - np.nanmean(adv_copy)) / (np.nanstd(adv_copy) + 1e-5)) 167 | adv = adv.to(self.device) 168 | sum_actor_loss, sum_critic_loss = 0, 0 169 | for _ in range(self.K_epochs): 170 | for index in BatchSampler(SequentialSampler(range(self.batch_size)), self.mini_batch_size, False): 171 | # rnn net need to stack according to the time 172 | self.actor.embedding_net.agent_embedding_net.rnn_hidden = None 173 | self.critic.state_gru_hidden = None 174 | self.critic.obs_gru_hidden = None 175 | agent_embeddings, tau_obs, tau_state = [], [], [] 176 | for t in range(max_episode_len): 177 | # batch_s.shape=(batch, max_len, state_dim) 178 | obs = batch_obs[index, t].reshape(-1, self.obs_dim) # (batch*N, obs_dim) 179 | s = batch_s[index, t].reshape(-1, self.state_dim) #(batch, state_dim) 180 | agent_embed = self.actor.embedding_net.agent_embed_forward(obs, detach=False) # (batch*N, agent_embed_dim) 181 | # h_obs = self.critic.obs_forward(obs,s.unsqueeze(1).repeat(1,self.N,1).reshape(-1, self.state_dim)) # (batch*N, rnn_dim) 182 | h_obs = self.critic.obs_forward(obs) # (batch*N, rnn_dim) 183 | h_state = self.critic.state_forward(s) # (batch, N*rnn_dim) 184 | agent_embeddings.append(agent_embed.reshape(self.mini_batch_size, self.N, -1)) 185 | tau_obs.append(h_obs.reshape(self.mini_batch_size, self.N, -1)) 186 | tau_state.append(h_state.reshape(self.mini_batch_size, -1)) 187 | # stack according to the time 188 | agent_embeddings = torch.stack(agent_embeddings, dim=1) # (batch, max_len, N, agent_embed_dim) 189 | tau_obs = torch.stack(tau_obs, dim=1) # (batch, max_len, N, rnn_dim) 190 | tau_state = torch.stack(tau_state, dim=1) # (batch, max_len, N*rnn_dim) 191 | 192 | # calculate prob, value 193 | role_embeddings = self.actor.embedding_net.role_embed_foward(agent_embeddings.reshape(-1, self.agent_embedding_dim)) # (batch*len*N, role_embed_dim) 194 | probs_now = self.actor.actor_forward(agent_embeddings.reshape(-1, self.agent_embedding_dim), 195 | role_embeddings, batch_avail_a_n[index].reshape(-1, self.action_dim)) # (batch*len*N, actor_dim) 196 | probs_now = probs_now.reshape(self.mini_batch_size, max_episode_len, self.N, -1) # (batch, len, N, actor_dim) 197 | 198 | tau_state = tau_state.reshape(-1, self.N, self.rnn_hidden_dim) # (batch*len, rnn_dim)->(batch*len, N, rnn_dim) 199 | att = self.critic.att_forward(tau_state, role_embeddings.reshape(-1, self.N, self.role_embedding_dim).detach()) # (batch*len, N, att_out_dim) 200 | values_now = self.critic.critic_forward(tau_obs.reshape(-1, tau_obs.shape[-1]), 201 | tau_state.unsqueeze(1).repeat(1,self.N,1,1).reshape(-1, self.N*self.rnn_hidden_dim), 202 | att.unsqueeze(1).repeat(1,self.N,1,1).reshape(-1, self.N*att.shape[-1])) # (batch*len*N, 1) 203 | values_now = values_now.reshape(self.mini_batch_size, max_episode_len, self.N) 204 | 205 | # calcute loss 206 | dist_now = Categorical(probs_now) 207 | dist_entropy = dist_now.entropy() # shape=(mini_batch, max_len, N) 208 | a_logprob_n_now = dist_now.log_prob(batch_a_n[index]) # shape=(mini_batch, max_len, N) 209 | # a/b = exp(log(a)-log(b)) 210 | ratios = torch.exp(a_logprob_n_now-batch_a_logprob_n[index].detach()) # ratios.shape=(mini_batch, max_len, N) 211 | surr1 = ratios * adv[index] 212 | surr2 = torch.clamp(ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * adv[index] 213 | actor_loss = -torch.min(surr1, surr2) - self.entropy_coef * dist_entropy 214 | actor_loss = (actor_loss * batch_active[index]).sum() / batch_active[index].sum() 215 | # sum_actor_loss += actor_loss.item() 216 | # self.actor_optimizer.zero_grad() 217 | # actor_loss.backward() 218 | # if self.use_grad_clip: 219 | # torch.nn.utils.clip_grad_norm_(self.actor_parameters, 10.0) 220 | # self.actor_optimizer.step() 221 | 222 | critic_loss = (values_now - v_target[index]) ** 2 223 | critic_loss = (critic_loss * batch_active[index]).sum() / batch_active[index].sum() 224 | # sum_critic_loss += critic_loss.item() 225 | # self.critic_optimizer.zero_grad() 226 | # critic_loss.backward() 227 | # if self.use_grad_clip: 228 | # torch.nn.utils.clip_grad_norm_(self.critic_parameters, 10.0) 229 | # self.critic_optimizer.step() 230 | 231 | self.ac_optimizer.zero_grad() 232 | ac_loss = actor_loss + critic_loss 233 | ac_loss.backward() 234 | if self.use_grad_clip: 235 | torch.nn.utils.clip_grad_norm_(self.ac_parameters, 10.0) 236 | self.ac_optimizer.step() 237 | if self.use_lr_decay: 238 | self.ac_lr_decay.step() 239 | 240 | # if self.use_lr_decay: 241 | # self.actor_lr_decay.step() 242 | # self.critic_lr_decay.step() 243 | 244 | return sum_actor_loss, sum_critic_loss 245 | 246 | 247 | def update_recl(self, batch_obs, batch_active, max_episode_len): # role embedding contrative learning 248 | loss = 0.0 249 | self.actor.embedding_net.agent_embedding_net.rnn_hidden = None 250 | labels = np.zeros((batch_obs.shape[0], self.N)) # (batch, N) 251 | for t in range(max_episode_len): # t = 0, 1, 2...(max_episode_len-1) 252 | with torch.no_grad(): 253 | agent_embedding = self.actor.embedding_net.agent_embed_forward(batch_obs[:, t].reshape(-1, self.obs_dim), detach=True) # (batch*N, obs_dim) 254 | role_embedding_query = self.actor.embedding_net.role_embed_foward(agent_embedding, detach=False, ema=False).reshape(-1, self.N, self.role_embedding_dim) # (batch, N, role_dim) 255 | role_embedding_key = self.actor.embedding_net.role_embed_foward(agent_embedding, detach=True, ema=True).reshape(-1, self.N, self.role_embedding_dim) 256 | 257 | logits = torch.bmm(role_embedding_query, self.actor.embedding_net.W.squeeze(0).expand((role_embedding_query.shape[0],self.role_embedding_dim,self.role_embedding_dim))) 258 | logits = torch.bmm(logits, role_embedding_key.transpose(1,2)) # (batch_size, N, N) 259 | logits = logits - torch.max(logits, dim=-1)[0][:,:,None] 260 | exp_logits = torch.exp(logits) # (batch_size, N, 1) 261 | agent_embedding = agent_embedding.reshape(batch_obs.shape[0],self.N, -1).to('cpu') # shape=(batch_size,N, agent_embed_dim) 262 | for idx in range(agent_embedding.shape[0]): # idx = 0,1,2...(batch_size-1) 263 | if torch.sum(batch_active[idx, t]).item() > (self.N -1): 264 | if t % self.multi_steps == 0: 265 | clusters_labels = KMeans(n_clusters=self.cluster_num).fit(agent_embedding[idx]).labels_ # (1,N) 266 | labels[idx] = copy.deepcopy(clusters_labels) 267 | else: 268 | clusters_labels = copy.deepcopy(labels[idx]) 269 | # clusters_labels, _ = kmeans(X=agent_embedding[idx],num_clusters=self.cluster_num) 270 | for j in range(self.cluster_num): # j = 0,1,...(cluster_num -1) 271 | label_pos = [idx for idx, value in enumerate(clusters_labels) if value==j] 272 | # label_neg = [idx for idx, value in enumerate(clusters_labels) if value!=j] 273 | for anchor in label_pos: 274 | loss += -torch.log(exp_logits[idx, anchor, label_pos].sum()/exp_logits[idx, anchor].sum()) 275 | loss /= (self.batch_size * max_episode_len * self.N*10) 276 | if torch.sum(batch_active[idx, t]).item() > (self.N -1): 277 | self.cl_optimizer.zero_grad() 278 | loss.backward() 279 | self.cl_optimizer.step() 280 | return loss 281 | 282 | def get_inputs(self, batch): 283 | batch_obs = batch['obs_n'].to(self.device) # (batch, max_len, N, obs_dim) 284 | batch_s = batch['s'].to(self.device) # (batch, max_len, state_dim) 285 | 286 | batch_r = batch['r'].to(self.device) # (batch, max_len, N) 287 | batch_v_n = batch['v_n'].to(self.device) # (batch, max_len+1, N) 288 | batch_dw = batch['dw'].to(self.device) # (batch, max_len, N) 289 | batch_active = batch['active'].to(self.device) # (batch, max_len, N) 290 | batch_avail_a_n = batch['avail_a_n'] # (batch, max_len, N, action_dim) 291 | batch_a_n = batch['a_n'].to(self.device) # (batch, max_len, N) 292 | batch_a_logprob_n = batch['a_logprob_n'].to(self.device) # (batch, max_len, N) 293 | 294 | return batch_obs, batch_s, batch_r, batch_v_n, batch_dw, batch_active, batch_avail_a_n, batch_a_n, batch_a_logprob_n 295 | 296 | def soft_update_params(self, net, target_net, tau): 297 | for param, target_param in zip(net.parameters(), target_net.parameters()): 298 | target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | -------------------------------------------------------------------------------- /ACORM_MAPPO/algorithm/mappo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from util.net import Actor, Critic 3 | from torch.distributions import Categorical 4 | from torch.utils.data.sampler import * 5 | import numpy as np 6 | import copy 7 | from torch.optim.lr_scheduler import StepLR 8 | 9 | class MAPPO(object): 10 | def __init__(self, args): 11 | self.N = args.N 12 | self.obs_dim = args.obs_dim 13 | self.action_dim = args.action_dim 14 | 15 | self.batch_size = args.batch_size 16 | self.mini_batch_size = args.mini_batch_size 17 | self.max_train_steps = args.max_train_steps 18 | 19 | self.lr = args.lr 20 | self.gamma = args.gamma 21 | self.lamda = args.lamda 22 | self.clip_epsilon = args.clip_epsilon 23 | self.K_epochs = args.K_epochs 24 | self.entropy_coef = args.entropy_coef 25 | 26 | self.use_lr_decay = args.use_lr_decay 27 | self.lr_decay_steps = args.lr_decay_steps 28 | self.lr_decay_rate = args.lr_decay_rate 29 | self.use_adv_norm = args.use_adv_norm 30 | self.use_grad_clip = args.use_grad_clip 31 | self.add_agent_id = args.add_agent_id 32 | self.use_agent_specific = args.use_agent_specific 33 | 34 | self.actor_input_dim = args.obs_dim 35 | self.critic_input_dim = args.state_dim 36 | 37 | if self.add_agent_id: 38 | self.actor_input_dim += args.N 39 | self.critic_input_dim += args.N 40 | if self.use_agent_specific: 41 | self.critic_input_dim += args.obs_dim 42 | self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu') 43 | 44 | self.actor = Actor(args, self.actor_input_dim) 45 | self.critic = Critic(args, self.critic_input_dim) 46 | 47 | self.ac_parameters = list(self.actor.parameters()) + list(self.critic.parameters()) 48 | self.ac_optimizer = torch.optim.Adam(self.ac_parameters, lr=self.lr, eps=1e-5) 49 | self.ac_lr_decay = StepLR(self.ac_optimizer, step_size=self.lr_decay_steps, gamma=self.lr_decay_rate) 50 | 51 | self.actor.to(self.device) 52 | self.critic.to(self.device) 53 | 54 | def choose_action(self, obs_n, avail_a_n, evaluate): 55 | with torch.no_grad(): 56 | actor_input = torch.tensor(np.array(obs_n), dtype=torch.float32) # obs_n.shape=(N, obs_dim) 57 | if self.add_agent_id: 58 | actor_input = torch.cat([actor_input, torch.eye(self.N)], dim=-1) # input.shape=(N, obs_dim+N) 59 | avail_a_n = torch.tensor(avail_a_n, dtype=torch.float32) # avail_a_n.shape=(N, action_dim) 60 | actor_input = actor_input.to(self.device) 61 | avail_a_n = avail_a_n.to(self.device) 62 | prob = self.actor(actor_input, avail_a_n) # prob.shape=(N, action_dim) 63 | 64 | if evaluate: 65 | a_n = prob.argmax(dim=-1).to('cpu') 66 | return a_n.numpy(), None 67 | else: 68 | dist = Categorical(probs=prob) 69 | a_n = dist.sample() 70 | a_logprob_n = dist.log_prob(a_n) 71 | return a_n.to('cpu').numpy(), a_logprob_n.to('cpu').numpy() 72 | 73 | def get_value(self, s, obs_n): 74 | with torch.no_grad(): 75 | obs_n = torch.tensor(np.array(obs_n),dtype=torch.float32) 76 | critic_input = torch.tensor(np.array(s), dtype=torch.float32).unsqueeze(0).repeat(self.N,1) # (state_dim,)->(N, state_dim) 77 | if self.use_agent_specific: 78 | critic_input = torch.cat([critic_input, obs_n], dim=-1) # (N, state_dim+obs_dim) 79 | if self.add_agent_id: 80 | critic_input = torch.cat([critic_input, torch.eye(self.N)], dim=-1) # (N, input_dim) 81 | critic_input = critic_input.to(self.device) 82 | v_n = self.critic(critic_input) # v_n.shape=(N, 1) 83 | return v_n.to('cpu').numpy().flatten() 84 | 85 | def train(self, replay_buffer): 86 | batch = replay_buffer.get_training_data() 87 | max_episode_len = replay_buffer.max_episode_len 88 | actor_inputs, critic_inputs, batch_r, batch_v_n, batch_dw, batch_active, batch_avail_a_n, batch_a_n, batch_a_logprob_n = self.get_inputs(batch) 89 | 90 | # Calculate the advantage using GAE 91 | adv = [] 92 | gae = 0 93 | with torch.no_grad(): # adv and v_target have no gradient 94 | # deltas.shape = (batch, max_episode_len, N) 95 | deltas = batch_r + self.gamma * (1-batch_dw) * batch_v_n[:, 1:] - batch_v_n[:, :-1] 96 | for t in reversed(range(max_episode_len)): 97 | gae = deltas[:, t] + self.gamma * self.lamda * gae 98 | adv.insert(0, gae) 99 | adv = torch.stack(adv, dim=1) # adv.shape=(batch, max_len, N) 100 | v_target = adv + batch_v_n[:, :-1] # v_target.shape=(batch, max_len, N) 101 | # normalization 102 | if self.use_adv_norm: 103 | adv_copy = copy.deepcopy(adv.to('cpu').numpy()) 104 | adv_copy[batch['active'].numpy() == 0] = np.nan 105 | adv = ((adv - np.nanmean(adv_copy)) / (np.nanstd(adv_copy) + 1e-5)) 106 | adv = adv.to(self.device) 107 | 108 | for _ in range(self.K_epochs): 109 | for index in BatchSampler(SequentialSampler(range(self.batch_size)), self.mini_batch_size, False): 110 | # probs_now.shape=(mini_batch, max_len, N, actor_dim) 111 | # values_now.shape=(mini_batch, max_len, N) 112 | self.actor.rnn_hidden = None 113 | self.critic.rnn_hidden = None 114 | probs_now, values_now = [], [] 115 | for t in range(max_episode_len): 116 | prob = self.actor(actor_inputs[index, t].reshape(self.mini_batch_size*self.N, -1), 117 | batch_avail_a_n[index, t].reshape(self.mini_batch_size*self.N, -1)) # prob.shape=(mini_batch*N,action_dim) 118 | probs_now.append(prob.reshape(self.mini_batch_size, self.N, -1)) 119 | value = self.critic(critic_inputs[index, t].reshape(self.mini_batch_size*self.N, -1)) # value.shape(mini_batch*N, 1) 120 | values_now.append(value.reshape(self.mini_batch_size, self.N)) 121 | # stack according to the time 122 | probs_now = torch.stack(probs_now, dim=1) 123 | values_now = torch.stack(values_now, dim=1) 124 | 125 | dist_now = Categorical(probs_now) 126 | dist_entropy = dist_now.entropy() # shape=(mini_batch, max_len, N) 127 | a_logprob_n_now = dist_now.log_prob(batch_a_n[index]) # shape=(mini_batch, max_len, N) 128 | # a/b = exp(log(a)-log(b)) 129 | ratios = torch.exp(a_logprob_n_now-batch_a_logprob_n[index].detach()) # ratios.shape=(mini_batch, max_len, N) 130 | surr1 = ratios * adv[index] 131 | surr2 = torch.clamp(ratios, 1 - self.clip_epsilon, 1 + self.clip_epsilon) * adv[index] 132 | actor_loss = -torch.min(surr1, surr2) - self.entropy_coef * dist_entropy 133 | actor_loss = (actor_loss * batch_active[index]).sum() / batch_active[index].sum() 134 | 135 | critic_loss = (values_now - v_target[index]) ** 2 136 | critic_loss = (critic_loss * batch_active[index]).sum() / batch_active[index].sum() 137 | 138 | self.ac_optimizer.zero_grad() 139 | ac_loss = actor_loss + critic_loss 140 | ac_loss.backward() 141 | if self.use_grad_clip: 142 | torch.nn.utils.clip_grad_norm_(self.ac_parameters, 10.0) 143 | self.ac_optimizer.step() 144 | if self.use_lr_decay: 145 | self.ac_lr_decay.step() 146 | 147 | def get_inputs(self, batch): 148 | # batch['obs_n'].shape=(batch, max_len, N, obs_dim) 149 | # batch['s].shape=(batch, max_len, state_dim) 150 | actor_inputs = copy.deepcopy(batch['obs_n']) 151 | critic_inputs = copy.deepcopy(batch['s'].unsqueeze(2).repeat(1, 1, self.N, 1)) 152 | if self.use_agent_specific: 153 | critic_inputs = torch.cat([critic_inputs, batch['obs_n']], dim=-1) # 154 | if self.add_agent_id: 155 | agent_id_one_hot = torch.eye(self.N).unsqueeze(0).unsqueeze(0).repeat(self.batch_size, batch['s'].shape[1], 1, 1) 156 | actor_inputs = torch.cat([actor_inputs, agent_id_one_hot], dim=-1) # shape=(batch, max_len, N, obs_dim+N) 157 | critic_inputs = torch.cat([critic_inputs, agent_id_one_hot], dim=-1) # shape=(batch, max_len, N, state_dim+obs_dim+N) 158 | 159 | actor_inputs = actor_inputs.to(self.device) 160 | critic_inputs = critic_inputs.to(self.device) 161 | batch_r = batch['r'].to(self.device) # (batch, max_len, N) 162 | batch_v_n = batch['v_n'].to(self.device) # (batch, max_len+1, N) 163 | batch_dw = batch['dw'].to(self.device) # (batch, max_len, N) 164 | batch_active = batch['active'].to(self.device) # (batch, max_len, N) 165 | batch_avail_a_n = batch['avail_a_n'] # (batch, max_len, N, action_dim) 166 | batch_a_n = batch['a_n'].to(self.device) # (batch, max_len, N, action_dim) 167 | batch_a_logprob_n = batch['a_logprob_n'].to(self.device) 168 | return actor_inputs, critic_inputs, batch_r, batch_v_n, batch_dw, batch_active, batch_avail_a_n, batch_a_n, batch_a_logprob_n 169 | 170 | def soft_update_params(self, net, target_net, tau): 171 | for param, target_param in zip(net.parameters(), target_net.parameters()): 172 | target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) 173 | 174 | -------------------------------------------------------------------------------- /ACORM_MAPPO/main.py: -------------------------------------------------------------------------------- 1 | from run import Runner 2 | import argparse 3 | import torch 4 | 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser("Hyperparameters Setting for MAPPO in SMAC environment") 8 | parser.add_argument("--algorithm", type=str, default="acorm", help="acorm or mappo") 9 | parser.add_argument("--max_train_steps", type=int, default=int(3e6), help=" Maximum number of training steps") 10 | parser.add_argument("--evaluate_freq", type=float, default=5000, help="Evaluate the policy every 'evaluate_freq' steps") 11 | parser.add_argument("--evaluate_times", type=float, default=32, help="Evaluate times") 12 | parser.add_argument("--save_freq", type=int, default=int(1e5), help="Save frequency") 13 | 14 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size (the number of episodes)") 15 | parser.add_argument("--mini_batch_size", type=int, default=8, help="Minibatch size (the number of episodes)") 16 | parser.add_argument("--rnn_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of RNN") 17 | parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate") 18 | parser.add_argument("--actor_lr", type=float, default=5e-4, help="Learning rate") 19 | parser.add_argument("--critic_lr", type=float, default=8e-4, help="Learning rate") 20 | parser.add_argument("--lr_decay_steps", type=int, default=500, help="every steps decay steps") 21 | parser.add_argument("--lr_decay_rate", type=float, default=0.98, help="learn decay rate") 22 | parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor") 23 | parser.add_argument("--lamda", type=float, default=0.95, help="GAE parameter") 24 | parser.add_argument("--clip_epsilon", type=float, default=0.2, help="GAE parameter") 25 | parser.add_argument("--K_epochs", type=int, default=5, help="GAE parameter") 26 | parser.add_argument("--entropy_coef", type=float, default=0.015, help="policy entropy") 27 | 28 | # ppo tricks 29 | parser.add_argument("--use_lr_decay", type=bool, default=False, help="Trick:learning rate Decay") 30 | parser.add_argument("--use_adv_norm", type=bool, default=True, help="Trick:advantage normalization") 31 | parser.add_argument("--use_grad_clip", type=bool, default=False, help="Trick: Gradient clip") 32 | parser.add_argument("--use_orthogonal_init", type=bool, default=True, help="Trick: orthogonal initialization") 33 | parser.add_argument("--add_agent_id", type=bool, default=True, help="Whether to add agent_id. Here, we do not use it.") 34 | parser.add_argument("--use_agent_specific", type=bool, default=True, help="Whether to use agent specific global state.") 35 | 36 | parser.add_argument('--env_name', type=str, default='MMM2') #['3m', '8m', '2s3z'] 37 | parser.add_argument('--device', type=str, default='cuda:0') 38 | parser.add_argument("--seed", type=int, default=123, help="random seed") 39 | 40 | # recl 41 | parser.add_argument("--agent_embedding_dim", type=int, default=64, help="The dimension of the agent embedding") 42 | parser.add_argument("--role_embedding_dim", type=int, default=32, help="The dimension of the role embedding") 43 | parser.add_argument("--cluster_num", type=int, default=int(3), help="the cluster number of k-means") 44 | parser.add_argument("--cl_lr", type=float, default=5e-4, help="Learning rate") 45 | parser.add_argument("--agent_embedding_lr", type=float, default=1e-3, help="agent_embedding Learning rate") 46 | parser.add_argument("--train_recl_freq", type=int, default=2, help="Train frequency of the contrastive role embedding") 47 | parser.add_argument("--multi_steps", type=int, default=1, help="Train frequency of the RECL network") 48 | parser.add_argument("--tau", type=float, default=0.005, help="If use soft update") 49 | parser.add_argument("--agent_embed_pretrain_epochs", type=int, default=150, help="pretrain steps") 50 | parser.add_argument("--recl_pretrain_epochs", type=int, default=120, help="pretrain steps") 51 | 52 | # attention 53 | parser.add_argument("--att_dim", type=int, default=256, help="The dimension of the attention net") 54 | parser.add_argument("--att_out_dim", type=int, default=64, help="The dimension of the attention net") 55 | parser.add_argument("--n_heads", type=int, default=8, help="multi-head attention") 56 | parser.add_argument("--soft_temp", type=float, default=1.0, help="soft tempture") 57 | 58 | # save path 59 | parser.add_argument('--save_path', type=str, default='./result/') 60 | args = parser.parse_args() 61 | 62 | torch.multiprocessing.set_start_method('spawn') 63 | runner = Runner(args) 64 | runner.run() 65 | 66 | 67 | -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/2s3z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/2s3z_seed0.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/2s3z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/2s3z_seed1.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/2s3z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/2s3z_seed2.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/2s3z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/2s3z_seed3.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/3s5z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/3s5z_seed0.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/3s5z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/3s5z_seed1.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/3s5z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/3s5z_seed2.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/3s5z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/3s5z_seed3.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/3s5z_vs_3s6z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/3s5z_vs_3s6z_seed0.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/3s5z_vs_3s6z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/3s5z_vs_3s6z_seed1.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/3s5z_vs_3s6z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/3s5z_vs_3s6z_seed2.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/3s5z_vs_3s6z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/3s5z_vs_3s6z_seed3.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/5m_vs_6m_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/5m_vs_6m_seed0.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/5m_vs_6m_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/5m_vs_6m_seed1.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/5m_vs_6m_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/5m_vs_6m_seed2.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/5m_vs_6m_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/5m_vs_6m_seed3.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/MMM2_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/MMM2_seed0.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/MMM2_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/MMM2_seed1.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/MMM2_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/MMM2_seed2.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/MMM2_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/MMM2_seed3.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/corridor_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/corridor_seed0.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/corridor_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/corridor_seed1.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/corridor_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/corridor_seed2.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/acorm/corridor_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/acorm/corridor_seed3.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/2s3z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/2s3z_seed0.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/2s3z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/2s3z_seed1.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/2s3z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/2s3z_seed2.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/2s3z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/2s3z_seed3.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/3s5z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/3s5z_seed0.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/3s5z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/3s5z_seed1.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/3s5z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/3s5z_seed2.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/3s5z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/3s5z_seed3.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/3s5z_vs_3s6z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/3s5z_vs_3s6z_seed0.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/3s5z_vs_3s6z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/3s5z_vs_3s6z_seed1.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/3s5z_vs_3s6z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/3s5z_vs_3s6z_seed2.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/3s5z_vs_3s6z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/3s5z_vs_3s6z_seed3.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/5m_vs_6m_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/5m_vs_6m_seed0.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/5m_vs_6m_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/5m_vs_6m_seed1.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/5m_vs_6m_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/5m_vs_6m_seed2.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/5m_vs_6m_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/5m_vs_6m_seed3.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/MMM2_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/MMM2_seed0.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/MMM2_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/MMM2_seed1.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/MMM2_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/MMM2_seed2.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/MMM2_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/MMM2_seed3.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/corridor_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/corridor_seed0.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/corridor_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/corridor_seed1.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/corridor_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/corridor_seed2.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/result/sacred/mappo/corridor_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/result/sacred/mappo/corridor_seed3.npy -------------------------------------------------------------------------------- /ACORM_MAPPO/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from smac.env import StarCraft2Env 4 | from algorithm.mappo import MAPPO 5 | from algorithm.acorm import ACORM 6 | from util.replay_buffer import ReplayBuffer 7 | import seaborn as sns 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | class Runner(object): 12 | def __init__(self, args): 13 | self.args = args 14 | self.env_name = self.args.env_name 15 | self.seed = self.args.seed 16 | # Set random seed 17 | np.random.seed(self.seed) 18 | torch.manual_seed(self.seed) 19 | # Create env 20 | self.env = StarCraft2Env(map_name=self.env_name, seed=self.seed) 21 | self.env_info = self.env.get_env_info() 22 | self.args.N = self.env_info["n_agents"] # The number of agents 23 | self.args.obs_dim = self.env_info["obs_shape"] # The dimensions of an agent's observation space 24 | self.args.state_dim = self.env_info["state_shape"] # The dimensions of global state space 25 | self.args.action_dim = self.env_info["n_actions"] # The dimensions of an agent's action space 26 | self.args.episode_limit = self.env_info["episode_limit"] # Maximum number of steps per episode 27 | print("number of agents={}".format(self.args.N)) 28 | print("obs_dim={}".format(self.args.obs_dim)) 29 | print("state_dim={}".format(self.args.state_dim)) 30 | print("action_dim={}".format(self.args.action_dim)) 31 | print("episode_limit={}".format(self.args.episode_limit)) 32 | self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu') 33 | self.save_path = args.save_path + args.algorithm +'/'+ args.env_name + '/' 34 | 35 | # create N agent 36 | if args.algorithm == 'mappo': 37 | self.agent_n = MAPPO(self.args) 38 | elif args.algorithm == 'acorm': 39 | self.agent_n = ACORM(self.args) 40 | self.replay_buffer = ReplayBuffer(self.args) 41 | 42 | self.win_rates = [] # Record the win rates 43 | self.evaluate_reward = [] 44 | self.total_steps = 0 45 | self.agent_embed_pretrain_epoch, self.recl_pretrain_epoch = 0, 0 46 | self.pretrain_agent_embed_loss, self.pretrain_recl_loss = [], [] 47 | 48 | def run(self, ): 49 | evaluate_num = -1 # Record the number of evaluations 50 | while self.total_steps < self.args.max_train_steps: 51 | if self.total_steps // self.args.evaluate_freq > evaluate_num: 52 | self.evaluate_policy() # Evaluate the policy every 'evaluate_freq' steps 53 | evaluate_num += 1 54 | 55 | _, _, episode_steps = self.run_episode_smac(evaluate=False) # Run an episode 56 | 57 | if self.agent_embed_pretrain_epoch < self.args.agent_embed_pretrain_epochs: # agent_embed_pretrain mode 58 | if self.replay_buffer.episode_num == self.args.batch_size: 59 | self.agent_embed_pretrain_epoch += 1 60 | for _ in range(1): 61 | agent_embedding_loss = self.agent_n.pretrain_agent_embedding(self.replay_buffer) 62 | self.pretrain_agent_embed_loss.append(agent_embedding_loss.item()) 63 | recl_loss = self.agent_n.pretrain_recl(self.replay_buffer) 64 | self.pretrain_recl_loss.append(recl_loss.item()) 65 | self.replay_buffer.reset_buffer() 66 | 67 | if self.agent_embed_pretrain_epoch >= self.args.agent_embed_pretrain_epochs: # plot loss 68 | sns.set_style('whitegrid') 69 | plt.figure() 70 | x_step = np.array(range(len(self.pretrain_agent_embed_loss))) 71 | ax = sns.lineplot(x=x_step, y=np.array(self.pretrain_agent_embed_loss).flatten(), label='agent_embedding_loss') 72 | plt.ylabel('loss', fontsize=14) 73 | plt.xlabel(f'step', fontsize=14) 74 | plt.title(f'agent_embedding network pretrain') 75 | plt.savefig(f'{self.save_path}/{self.env_name}_agent_loss_seed{self.seed}.jpg') 76 | plt.figure() 77 | x_step = np.array(range(len(self.pretrain_recl_loss))) 78 | ax = sns.lineplot(x=x_step, y=np.array(self.pretrain_recl_loss).flatten(), label='recl_loss') 79 | plt.ylabel('loss', fontsize=14) 80 | plt.xlabel(f'step', fontsize=14) 81 | plt.title(f'RECL network pretrain') 82 | plt.savefig(f'{self.save_path}/{self.env_name}_recl_loss_seed{self.seed}.jpg') 83 | print("pretrain_end!") 84 | 85 | # else: 86 | # if self.recl_pretrain_epoch < self.args.recl_pretrain_epochs: # recl_pretrain mode 87 | # self.recl_pretrain_epoch += 1 88 | # for _ in range(self.args.K_epochs): 89 | # recl_loss = self.agent_n.pretrain_recl(self.replay_buffer) 90 | # self.pretrain_recl_loss.append(recl_loss.item()) 91 | # self.replay_buffer.reset_buffer() 92 | # if self.recl_pretrain_epoch >= self.args.recl_pretrain_epochs: # plot loss 93 | # sns.set_style('whitegrid') 94 | # plt.figure() 95 | # x_step = np.array(range(len(self.pretrain_agent_embed_loss))) 96 | # ax = sns.lineplot(x=x_step, y=np.array(self.pretrain_agent_embed_loss).flatten(), label='agent_embedding_loss') 97 | # plt.ylabel('loss', fontsize=14) 98 | # plt.xlabel(f'step', fontsize=14) 99 | # plt.title(f'agent_embedding network pretrain') 100 | # plt.savefig(f'{self.save_path}/{self.env_name}_agent_loss_seed{self.seed}.jpg') 101 | # plt.figure() 102 | # x_step = np.array(range(len(self.pretrain_recl_loss))) 103 | # ax = sns.lineplot(x=x_step, y=np.array(self.pretrain_recl_loss).flatten(), label='recl_loss') 104 | # plt.ylabel('loss', fontsize=14) 105 | # plt.xlabel(f'step', fontsize=14) 106 | # plt.title(f'RECL network pretrain') 107 | # plt.savefig(f'{self.save_path}/{self.env_name}_recl_loss_seed{self.seed}.jpg') 108 | # print("pretrain_end!") 109 | else: 110 | self.total_steps += episode_steps 111 | if self.replay_buffer.episode_num == self.args.batch_size: 112 | actor_loss, critic_loss = self.agent_n.train(self.replay_buffer) 113 | self.replay_buffer.reset_buffer() 114 | 115 | self.evaluate_policy() 116 | self.env.close() 117 | 118 | def evaluate_policy(self, ): 119 | win_times = 0 120 | evaluate_reward = 0 121 | for _ in range(self.args.evaluate_times): 122 | win_tag, episode_reward, _ = self.run_episode_smac(evaluate=True) 123 | if win_tag: 124 | win_times += 1 125 | evaluate_reward += episode_reward 126 | win_rate = win_times / self.args.evaluate_times 127 | self.win_rates.append(win_rate) 128 | evaluate_reward = evaluate_reward / self.args.evaluate_times 129 | self.evaluate_reward.append(evaluate_reward) 130 | print("total_steps:{} \t win_rate:{} \t evaluate_reward:{}".format(self.total_steps, win_rate, evaluate_reward)) 131 | # # plot curve 132 | sns.set_style('whitegrid') 133 | plt.figure() 134 | x_step = np.array(range(len(self.win_rates))) 135 | ax = sns.lineplot(x=x_step, y=np.array(self.win_rates).flatten(), label=self.args.algorithm) 136 | plt.ylabel('win_rates', fontsize=14) 137 | plt.xlabel(f'step*{self.args.evaluate_freq}', fontsize=14) 138 | plt.title(f'{self.args.algorithm} on {self.env_name}') 139 | plt.savefig(f'{self.save_path}/{self.env_name}_seed{self.seed}.jpg') 140 | 141 | # Save the win rates 142 | np.save(f'{self.save_path}/{self.env_name}_seed{self.seed}.npy', np.array(self.win_rates)) 143 | np.save(f'{self.save_path}/{self.env_name}_seed{self.seed}_return.npy', np.array(self.evaluate_reward)) 144 | 145 | def run_episode_smac(self, evaluate=False): 146 | win_tag = False 147 | episode_reward = 0 148 | self.env.reset() 149 | self.agent_n.actor.embedding_net.agent_embedding_net.rnn_hidden = None 150 | self.agent_n.critic.state_gru_hidden = None 151 | self.agent_n.critic.obs_gru_hidden = None 152 | for episode_step in range(self.args.episode_limit): 153 | obs_n = self.env.get_obs() # obs_n.shape=(N, obs_dim) 154 | temp_obs_n = obs_n 155 | s = self.env.get_state() # s.shape=(state_dim,) 156 | avail_a_n = self.env.get_avail_actions() # avail_a_n 157 | 158 | temp_obs_n = torch.tensor(np.array(temp_obs_n),dtype=torch.float32).to(self.device) 159 | agent_embedding = self.agent_n.actor.embedding_net.agent_embed_forward(temp_obs_n, detach=True) # (N, agent_embed_dim) 160 | role_embedding = self.agent_n.actor.embedding_net.role_embed_foward(agent_embedding, detach=True, ema=False) # (N, role_embed_dim) 161 | 162 | a_n, a_logprob_n = self.agent_n.choose_action(agent_embedding, role_embedding, avail_a_n, evaluate=evaluate) 163 | r, done, info = self.env.step(a_n) 164 | win_tag = True if done and 'battle_won' in info and info['battle_won'] else False 165 | episode_reward += r 166 | 167 | if not evaluate: 168 | if done and episode_step + 1 != self.args.episode_limit: 169 | dw = True 170 | else: 171 | dw = False 172 | v_n = self.agent_n.get_value(s, temp_obs_n, role_embedding.unsqueeze(0)) # Get the state values (V(s)) of N agents 173 | # Store the transition 174 | self.replay_buffer.store_transition(episode_step, obs_n, s, v_n, avail_a_n, a_n, a_logprob_n, r, dw) 175 | 176 | if done: 177 | break 178 | 179 | if not evaluate: 180 | # An episode is over, store obs_n, s and avail_a_n in the last step 181 | obs_n = self.env.get_obs() 182 | s = self.env.get_state() 183 | 184 | obs_n = torch.tensor(np.array(obs_n),dtype=torch.float32).to(self.device) 185 | agent_embedding = self.agent_n.actor.embedding_net.agent_embed_forward(obs_n, detach=True) # (N, agent_embed_dim) 186 | role_embedding = self.agent_n.actor.embedding_net.role_embed_foward(agent_embedding, detach=True, ema=False) # (N, role_embed_dim) 187 | v_n = self.agent_n.get_value(s, obs_n, role_embedding.unsqueeze(0)) 188 | self.replay_buffer.store_last_value(episode_step+1, v_n) 189 | 190 | return win_tag, episode_reward, episode_step+1 -------------------------------------------------------------------------------- /ACORM_MAPPO/run.sh: -------------------------------------------------------------------------------- 1 | nohup python main.py --algorithm 'acorm' \ 2 | --env_name 'MMM2' \ 3 | --seed 3 \ 4 | --device 'cuda:1' \ 5 | --max_train_steps 3000000 \ 6 | --actor_lr +6e-4 \ 7 | --critic_lr +8e-4 \ 8 | --cl_lr +5e-4 \ 9 | --batch_size 32 \ 10 | --mini_batch_size 32 \ 11 | --agent_embedding_dim 64 \ 12 | --role_embedding_dim 32 \ 13 | --rnn_hidden_dim 64 \ 14 | --gamma 0.99 \ 15 | --lamda 0.95 \ 16 | --clip_epsilon 0.2 \ 17 | --K_epochs 5 \ 18 | --entropy_coef 0.015 \ 19 | --add_agent_id False \ 20 | --use_adv_norm True \ 21 | --use_grad_clip False \ 22 | --use_orthogonal_init False \ 23 | --use_lr_decay False \ 24 | --cluster_num 3 \ 25 | --train_recl_freq 16 \ 26 | --multi_steps 1 \ 27 | --tau 0.005 \ 28 | --att_dim 128 \ 29 | --att_out_dim 64 \ 30 | --n_heads 4 \ 31 | --soft_temp 1.0 \ 32 | & 33 | 34 | 35 | -------------------------------------------------------------------------------- /ACORM_MAPPO/util/__pycache__/acorm_net.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/util/__pycache__/acorm_net.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_MAPPO/util/__pycache__/attention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/util/__pycache__/attention.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_MAPPO/util/__pycache__/net.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/util/__pycache__/net.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_MAPPO/util/__pycache__/replay_buffer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_MAPPO/util/__pycache__/replay_buffer.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_MAPPO/util/acorm_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from util.attention import MultiHeadAttention 4 | 5 | class Agent_Embedding(nn.Module): 6 | def __init__(self, args): 7 | super(Agent_Embedding, self).__init__() 8 | self.input_dim = args.obs_dim 9 | self.agent_embedding_dim = args.agent_embedding_dim 10 | 11 | self.fc1 = nn.Linear(self.input_dim, self.input_dim) 12 | self.rnn_hidden = None 13 | self.rnn_fc = nn.GRUCell(self.input_dim, self.agent_embedding_dim) 14 | 15 | def forward(self, obs, detach=False): 16 | inputs = obs.view(-1, self.input_dim) 17 | fc1_out = torch.relu(self.fc1(inputs)) 18 | self.rnn_hidden = self.rnn_fc(fc1_out, self.rnn_hidden) 19 | fc2_out = self.rnn_hidden 20 | if detach: 21 | fc2_out.detach() 22 | return fc2_out 23 | 24 | class Agent_Embedding_Decoder(nn.Module): 25 | def __init__(self, args): 26 | super(Agent_Embedding_Decoder, self).__init__() 27 | self.agent_embedding_dim = args.agent_embedding_dim 28 | self.decoder_out_dim = args.obs_dim + args.N # out_put: o(t+1)+agent_idx 29 | 30 | self.fc1 = nn.Linear(self.agent_embedding_dim, self.agent_embedding_dim) 31 | self.fc2 = nn.Linear(self.agent_embedding_dim, self.decoder_out_dim) 32 | 33 | def forward(self, agent_embedding): 34 | fc1_out = torch.relu(self.fc1(agent_embedding)) 35 | decoder_out = self.fc2(fc1_out) 36 | return decoder_out 37 | 38 | 39 | class Role_Embedding(nn.Module): 40 | def __init__(self, args): 41 | super(Role_Embedding, self).__init__() 42 | self.agent_embedding_dim = args.agent_embedding_dim 43 | self.role_embedding_dim = args.role_embedding_dim 44 | self.encoder = nn.ModuleList([nn.Linear(self.agent_embedding_dim, self.agent_embedding_dim), 45 | nn.Linear(self.agent_embedding_dim, self.role_embedding_dim)]) 46 | self.target_encoder = nn.ModuleList([nn.Linear(self.agent_embedding_dim, self.agent_embedding_dim), 47 | nn.Linear(self.agent_embedding_dim, self.role_embedding_dim)]) 48 | 49 | self.target_encoder.load_state_dict(self.encoder.state_dict()) 50 | 51 | def forward(self, agent_embedding, detach=False, ema=False): 52 | if ema: # target encoder 53 | output = torch.relu(self.target_encoder[0](agent_embedding)) 54 | output = self.target_encoder[1](output) 55 | else: # encoder 56 | output = torch.relu(self.encoder[0](agent_embedding)) 57 | output = self.encoder[1](output) 58 | 59 | if detach: 60 | output.detach() 61 | return output 62 | 63 | class Embedding_Net(nn.Module): 64 | def __init__(self, args): 65 | super(Embedding_Net, self).__init__() 66 | self.agent_embedding_net = Agent_Embedding(args) 67 | self.agent_embedding_decoder = Agent_Embedding_Decoder(args) 68 | self.role_embedding_net = Role_Embedding(args) 69 | self.W = nn.Parameter(torch.rand(args.role_embedding_dim, args.role_embedding_dim)) 70 | 71 | def role_embed_foward(self, agent_embedding, detach=False, ema=False): 72 | return self.role_embedding_net(agent_embedding, detach, ema) 73 | 74 | def agent_embed_forward(self, obs, detach=False): 75 | return self.agent_embedding_net(obs, detach) 76 | 77 | def encoder_decoder_forward(self, obs): 78 | agent_embedding = self.agent_embed_forward(obs, detach=False) 79 | decoder_out = self.agent_embedding_decoder(agent_embedding) 80 | return decoder_out 81 | 82 | class ACORM_Actor(nn.Module): 83 | def __init__(self, args): 84 | super(ACORM_Actor, self).__init__() 85 | self.args = args 86 | self.embedding_net = Embedding_Net(args) 87 | self.actor_input_dim = args.agent_embedding_dim + args.role_embedding_dim 88 | self.actor_net = nn.ModuleList([nn.Linear(self.actor_input_dim, self.actor_input_dim), 89 | nn.Linear(self.actor_input_dim, args.action_dim)]) 90 | 91 | def actor_forward(self, agent_embedding, role_embedding, avail_a_n): 92 | actor_input = torch.cat([agent_embedding, role_embedding], dim=-1) 93 | output = torch.relu(self.actor_net[0](actor_input)) 94 | output = self.actor_net[1](output) 95 | output[avail_a_n==0] = -1e10 # mask the unavailable action 96 | prob = torch.softmax(output, dim=-1) 97 | return prob 98 | 99 | def forward(self, obs, avail_a_n): 100 | agent_embed = self.embedding_net.agent_embed_forward(obs, detach=True) 101 | role_embed = self.embedding_net.role_embed_foward(agent_embed, detach=True, ema=False) 102 | prob = self.actor_forward(agent_embed, role_embed, avail_a_n) 103 | return prob 104 | 105 | 106 | class ACORM_Critic(nn.Module): 107 | def __init__(self, args): 108 | super(ACORM_Critic, self).__init__() 109 | self.N = args.N 110 | self.att_out_dim = args.att_out_dim 111 | self.rnn_hidden_dim = args.rnn_hidden_dim 112 | self.state_dim = args.state_dim 113 | 114 | self.state_gru_net = nn.ModuleList([nn.Linear(args.state_dim, args.state_dim), 115 | nn.GRUCell(args.state_dim, args.N*args.rnn_hidden_dim)]) 116 | self.state_gru_hidden = None 117 | self.attention_net = MultiHeadAttention(args.n_heads, args.att_dim, args.att_out_dim, args.soft_temp, 118 | args.rnn_hidden_dim, args.role_embedding_dim, args.role_embedding_dim) 119 | self.obs_gru_net = nn.ModuleList([nn.Linear(args.obs_dim, args.rnn_hidden_dim), 120 | nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim)]) 121 | self.obs_gru_hidden = None 122 | self.fc_final = nn.ModuleList([nn.Linear(args.rnn_hidden_dim+self.N*args.rnn_hidden_dim+self.N*args.att_out_dim, 2*args.rnn_hidden_dim), 123 | nn.Linear(2*args.rnn_hidden_dim, 1)]) 124 | 125 | def state_forward(self, state): 126 | fc_out = torch.relu(self.state_gru_net[0](state)) 127 | self.state_gru_hidden = self.state_gru_net[1](fc_out, self.state_gru_hidden) 128 | return self.state_gru_hidden # (batch, rnn_dim) 129 | 130 | def att_forward(self, tau_s, role_embeddings): 131 | # tau_s.shape=(batch, N, rnn_hidden_dim), role_embeddings.shape=(batch, N, role_dim) 132 | return self.attention_net(tau_s, role_embeddings, role_embeddings) # output.shape=(batch, N, att_out_dim) 133 | 134 | def obs_forward(self, obs): 135 | # x = torch.cat([obs],dim=-1) 136 | x = torch.relu(self.obs_gru_net[0](obs)) 137 | self.obs_gru_hidden = self.obs_gru_net[1](x, self.obs_gru_hidden) 138 | return self.obs_gru_hidden # (batch*N, rnn_dim) 139 | 140 | def critic_forward(self, tau_obs, tau_s, att_out): 141 | x = torch.cat([tau_obs, tau_s, att_out], dim=-1) 142 | x = torch.relu(self.fc_final[0](x)) 143 | value = self.fc_final[1](x) 144 | return value 145 | 146 | def forward(self, obs, state, role_embeding): 147 | tau_s = self.state_forward(state) # (batch, state_dim)-> (batch, N*rnn_dim) 148 | tau_s = tau_s.reshape(-1, self.N, self.rnn_hidden_dim) # (batch, N, rnn_dim) 149 | att = self.att_forward(tau_s, role_embeding).unsqueeze(1).repeat(1,self.N,1,1).reshape(-1, self.N*self.att_out_dim) # (batch, N, att_out_dim)->(batch,N,N,att_out_dim)->(batch*N, att_out_dim) 150 | # tau_obs = self.obs_forward(obs, state.unsqueeze(1).repeat(1,self.N,1).reshape(-1,self.state_dim)) # (batch*N, obs_dim) 151 | tau_obs = self.obs_forward(obs) # (batch*N, obs_dim) 152 | value = self.critic_forward(tau_obs, tau_s.unsqueeze(1).repeat(1,self.N,1,1).reshape(-1,self.N*self.rnn_hidden_dim), att) # (batch*N, 1) 153 | return value -------------------------------------------------------------------------------- /ACORM_MAPPO/util/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class MultiHeadAttention(nn.Module): 6 | def __init__(self, n_heads, att_dim, att_out_dim, soft_temp, dim_q, dim_k, dim_v): 7 | super(MultiHeadAttention, self).__init__() 8 | assert (att_dim % n_heads) == 0, "n_heads must divide att_dim" 9 | self.n_heads = n_heads 10 | self.att_dim = att_dim 11 | self.head_att_dim = att_dim // n_heads 12 | self.att_out_dim = att_out_dim 13 | self.temperature = self.head_att_dim ** 0.5 / soft_temp 14 | 15 | self.fc_q = nn.Linear(dim_q, self.att_dim, bias=False) 16 | self.fc_k = nn.Linear(dim_k, self.att_dim, bias=False) 17 | self.fc_v = nn.Linear(dim_v, self.att_dim) 18 | self.fc_final = nn.Linear(self.att_dim, self.att_out_dim) 19 | 20 | def forward(self, q, k, v): 21 | # q.shape = (batch, N, dim) 22 | batch_size = q.shape[0] 23 | # shape = (batch*N, att_dim)->(batch, N, heads, head_att_dim)->(batch, heads, N, head_att_dim) 24 | q = self.fc_q(q.view(-1, q.shape[2])).view(batch_size, -1, self.n_heads, self.head_att_dim).transpose(1, 2) 25 | # shape = (batch*N, att_dim)->(batch, N, heads, head_att_dim)->(batch, heads, head_att_dim, N) 26 | k_T = self.fc_k(k.view(-1, k.shape[2])).view(batch_size, -1, self.n_heads, self.head_att_dim).permute(0,2,3,1) 27 | v = self.fc_v(v.view(-1, v.shape[2])).view(batch_size, -1, self.n_heads, self.head_att_dim).transpose(1, 2) 28 | alpha = F.softmax(torch.matmul(q/self.temperature, k_T), dim=-1) # shape = (batch, heads, N, N) 29 | # shape = (batch, heads, N, head_att_dim)->(batch, N, heads, head_att_dim)->(batch, N, att_dim) 30 | result = torch.matmul(alpha, v).transpose(1, 2).reshape(batch_size, -1, self.att_dim) 31 | result = self.fc_final(result) # shape = (batch, N, att_out_dim) 32 | return result 33 | -------------------------------------------------------------------------------- /ACORM_MAPPO/util/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def orthogonal_init(layer, gain=1.0): 5 | for name, param in layer.named_parameters(): 6 | if 'bias' in name: 7 | nn.init.constant_(param, 0) 8 | elif 'weight' in name: 9 | nn.init.orthogonal_(param, gain=gain) 10 | 11 | class Actor(nn.Module): 12 | def __init__(self, args, input_dim): 13 | super(Actor, self).__init__() 14 | self. rnn_hidden = None 15 | self.fc1 = nn.Linear(input_dim, args.rnn_hidden_dim) 16 | self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim) 17 | self.fc2 = nn.Linear(args.rnn_hidden_dim, args.action_dim) 18 | if args.use_orthogonal_init: 19 | orthogonal_init(self.fc1) 20 | orthogonal_init(self.rnn) 21 | orthogonal_init(self.fc2, gain=0.01) 22 | 23 | def forward(self, input, avail_a_n): 24 | # When 'choose_action': input.shape=(N, input_dim), prob.shape=(N, action_dim) 25 | # When 'train': input.shape=(batch*N, input_dim),prob.shape=(batch*N, action_dim) 26 | x = torch.relu(self.fc1(input)) 27 | self.rnn_hidden = self.rnn(x, self.rnn_hidden) 28 | x = self.fc2(self.rnn_hidden) 29 | x[avail_a_n==0] = -1e10 # mask the unavailable actions 30 | prob = torch.softmax(x, dim=-1) 31 | return prob 32 | 33 | 34 | class Critic(nn.Module): 35 | def __init__(self, args, input_dim): 36 | super(Critic, self).__init__() 37 | self.rnn_hidden = None 38 | self.fc1 = nn.Linear(input_dim, args.rnn_hidden_dim) 39 | self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim) 40 | self.fc2 = nn.Linear(args.rnn_hidden_dim, 1) 41 | if args.use_orthogonal_init: 42 | orthogonal_init(self.fc1) 43 | orthogonal_init(self.rnn) 44 | orthogonal_init(self.fc2) 45 | 46 | def forward(self, input): 47 | # When 'get_value': input.shape=(N, input_dim), value.shape=(N, 1) 48 | # When 'train': input.shape=(batch*N, input_dim), value.shape=(batch_size*N, 1) 49 | x = torch.relu(self.fc1(input)) 50 | self.rnn_hidden = self.rnn(x, self.rnn_hidden) 51 | value = self.fc2(self.rnn_hidden) 52 | return value 53 | -------------------------------------------------------------------------------- /ACORM_MAPPO/util/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ReplayBuffer: 6 | def __init__(self, args): 7 | self.N = args.N 8 | self.obs_dim = args.obs_dim 9 | self.state_dim = args.state_dim 10 | self.action_dim = args.action_dim 11 | self.episode_limit = args.episode_limit 12 | self.batch_size = args.batch_size 13 | self.episode_num = 0 14 | self.max_episode_len = 0 15 | 16 | self.buffer = {'obs_n': np.zeros([self.batch_size, self.episode_limit, self.N, self.obs_dim]), 17 | 's': np.zeros([self.batch_size, self.episode_limit, self.state_dim]), 18 | 'v_n': np.zeros([self.batch_size, self.episode_limit+1, self.N]), 19 | 'avail_a_n': np.ones([self.batch_size, self.episode_limit, self.N, self.action_dim]), 20 | 'a_n': np.zeros([self.batch_size, self.episode_limit, self.N]), 21 | 'a_logprob_n': np.zeros([self.batch_size, self.episode_limit, self.N]), 22 | 'r': np.zeros([self.batch_size, self.episode_limit, self.N]), # repeat N 23 | 'dw': np.ones([self.batch_size, self.episode_limit, self.N]), 24 | 'active': np.zeros([self.batch_size, self.episode_limit, self.N]) 25 | } 26 | 27 | def reset_buffer(self): 28 | self.buffer['active'] = np.zeros([self.batch_size, self.episode_limit, self.N]) 29 | self.episode_num = 0 30 | self.max_episode_len = 0 31 | 32 | def store_transition(self, episode_step, obs_n, s, v_n, avail_a_n, a_n, a_logprob_n, r, dw): 33 | self.buffer['obs_n'][self.episode_num][episode_step] = obs_n 34 | self.buffer['s'][self.episode_num][episode_step] = s 35 | self.buffer['v_n'][self.episode_num][episode_step] = v_n 36 | self.buffer['avail_a_n'][self.episode_num][episode_step] = avail_a_n 37 | self.buffer['a_n'][self.episode_num][episode_step] = a_n 38 | self.buffer['a_logprob_n'][self.episode_num][episode_step] = a_logprob_n 39 | self.buffer['r'][self.episode_num][episode_step] = np.array(r).repeat(self.N) 40 | self.buffer['dw'][self.episode_num][episode_step] = np.array(dw).repeat(self.N) 41 | 42 | self.buffer['active'][self.episode_num][episode_step] = np.ones(self.N) 43 | 44 | def store_last_value(self, episode_step, v_n): 45 | self.buffer['v_n'][self.episode_num][episode_step] = v_n 46 | self.episode_num += 1 47 | # Record max_episode_len 48 | if episode_step > self.max_episode_len: 49 | self.max_episode_len = episode_step 50 | 51 | def get_training_data(self): 52 | batch = {} 53 | for key in self.buffer.keys(): 54 | if key == 'a_n': 55 | batch[key] = torch.tensor(self.buffer[key][:, :self.max_episode_len], dtype=torch.long) 56 | elif key == 'v_n': 57 | batch[key] = torch.tensor(self.buffer[key][:, :self.max_episode_len + 1], dtype=torch.float32) 58 | else: 59 | batch[key] = torch.tensor(self.buffer[key][:, :self.max_episode_len], dtype=torch.float32) 60 | return batch -------------------------------------------------------------------------------- /ACORM_QMIX.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX.jpg -------------------------------------------------------------------------------- /ACORM_QMIX/algorithm/__pycache__/acorm.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/algorithm/__pycache__/acorm.cpython-310.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/algorithm/__pycache__/acorm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/algorithm/__pycache__/acorm.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/algorithm/__pycache__/recl.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/algorithm/__pycache__/recl.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/algorithm/__pycache__/vdn_qmix.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/algorithm/__pycache__/vdn_qmix.cpython-310.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/algorithm/__pycache__/vdn_qmix.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/algorithm/__pycache__/vdn_qmix.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/algorithm/__pycache__/vdn_qmix_reuse.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/algorithm/__pycache__/vdn_qmix_reuse.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/algorithm/__pycache__/vdn_qmix_reuse_v1.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/algorithm/__pycache__/vdn_qmix_reuse_v1.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/algorithm/__pycache__/vdn_qmix_reuse_v2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/algorithm/__pycache__/vdn_qmix_reuse_v2.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/algorithm/acorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from util.net import * 5 | from util.attention import MultiHeadAttention 6 | # from algorithm.vdn_qmix import QMIX_Net 7 | import numpy as np 8 | import copy 9 | from sklearn.cluster import KMeans 10 | from torch.optim.lr_scheduler import StepLR 11 | # from kmeans_pytorch import kmeans 12 | 13 | 14 | class RECL_MIX(nn.Module): 15 | def __init__(self, args): 16 | super(RECL_MIX, self).__init__() 17 | self.args = args 18 | self.N = args.N 19 | self.state_dim = args.state_dim 20 | self.mix_input_dim = args.state_dim + args.N * args.att_out_dim 21 | self.batch_size = args.batch_size 22 | self.qmix_hidden_dim = args.qmix_hidden_dim 23 | self.hyper_hidden_dim = args.hyper_hidden_dim 24 | self.hyper_layers_num = args.hyper_layers_num 25 | 26 | self.state_fc = nn.Linear(args.state_dim, args.state_dim) 27 | self.state_gru = nn.GRUCell(args.state_dim, args.N*args.state_embed_dim) 28 | self.state_gru_hidden = None 29 | self.dim_q = args.state_embed_dim 30 | self.attention_net = MultiHeadAttention(args.n_heads, args.att_dim, args.att_out_dim, args.soft_temperature, self.dim_q,args.role_embedding_dim, args.role_embedding_dim) 31 | 32 | 33 | """ 34 | w1:(N, qmix_hidden_dim) 35 | b1:(1, qmix_hidden_dim) 36 | w2:(qmix_hidden_dim, 1) 37 | b2:(1, 1) 38 | 39 | """ 40 | if self.hyper_layers_num == 2: 41 | print("hyper_layers_num=2") 42 | self.hyper_w1 = nn.Sequential(nn.Linear(self.mix_input_dim, self.hyper_hidden_dim), 43 | nn.ReLU(), 44 | nn.Linear(self.hyper_hidden_dim, self.N * self.qmix_hidden_dim)) 45 | self.hyper_w2 = nn.Sequential(nn.Linear(self.mix_input_dim, self.hyper_hidden_dim), 46 | nn.ReLU(), 47 | nn.Linear(self.hyper_hidden_dim, self.qmix_hidden_dim * 1)) 48 | elif self.hyper_layers_num == 1: 49 | print("hyper_layers_num=1") 50 | self.hyper_w1 = nn.Linear(self.mix_input_dim, self.N * self.qmix_hidden_dim) 51 | self.hyper_w2 = nn.Linear(self.mix_input_dim, self.qmix_hidden_dim * 1) 52 | else: 53 | print("wrong!!!") 54 | 55 | self.hyper_b1 = nn.Linear(self.mix_input_dim, self.qmix_hidden_dim) 56 | self.hyper_b2 = nn.Sequential(nn.Linear(self.mix_input_dim, self.qmix_hidden_dim), 57 | nn.ReLU(), 58 | nn.Linear(self.qmix_hidden_dim, 1)) 59 | 60 | def role_gru_forward(self, role_embeddings): 61 | # role_embeddings.shape = (batch_size, N*role_embedding_dim) 62 | self.role_gru_hidden = self.role_gru(role_embeddings, self.role_gru_hidden) 63 | output = torch.sigmoid(self.role_gru_hidden) 64 | return output 65 | 66 | def forward(self, q, s, att): 67 | # q.shape(batch_size, max_episode_len, N) 68 | # s.shape(batch_size, max_episode_len,state_dim) 69 | 70 | q = q.view(-1, 1, self.N) # (batch_size * max_episode_len, 1, N) 71 | s = s.reshape(-1, self.state_dim) # (batch_size * max_episode_len, state_dim) 72 | att = att.reshape(-1, att.shape[2]) 73 | state = torch.cat([s, att], dim=-1) 74 | 75 | w1 = torch.abs(self.hyper_w1(state)) # (batch_size * max_episode_len, N * qmix_hidden_dim) 76 | b1 = self.hyper_b1(state) # (batch_size * max_episode_len, qmix_hidden_dim) 77 | w1 = w1.view(-1, self.N, self.qmix_hidden_dim) # (batch_size * max_episode_len, N, qmix_hidden_dim) 78 | b1 = b1.view(-1, 1, self.qmix_hidden_dim) # (batch_size * max_episode_len, 1, qmix_hidden_dim) 79 | 80 | # torch.bmm: 3 dimensional tensor multiplication 81 | q_hidden = F.elu(torch.bmm(q, w1) + b1) # (batch_size * max_episode_len, 1, qmix_hidden_dim) 82 | 83 | w2 = torch.abs(self.hyper_w2(state)) # (batch_size * max_episode_len, qmix_hidden_dim * 1) 84 | b2 = self.hyper_b2(state) # (batch_size * max_episode_len,1) 85 | w2 = w2.view(-1, self.qmix_hidden_dim, 1) # (b\atch_size * max_episode_len, qmix_hidden_dim, 1) 86 | b2 = b2.view(-1, 1, 1) # (batch_size * max_episode_len, 1, 1) 87 | 88 | q_total = torch.bmm(q_hidden, w2) + b2 # (batch_size * max_episode_len, 1, 1) 89 | q_total = q_total.view(self.batch_size, -1, 1) # (batch_size, max_episode_len, 1) 90 | return q_total 91 | 92 | class RECL_NET(nn.Module): 93 | def __init__(self, args): 94 | super(RECL_NET, self).__init__() 95 | 96 | self.N = args.N 97 | self.agent_embedding_dim = args.agent_embedding_dim 98 | self.role_embedding_dim = args.role_embedding_dim 99 | self.action_dim = args.action_dim 100 | self.obs_dim = args.obs_dim 101 | 102 | self.agent_embedding_net = Agent_Embedding(args) 103 | self.agent_embedding_decoder = Agent_Embedding_Decoder(args) 104 | self.role_embedding_net = Role_Embedding(args) 105 | self.role_embedding_target_net = Role_Embedding(args) 106 | self.role_embedding_target_net.load_state_dict(self.role_embedding_net.state_dict()) 107 | 108 | self.W = nn.Parameter(torch.rand(self.role_embedding_dim, self.role_embedding_dim)) 109 | 110 | def forward(self, obs, action, detach=False): 111 | agent_embedding = self.agent_embedding_net(obs, action, detach) 112 | role_embedding = self.role_embedding_net(agent_embedding) 113 | return role_embedding 114 | 115 | def encoder_decoder_forward(self, obs, action): 116 | agent_embedding = self.agent_embedding_forward(obs, action, detach=False) 117 | decoder_out = self.agent_embedding_decoder(agent_embedding) 118 | return decoder_out 119 | 120 | def agent_embedding_forward(self, obs, action, detach=False): 121 | return self.agent_embedding_net(obs, action, detach) 122 | 123 | def role_embedding_forward(self, agent_embedding, detach=False, ema=False): 124 | if ema: 125 | output = self.role_embedding_target_net(agent_embedding, detach) 126 | else: 127 | output = self.role_embedding_net(agent_embedding) 128 | return output 129 | 130 | def batch_role_embed_forward(self, batch_o, batch_a, max_episode_len, detach=False): 131 | self.agent_embedding_net.rnn_hidden = None 132 | agent_embeddings = [] 133 | for t in range(max_episode_len+1): # t = 0,1,2...(max_episode_len-1), max_episode_len 134 | agent_embedding = self.agent_embedding_forward(batch_o[:, t].reshape(-1, self.obs_dim), 135 | batch_a[:, t].reshape(-1, self.action_dim), 136 | detach=detach) # agent_embedding.shape=(batch_size*N, agent_embed_dim) 137 | agent_embedding = agent_embedding.reshape(batch_o.shape[0], self.N, -1) # shape=(batch_size,N, agent_embed_dim) 138 | agent_embeddings.append(agent_embedding.reshape(batch_o.shape[0],self.N, -1)) 139 | # Stack them according to the time (dim=1) 140 | agent_embeddings = torch.stack(agent_embeddings, dim=1).reshape(-1,self.agent_embedding_dim) # agent_embeddings.shape=(batch_size*(max_episode_len+1)*N, agent_embed_dim) 141 | role_embeddings = self.role_embedding_forward(agent_embeddings, detach=False, ema=False).reshape(-1, max_episode_len+1, self.N, self.role_embedding_dim) 142 | agent_embeddings = agent_embeddings.reshape(-1, max_episode_len+1, self.N, self.agent_embedding_dim) 143 | return agent_embeddings, role_embeddings 144 | 145 | class ACORM_Agent(object): 146 | def __init__(self, args): 147 | self.args = args 148 | self.N = args.N 149 | self.action_dim = args.action_dim 150 | self.obs_dim = args.obs_dim 151 | self.state_dim = args.state_dim 152 | self.role_embedding_dim = args.role_embedding_dim 153 | self.agent_embedding_dim = args.agent_embedding_dim 154 | self.att_out_dim = args.att_out_dim 155 | self.cluster_num = args.cluster_num 156 | self.add_last_action = args.add_last_action 157 | self.max_train_steps=args.max_train_steps 158 | self.lr = args.lr 159 | self.recl_lr = args.recl_lr 160 | self.agent_embedding_lr = args.agent_embedding_lr 161 | self.gamma = args.gamma 162 | 163 | self.batch_size = args.batch_size 164 | self.multi_steps = args.multi_steps 165 | self.target_update_freq = args.target_update_freq 166 | self.train_recl_freq = args.train_recl_freq 167 | self.tau = args.tau 168 | self.role_tau = args.role_tau 169 | self.use_hard_update = args.use_hard_update 170 | self.use_lr_decay = args.use_lr_decay 171 | self.lr_decay_steps = args.lr_decay_steps 172 | self.lr_decay_rate = args.lr_decay_rate 173 | self.algorithm = args.algorithm 174 | self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu') 175 | self.QMIX_input_dim = args.obs_dim 176 | if self.add_last_action: 177 | print("------add last action------") 178 | self.QMIX_input_dim += self.action_dim 179 | self.QMIX_input_dim += self.role_embedding_dim 180 | 181 | self.RECL = RECL_NET(args) 182 | # self.agent_embedding_optimizer = torch.optim.Adam(self.RECL.agent_embedding_net.parameters(), lr=self.recl_lr) 183 | self.role_parameters = list(self.RECL.role_embedding_net.parameters()) + list(self.RECL.agent_embedding_net.parameters()) 184 | self.role_embedding_optimizer = torch.optim.Adam(self.role_parameters, lr=self.lr) 185 | self.role_lr_decay = StepLR(self.role_embedding_optimizer, step_size=self.lr_decay_steps, gamma=self.lr_decay_rate) 186 | self.RECL_parameters = list(self.RECL.parameters()) 187 | self.RECL_optimizer = torch.optim.Adam(self.RECL_parameters, lr=self.recl_lr) 188 | self.encoder_decoder_para = list(self.RECL.agent_embedding_net.parameters()) + list(self.RECL.agent_embedding_decoder.parameters()) 189 | self.encoder_decoder_optimizer = torch.optim.Adam(self.encoder_decoder_para, lr=self.agent_embedding_lr) 190 | 191 | self.eval_Q_net = Q_network_RNN(args, self.QMIX_input_dim) 192 | self.target_Q_net = Q_network_RNN(args, self.QMIX_input_dim) 193 | self.target_Q_net.load_state_dict(self.eval_Q_net.state_dict()) 194 | 195 | self.eval_mix_net = RECL_MIX(args) 196 | self.target_mix_net = RECL_MIX(args) 197 | self.target_mix_net.load_state_dict(self.eval_mix_net.state_dict()) 198 | 199 | self.eval_parameters = list(self.eval_mix_net.parameters()) + list(self.eval_Q_net.parameters()) 200 | self.optimizer = torch.optim.Adam(self.eval_parameters, lr=self.lr) 201 | self.qmix_lr_decay = StepLR(self.optimizer, step_size=self.lr_decay_steps, gamma=self.lr_decay_rate) 202 | 203 | self.target_Q_net.to(self.device) 204 | self.eval_Q_net.to(self.device) 205 | self.target_mix_net.to(self.device) 206 | self.eval_mix_net.to(self.device) 207 | self.RECL.to(self.device) 208 | 209 | self.train_step = 0 210 | 211 | def get_role_embedding(self, obs_n, last_a): 212 | recl_obs = torch.tensor(np.array(obs_n), dtype=torch.float32).to(self.device) 213 | recl_last_a = torch.tensor(np.array(last_a), dtype=torch.float32).to(self.device) 214 | role_embedding = self.RECL(recl_obs, recl_last_a, detach=True) 215 | return role_embedding 216 | 217 | def choose_action(self, obs_n, last_onehot_a_n, role_embedding, avail_a_n, epsilon): 218 | with torch.no_grad(): 219 | if np.random.uniform() < epsilon: # epsilon-greedy 220 | # Only available actions can be chosen 221 | a_n = [np.random.choice(np.nonzero(avail_a)[0]) for avail_a in avail_a_n] 222 | else: 223 | inputs = copy.deepcopy(obs_n) 224 | if self.add_last_action: 225 | inputs = np.hstack((inputs, last_onehot_a_n)) 226 | inputs = np.hstack((inputs, role_embedding.to('cpu'))) 227 | inputs = torch.tensor(inputs, dtype=torch.float32) 228 | inputs = inputs.to(self.device) 229 | 230 | q_value = self.eval_Q_net(inputs) 231 | avail_a_n = torch.tensor(avail_a_n, dtype=torch.float32) # avail_a_n.shape=(N, action_dim) 232 | q_value = q_value.to('cpu') 233 | q_value[avail_a_n == 0] = -float('inf') # Mask the unavailable actions 234 | 235 | a_n = q_value.argmax(dim=-1).numpy() 236 | return a_n 237 | 238 | def get_inputs(self, batch): 239 | inputs = copy.deepcopy(batch['obs_n']) 240 | if self.add_last_action: 241 | inputs = np.concatenate((inputs, batch['last_onehot_a_n']),axis=-1) 242 | inputs = torch.tensor(inputs, dtype=torch.float32) 243 | 244 | inputs = inputs.to(self.device) 245 | batch_o = batch['obs_n'].to(self.device) 246 | batch_s = batch['s'].to(self.device) 247 | batch_r = batch['r'].to(self.device) 248 | batch_a = batch['a_n'].to(self.device) 249 | batch_last_a = batch['last_onehot_a_n'].to(self.device) 250 | batch_active = batch['active'].to(self.device) 251 | batch_dw = batch['dw'].to(self.device) 252 | batch_avail_a_n = batch['avail_a_n'] 253 | return inputs, batch_o, batch_s, batch_r, batch_a, batch_last_a, batch_avail_a_n, batch_active, batch_dw 254 | 255 | def train(self, replay_buffer): 256 | self.train_step += 1 257 | batch, max_episode_len = replay_buffer.sample(self.batch_size) # Get training data 258 | inputs, batch_o, batch_s, batch_r, batch_a, batch_last_a, batch_avail_a_n, batch_active, batch_dw = self.get_inputs(batch) 259 | 260 | if self.train_step % self.train_recl_freq == 0: 261 | self.update_recl(batch_o, batch_last_a, batch_active, max_episode_len) 262 | self.soft_update_params(self.RECL.role_embedding_net, self.RECL.role_embedding_target_net, self.role_tau) 263 | 264 | self.update_qmix(inputs, batch_o, batch_s, batch_r, batch_a, batch_last_a, batch_avail_a_n, batch_active, batch_dw, max_episode_len) 265 | if self.use_hard_update: 266 | # hard update 267 | if self.train_step % self.target_update_freq == 0: 268 | self.target_Q_net.load_state_dict(self.eval_Q_net.state_dict()) 269 | self.target_mix_net.load_state_dict(self.eval_mix_net.state_dict()) 270 | else: 271 | # Softly update the target networks 272 | self.soft_update_params(self.eval_Q_net, self.target_Q_net, self.tau) 273 | self.soft_update_params(self.eval_mix_net, self.target_mix_net, self.tau) 274 | self.soft_update_params(self.RECL.role_embedding_net, self.RECL.role_embedding_target_net, self.tau) 275 | 276 | if self.use_lr_decay: 277 | self.qmix_lr_decay.step() 278 | self.role_lr_decay.step() 279 | 280 | def pretrain_recl(self, replay_buffer): 281 | batch, max_episode_len = replay_buffer.sample(self.batch_size) 282 | batch_o = batch['obs_n'].to(self.device) 283 | batch_last_a = batch['last_onehot_a_n'].to(self.device) 284 | batch_active = batch['active'].to(self.device) 285 | recl_loss = self.update_recl(batch_o, batch_last_a, batch_active, max_episode_len) 286 | self.soft_update_params(self.RECL.role_embedding_net, self.RECL.role_embedding_target_net, self.role_tau) 287 | return recl_loss 288 | 289 | 290 | def pretrain_agent_embedding(self, replay_buffer): 291 | batch, max_episode_len = replay_buffer.sample(self.batch_size) 292 | batch_o = batch['obs_n'].to(self.device) 293 | batch_last_a = batch['last_onehot_a_n'].to(self.device) 294 | batch_active = batch['active'].to(self.device) 295 | 296 | self.RECL.agent_embedding_net.rnn_hidden = None 297 | agent_embeddings = [] 298 | for t in range(max_episode_len): 299 | agent_embedding = self.RECL.agent_embedding_forward(batch_o[:, t].reshape(-1, self.obs_dim), 300 | batch_last_a[:, t].reshape(-1, self.action_dim), 301 | detach=False) 302 | agent_embeddings.append(agent_embedding.reshape(-1, self.N, self.agent_embedding_dim)) # (batch_size, N, agent_embedding_dim) 303 | agent_embeddings = torch.stack(agent_embeddings, dim=1) #(batch_size, max_episode_len, N, agent_embedding_dim) 304 | decoder_output = self.RECL.agent_embedding_decoder(agent_embeddings.reshape(-1,self.agent_embedding_dim)).reshape(-1, max_episode_len, self.N, self.obs_dim+self.N) 305 | batch_obs_hat = batch_o[:,1:] 306 | agent_id_one_hot = torch.eye(self.N).unsqueeze(0).unsqueeze(0).repeat(batch_o.shape[0], max_episode_len, 1, 1).to(self.device) 307 | decoder_target = torch.cat([batch_obs_hat, agent_id_one_hot], dim=-1) # (batch_size, max_len, N, obs_dim+N) 308 | mask = batch_active.unsqueeze(-1).repeat(1, 1, self.N, self.obs_dim+self.N) 309 | loss = (((decoder_output - decoder_target) * mask)**2).sum()/mask.sum() 310 | 311 | self.encoder_decoder_optimizer.zero_grad() 312 | loss.backward() 313 | self.encoder_decoder_optimizer.step() 314 | return loss 315 | 316 | 317 | def update_recl(self, batch_o, batch_last_a, batch_active, max_episode_len): 318 | """ 319 | N = agent_num 320 | batch_o.shape = (batch_size, max_episode_len + 1, N, obs_dim) 321 | batch_a.shape = (batch_size, max_episode_len, N, action_dim) 322 | batch_active = (batch_size, max_episode_len, 1) 323 | """ 324 | self.RECL.agent_embedding_net.rnn_hidden = None 325 | loss = 0 326 | labels = np.zeros((batch_o.shape[0], self.N)) # (batch_size, N) 327 | for t in range(max_episode_len): # t = 0,1,2...(max_episode_len-1) 328 | with torch.no_grad(): 329 | agent_embedding = self.RECL.agent_embedding_forward(batch_o[:, t].reshape(-1, self.obs_dim), 330 | batch_last_a[:, t].reshape(-1, self.action_dim), 331 | detach=True) # agent_embedding.shape=(batch_size*N, agent_embed_dim) 332 | role_embedding_qury = self.RECL.role_embedding_forward(agent_embedding, 333 | detach=False, 334 | ema=False).reshape(-1,self.N, self.role_embedding_dim) # shape=(batch_size, N, role_embed_dim) 335 | role_embedding_key = self.RECL.role_embedding_forward(agent_embedding, 336 | detach=True, 337 | ema=True).reshape(-1,self.N, self.role_embedding_dim) 338 | logits = torch.bmm(role_embedding_qury, self.RECL.W.squeeze(0).expand((role_embedding_qury.shape[0],self.role_embedding_dim,self.role_embedding_dim))) 339 | logits = torch.bmm(logits, role_embedding_key.transpose(1,2)) # (batch_size, N, N) 340 | logits = logits - torch.max(logits, dim=-1)[0][:,:,None] 341 | exp_logits = torch.exp(logits) # (batch_size, N, 1) 342 | agent_embedding = agent_embedding.reshape(batch_o.shape[0],self.N, -1).to('cpu') # shape=(batch_size,N, agent_embed_dim) 343 | 344 | for idx in range(agent_embedding.shape[0]): # idx = 0,1,2...(batch_size-1) 345 | if batch_active[idx, t] > 0.5: 346 | if t % self.multi_steps == 0: 347 | clusters_labels = KMeans(n_clusters=self.cluster_num).fit(agent_embedding[idx]).labels_ # (1,N) 348 | labels[idx] = copy.deepcopy(clusters_labels) 349 | else: 350 | clusters_labels = copy.deepcopy(labels[idx]) 351 | # clusters_labels, _ = kmeans(X=agent_embedding[idx],num_clusters=self.cluster_num) 352 | for j in range(self.cluster_num): # j = 0,1,...(cluster_num -1) 353 | label_pos = [idx for idx, value in enumerate(clusters_labels) if value==j] 354 | # label_neg = [idx for idx, value in enumerate(clusters_labels) if value!=j] 355 | for anchor in label_pos: 356 | loss += -torch.log(exp_logits[idx, anchor, label_pos].sum()/exp_logits[idx, anchor].sum()) 357 | loss /= (self.batch_size * max_episode_len * self.N) 358 | if batch_active[idx, t] > 0.5: 359 | self.RECL_optimizer.zero_grad() 360 | loss.backward() 361 | self.RECL_optimizer.step() 362 | return loss 363 | 364 | def update_qmix(self, inputs, batch_o, batch_s, batch_r, batch_a, batch_last_a, batch_avail_a_n, batch_active, batch_dw, max_episode_len): 365 | self.eval_Q_net.rnn_hidden = None 366 | self.target_Q_net.rnn_hidden = None 367 | _, role_embeddings = self.RECL.batch_role_embed_forward(batch_o, batch_last_a, max_episode_len, detach=False) # shape=(batch_size, (max_episode_len+1),N, role_embed_dim) 368 | inputs = torch.cat([inputs, role_embeddings], dim=-1) 369 | q_evals, q_targets = [], [] 370 | 371 | self.eval_mix_net.state_gru_hidden = None 372 | # self.target_mix_net.state_gru_hidden = None 373 | fc_batch_s = F.relu(self.eval_mix_net.state_fc(batch_s.reshape(-1, self.state_dim))).reshape(-1, max_episode_len+1, self.state_dim) # shape(batch*max_len+1, state_dim) 374 | state_gru_outs = [] 375 | for t in range(max_episode_len): # t=0,1,2,...(episode_len-1) 376 | q_eval = self.eval_Q_net(inputs[:, t].reshape(-1, self.QMIX_input_dim)) # q_eval.shape=(batch_size*N,action_dim) 377 | q_target = self.target_Q_net(inputs[:, t + 1].reshape(-1, self.QMIX_input_dim)) 378 | q_evals.append(q_eval.reshape(self.batch_size, self.N, -1)) # q_eval.shape=(batch_size,N,action_dim) 379 | q_targets.append(q_target.reshape(self.batch_size, self.N, -1)) 380 | 381 | self.eval_mix_net.state_gru_hidden = self.eval_mix_net.state_gru(fc_batch_s[:, t].reshape(-1,self.state_dim), self.eval_mix_net.state_gru_hidden) # shape=(batch, N*state_embed_dim) 382 | state_gru_outs.append(self.eval_mix_net.state_gru_hidden) 383 | 384 | # role_eval = self.eval_mix_net.role_gru_forward(role_embeddings[:,t].reshape(-1, self.N*self.role_embedding_dim)) # shape=(batch_size, N*role_embed_dim) 385 | # role_target = self.target_mix_net.role_gru_forward(role_embeddings[:,t].reshape(-1, self.N*self.role_embedding_dim)) # shape=(batch_size, N*role_embed_dim) 386 | # role_evals.append(role_eval) 387 | # role_targets.append(role_target) 388 | self.eval_mix_net.state_gru_hidden = self.eval_mix_net.state_gru(fc_batch_s[:, max_episode_len].reshape(-1,self.state_dim), self.eval_mix_net.state_gru_hidden) 389 | state_gru_outs.append(self.eval_mix_net.state_gru_hidden) 390 | # role_targets.append(self.target_mix_net.role_gru_forward(role_embeddings[:,max_episode_len].reshape(-1, self.N*self.role_embedding_dim))) 391 | 392 | # Stack them according to the time (dim=1) 393 | # role_evals = torch.stack(role_evals, dim=1) # shape=(batch_size, max_len+1, N*role_dim) 394 | # role_targets = torch.stack(role_targets, dim=1) 395 | state_gru_outs = torch.stack(state_gru_outs, dim=1).reshape(-1, self.N, self.args.state_embed_dim) # shape=(batch*max_len+1, N,state_embed_dim) 396 | q_evals = torch.stack(q_evals, dim=1) # q_evals.shape=(batch_size,max_episode_len,N,action_dim) 397 | q_targets = torch.stack(q_targets, dim=1) 398 | 399 | with torch.no_grad(): 400 | q_eval_last = self.eval_Q_net(inputs[:, -1].reshape(-1, self.QMIX_input_dim)).reshape(self.batch_size, 1, self.N, -1) 401 | q_evals_next = torch.cat([q_evals[:, 1:], q_eval_last], dim=1) # q_evals_next.shape=(batch_size,max_episode_len,N,action_dim) 402 | q_evals_next[batch_avail_a_n[:, 1:] == 0] = -999999 403 | a_argmax = torch.argmax(q_evals_next, dim=-1, keepdim=True) # a_max.shape=(batch_size,max_episode_len, N, 1) 404 | q_targets = torch.gather(q_targets, dim=-1, index=a_argmax).squeeze(-1) # q_targets.shape=(batch_size, max_episode_len, N) 405 | q_evals = torch.gather(q_evals, dim=-1, index=batch_a.unsqueeze(-1)).squeeze(-1) # q_evals.shape(batch_size, max_episode_len, N) 406 | 407 | role_embeddings = role_embeddings.reshape(-1, self.N, self.role_embedding_dim) # shape=((batch_size * max_episode_len+1), N, role_embed_dim) 408 | # eval_state_qv = self.eval_mix_net.state_fc2(batch_s.reshape(-1, self.state_dim)).reshape(-1, self.N, self.eval_mix_net.dim_k) # shape=(batch*max_len, N, state_dim//N) 409 | # target_state_qv = self.target_mix_net.state_fc2(batch_s.reshape(-1, self.state_dim)).reshape(-1, self.N, self.eval_mix_net.dim_k) 410 | # agent_embeddings = agent_embeddings.reshape(-1, self.N, self.agent_embedding_dim) 411 | att_eval = self.eval_mix_net.attention_net(state_gru_outs, role_embeddings, role_embeddings).reshape(-1, max_episode_len+1, self.N*self.att_out_dim) # ((batch*max_episode_len+1), N, att_dim)->(batch, len, N*att_dim) 412 | with torch.no_grad(): 413 | att_target = self.target_mix_net.attention_net(state_gru_outs, role_embeddings, role_embeddings).reshape(-1, max_episode_len+1, self.N*self.att_out_dim) # ((batch*max_episode_len+1), N, att_dim)->(batch, len, N*att_dim) 414 | 415 | # eval_batch_s = self.eval_mix_net.state_fc1(batch_s.reshape(-1, self.state_dim)).reshape(-1, max_episode_len+1, self.state_dim) 416 | # traget_batch_s = self.target_mix_net.state_fc1(batch_s.reshape(-1, self.state_dim)).reshape(-1, max_episode_len+1, self.state_dim) 417 | 418 | q_total_eval = self.eval_mix_net(q_evals, fc_batch_s[:, :-1], att_eval[:, :-1]) 419 | q_total_target = self.target_mix_net(q_targets, fc_batch_s[:, 1:], att_target[:, 1:]) 420 | targets = batch_r + self.gamma * (1 - batch_dw) * q_total_target 421 | td_error = (q_total_eval - targets.detach()) # targets.detach() to cut the backward 422 | mask_td_error = td_error * batch_active 423 | loss = (mask_td_error ** 2).sum() / batch_active.sum() 424 | self.optimizer.zero_grad() 425 | self.role_embedding_optimizer.zero_grad() 426 | 427 | loss.backward() 428 | torch.nn.utils.clip_grad_norm_(self.role_parameters, 10) 429 | torch.nn.utils.clip_grad_norm_(self.eval_parameters, 10) 430 | self.optimizer.step() 431 | self.role_embedding_optimizer.step() 432 | 433 | def soft_update_params(self, net, target_net, tau): 434 | for param, target_param in zip(net.parameters(), target_net.parameters()): 435 | target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) 436 | -------------------------------------------------------------------------------- /ACORM_QMIX/algorithm/vdn_qmix.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from util.net import Q_network_MLP, Q_network_RNN 6 | import copy 7 | 8 | class QMIX_Net(nn.Module): 9 | def __init__(self, args): 10 | super(QMIX_Net, self).__init__() 11 | self.N = args.N 12 | self.state_dim = args.state_dim 13 | self.batch_size = args.batch_size 14 | self.qmix_hidden_dim = args.qmix_hidden_dim 15 | self.hyper_hidden_dim = args.hyper_hidden_dim 16 | self.hyper_layers_num = args.hyper_layers_num 17 | """ 18 | w1:(N, qmix_hidden_dim) 19 | b1:(1, qmix_hidden_dim) 20 | w2:(qmix_hidden_dim, 1) 21 | b2:(1, 1) 22 | 23 | """ 24 | if self.hyper_layers_num == 2: 25 | print("hyper_layers_num=2") 26 | self.hyper_w1 = nn.Sequential(nn.Linear(self.state_dim, self.hyper_hidden_dim), 27 | nn.ReLU(), 28 | nn.Linear(self.hyper_hidden_dim, self.N * self.qmix_hidden_dim)) 29 | self.hyper_w2 = nn.Sequential(nn.Linear(self.state_dim, self.hyper_hidden_dim), 30 | nn.ReLU(), 31 | nn.Linear(self.hyper_hidden_dim, self.qmix_hidden_dim * 1)) 32 | elif self.hyper_layers_num == 1: 33 | print("hyper_layers_num=1") 34 | self.hyper_w1 = nn.Linear(self.state_dim, self.N * self.qmix_hidden_dim) 35 | self.hyper_w2 = nn.Linear(self.state_dim, self.qmix_hidden_dim * 1) 36 | else: 37 | print("wrong!!!") 38 | 39 | self.hyper_b1 = nn.Linear(self.state_dim, self.qmix_hidden_dim) 40 | self.hyper_b2 = nn.Sequential(nn.Linear(self.state_dim, self.qmix_hidden_dim), 41 | nn.ReLU(), 42 | nn.Linear(self.qmix_hidden_dim, 1)) 43 | 44 | def forward(self, q, s): 45 | # q.shape(batch_size, max_episode_len, N) 46 | # s.shape(batch_size, max_episode_len,state_dim) 47 | q = q.view(-1, 1, self.N) # (batch_size * max_episode_len, 1, N) 48 | s = s.reshape(-1, self.state_dim) # (batch_size * max_episode_len, state_dim) 49 | 50 | w1 = torch.abs(self.hyper_w1(s)) # (batch_size * max_episode_len, N * qmix_hidden_dim) 51 | b1 = self.hyper_b1(s) # (batch_size * max_episode_len, qmix_hidden_dim) 52 | w1 = w1.view(-1, self.N, self.qmix_hidden_dim) # (batch_size * max_episode_len, N, qmix_hidden_dim) 53 | b1 = b1.view(-1, 1, self.qmix_hidden_dim) # (batch_size * max_episode_len, 1, qmix_hidden_dim) 54 | 55 | # torch.bmm: 3 dimensional tensor multiplication 56 | q_hidden = F.elu(torch.bmm(q, w1) + b1) # (batch_size * max_episode_len, 1, qmix_hidden_dim) 57 | 58 | w2 = torch.abs(self.hyper_w2(s)) # (batch_size * max_episode_len, qmix_hidden_dim * 1) 59 | b2 = self.hyper_b2(s) # (batch_size * max_episode_len,1) 60 | w2 = w2.view(-1, self.qmix_hidden_dim, 1) # (batch_size * max_episode_len, qmix_hidden_dim, 1) 61 | b2 = b2.view(-1, 1, 1) # (batch_size * max_episode_len, 1, 1) 62 | 63 | q_total = torch.bmm(q_hidden, w2) + b2 # (batch_size * max_episode_len, 1, 1) 64 | q_total = q_total.view(self.batch_size, -1, 1) # (batch_size, max_episode_len, 1) 65 | return q_total 66 | 67 | 68 | class VDN_Net(nn.Module): 69 | def __init__(self, ): 70 | super(VDN_Net, self).__init__() 71 | 72 | def forward(self, q): 73 | return torch.sum(q, dim=-1, keepdim=True) # (batch_size, max_episode_len, 1) 74 | 75 | 76 | class VDN_QMIX(object): 77 | def __init__(self, args): 78 | self.N = args.N 79 | self.action_dim = args.action_dim 80 | self.obs_dim = args.obs_dim 81 | self.state_dim = args.state_dim 82 | self.add_last_action = args.add_last_action 83 | self.add_agent_id = args.add_agent_id 84 | self.max_train_steps=args.max_train_steps 85 | self.lr = args.lr 86 | self.gamma = args.gamma 87 | self.use_grad_clip = args.use_grad_clip 88 | self.batch_size = args.batch_size 89 | self.target_update_freq = args.target_update_freq 90 | self.tau = args.tau 91 | self.use_hard_update = args.use_hard_update 92 | self.use_rnn = args.use_rnn 93 | self.algorithm = args.algorithm 94 | self.use_double_q = args.use_double_q 95 | self.use_RMS = args.use_RMS 96 | self.use_lr_decay = args.use_lr_decay 97 | self.device = torch.device(args.device if torch.cuda.is_available() else 'cpu') 98 | self.use_gpu = args.use_gpu 99 | # Compute the input dimension 100 | self.input_dim = self.obs_dim 101 | if self.add_last_action: 102 | print("------add last action------") 103 | self.input_dim += self.action_dim 104 | if self.add_agent_id: 105 | print("------add agent id------") 106 | self.input_dim += self.N 107 | 108 | if self.use_rnn: 109 | print("------use RNN------") 110 | self.eval_Q_net = Q_network_RNN(args, self.input_dim) 111 | self.target_Q_net = Q_network_RNN(args, self.input_dim) 112 | else: 113 | print("------use MLP------") 114 | self.eval_Q_net = Q_network_MLP(args, self.input_dim) 115 | self.target_Q_net = Q_network_MLP(args, self.input_dim) 116 | self.target_Q_net.load_state_dict(self.eval_Q_net.state_dict()) 117 | 118 | if self.algorithm == "QMIX": 119 | print("------algorithm: QMIX------") 120 | self.eval_mix_net = QMIX_Net(args) 121 | self.target_mix_net = QMIX_Net(args) 122 | elif self.algorithm == "VDN": 123 | print("------algorithm: VDN------") 124 | self.eval_mix_net = VDN_Net() 125 | self.target_mix_net = VDN_Net() 126 | else: 127 | print("wrong!!!") 128 | 129 | self.target_mix_net.load_state_dict(self.eval_mix_net.state_dict()) 130 | self.eval_parameters = list(self.eval_mix_net.parameters()) + list(self.eval_Q_net.parameters()) 131 | 132 | if self.use_RMS: 133 | print("------optimizer: RMSprop------") 134 | self.optimizer = torch.optim.RMSprop(self.eval_parameters, lr=self.lr) 135 | else: 136 | print("------optimizer: Adam------") 137 | self.optimizer = torch.optim.Adam(self.eval_parameters, lr=self.lr) 138 | 139 | if self.use_gpu: 140 | self.target_Q_net.to(self.device) 141 | self.eval_Q_net.to(self.device) 142 | self.target_mix_net.to(self.device) 143 | self.eval_mix_net.to(self.device) 144 | self.train_step = 0 145 | 146 | def choose_action(self, obs_n, last_onehot_a_n, avail_a_n, epsilon): 147 | with torch.no_grad(): 148 | if np.random.uniform() < epsilon: # epsilon-greedy 149 | # Only available actions can be chosen 150 | a_n = [np.random.choice(np.nonzero(avail_a)[0]) for avail_a in avail_a_n] 151 | else: 152 | # inputs = [] 153 | # obs_n = torch.tensor(obs_n, dtype=torch.float32) # obs_n.shape=(N,obs_dim) 154 | # inputs.append(obs_n) 155 | inputs = copy.deepcopy(obs_n) 156 | if self.add_last_action: 157 | # last_a_n = torch.tensor(last_onehot_a_n, dtype=torch.float32) 158 | # inputs.append(last_a_n) 159 | inputs = np.hstack((inputs, last_onehot_a_n)) 160 | if self.add_agent_id: 161 | inputs = np.hstack((inputs, np.eye(self.N))) 162 | # inputs.append(torch.eye(self.N)) 163 | 164 | # inputs = torch.cat([x for x in inputs], dim=-1) # inputs.shape=(N,inputs_dim) 165 | inputs = torch.tensor(inputs, dtype=torch.float32) # nputs.shape = (N, obs_dim+action_dim+N) 166 | if self.use_gpu: 167 | inputs = inputs.to(self.device) 168 | q_value = self.eval_Q_net(inputs) 169 | 170 | avail_a_n = torch.tensor(avail_a_n, dtype=torch.float32) # avail_a_n.shape=(N, action_dim) 171 | if self.use_gpu: 172 | q_value = q_value.to('cpu') 173 | q_value[avail_a_n == 0] = -float('inf') # Mask the unavailable actions 174 | a_n = q_value.argmax(dim=-1).numpy() 175 | return a_n 176 | 177 | def train(self, replay_buffer): 178 | batch, max_episode_len = replay_buffer.sample() # Get training data 179 | self.train_step += 1 180 | 181 | inputs = self.get_inputs(batch, max_episode_len) # inputs.shape=(bach_size,max_episode_len+1,N,input_dim) 182 | if self.use_gpu: 183 | inputs = inputs.to(self.device) 184 | batch_s = batch['s'].to(self.device) 185 | batch_r = batch['r'].to(self.device) 186 | batch_n = batch['a_n'].to(self.device) 187 | batch_active = batch['active'].to(self.device) 188 | batch_dw = batch['dw'].to(self.device) 189 | if self.use_rnn: 190 | self.eval_Q_net.rnn_hidden = None 191 | self.target_Q_net.rnn_hidden = None 192 | q_evals, q_targets = [], [] 193 | for t in range(max_episode_len): # t=0,1,2,...(episode_len-1) 194 | q_eval = self.eval_Q_net(inputs[:, t].reshape(-1, self.input_dim)) # q_eval.shape=(batch_size*N,action_dim) 195 | q_target = self.target_Q_net(inputs[:, t + 1].reshape(-1, self.input_dim)) 196 | q_evals.append(q_eval.reshape(self.batch_size, self.N, -1)) # q_eval.shape=(batch_size,N,action_dim) 197 | q_targets.append(q_target.reshape(self.batch_size, self.N, -1)) 198 | 199 | # Stack them according to the time (dim=1) 200 | q_evals = torch.stack(q_evals, dim=1) # q_evals.shape=(batch_size,max_episode_len,N,action_dim) 201 | q_targets = torch.stack(q_targets, dim=1) 202 | else: 203 | q_evals = self.eval_Q_net(inputs[:, :-1]) # q_evals.shape=(batch_size,max_episode_len,N,action_dim) 204 | q_targets = self.target_Q_net(inputs[:, 1:]) # inputs[:, 1:] -> obs_next 205 | 206 | with torch.no_grad(): 207 | if self.use_double_q: # If use double q-learning, we use eval_net to choose actions,and use target_net to compute q_target 208 | q_eval_last = self.eval_Q_net(inputs[:, -1].reshape(-1, self.input_dim)).reshape(self.batch_size, 1, self.N, -1) 209 | q_evals_next = torch.cat([q_evals[:, 1:], q_eval_last], dim=1) # q_evals_next.shape=(batch_size,max_episode_len,N,action_dim) 210 | q_evals_next[batch['avail_a_n'][:, 1:] == 0] = -999999 211 | a_argmax = torch.argmax(q_evals_next, dim=-1, keepdim=True) # a_max.shape=(batch_size,max_episode_len, N, 1) 212 | q_targets = torch.gather(q_targets, dim=-1, index=a_argmax).squeeze(-1) # q_targets.shape=(batch_size, max_episode_len, N) 213 | else: 214 | q_targets[batch['avail_a_n'][:, 1:] == 0] = -999999 # batch['avail_a_n'].shape = (batch_size, max_episode_len, N, action_dim) 215 | q_targets = q_targets.max(dim=-1)[0] # q_targets.shape=(batch_size, max_episode_len, N) 216 | 217 | # batch['a_n'].shape(batch_size,max_episode_len, N) 218 | q_evals = torch.gather(q_evals, dim=-1, index=batch_n.unsqueeze(-1)).squeeze(-1) # q_evals.shape(batch_size, max_episode_len, N) 219 | 220 | # Compute q_total using QMIX or VDN, q_total.shape=(batch_size, max_episode_len, 1) 221 | if self.algorithm == "QMIX": 222 | q_total_eval = self.eval_mix_net(q_evals, batch_s[:, :-1]) 223 | q_total_target = self.target_mix_net(q_targets, batch_s[:, 1:]) 224 | else: 225 | q_total_eval = self.eval_mix_net(q_evals) 226 | q_total_target = self.target_mix_net(q_targets) 227 | # targets.shape=(batch_size,max_episode_len,1) 228 | 229 | targets = batch_r + self.gamma * (1 - batch_dw) * q_total_target 230 | 231 | td_error = (q_total_eval - targets.detach()) # targets.detach() to cut the backward 232 | mask_td_error = td_error * batch_active 233 | loss = (mask_td_error ** 2).sum() / batch_active.sum() 234 | self.optimizer.zero_grad() 235 | loss.backward() 236 | if self.use_grad_clip: 237 | torch.nn.utils.clip_grad_norm_(self.eval_parameters, 10) 238 | self.optimizer.step() 239 | 240 | if self.use_hard_update: 241 | # hard update 242 | if self.train_step % self.target_update_freq == 0: 243 | self.target_Q_net.load_state_dict(self.eval_Q_net.state_dict()) 244 | self.target_mix_net.load_state_dict(self.eval_mix_net.state_dict()) 245 | else: 246 | # Softly update the target networks 247 | for param, target_param in zip(self.eval_Q_net.parameters(), self.target_Q_net.parameters()): 248 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 249 | 250 | for param, target_param in zip(self.eval_mix_net.parameters(), self.target_mix_net.parameters()): 251 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 252 | 253 | 254 | def get_inputs(self, batch, max_episode_len): 255 | # inputs = [] 256 | # inputs.append(batch['obs_n']) 257 | inputs = copy.deepcopy(batch['obs_n']) # batch['obs_n'].shape = (batch_size,max_episode_len+1, N, obs_dim) 258 | if self.add_last_action: 259 | inputs = np.concatenate((inputs, batch['last_onehot_a_n']),axis=-1) 260 | # inputs.append(batch['last_onehot_a_n']) 261 | if self.add_agent_id: 262 | agent_id_one_hot = torch.eye(self.N).unsqueeze(0).unsqueeze(0).repeat(self.batch_size, max_episode_len + 1, 1, 1) 263 | # inputs.append(agent_id_one_hot) 264 | inputs = np.concatenate((inputs, agent_id_one_hot),axis=-1) 265 | inputs = torch.tensor(inputs, dtype=torch.float32) 266 | # inputs.shape=(bach_size,max_episode_len+1,N,input_dim) 267 | # inputs = torch.cat([x for x in inputs], dim=-1) 268 | return inputs 269 | 270 | def save_model(self, env_name, algorithm, seed, total_steps): 271 | torch.save(self.eval_Q_net.state_dict(), "./model/{}/{}_seed_{}_step_{}k.pth".format(env_name, algorithm, seed, int(total_steps / 1000))) 272 | -------------------------------------------------------------------------------- /ACORM_QMIX/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from run import Runner 3 | import torch 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser("Hyperparameter Setting for QMIX,VDN and ACORM in SMAC environment") 7 | parser.add_argument("--max_train_steps", type=int, default=5000000, help="Maximum number of training steps") 8 | parser.add_argument("--evaluate_freq", type=int, default=10000, help="Evaluate the policy every 'evaluate_freq' steps") 9 | parser.add_argument("--evaluate_times", type=float, default=32, help="Evaluate times") 10 | # parser.add_argument("--save_freq", type=int, default=int(1e5), help="Save frequency") 11 | parser.add_argument("--algorithm", type=str, default="ACORM", help="QMIX or VDN") 12 | parser.add_argument("--epsilon", type=float, default=1.0, help="Initial epsilon") 13 | parser.add_argument("--epsilon_decay_steps", type=float, default=80000, help="How many steps before the epsilon decays to the minimum") 14 | parser.add_argument("--epsilon_min", type=float, default=0.02, help="Minimum epsilon") 15 | parser.add_argument("--buffer_size", type=int, default=5000, help="The capacity of the replay buffer") 16 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size (the number of episodes)") 17 | parser.add_argument("--lr", type=float, default=6e-4, help="Learning rate") 18 | parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor") 19 | parser.add_argument("--qmix_hidden_dim", type=int, default=32, help="The dimension of the hidden layer of the QMIX network") 20 | parser.add_argument("--hyper_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of the hyper-network") 21 | parser.add_argument("--hyper_layers_num", type=int, default=2, help="The number of layers of hyper-network") 22 | parser.add_argument("--rnn_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of RNN") 23 | parser.add_argument("--mlp_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of MLP") 24 | parser.add_argument("--add_last_action", type=bool, default=True, help="Whether to add last actions into the observation") 25 | parser.add_argument("--use_hard_update", type=bool, default=False, help="Whether to use hard update") 26 | parser.add_argument("--use_lr_decay", type=bool, default=True, help="Whether to use learning rate decay") 27 | parser.add_argument("--lr_decay_steps", type=int, default=500, help="every steps decay steps") 28 | parser.add_argument("--lr_decay_rate", type=float, default=0.98, help="learn decay rate") 29 | parser.add_argument("--target_update_freq", type=int, default=100, help="Update frequency of the target network") 30 | parser.add_argument("--tau", type=float, default=0.005, help="If use soft update") 31 | parser.add_argument("--seed", type=int, default=123, help="random seed") 32 | parser.add_argument('--device', type=str, default='cuda:0') 33 | parser.add_argument('--env_name', type=str, default='MMM2') #['3m', '8m', '2s3z'] 34 | 35 | # plot 36 | parser.add_argument("--sns_plot", type=bool, default=False, help="Whether to use seaborn plot") 37 | parser.add_argument("--tb_plot", type=bool, default=True, help="Whether to use tensorboard plot") 38 | 39 | # RECL 40 | parser.add_argument("--agent_embedding_dim", type=int, default=128, help="The dimension of the agent embedding") 41 | parser.add_argument("--role_embedding_dim", type=int, default=64, help="The dimension of the role embedding") 42 | parser.add_argument("--use_ln", type=bool, default=False, help="Whether to use layer normalization") 43 | parser.add_argument("--cluster_num", type=int, default=int(3), help="the cluster number of knn") 44 | parser.add_argument("--recl_lr", type=float, default=8e-4, help="Learning rate") 45 | parser.add_argument("--agent_embedding_lr", type=float, default=1e-3, help="agent_embedding Learning rate") 46 | parser.add_argument("--train_recl_freq", type=int, default=200, help="Train frequency of the RECL network") 47 | parser.add_argument("--role_tau", type=float, default=0.005, help="If use soft update") 48 | parser.add_argument("--multi_steps", type=int, default=1, help="Train frequency of the RECL network") 49 | parser.add_argument("--role_mix_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of the QMIX network") 50 | 51 | # attention 52 | parser.add_argument("--att_dim", type=int, default=128, help="The dimension of the attention net") 53 | parser.add_argument("--att_out_dim", type=int, default=64, help="The dimension of the attention net") 54 | parser.add_argument("--n_heads", type=int, default=4, help="multi-head attention") 55 | parser.add_argument("--soft_temperature", type=float, default=1.0, help="multi-head attention") 56 | parser.add_argument("--state_embed_dim", type=int, default=64, help="The dimension of the gru state net") 57 | 58 | # save path 59 | parser.add_argument('--save_path', type=str, default='./result/acorm') 60 | parser.add_argument('--model_path', type=str, default='./model/acorm') 61 | 62 | args = parser.parse_args() 63 | args.epsilon_decay = (args.epsilon - args.epsilon_min) / args.epsilon_decay_steps 64 | torch.multiprocessing.set_start_method('spawn') 65 | 66 | runner = Runner(args) 67 | runner.run() 68 | 69 | -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/10m_vs_11m_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/10m_vs_11m_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/10m_vs_11m_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/10m_vs_11m_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/10m_vs_11m_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/10m_vs_11m_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/10m_vs_11m_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/10m_vs_11m_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/1c3s5z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/1c3s5z_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/1c3s5z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/1c3s5z_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/1c3s5z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/1c3s5z_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/1c3s5z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/1c3s5z_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/27m_vs_30m_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/27m_vs_30m_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/27m_vs_30m_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/27m_vs_30m_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/27m_vs_30m_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/27m_vs_30m_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/27m_vs_30m_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/27m_vs_30m_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/2c_vs_64zg_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/2c_vs_64zg_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/2c_vs_64zg_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/2c_vs_64zg_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/2c_vs_64zg_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/2c_vs_64zg_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/2c_vs_64zg_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/2c_vs_64zg_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/2c_vs_64zg_seed4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/2c_vs_64zg_seed4.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/2s3z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/2s3z_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/2s3z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/2s3z_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/2s3z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/2s3z_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/2s3z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/2s3z_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/3s5z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/3s5z_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/3s5z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/3s5z_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/3s5z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/3s5z_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/3s5z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/3s5z_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/3s5z_vs_3s6z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/3s5z_vs_3s6z_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/3s5z_vs_3s6z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/3s5z_vs_3s6z_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/3s5z_vs_3s6z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/3s5z_vs_3s6z_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/3s5z_vs_3s6z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/3s5z_vs_3s6z_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/5m_vs_6m_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/5m_vs_6m_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/5m_vs_6m_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/5m_vs_6m_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/5m_vs_6m_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/5m_vs_6m_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/5m_vs_6m_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/5m_vs_6m_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/6h_vs_8z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/6h_vs_8z_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/6h_vs_8z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/6h_vs_8z_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/6h_vs_8z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/6h_vs_8z_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/6h_vs_8z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/6h_vs_8z_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/MMM2_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/MMM2_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/MMM2_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/MMM2_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/MMM2_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/MMM2_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/MMM2_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/MMM2_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/MMM2_seed4.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/MMM2_seed4.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/bane_vs_bane_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/bane_vs_bane_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/bane_vs_bane_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/bane_vs_bane_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/bane_vs_bane_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/bane_vs_bane_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/bane_vs_bane_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/bane_vs_bane_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/corridor_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/corridor_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/corridor_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/corridor_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/corridor_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/corridor_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/acorm/corridor_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/acorm/corridor_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/10m_vs_11m_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/10m_vs_11m_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/10m_vs_11m_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/10m_vs_11m_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/10m_vs_11m_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/10m_vs_11m_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/10m_vs_11m_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/10m_vs_11m_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/1c3s5z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/1c3s5z_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/1c3s5z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/1c3s5z_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/1c3s5z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/1c3s5z_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/1c3s5z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/1c3s5z_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/27m_vs_30m_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/27m_vs_30m_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/27m_vs_30m_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/27m_vs_30m_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/27m_vs_30m_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/27m_vs_30m_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/27m_vs_30m_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/27m_vs_30m_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/2c_vs_64zg_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/2c_vs_64zg_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/2c_vs_64zg_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/2c_vs_64zg_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/2c_vs_64zg_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/2c_vs_64zg_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/2c_vs_64zg_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/2c_vs_64zg_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/2s3z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/2s3z_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/2s3z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/2s3z_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/2s3z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/2s3z_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/2s3z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/2s3z_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/3s5z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/3s5z_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/3s5z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/3s5z_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/3s5z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/3s5z_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/3s5z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/3s5z_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/3s5z_vs_3s6z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/3s5z_vs_3s6z_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/3s5z_vs_3s6z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/3s5z_vs_3s6z_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/3s5z_vs_3s6z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/3s5z_vs_3s6z_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/3s5z_vs_3s6z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/3s5z_vs_3s6z_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/5m_vs_6m_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/5m_vs_6m_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/5m_vs_6m_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/5m_vs_6m_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/5m_vs_6m_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/5m_vs_6m_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/5m_vs_6m_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/5m_vs_6m_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/6h_vs_8z_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/6h_vs_8z_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/6h_vs_8z_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/6h_vs_8z_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/6h_vs_8z_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/6h_vs_8z_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/6h_vs_8z_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/6h_vs_8z_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/MMM2_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/MMM2_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/MMM2_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/MMM2_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/MMM2_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/MMM2_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/MMM2_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/MMM2_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/bane_vs_bane_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/bane_vs_bane_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/bane_vs_bane_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/bane_vs_bane_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/bane_vs_bane_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/bane_vs_bane_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/bane_vs_bane_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/bane_vs_bane_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/corridor_seed0.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/corridor_seed0.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/corridor_seed1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/corridor_seed1.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/corridor_seed2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/corridor_seed2.npy -------------------------------------------------------------------------------- /ACORM_QMIX/result/sacred/qmix/corridor_seed3.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/result/sacred/qmix/corridor_seed3.npy -------------------------------------------------------------------------------- /ACORM_QMIX/run.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from algorithm.vdn_qmix import VDN_QMIX 4 | from algorithm.acorm import ACORM_Agent 5 | from util.replay_buffer import ReplayBuffer 6 | from smac.env import StarCraft2Env 7 | import seaborn as sns 8 | import matplotlib.pyplot as plt 9 | import datetime 10 | 11 | class Runner: 12 | def __init__(self, args): 13 | self.args = args 14 | self.env_name = self.args.env_name 15 | self.seed = self.args.seed 16 | # Set random seed 17 | np.random.seed(self.seed) 18 | torch.manual_seed(self.seed) 19 | # Create env 20 | self.env = StarCraft2Env(map_name=self.env_name, seed=self.seed) 21 | self.env_info = self.env.get_env_info() 22 | self.args.N = self.env_info["n_agents"] # The number of agents 23 | self.args.obs_dim = self.env_info["obs_shape"] # The dimensions of an agent's observation space 24 | self.args.state_dim = self.env_info["state_shape"] # The dimensions of global state space 25 | self.args.action_dim = self.env_info["n_actions"] # The dimensions of an agent's action space 26 | self.args.episode_limit = self.env_info["episode_limit"] # Maximum number of steps per episode 27 | print("number of agents={}".format(self.args.N)) 28 | print("obs_dim={}".format(self.args.obs_dim)) 29 | print("state_dim={}".format(self.args.state_dim)) 30 | print("action_dim={}".format(self.args.action_dim)) 31 | print("episode_limit={}".format(self.args.episode_limit)) 32 | self.save_path = args.save_path 33 | self.model_path = args.model_path 34 | 35 | from tensorboardX import SummaryWriter 36 | time_path = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 37 | self.writer = SummaryWriter(log_dir='./result/tb_logs/{}/{}/{}_seed_{}_{}'.format(self.args.algorithm, self.env_name, self.env_name, self.seed,time_path)) 38 | 39 | # Create N agents 40 | if args.algorithm in ['QMIX', 'VDN']: 41 | self.agent_n = VDN_QMIX(self.args) 42 | elif args.algorithm == 'ACORM': 43 | self.agent_n = ACORM_Agent(self.args) 44 | self.replay_buffer = ReplayBuffer(self.args, self.args.buffer_size) 45 | 46 | self.epsilon = self.args.epsilon # Initialize the epsilon 47 | self.win_rates = [] # Record the win rates 48 | self.evaluate_reward = [] 49 | self.total_steps = 0 50 | self.agent_embed_pretrain_epoch, self.recl_pretrain_epoch = 0, 0 51 | self.pretrain_agent_embed_loss, self.pretrain_recl_loss = [], [] 52 | self.args.agent_embed_pretrain_epochs =120 53 | self.args.recl_pretrain_epochs = 100 54 | 55 | def run(self, ): 56 | evaluate_num = -1 # Record the number of evaluations 57 | while self.total_steps < self.args.max_train_steps: 58 | if self.total_steps // self.args.evaluate_freq > evaluate_num: 59 | self.evaluate_policy() # Evaluate the policy every 'evaluate_freq' steps 60 | evaluate_num += 1 61 | 62 | _, _, episode_steps = self.run_episode_smac(evaluate=False) # Run an episode 63 | 64 | if self.agent_embed_pretrain_epoch < self.args.agent_embed_pretrain_epochs: 65 | if self.replay_buffer.current_size >= self.args.batch_size: 66 | self.agent_embed_pretrain_epoch += 1 67 | agent_embedding_loss = self.agent_n.pretrain_agent_embedding(self.replay_buffer) 68 | self.pretrain_agent_embed_loss.append(agent_embedding_loss.item()) 69 | else: 70 | if self.recl_pretrain_epoch < self.args.recl_pretrain_epochs: 71 | self.recl_pretrain_epoch += 1 72 | recl_loss = self.agent_n.pretrain_recl(self.replay_buffer) 73 | self.pretrain_recl_loss.append(recl_loss.item()) 74 | 75 | else: 76 | self.total_steps += episode_steps 77 | if self.replay_buffer.current_size >= self.args.batch_size: 78 | self.agent_n.train(self.replay_buffer) # Training 79 | 80 | self.evaluate_policy() 81 | # save model 82 | model_path = f'{self.model_path}/{self.env_name}_seed{self.seed}_' 83 | torch.save(self.agent_n.eval_Q_net, model_path + 'q_net.pth') 84 | torch.save(self.agent_n.RECL.role_embedding_net, model_path + 'role_net.pth') 85 | torch.save(self.agent_n.RECL.agent_embedding_net, model_path+'agent_embed_net.pth') 86 | torch.save(self.agent_n.eval_mix_net.attention_net, model_path+'attention_net.pth') 87 | torch.save(self.agent_n.eval_mix_net, model_path+'mix_net.pth') 88 | self.env.close() 89 | 90 | def evaluate_policy(self, ): 91 | win_times = 0 92 | evaluate_reward = 0 93 | for _ in range(self.args.evaluate_times): 94 | win_tag, episode_reward, _ = self.run_episode_smac(evaluate=True) 95 | if win_tag: 96 | win_times += 1 97 | evaluate_reward += episode_reward 98 | 99 | win_rate = win_times / self.args.evaluate_times 100 | evaluate_reward = evaluate_reward / self.args.evaluate_times 101 | self.win_rates.append(win_rate) 102 | self.evaluate_reward.append(evaluate_reward) 103 | print("total_steps:{} \t win_rate:{} \t evaluate_reward:{}".format(self.total_steps, win_rate, evaluate_reward)) 104 | 105 | if self.args.tb_plot: 106 | self.writer.add_scalar('win_rate', win_rate, global_step=self.total_steps) 107 | if self.args.sns_plot: 108 | # # plot curve 109 | sns.set_style('whitegrid') 110 | plt.figure() 111 | x_step = np.array(range(len(self.win_rates))) 112 | ax = sns.lineplot(x=x_step, y=np.array(self.win_rates).flatten(), label=self.args.algorithm) 113 | plt.ylabel('win_rates', fontsize=14) 114 | plt.xlabel(f'step*{self.args.evaluate_freq}', fontsize=14) 115 | plt.title(f'{self.args.algorithm} on {self.env_name}') 116 | plt.savefig(f'{self.save_path}/{self.env_name}_seed{self.seed}.jpg') 117 | 118 | # Save the win rates 119 | np.save(f'{self.save_path}/{self.env_name}_seed{self.seed}.npy', np.array(self.win_rates)) 120 | np.save(f'{self.save_path}/{self.env_name}_seed{self.seed}_return.npy', np.array(self.evaluate_reward)) 121 | 122 | def run_episode_smac(self, evaluate=False): 123 | win_tag = False 124 | episode_reward = 0 125 | self.env.reset() 126 | 127 | self.agent_n.eval_Q_net.rnn_hidden = None 128 | if self.args.algorithm == 'ACORM': 129 | self.agent_n.RECL.agent_embedding_net.rnn_hidden = None 130 | 131 | last_onehot_a_n = np.zeros((self.args.N, self.args.action_dim)) # Last actions of N agents(one-hot) 132 | for episode_step in range(self.args.episode_limit): 133 | obs_n = self.env.get_obs() # obs_n.shape=(N,obs_dim) 134 | s = self.env.get_state() # s.shape=(state_dim,) 135 | avail_a_n = self.env.get_avail_actions() # Get available actions of N agents, avail_a_n.shape=(N,action_dim) 136 | epsilon = 0 if evaluate else self.epsilon 137 | 138 | if self.args.algorithm == 'ACORM': 139 | role_embedding = self.agent_n.get_role_embedding(obs_n, last_onehot_a_n) 140 | a_n = self.agent_n.choose_action(obs_n, last_onehot_a_n, role_embedding, avail_a_n, epsilon) 141 | else: 142 | a_n = self.agent_n.choose_action(obs_n, last_onehot_a_n, avail_a_n, epsilon) 143 | 144 | r, done, info = self.env.step(a_n) # Take a step 145 | win_tag = True if done and 'battle_won' in info and info['battle_won'] else False 146 | episode_reward += r 147 | 148 | if not evaluate: 149 | """" 150 | When dead or win or reaching the episode_limit, done will be Ture, we need to distinguish them; 151 | dw means dead or win,there is no next state s'; 152 | but when reaching the max_episode_steps,there is a next state s' actually. 153 | """ 154 | if done and episode_step + 1 != self.args.episode_limit: 155 | dw = True 156 | else: 157 | dw = False 158 | 159 | # Store the transition 160 | self.replay_buffer.store_transition(episode_step, obs_n, s, avail_a_n, last_onehot_a_n, a_n, r, dw) 161 | last_onehot_a_n = np.eye(self.args.action_dim)[a_n] # Convert actions to one-hot vectors 162 | # obs_a_n_buffer[episode_step] = obs_n 163 | # Decay the epsilon 164 | self.epsilon = self.epsilon - self.args.epsilon_decay if self.epsilon - self.args.epsilon_decay > self.args.epsilon_min else self.args.epsilon_min 165 | 166 | if done: 167 | break 168 | 169 | if not evaluate: 170 | # An episode is over, store obs_n, s and avail_a_n in the last step 171 | obs_n = self.env.get_obs() 172 | s = self.env.get_state() 173 | avail_a_n = self.env.get_avail_actions() 174 | self.replay_buffer.store_last_step(episode_step + 1, obs_n, s, avail_a_n, last_onehot_a_n) 175 | return win_tag, episode_reward, episode_step+1 -------------------------------------------------------------------------------- /ACORM_QMIX/run.sh: -------------------------------------------------------------------------------- 1 | nohup python main.py --algorithm ACORM \ 2 | --env_name 'MMM2' \ 3 | --device 'cuda:1' \ 4 | --max_train_steps 3000000 \ 5 | --seed 4 \ 6 | --epsilon 1.0 \ 7 | --epsilon_decay_steps 80000 \ 8 | --epsilon_min 0.02 \ 9 | --use_hard_update False \ 10 | --use_lr_decay True \ 11 | --lr_decay_steps 500 \ 12 | --lr_decay_rate 0.98 \ 13 | --train_recl_freq 100 \ 14 | --use_ln False \ 15 | --role_tau 0.005 \ 16 | --cluster_num 3 \ 17 | --agent_embedding_dim 128 \ 18 | --hyper_layers_num 2 \ 19 | --lr +6e-4 \ 20 | --recl_lr +8e-4 \ 21 | --role_embedding_dim 64 \ 22 | --save_path './result/acorm'\ 23 | --att_dim 128 \ 24 | --att_out_dim 64 \ 25 | --n_heads 4 \ 26 | --soft_temperature 1.0 \ 27 | --state_embed_dim 64 \ 28 | & 29 | 30 | 31 | -------------------------------------------------------------------------------- /ACORM_QMIX/util/__pycache__/attention.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/util/__pycache__/attention.cpython-310.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/util/__pycache__/attention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/util/__pycache__/attention.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/util/__pycache__/net.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/util/__pycache__/net.cpython-310.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/util/__pycache__/net.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/util/__pycache__/net.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/util/__pycache__/replay_buffer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/util/__pycache__/replay_buffer.cpython-310.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/util/__pycache__/replay_buffer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/util/__pycache__/replay_buffer.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/util/__pycache__/replay_buffer_v1.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ACORM_QMIX/util/__pycache__/replay_buffer_v1.cpython-37.pyc -------------------------------------------------------------------------------- /ACORM_QMIX/util/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class MultiHeadAttention(nn.Module): 6 | def __init__(self, n_heads, att_dim, att_out_dim, soft_temperature, dim_q, dim_k, dim_v): 7 | super(MultiHeadAttention, self).__init__() 8 | assert (att_dim % n_heads) == 0, "n_heads must divide att_dim" 9 | self.att_dim = att_dim 10 | self.att_out_dim = att_out_dim 11 | self.head_att_dim = att_dim // n_heads 12 | self.n_heads = n_heads 13 | self.temperature = self.head_att_dim ** 0.5 / soft_temperature 14 | 15 | self.fc_q = nn.Linear(dim_q, self.att_dim, bias=False) 16 | self.fc_k = nn.Linear(dim_k, self.att_dim, bias=False) 17 | self.fc_v = nn.Linear(dim_v, self.att_dim) 18 | self.fc_final = nn.Linear(self.att_dim, self.att_out_dim) 19 | 20 | def forward(self, q, k, v): 21 | # q.shape = (batch, N, dim) 22 | batch_size = q.shape[0] 23 | # shape = (batch*N, att_dim)->(batch, N, heads, head_att_dim)->(batch, heads, N, head_att_dim) 24 | q = self.fc_q(q.view(-1, q.shape[2])).view(batch_size, -1, self.n_heads, self.head_att_dim).transpose(1, 2) 25 | # shape = (batch*N, att_dim)->(batch, N, heads, head_att_dim)->(batch, heads, head_att_dim, N) 26 | k_T = self.fc_k(k.view(-1, k.shape[2])).view(batch_size, -1, self.n_heads, self.head_att_dim).permute(0,2,3,1) 27 | v = self.fc_v(v.view(-1, v.shape[2])).view(batch_size, -1, self.n_heads, self.head_att_dim).transpose(1, 2) 28 | alpha = F.softmax(torch.matmul(q/self.temperature, k_T), dim=-1) # shape = (batch, heads, N, N) 29 | # shape = (batch, heads, N, head_att_dim)->(batch, N, heads, head_att_dim)->(batch, N, att_dim) 30 | result = torch.matmul(alpha, v).transpose(1, 2).reshape(batch_size, -1, self.att_dim) 31 | result = self.fc_final(result) # shape = (batch, N, att_out_dim) 32 | return result 33 | -------------------------------------------------------------------------------- /ACORM_QMIX/util/net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | class Q_network_RNN(nn.Module): 6 | def __init__(self, args, input_dim): 7 | super(Q_network_RNN, self).__init__() 8 | self.rnn_hidden = None 9 | 10 | self.fc1 = nn.Linear(input_dim, args.rnn_hidden_dim) 11 | self.rnn = nn.GRUCell(args.rnn_hidden_dim, args.rnn_hidden_dim) 12 | self.fc2 = nn.Linear(args.rnn_hidden_dim, args.action_dim) 13 | 14 | def forward(self, inputs): 15 | # When 'choose_action', inputs.shape(N,input_dim) 16 | # When 'train', inputs.shape(bach_size*N,input_dim) 17 | x = F.relu(self.fc1(inputs)) 18 | self.rnn_hidden = self.rnn(x, self.rnn_hidden) 19 | Q = self.fc2(self.rnn_hidden) 20 | return Q 21 | 22 | 23 | class Q_network_MLP(nn.Module): 24 | def __init__(self, args, input_dim): 25 | super(Q_network_MLP, self).__init__() 26 | self.rnn_hidden = None 27 | 28 | self.fc1 = nn.Linear(input_dim, args.mlp_hidden_dim) 29 | self.fc2 = nn.Linear(args.mlp_hidden_dim, args.mlp_hidden_dim) 30 | self.fc3 = nn.Linear(args.mlp_hidden_dim, args.action_dim) 31 | 32 | def forward(self, inputs): 33 | # When 'choose_action', inputs.shape(N,input_dim) 34 | # When 'train', inputs.shape(bach_size,max_episode_len,N,input_dim) 35 | x = F.relu(self.fc1(inputs)) 36 | x = F.relu(self.fc2(x)) 37 | Q = self.fc3(x) 38 | return Q 39 | 40 | class Agent_Embedding(nn.Module): 41 | def __init__(self, args): 42 | super(Agent_Embedding, self).__init__() 43 | self.input_dim = args.obs_dim + args.action_dim 44 | self.agent_embedding_dim = args.agent_embedding_dim 45 | 46 | self.fc1 = nn.Linear(self.input_dim, self.agent_embedding_dim) 47 | self.rnn_hidden = None 48 | self.agent_embedding_fc = nn.GRUCell(self.agent_embedding_dim, self.agent_embedding_dim) 49 | self.fc2 = nn.Linear(self.agent_embedding_dim, self.agent_embedding_dim) 50 | 51 | def forward(self, obs, last_a, detach=False): 52 | inputs = torch.cat([obs, last_a], dim=-1) 53 | fc1_out = torch.relu(self.fc1(inputs)) 54 | self.rnn_hidden = self.agent_embedding_fc(fc1_out, self.rnn_hidden) 55 | fc2_out = self.fc2(self.rnn_hidden) 56 | if detach: 57 | fc2_out.detach() 58 | return fc2_out 59 | 60 | class Agent_Embedding_Decoder(nn.Module): 61 | def __init__(self, args): 62 | super(Agent_Embedding_Decoder, self).__init__() 63 | self.agent_embedding_dim = args.agent_embedding_dim 64 | self.decoder_out_dim = args.obs_dim + args.N # out_put: o(t+1)+agent_idx 65 | 66 | self.fc1 = nn.Linear(self.agent_embedding_dim, self.agent_embedding_dim) 67 | self.fc2 = nn.Linear(self.agent_embedding_dim, self.decoder_out_dim) 68 | 69 | def forward(self, agent_embedding): 70 | fc1_out = torch.relu(self.fc1(agent_embedding)) 71 | decoder_out = self.fc2(fc1_out) 72 | return decoder_out 73 | 74 | 75 | class Role_Embedding(nn.Module): 76 | def __init__(self, args): 77 | super(Role_Embedding, self).__init__() 78 | self.agent_embedding_dim = args.agent_embedding_dim 79 | self.role_embedding_dim = args.role_embedding_dim 80 | self.use_ln = args.use_ln 81 | 82 | if self.use_ln: # 使用layer_norm 83 | self.role_embeding = nn.ModuleList([nn.Linear(self.agent_embedding_dim, self.role_embedding_dim), 84 | nn.LayerNorm(self.role_embedding_dim)]) 85 | else: 86 | self.role_embeding = nn.Linear(self.agent_embedding_dim, self.role_embedding_dim) 87 | 88 | def forward(self, agent_embedding, detach=False): 89 | if self.use_ln: 90 | output = self.role_embeding[1](self.role_embeding[0](agent_embedding)) 91 | else: 92 | output = self.role_embeding(agent_embedding) 93 | 94 | if detach: 95 | output.detach() 96 | output = torch.sigmoid(output) 97 | return output -------------------------------------------------------------------------------- /ACORM_QMIX/util/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class ReplayBuffer: 6 | def __init__(self, args, buffer_size): 7 | self.N = args.N 8 | self.obs_dim = args.obs_dim 9 | self.state_dim = args.state_dim 10 | self.action_dim = args.action_dim 11 | self.episode_limit = args.episode_limit 12 | self.buffer_size = buffer_size 13 | self.episode_num = 0 14 | self.current_size = 0 15 | self.buffer = {'obs_n': np.zeros([self.buffer_size, self.episode_limit + 1, self.N, self.obs_dim]), 16 | 's': np.zeros([self.buffer_size, self.episode_limit + 1, self.state_dim]), 17 | 'avail_a_n': np.ones([self.buffer_size, self.episode_limit + 1, self.N, self.action_dim]), # Note: We use 'np.ones' to initialize 'avail_a_n' 18 | 'last_onehot_a_n': np.zeros([self.buffer_size, self.episode_limit + 1, self.N, self.action_dim]), 19 | 'a_n': np.zeros([self.buffer_size, self.episode_limit, self.N]), 20 | 'r': np.zeros([self.buffer_size, self.episode_limit, 1]), 21 | 'dw': np.ones([self.buffer_size, self.episode_limit, 1]), # Note: We use 'np.ones' to initialize 'dw' 22 | 'active': np.zeros([self.buffer_size, self.episode_limit, 1]) 23 | } 24 | self.episode_len = np.zeros(self.buffer_size) 25 | 26 | def store_transition(self, episode_step, obs_n, s, avail_a_n, last_onehot_a_n, a_n, r, dw): 27 | self.buffer['obs_n'][self.episode_num][episode_step] = obs_n 28 | self.buffer['s'][self.episode_num][episode_step] = s 29 | self.buffer['avail_a_n'][self.episode_num][episode_step] = avail_a_n 30 | self.buffer['last_onehot_a_n'][self.episode_num][episode_step + 1] = last_onehot_a_n 31 | self.buffer['a_n'][self.episode_num][episode_step] = a_n 32 | self.buffer['r'][self.episode_num][episode_step] = r 33 | self.buffer['dw'][self.episode_num][episode_step] = dw 34 | self.buffer['active'][self.episode_num][episode_step] = 1.0 35 | 36 | def store_last_step(self, episode_step, obs_n, s, avail_a_n, last_onehot_a_n): 37 | self.buffer['obs_n'][self.episode_num][episode_step] = obs_n 38 | self.buffer['s'][self.episode_num][episode_step] = s 39 | self.buffer['avail_a_n'][self.episode_num][episode_step] = avail_a_n 40 | self.buffer['last_onehot_a_n'][self.episode_num][episode_step] = last_onehot_a_n 41 | self.buffer['active'][self.episode_num][episode_step:] = 0 42 | self.episode_len[self.episode_num] = episode_step # Record the length of this episode 43 | self.episode_num = (self.episode_num + 1) % self.buffer_size 44 | self.current_size = min(self.current_size + 1, self.buffer_size) 45 | 46 | def sample(self, batch_size): 47 | # Randomly sampling 48 | index = np.random.choice(self.current_size, size=batch_size, replace=False) 49 | max_episode_len = int(np.max(self.episode_len[index])) 50 | batch = {} 51 | for key in self.buffer.keys(): 52 | if key == 'obs_n' or key == 's' or key == 'avail_a_n' or key == 'last_onehot_a_n': 53 | batch[key] = torch.tensor(self.buffer[key][index, :max_episode_len + 1], dtype=torch.float32) 54 | elif key == 'a_n': 55 | batch[key] = torch.tensor(self.buffer[key][index, :max_episode_len], dtype=torch.long) 56 | else: 57 | batch[key] = torch.tensor(self.buffer[key][index, :max_episode_len], dtype=torch.float32) 58 | 59 | return batch, max_episode_len 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # **Attention-Guided Contrastive Role Representations for Multi-agent Reinforcement Learning** 2 | 3 | Zican Hu, Zongzhang Zhang, Huaxiong Li, Chunlin Chen, Hongyu Ding, Zhi Wang* 4 | 5 | A link to our paper can be found on [Paper Link](https://openreview.net/forum?id=LWmuPfEYhH) 6 | 7 | ## **Overview** 8 | 9 | ![ACORM_QMIX](./ACORM_QMIX.jpg) 10 | 11 | ## **Instructions** 12 | 13 | ACORM tested on two benchmark tasks [SMAC ](https://github.com/oxwhirl/smac) and [GRF](https://github.com/google-research/football) based on two algorithm framework [QMIX](https://arxiv.org/abs/1803.11485) and [MAPPO](https://arxiv.org/abs/2103.01955). 14 | 15 | ## **Citation** 16 | 17 | Please cite our paper as: 18 | ```tex 19 | @inproceedings{ 20 | hu2024attentionguided, 21 | title={Attention-Guided Contrastive Role Representations for Multi-agent Reinforcement Learning}, 22 | author={Zican Hu and Zongzhang Zhang and Huaxiong Li and Chunlin Chen and Hongyu Ding and Zhi Wang}, 23 | booktitle={The Twelfth International Conference on Learning Representations}, 24 | year={2024}, 25 | url={https://openreview.net/forum?id=LWmuPfEYhH} 26 | } 27 | ``` 28 | 29 | ## **experiment instructions** 30 | 31 | ### **Installation instructions** 32 | Download the Linux version 4.10 of StarCraft II from the Blizzard's [repository](https://github.com/Blizzard/s2client-proto#downloads). By default, the game is expected to be in `~/StarCraftII/` directory. 33 | See `requirments.txt` file for more information about how to install the dependencies. 34 | ```python 35 | conda create -n acorm python=3.9.16 -y 36 | conda activate acorm 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | ### Run an experiment 41 | 42 | You can execute the following command to run ACORM based on QMIX with a map config, such as `MMM2`: 43 | 44 | ```python 45 | python ./ACORM_QMIX/main.py --algorithm ACORM --env_name MMM2 --cluster_num 3 --max_train_steps 3050000 46 | ``` 47 | or you can execute the following command to run ACORM base on MAPPO with a map config, such as `corridor` 48 | 49 | ```python 50 | python ./ACORM_MAPPO/main.py --algorithm ACORM --env_name corridor --cluster_num 3 --max_train_steps 5050000 51 | ``` 52 | 53 | All results will be stored in the `ACORM_QMIX or ACORM_MAPPO/results` folder. You can see the console output, config, and tensorboard logging in the `ACORM_QMIX or ACORM_MAPPO/results/tb_logs` folder. 54 | 55 | You can plot the curve with `seaborn`: 56 | 57 | ```python 58 | python plot.py --algorithm 'ACORM_QMIX' or 'ACORM_MAPPO' 59 | ``` 60 | 61 | ## License 62 | 63 | Code licensed under the Apache License v2.0. 64 | 65 | -------------------------------------------------------------------------------- /ablation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ablation.jpg -------------------------------------------------------------------------------- /ablation_k_means.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/ablation_k_means.jpg -------------------------------------------------------------------------------- /plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import seaborn as sns 3 | import matplotlib.pyplot as plt 4 | import argparse 5 | 6 | def smooth(y, radius, mode='two_sided', valid_only=False): 7 | assert mode in ('two_sided', 'causal') 8 | if len(y) < 2*radius+1: 9 | return np.ones_like(y) * y.mean() 10 | elif mode == 'two_sided': 11 | convkernel = np.ones(2 * radius+1) 12 | out = np.convolve(y, convkernel,mode='same') / np.convolve(np.ones_like(y), convkernel, mode='same') 13 | if valid_only: 14 | out[:radius] = out[-radius:] = np.nan 15 | elif mode == 'causal': 16 | convkernel = np.ones(radius) 17 | out = np.convolve(y, convkernel,mode='full') / np.convolve(np.ones_like(y), convkernel, mode='full') 18 | out = out[:-radius+1] 19 | if valid_only: 20 | out[:radius] = np.nan 21 | return out 22 | 23 | 24 | parser = argparse.ArgumentParser("Hyperparameter Setting for QMIX,VDN and RECL_QMIX in SMAC environment") 25 | parser.add_argument("--algorithm", default='ACORM_QMIX', help="env_names") 26 | args = parser.parse_args() 27 | 28 | if args.algorithm == 'ACORM_QMIX': 29 | algs = ['acorm','qmix'] 30 | labels = ['ACORM(Ours)','QMIX'] 31 | elif args.algorithm == 'ACORM_MAPPO': 32 | algs = ['acorm','mappo'] 33 | labels = ['ACORM(Ours)','MAPPO'] 34 | color = ['#D62728', '#2A9F2A', '#FF7F0E','#1B75B4','#9467BD','#7F7F7F','#E377C2', '#BBBC1F'] 35 | 36 | env_names = ['2s3z', '1c3s5z', '3s5z', # easy 37 | 'bane_vs_bane','2c_vs_64zg','5m_vs_6m','10m_vs_11m', # hard 38 | '27m_vs_30m','MMM2','3s5z_vs_3s6z','6h_vs_8z','corridor'] # super hard 39 | env_difficulty = ['easy']*3 + ['hard']*4 + ['super hard']*5 40 | 41 | sns.set_style('ticks') 42 | sns.set_context('talk') 43 | fig = plt.figure(figsize=(20, 18), dpi=800) 44 | # Grid = plt.GridSpec(2, 3, wspace=0.2, hspace=0.4) 45 | Grid = plt.GridSpec(4, 3, wspace=0.2, hspace=0.4) 46 | plt.rcParams.update({'font.size': 15}) 47 | 48 | 49 | for sub_i, env_name in enumerate(env_names): 50 | print(env_name) 51 | # fig=plt.figure(figsize=(6, 4),dpi=120) 52 | sub_ax = plt.subplot(Grid[sub_i//3,sub_i%3]) 53 | plt.title(f'{env_name} ({env_difficulty[sub_i]})') 54 | for index, alg in enumerate(algs): 55 | print(alg) 56 | vdn_qmix_data = [] 57 | for seed in range(1): 58 | dir = (f'./{args.algorithm}/result/sacred/{alg}/{env_name}_seed{seed}.npy') 59 | # if env_name in ['corridor']: 60 | # data = np.load(dir, allow_pickle=True)[:980] 61 | if env_name in ['3s5z_vs_3s6z','6h_vs_8z','corridor']: 62 | data = np.load(dir, allow_pickle=True)[:600] 63 | elif env_name in ['bane_vs_bane','2s3z','1c3s5z']: 64 | data = np.load(dir, allow_pickle=True)[:200] 65 | elif env_name in ['27m_vs_30m','10m_vs_11m','5m_vs_6m','2c_vs_64zg']: 66 | data = np.load(dir, allow_pickle=True)[:300] 67 | else: 68 | data = np.load(dir, allow_pickle=True)[:400] 69 | data[0:5] = 0.0 70 | data = smooth(data, radius=3) 71 | vdn_qmix_data.append(data) 72 | vdn_qmix_data = np.array(vdn_qmix_data) 73 | x_step = np.tile(np.array(range(vdn_qmix_data.shape[1]))*5000, vdn_qmix_data.shape[0]) 74 | ax = sns.lineplot(x=x_step, y=vdn_qmix_data.flatten(),label=labels[index], color=color[index], linewidth=2) 75 | 76 | plt.grid(True,linestyle='-.',alpha=0.4) 77 | plt.legend(fontsize = 11) 78 | plt.ylabel('Test Win Rate', labelpad=-0.5) 79 | plt.xlabel(f'Timesteps') 80 | if env_name in ['3s5z_vs_3s6z','6h_vs_8z', 'corridor']: 81 | plt.xticks(np.array(range(data.shape[0]//100 +1))*500000) 82 | plt.yticks(np.array(range(0, 10+2, 2))/10) 83 | handles, labels = sub_ax.get_legend_handles_labels() 84 | sub_ax.legend_.remove() 85 | 86 | # plt.legend(handles, labels, ncol=6, bbox_to_anchor=(0.44, 5.55)) 87 | plt.legend(handles, labels, ncol=8, bbox_to_anchor=(0.82, 5.55)) 88 | 89 | plt.tight_layout() 90 | plt.savefig(f'./smac_{args.algorithm}.jpg',bbox_inches='tight') -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | attrs==22.2.0 3 | cachetools==5.3.0 4 | certifi @ file:///croot/certifi_1671487769961/work/certifi 5 | charset-normalizer==3.1.0 6 | cloudpickle==2.2.1 7 | colorama==0.4.6 8 | cycler==0.11.0 9 | deepdiff==6.3.0 10 | dm-env==1.6 11 | dm-env-rpc==1.1.5 12 | dm-tree==0.1.8 13 | docopt==0.6.2 14 | enum34==1.1.10 15 | exceptiongroup==1.1.1 16 | fonttools==4.38.0 17 | gitdb==4.0.10 18 | GitPython==3.1.31 19 | google-auth==2.16.3 20 | google-auth-oauthlib==0.4.6 21 | googleapis-common-protos==1.59.0 22 | grpcio==1.51.3 23 | gym==0.26.2 24 | gym-notices==0.0.8 25 | idna==3.4 26 | imageio==2.26.1 27 | immutabledict==2.2.3 28 | importlib-metadata==6.1.0 29 | iniconfig==2.0.0 30 | joblib==1.2.0 31 | jsonpickle==1.5.2 32 | kiwisolver==1.4.4 33 | kmeans-pytorch==0.3 34 | lbforaging==1.1.1 35 | Markdown==3.4.3 36 | MarkupSafe==2.1.2 37 | matplotlib==3.5.3 38 | mock==5.0.1 39 | mpyq==0.2.5 40 | munch==2.5.0 41 | numpy==1.21.6 42 | nvidia-cublas-cu11==11.10.3.66 43 | nvidia-cuda-nvrtc-cu11==11.7.99 44 | nvidia-cuda-runtime-cu11==11.7.99 45 | nvidia-cudnn-cu11==8.5.0.96 46 | oauthlib==3.2.2 47 | ordered-set==4.1.0 48 | packaging==23.0 49 | pandas==1.3.5 50 | Pillow==9.4.0 51 | pluggy==1.0.0 52 | portpicker==1.5.2 53 | probscale==0.2.5 54 | protobuf==3.19.5 55 | psutil==5.9.4 56 | py-cpuinfo==9.0.0 57 | pyasn1==0.4.8 58 | pyasn1-modules==0.2.8 59 | pygame==2.3.0 60 | pyglet==2.0.7 61 | pyparsing==3.0.9 62 | PySC2==4.0.0 63 | pytest==7.2.2 64 | python-dateutil==2.8.2 65 | pytz==2022.7.1 66 | PyYAML==3.13 67 | requests==2.28.2 68 | requests-oauthlib==1.3.1 69 | rsa==4.9 70 | s2clientprotocol==5.0.11.89720.0 71 | s2protocol==5.0.11.89720.0 72 | sacred==0.8.2 73 | scikit-learn==1.0.2 74 | scipy==1.7.3 75 | seaborn==0.12.2 76 | six==1.16.0 77 | sk-video==1.1.10 78 | sklearn==0.0.post1 79 | SMAC @ git+https://github.com/oxwhirl/smac.git@8a092e4bd0c6f5d3cf929523c67e1bca861463aa 80 | smmap==5.0.0 81 | snakeviz==2.1.1 82 | tensorboard==2.11.2 83 | tensorboard-data-server==0.6.1 84 | tensorboard-logger==0.1.0 85 | tensorboard-plugin-wit==1.8.1 86 | threadpoolctl==3.1.0 87 | tomli==2.0.1 88 | torch==1.13.1 89 | torchaudio==0.13.1 90 | torchvision==0.14.1 91 | tornado==6.2 92 | tqdm==4.65.0 93 | typing_extensions==4.5.0 94 | urllib3==1.26.15 95 | websocket-client==1.5.1 96 | Werkzeug==2.2.3 97 | wrapt==1.15.0 98 | zipp==3.15.0 99 | -------------------------------------------------------------------------------- /results_grf.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/results_grf.jpg -------------------------------------------------------------------------------- /results_mappo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/results_mappo.jpg -------------------------------------------------------------------------------- /results_smac.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/results_smac.jpg -------------------------------------------------------------------------------- /visual_mha.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/visual_mha.jpg -------------------------------------------------------------------------------- /visual_t.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NJU-RL/ACORM/67aa762e3ccc1d7333f77765d9780b2f30aca296/visual_t.jpg --------------------------------------------------------------------------------