├── saved_3m
│   └── agents_113500
├── saved_8m
│   └── agents_144500
├── saved_3z_vs_3s
│   └── agents_45000
├── launch tensorboard.bat
├── launch.bat
├── launch eval.bat
├── README.md
├── rnn_agent.py
├── qmixer.py
├── train.py
├── runner.py
└── qmix.py

/saved_3m/agents_113500:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gouet/QMIX-Starcraft/HEAD/saved_3m/agents_113500
--------------------------------------------------------------------------------
/saved_8m/agents_144500:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gouet/QMIX-Starcraft/HEAD/saved_8m/agents_144500
--------------------------------------------------------------------------------
/saved_3z_vs_3s/agents_45000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Gouet/QMIX-Starcraft/HEAD/saved_3z_vs_3s/agents_45000
--------------------------------------------------------------------------------
/launch tensorboard.bat:
--------------------------------------------------------------------------------
call C:\Users\Victor\Anaconda3\Scripts\activate.bat
call conda activate GYM_ENV_RL

tensorboard --logdir=./logs --host localhost --port 8000
pause
--------------------------------------------------------------------------------
/launch.bat:
--------------------------------------------------------------------------------
call C:\Users\Victor\Anaconda3\Scripts\activate.bat
call conda activate GYM_ENV_RL
set SC2PATH=C:\Program Files (x86)\StarCraft II

python train.py --train --scenario 2c_vs_64zg
pause
--------------------------------------------------------------------------------
/launch eval.bat:
--------------------------------------------------------------------------------
call C:\Users\Victor\Anaconda3\Scripts\activate.bat
call conda activate GYM_ENV_RL
set SC2PATH=C:\Program Files (x86)\StarCraft II

python train.py --load-episode-saved 75000 --scenario 2c_vs_64zg
pause
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# QMIX-Starcraft

## Research Paper and Environment

[*QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning*](https://arxiv.org/pdf/1803.11485.pdf)

[*The StarCraft Multi-Agent Challenge: Environment Code*](https://github.com/oxwhirl/smac)

[*The StarCraft Multi-Agent Challenge: Research Paper*](https://arxiv.org/pdf/1902.04043.pdf)


## Setup

- PyTorch 1.3
- Anaconda
- Windows 10

Be sure to set the `SC2PATH` environment variable (see `launch.bat`).

### Train an AI

```
python train.py --scenario [scenario_name] --train
```

*or*

```
launch.bat
```


### Launch the AI

```
python train.py --scenario [scenario_name] --load-episode-saved [episode number]
```

*or*

```
launch eval.bat
```

This will generate an SC2Replay file in `{SC2PATH}/Replays/replay`.
--------------------------------------------------------------------------------
/rnn_agent.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class RNNAgent(nn.Module):
    def __init__(self, input_shape, rnn_hidden_dim=64, n_actions=1):
        super(RNNAgent, self).__init__()
        self.rnn_hidden_dim = rnn_hidden_dim

        print('input_shape: ', input_shape)

        self.fc1 = nn.Linear(input_shape, rnn_hidden_dim)
        self.rnn = nn.GRUCell(rnn_hidden_dim, rnn_hidden_dim)
        self.fc2 = nn.Linear(rnn_hidden_dim, n_actions)

    def init_hidden(self):
        # make hidden states on same device as model
        return self.fc1.weight.new(1, self.rnn_hidden_dim).zero_()

    def forward(self, inputs, hidden_state):
        x = F.relu(self.fc1(inputs))
        h_in = hidden_state.reshape(-1, self.rnn_hidden_dim)

        h = self.rnn(x, h_in)
        q = self.fc2(h)
        return q, h

    def update(self, agent):
        # copy the given network's weights into this one (used to refresh the target network)
        for param, target_param in zip(agent.parameters(), self.parameters()):
            target_param.data.copy_(param.data)
--------------------------------------------------------------------------------
/qmixer.py:
--------------------------------------------------------------------------------
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class QMixer(nn.Module):
    def __init__(self, n_agents, state_shape, mixing_embed_dim=64):
        super(QMixer, self).__init__()

        #self.args = args
        self.n_agents = n_agents
        self.state_dim = int(np.prod(state_shape))

        self.embed_dim = mixing_embed_dim

        self.hyper_w_1 = nn.Linear(self.state_dim, self.embed_dim * self.n_agents)
        self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim)

        # State dependent bias for hidden layer
        self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)

        # V(s) instead of a bias for the last layers
        self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim),
                               nn.ReLU(),
                               nn.Linear(self.embed_dim, 1))

    def forward(self, agent_qs, states):
        bs = agent_qs.size(0)
        states = states.reshape(-1, self.state_dim)
        agent_qs = agent_qs.view(-1, 1, self.n_agents)
        # First layer
        w1 = th.abs(self.hyper_w_1(states))
        b1 = self.hyper_b_1(states)
        w1 = w1.view(-1, self.n_agents, self.embed_dim)
        b1 = b1.view(-1, 1, self.embed_dim)
        hidden = F.elu(th.bmm(agent_qs, w1) + b1)
        # Second layer
        w_final = th.abs(self.hyper_w_final(states))
        w_final = w_final.view(-1, self.embed_dim, 1)
        # State-dependent bias
        v = self.V(states).view(-1, 1, 1)
        # Compute final output
        y = th.bmm(hidden, w_final) + v
        # Reshape and return
        q_tot = y.view(bs, -1, 1)
        return q_tot

    def update(self, agent):
        # copy the given mixer's weights into this one (used to refresh the target mixer)
        for param, target_param in zip(agent.parameters(), self.parameters()):
            target_param.data.copy_(param.data)
--------------------------------------------------------------------------------
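A shape-check sketch for the mixing network above (not part of the repository; the batch, timestep, agent and state sizes are arbitrary). Because both hypernetwork outputs pass through `th.abs`, the mixing weights are non-negative, which is what makes `q_tot` monotonic in each agent's Q-value:

```python
import torch as th
from qmixer import QMixer

# Arbitrary sizes for illustration: 8 episodes, 20 timesteps, 3 agents, 48-dim state.
mixer = QMixer(n_agents=3, state_shape=48, mixing_embed_dim=32)
agent_qs = th.rand(8, 20, 3)     # chosen per-agent Q-values
states = th.rand(8, 20, 48)      # global state at each timestep
q_tot = mixer(agent_qs, states)  # mixed team value
print(q_tot.shape)               # torch.Size([8, 20, 1])
```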
/train.py:
--------------------------------------------------------------------------------
from smac.env import StarCraft2Env
import numpy as np
import qmix
import torch
import os
import argparse
from time import gmtime, strftime
from torch.utils.tensorboard import SummaryWriter
import runner

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


def main(arglist):
    current_time = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
    writer = SummaryWriter(log_dir='./logs/' + current_time + '-snake')
    actors = 6
    if not arglist.train:
        actors = 1
    env_runner = runner.Runner(arglist, arglist.scenario, actors)

    while arglist.train or env_runner.episode < 1:
        env_runner.reset()
        replay_buffers = env_runner.run()
        for replay_buffer in replay_buffers:
            env_runner.qmix_algo.episode_batch.add(replay_buffer)
        env_runner.qmix_algo.train()
        for episode in env_runner.episodes:
            env_runner.qmix_algo.update_targets(episode)

        for episode in env_runner.episodes:
            if episode % 500 == 0 and arglist.train:
                env_runner.qmix_algo.save_model('./saved/agents_' + str(episode))

        print(env_runner.win_counted_array)
        for idx, episode in enumerate(env_runner.episodes):
            print("Total reward in episode {} = {} and global step: {}".format(episode, env_runner.episode_reward[idx], env_runner.episode_global_step))

            if arglist.train:
                writer.add_scalar('Reward', env_runner.episode_reward[idx], episode)
                writer.add_scalar('Victory', env_runner.win_counted_array[idx], episode)

    if not arglist.train:
        env_runner.save()

    env_runner.close()


def parse_args():
    parser = argparse.ArgumentParser('Reinforcement Learning parser for QMIX')

    parser.add_argument('--train', action='store_true')
    parser.add_argument('--load-episode-saved', type=int, default=8000)
    parser.add_argument('--scenario', type=str, default="3m")

    return parser.parse_args()


if __name__ == "__main__":
    try:
        os.mkdir('./saved')
    except OSError:
        print("Creation of the directory failed")
    else:
        print("Successfully created the directory")
    arglist = parse_args()
    main(arglist)
--------------------------------------------------------------------------------
/runner.py:
--------------------------------------------------------------------------------
import qmix
from smac.env import StarCraft2Env
import numpy as np
import torch
from multiprocessing import Process, Lock, Pipe, Value
from threading import Thread
import time

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


class Transform:
    def transform(self, tensor):
        raise NotImplementedError

    def infer_output_info(self, vshape_in, dtype_in):
        raise NotImplementedError


class OneHot(Transform):
    def __init__(self, out_dim):
        self.out_dim = out_dim

    def transform(self, tensor):
        y_onehot = tensor.new(*tensor.shape[:-1], self.out_dim).zero_()
        y_onehot.scatter_(-1, tensor.long(), 1)
        return y_onehot.float()

    def infer_output_info(self, vshape_in, dtype_in):
        return (self.out_dim,), torch.float32

def env_run(scenario, id, child_conn, locker, replay_buffer_size):
    #qmix_algo = wrapper.value
    env = StarCraft2Env(map_name=scenario, replay_dir="./replay/")

    env_info = env.get_env_info()

    process_id = id


    action_n = env_info["n_actions"]
    agent_nb = env_info["n_agents"]
    state_shape = env_info["state_shape"]
    obs_shape = env_info["obs_shape"] + agent_nb + action_n
    #self.episode_limit = env_info['episode_limit']

    agent_id_one_hot = OneHot(agent_nb)
    actions_one_hot = OneHot(action_n)

    agent_id_one_hot_array = []
    for agent_id in range(agent_nb):
        agent_id_one_hot_array.append(agent_id_one_hot.transform(torch.FloatTensor([agent_id])).cpu().detach().numpy())
    agent_id_one_hot_array = np.array(agent_id_one_hot_array)
    actions_one_hot_reset = torch.zeros_like(torch.empty(agent_nb, action_n))

    state_zeros = np.zeros(state_shape)
    obs_zeros = np.zeros((agent_nb, obs_shape))
    actions_zeros = np.zeros([agent_nb, 1])
    reward_zeros = 0
    agents_available_actions_zeros = np.zeros((agent_nb, action_n))
    agents_available_actions_zeros[:, 0] = 1

    child_conn.send(id)

    while True:

        while True:
            data = child_conn.recv()
            if data == 'save':
                env.save_replay()
                child_conn.send('save ok.')
            elif data == 'close':
                env.close()
                exit()
            else:
                break

        locker.acquire()
        env.reset()
        locker.release()

        episode_reward = 0
        episode_step = 0

        obs = np.array(env.get_obs())
        obs = np.concatenate([obs, actions_one_hot_reset, agent_id_one_hot_array], axis=-1)
        state = np.array(env.get_state())
        terminated = False
        #replay_buffer = qmix.ReplayBuffer(replay_buffer_size)

        while not terminated:

            agents_available_actions = []
            for agent_id in range(agent_nb):
                agents_available_actions.append(env.get_avail_agent_actions(agent_id))

            #locker.acquire()
            child_conn.send(["actions", obs, agents_available_actions])
            actions = child_conn.recv()
            #actions = qmix_algo.act(torch.FloatTensor(obs).to(device), torch.FloatTensor(agents_available_actions).to(device))
            #locker.release()

            reward, terminated, _ = env.step(actions)
            #self.terminated[i] = terminated

            agents_available_actions2 = []
            for agent_id in range(agent_nb):
                agents_available_actions2.append(env.get_avail_agent_actions(agent_id))

            obs2 = np.array(env.get_obs())
            actions_one_hot_agents = []
            for action in actions:
                actions_one_hot_agents.append(actions_one_hot.transform(torch.FloatTensor(action)).cpu().detach().numpy())
            actions_one_hot_agents = np.array(actions_one_hot_agents)

            obs2 = np.concatenate([obs2, actions_one_hot_agents, agent_id_one_hot_array], axis=-1)
            state2 = np.array(env.get_state())

            child_conn.send(["replay_buffer", state, actions, [reward], [terminated], obs, agents_available_actions, 0])
            #replay_buffer.add(state, state2, actions, [reward], [terminated], obs, obs2, agents_available_actions, agents_available_actions2, 0)

            #self.qmix_algo.decay_epsilon_greedy(self.episode_global_step)

            episode_reward += reward
            episode_step += 1

            obs = obs2
            state = state2

            #episode_global_step += 1

        for _ in range(episode_step, replay_buffer_size):
            child_conn.send(["actions", obs_zeros, agents_available_actions_zeros])
            child_conn.send(["replay_buffer", state_zeros, actions_zeros, [reward_zeros], [True], obs_zeros, agents_available_actions_zeros, 1])
            child_conn.recv()
            #replay_buffer.add(state_zeros, state_zeros, actions_zeros, [reward_zeros], [True], obs_zeros, obs_zeros, agents_available_actions_zeros, agents_available_actions_zeros, 1)

        child_conn.send(["episode_end", episode_reward, episode_step, env.win_counted])
        pass

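# Pipe protocol between each env_run worker (above) and the Runner (below):
#   worker -> parent, once at start-up: its integer id (readiness handshake).
#   parent -> worker: "Go !" to start an episode, 'save' to write an SC2Replay,
#                     'close' to shut the environment down, or an action array in
#                     reply to a pending "actions" request.
#   worker -> parent: ["actions", obs, avail_actions]                    (request actions)
#                     ["replay_buffer", s, a, r, t, obs, avail, filled]  (one transition)
#                     ["episode_end", episode_reward, episode_step, win_counted]
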
class Runner:
    def __init__(self, arglist, scenario, actors):
        env = StarCraft2Env(map_name=scenario, replay_dir="./replay/")

        env_info = env.get_env_info()

        self.actors = actors
        self.scenario = scenario

        self.n_actions = env_info["n_actions"]
        self.n_agents = env_info["n_agents"]
        self.state_shape = env_info["state_shape"]
        self.obs_shape = env_info["obs_shape"] + self.n_agents + self.n_actions
        self.episode_limit = env_info['episode_limit']

        self.qmix_algo = qmix.QMix(arglist.train, self.n_agents, self.obs_shape, self.state_shape, self.n_actions, 0.0005, replay_buffer_size=1000)
        if not arglist.train:
            self.qmix_algo.load_model('./saved/agents_' + str(arglist.load_episode_saved))
            print('Load model agent ', str(arglist.load_episode_saved))

        self.episode_global_step = 0
        self.episode = 0

        self.process_com = []
        self.locker = Lock()
        for idx in range(self.actors):
            parent_conn, child_conn = Pipe()
            Process(target=env_run, args=[self.scenario, idx, child_conn, self.locker, self.episode_limit]).start()
            self.process_com.append(parent_conn)

        for process_conn in self.process_com:
            process_id = process_conn.recv()
            print(process_id, " is ready !")

        pass

    def reset(self):
        self.qmix_algo.on_reset(self.actors)
        self.episodes = []
        self.episode_reward = []
        self.episode_step = []
        self.replay_buffers = []
        self.win_counted_array = []
        episode_managed = self.episode
        for _ in range(self.actors):
            self.episodes.append(episode_managed)
            self.episode_reward.append(0)
            self.episode_step.append(0)
            self.win_counted_array.append(False)
            self.replay_buffers.append(qmix.ReplayBuffer(self.episode_limit))
            episode_managed += 1
        for process_conn in self.process_com:
            process_conn.send("Go !")

    def run(self):
        episode_done = 0
        process_size = len(self.process_com)
        available_to_send = np.array([True for _ in range(self.actors)])

        while True:
            obs_batch = []
            available_batch = []
            actions = None
            for idx, process_conn in enumerate(self.process_com):
                #if process_conn.poll():
                data = process_conn.recv()
                if data[0] == "actions":
                    obs_batch.append(data[1])
                    available_batch.append(data[2])

                    if idx == process_size - 1:
                        obs_batch = np.concatenate(obs_batch, axis=0)
                        available_batch = np.concatenate(available_batch, axis=0)
                        actions = self.qmix_algo.act(self.actors, torch.FloatTensor(obs_batch).to(device), torch.FloatTensor(available_batch).to(device))

                elif data[0] == "replay_buffer":
                    self.replay_buffers[idx].add(data[1], data[2], data[3], data[4], data[5], data[6], data[7])

                elif data[0] == "episode_end":
                    self.episode_reward[idx] = data[1]
                    self.episode_step[idx] = data[2]
                    self.win_counted_array[idx] = data[3]
                    available_to_send[idx] = False
                    episode_done += 1

            if actions is not None:
                for idx_proc, process in enumerate(self.process_com):
                    if available_to_send[idx_proc]:
                        process.send(actions[idx_proc])

            if episode_done >= self.actors:
                break

        self.episode += self.actors

        self.episode_global_step += max(self.episode_step)

        self.qmix_algo.decay_epsilon_greedy(self.episode_global_step)

        return self.replay_buffers

    def save(self):
        for process in self.process_com:
            process.send('save')
            data = process.recv()
            print(data)
        pass
        """
        for env in range(self.actors):
            env.save_replay()
        """

    def close(self):
        for process in self.process_com:
            process.send('close')
        pass
        """
        for env in range(self.actors):
            env.close()
        """
--------------------------------------------------------------------------------
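A small sketch (not part of the repository) of how each agent's network input is built by the workers above: the raw SMAC observation is concatenated with the agent's previous action (one-hot) and its agent id (one-hot), which is why `Runner` uses `obs_shape = env obs_shape + n_agents + n_actions`. The sizes below are made up for illustration:

```python
import numpy as np
import torch
from runner import OneHot

# Illustrative sizes only: 3 agents, 9 actions, 30-dim raw observation per agent.
n_agents, n_actions, raw_obs_dim = 3, 9, 30

agent_id_one_hot = OneHot(n_agents)
ids = np.stack([agent_id_one_hot.transform(torch.FloatTensor([i])).numpy()
                for i in range(n_agents)])        # one-hot agent ids, shape (3, 3)
last_actions = np.zeros((n_agents, n_actions))    # zeros right after reset, as in env_run
raw_obs = np.random.rand(n_agents, raw_obs_dim)   # stand-in for env.get_obs()

obs = np.concatenate([raw_obs, last_actions, ids], axis=-1)
print(obs.shape)                                  # (3, 42) == (n_agents, 30 + 9 + 3)
```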
/qmix.py:
--------------------------------------------------------------------------------
import rnn_agent
import qmixer
import torch
import numpy as np
import random
from collections import deque

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


class EpsilonGreedy:
    def __init__(self, action_nb, agent_nb, final_step, epsilon_start=float(1), epsilon_end=0.05):
        self.epsilon = epsilon_start
        self.initial_epsilon = epsilon_start
        self.epsilon_end = epsilon_end
        self.action_nb = action_nb
        self.final_step = final_step
        self.agent_nb = agent_nb

    def act(self, value_action, avail_actions):
        if np.random.random() > self.epsilon:
            action = value_action.max(dim=1)[1].cpu().detach().numpy()
        else:
            action = torch.distributions.Categorical(avail_actions).sample().long().cpu().detach().numpy()
        return action

    def epsilon_decay(self, step):
        # linearly anneal epsilon towards epsilon_end as step approaches final_step
        progress = step / self.final_step

        decay = self.initial_epsilon - progress
        if decay <= self.epsilon_end:
            decay = self.epsilon_end
        self.epsilon = decay


class ReplayBuffer(object):

    def __init__(self, buffer_size, random_seed=123):
        """
        The right side of the deque contains the most recent experiences
        """
        self.buffer_size = buffer_size
        self.count = 0
        self.buffer = deque()

    def add(self, s, a, r, t, obs, available_actions, filled):
        experience = [s, a, r, t, obs, available_actions, np.array([filled])]
        if self.count < self.buffer_size:
            self.buffer.append(experience)
            self.count += 1
        else:
            self.buffer.popleft()
            self.buffer.append(experience)

    def size(self):
        return self.count

    def sample_batch(self, batch_size):
        batch = []

        for idx in range(batch_size):
            batch.append(self.buffer[idx])
        batch = np.array(batch)

        s_batch = np.array([_[0] for _ in batch], dtype='float32')
        a_batch = np.array([_[1] for _ in batch], dtype='float32')
        r_batch = np.array([_[2] for _ in batch])
        t_batch = np.array([_[3] for _ in batch])
        obs_batch = np.array([_[4] for _ in batch], dtype='float32')
        available_actions_batch = np.array([_[5] for _ in batch], dtype='float32')
        filled_batch = np.array([_[6] for _ in batch], dtype='float32')

        return s_batch, a_batch, r_batch, t_batch, obs_batch, available_actions_batch, filled_batch

    def clear(self):
        self.buffer.clear()
        self.count = 0


class EpisodeBatch:
    def __init__(self, buffer_size, random_seed=123):
        self.buffer_size = buffer_size
        self.count = 0
        self.buffer = deque()

    def reset(self):
        pass

    def add(self, replay_buffer):
        if self.count < self.buffer_size:
            self.buffer.append(replay_buffer)
            self.count += 1
        else:
            self.buffer.popleft()
            self.buffer.append(replay_buffer)

    def _get_max_episode_len(self, batch):
        max_episode_len = 0

        for replay_buffer in batch:
            _, _, _, t, _, _, _ = replay_buffer.sample_batch(replay_buffer.size())
            for idx, t_idx in enumerate(t):
                if t_idx == True:
                    if idx > max_episode_len:
                        max_episode_len = idx + 1
                    break

        return max_episode_len

    def sample_batch(self, batch_size):
        batch = []

        if self.count < batch_size:
            batch = random.sample(self.buffer, self.count)
        else:
            batch = random.sample(self.buffer, batch_size)
        episode_len = self._get_max_episode_len(batch)
        s_batch, a_batch, r_batch, t_batch, obs_batch, available_actions_batch, filled_batch = [], [], [], [], [], [], []
        for replay_buffer in batch:
            s, a, r, t, obs, available_actions, filled = replay_buffer.sample_batch(episode_len)
            s_batch.append(s)
            a_batch.append(a)
            r_batch.append(r)
            t_batch.append(t)
            obs_batch.append(obs)
            available_actions_batch.append(available_actions)
            filled_batch.append(filled)

        filled_batch = np.array(filled_batch)
        r_batch = np.array(r_batch)
        t_batch = np.array(t_batch)
        a_batch = np.array(a_batch)
        obs_batch = np.array(obs_batch)
        available_actions_batch = np.array(available_actions_batch)

        return s_batch, a_batch, r_batch, t_batch, obs_batch, available_actions_batch, filled_batch, episode_len

        #return batch

    def size(self):
        return self.count

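# How episode data is organised (summary of the classes above): each ReplayBuffer
# holds the transitions of a single episode, padded by the runner up to the episode
# limit; EpisodeBatch stores whole episodes and, when sampled, truncates every episode
# in the minibatch to the length of the longest one, returning arrays shaped
# [batch, T, ...] together with a `filled` flag (1 on padding steps) that QMix.train()
# uses to mask padded steps out of the TD loss.
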
class QMix:
    def __init__(self, training, agent_nb, obs_shape, states_shape, action_n, lr, gamma=0.99, batch_size=16, replay_buffer_size=5000, update_target_network=200, final_step=50000):
        self.training = training
        self.gamma = gamma
        self.batch_size = batch_size
        self.update_target_network = update_target_network
        self.hidden_states = None
        self.target_hidden_states = None
        self.agent_nb = agent_nb
        self.action_n = action_n
        self.state_shape = states_shape
        self.obs_shape = obs_shape

        self.epsilon_greedy = EpsilonGreedy(action_n, agent_nb, final_step)
        self.episode_batch = EpisodeBatch(replay_buffer_size)

        self.agents = rnn_agent.RNNAgent(obs_shape, n_actions=action_n).to(device)
        self.target_agents = rnn_agent.RNNAgent(obs_shape, n_actions=action_n).to(device)
        self.qmixer = qmixer.QMixer(agent_nb, states_shape, mixing_embed_dim=32).to(device)
        self.target_qmixer = qmixer.QMixer(agent_nb, states_shape, mixing_embed_dim=32).to(device)

        self.target_agents.update(self.agents)
        self.target_qmixer.update(self.qmixer)

        self.params = list(self.agents.parameters())
        self.params += self.qmixer.parameters()

        self.optimizer = torch.optim.RMSprop(params=self.params, lr=lr, alpha=0.99, eps=0.00001)

    def save_model(self, filename):
        torch.save(self.agents.state_dict(), filename)

    def load_model(self, filename):
        self.agents.load_state_dict(torch.load(filename))
        self.agents.eval()

    def _init_hidden_states(self, batch_size):
        self.hidden_states = self.agents.init_hidden().unsqueeze(0).expand(batch_size, self.agent_nb, -1)
        self.target_hidden_states = self.target_agents.init_hidden().unsqueeze(0).expand(batch_size, self.agent_nb, -1)

    def decay_epsilon_greedy(self, global_steps):
        self.epsilon_greedy.epsilon_decay(global_steps)

    def on_reset(self, batch_size):
        self._init_hidden_states(batch_size)

    def update_targets(self, episode):
        if episode % self.update_target_network == 0 and self.training:
            self.target_agents.update(self.agents)
            self.target_qmixer.update(self.qmixer)
        pass

    def train(self):
        if self.training and self.episode_batch.size() > self.batch_size:
            for _ in range(2):
                self._init_hidden_states(self.batch_size)
                s_batch, a_batch, r_batch, t_batch, obs_batch, available_actions_batch, filled_batch, episode_len = self.episode_batch.sample_batch(self.batch_size)

                r_batch = r_batch[:, :-1]
                a_batch = a_batch[:, :-1]
                t_batch = t_batch[:, :-1]
                filled_batch = filled_batch[:, :-1]

                mask = (1 - filled_batch) * (1 - t_batch)

                r_batch = torch.FloatTensor(r_batch).to(device)
                t_batch = torch.FloatTensor(t_batch).to(device)
                mask = torch.FloatTensor(mask).to(device)

                a_batch = torch.LongTensor(a_batch).to(device)

                mac_out = []

                for t in range(episode_len):
                    obs = obs_batch[:, t]
                    obs = np.concatenate(obs, axis=0)
                    obs = torch.FloatTensor(obs).to(device)
                    agent_actions, self.hidden_states = self.agents(obs, self.hidden_states)
                    agent_actions = agent_actions.view(self.batch_size, self.agent_nb, -1)
                    mac_out.append(agent_actions)
                mac_out = torch.stack(mac_out, dim=1)

                chosen_action_qvals = torch.gather(mac_out[:, :-1], dim=3, index=a_batch).squeeze(3)

                target_mac_out = []

                for t in range(episode_len):
                    obs = obs_batch[:, t]
                    obs = np.concatenate(obs, axis=0)
                    obs = torch.FloatTensor(obs).to(device)
                    agent_actions, self.target_hidden_states = self.target_agents(obs, self.target_hidden_states)
                    agent_actions = agent_actions.view(self.batch_size, self.agent_nb, -1)
                    target_mac_out.append(agent_actions)
                target_mac_out = torch.stack(target_mac_out[1:], dim=1)
                available_actions_batch = torch.Tensor(available_actions_batch).to(device)

                target_mac_out[available_actions_batch[:, 1:] == 0] = -9999999

                target_max_qvals = target_mac_out.max(dim=3)[0]

                states = torch.FloatTensor(s_batch).to(device)
                #states2 = torch.FloatTensor(s2_batch).to(device)

                chosen_action_qvals = self.qmixer(chosen_action_qvals, states[:, :-1])
                target_max_qvals = self.target_qmixer(target_max_qvals, states[:, 1:])

                yi = r_batch + self.gamma * (1 - t_batch) * target_max_qvals

                td_error = (chosen_action_qvals - yi.detach())

                mask = mask.expand_as(td_error)

                masked_td_error = td_error * mask

                loss = (masked_td_error ** 2).sum() / mask.sum()

                print('loss:', loss)
                self.optimizer.zero_grad()
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(self.params, 10)
                self.optimizer.step()

            pass
        pass

    def act(self, batch, obs, agents_available_actions):
        value_action, self.hidden_states = self.agents(obs, self.hidden_states)
        value_action[agents_available_actions == 0] = -1e10
        if self.training:
            value_action = self.epsilon_greedy.act(value_action, agents_available_actions)
        else:
            value_action = np.argmax(value_action.cpu().data.numpy(), -1)
        value_action = value_action.reshape(batch, self.agent_nb, -1)
        return value_action
--------------------------------------------------------------------------------
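A minimal end-to-end sketch of action selection with `QMix` (not part of the repository). All sizes are illustrative and do not correspond to a specific SMAC map; in practice `Runner` takes these shapes from `env.get_env_info()` and, for evaluation, first restores a trained checkpoint with `load_model`:

```python
import torch
import qmix

# Illustrative sizes only: 2 parallel episodes, 3 agents, 42-dim augmented obs,
# 48-dim global state, 9 actions.
algo = qmix.QMix(training=False, agent_nb=3, obs_shape=42, states_shape=48,
                 action_n=9, lr=0.0005)
# algo.load_model('./saved/agents_75000')  # optionally restore a trained checkpoint

algo.on_reset(2)                              # init one GRU hidden state per episode/agent
obs = torch.rand(2 * 3, 42).to(qmix.device)   # per-agent observations, stacked over episodes
avail = torch.ones(2 * 3, 9).to(qmix.device)  # availability mask (1 = action allowed)
actions = algo.act(2, obs, avail)             # greedy action ids, shape (2, 3, 1)
print(actions.shape)
```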