├── .gitignore
├── README.md
├── m3ddpg.py
├── memory.py
├── model.py
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# M3DDPG-pytorch
This is a PyTorch implementation of MiniMax Multi-Agent Deep Deterministic Policy Gradient (M3DDPG).

This repository contains the implementation code for the following paper:

[Robust Multi-Agent Reinforcement Learning via Minimax Deep Deterministic Policy Gradient](https://people.eecs.berkeley.edu/~russell/papers/aaai19-marl.pdf)
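In short, M3DDPG trains each agent against worst-case perturbations of the other agents' actions and approximates the inner minimization by a single gradient step on those actions (this is what `update()` in `m3ddpg.py` does). Roughly, the critic target for agent $i$ becomes

$$
y_i = r_i + \gamma\, Q_i\big(s',\, a'_1 + \epsilon_1, \dots, \mu'_i(o'_i), \dots, a'_N + \epsilon_N\big),
\qquad
\epsilon_j \approx -\alpha\, \nabla_{a'_j} Q_i\big(s', a'_1, \dots, a'_N\big),
$$

where $a'_j = \mu'_j(o'_j)$ are the other agents' target-policy actions and $\alpha$ is the perturbation step size (the `--eps` / `--adv-eps` flags in `train.py`). This is a rough sketch of the idea, not the paper's exact formulation.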
# Installing Multi-Agent Particle Environments (MPE)

`git clone https://github.com/openai/multiagent-particle-envs.git`

`cd multiagent-particle-envs`

`pip install -e .`

# How to run

`python train.py`
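Training defaults to the `simple` scenario on the CPU. The flags defined in `train.py` can be used to change the seed, train on a GPU (assuming a CUDA build of PyTorch), or render the evaluation rollouts, for example:

`python train.py --seed 0 --device cuda --render`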
--------------------------------------------------------------------------------
/m3ddpg.py:
--------------------------------------------------------------------------------
from model import Policy, Critic
from memory import ReplayMemory
from copy import deepcopy
import torch.nn.functional as F
import numpy as np
from torch.optim import Adam
import torch
from collections import namedtuple

Transition = namedtuple(
    'Transition', ('state', 'action', 'next_state', 'reward', 'done'))

def soft_update(target, source, t):
    for target_param, source_param in zip(target.parameters(),
                                          source.parameters()):
        target_param.data.copy_(
            (1 - t) * target_param.data + t * source_param.data)

class Agent():
    def __init__(self, args, state_dim, action_space, name):
        self.args = args
        self.name = name
        self.actor = Policy(state_dim, action_space.n).to(args.device)
        self.critic = Critic(state_dim, action_space.n).to(args.device)
        self.actor_target = deepcopy(self.actor)
        self.critic_target = deepcopy(self.critic)
        self.optimizer_actor = Adam(self.actor.parameters(), lr=args.actor_lr)
        self.optimizer_critic = Adam(self.critic.parameters(), lr=args.critic_lr)

    def get_action(self, observation, greedy):
        action = self.actor(observation)
        if not greedy:
            # exploration: per-dimension Gaussian noise on the actor output
            action += torch.tensor(np.random.normal(0, 0.1, size=action.shape),
                                   dtype=torch.float, device=self.args.device)
        return action

class M3DDPG():

    def __init__(self, args, env):
        self.args = args
        self.device = args.device
        self.obs_shape_n = [env.observation_space[i].shape for i in range(env.n)]
        self.action_space = env.action_space[0]
        num_adversaries = min(env.n, args.num_adversaries)
        self.agents = []
        # in MPE the first num_adversaries agents are the adversaries
        for i in range(num_adversaries):
            self.agents.append(Agent(args, self.obs_shape_n[i][0], env.action_space[i], f'adv_{i}'))
        for i in range(num_adversaries, env.n):
            self.agents.append(Agent(args, self.obs_shape_n[i][0], env.action_space[i], f'good_{i}'))
        self.memory = ReplayMemory(args.capacity)

    def add_memory_list(self, *args):
        transitions = Transition(*args)
        self.memory.append(transitions)

    def sample_action(self, state, greedy=False):
        actions = []
        for i, agent in enumerate(self.agents):
            observation_tensor = torch.tensor(
                state[i], dtype=torch.float, device=self.args.device).view(-1, self.obs_shape_n[i][0])
            action = agent.get_action(observation_tensor, greedy).squeeze(0).detach().cpu().numpy().tolist()
            actions.append(np.argmax(action))
        actions = np.array(actions)
        return actions

    def transition2batch(self, transitions):
        batch = Transition(*zip(*transitions))
        # (batch, n_agents, obs_dim) -> (n_agents, batch, obs_dim)
        state_batch = torch.transpose(torch.tensor(
            np.array(batch.state), device=self.args.device, dtype=torch.float), 0, 1)
        actions = []
        for action in batch.action:
            # one discrete action index per agent -> concatenated one-hot vector
            action_vec = np.zeros(len(self.agents) * self.action_space.n)
            for j, a in enumerate(action):
                action_vec[j * self.action_space.n + int(a)] = 1
            actions.append(action_vec)
        action_batch = torch.tensor(
            np.array(actions), device=self.args.device, dtype=torch.float)
        next_state_batch = torch.transpose(torch.tensor(
            np.array(batch.next_state), device=self.args.device, dtype=torch.float), 0, 1)
        # (batch, n_agents)
        reward_batch = torch.tensor(
            np.array(batch.reward), device=self.args.device, dtype=torch.float)
        not_done = np.array([[not d for d in done_n] for done_n in batch.done])
        not_done_batch = torch.tensor(
            not_done, device=self.args.device, dtype=torch.float)

        return state_batch, action_batch, next_state_batch, not_done_batch, reward_batch

    def update(self):
        actor_losses, critic_losses = [], []
        if self.memory.size() <= self.args.batch_size:
            return None, None
        transitions = self.memory.sample(self.args.batch_size)
        state_n_batch, action_n_batch, next_state_n_batch, not_done_n_batch, reward_n_batch = self.transition2batch(transitions)
        for i, agent in enumerate(self.agents):
            # perturbation budget: adversaries use adv_eps, the others use eps
            if 'adv' in agent.name:
                eps = self.args.adv_eps
            else:
                eps = self.args.eps

            reward_batch = reward_n_batch[:, i].unsqueeze(1)
            not_done_batch = not_done_n_batch[:, i].unsqueeze(1)

            # M3DDPG: perturb the other agents' next actions (from their target
            # policies) by one gradient step that lowers agent i's target Q
            _next_actions = [self.agents[j].actor_target(next_state_n_batch[j]) for j in range(len(self.agents))]
            for j, _next_action in enumerate(_next_actions):
                if j != i:
                    _next_action.retain_grad()
            _next_action_n_batch = torch.cat(
                [_next_action if j != i else _next_action.detach() for j, _next_action in enumerate(_next_actions)], axis=1)
            _critic_target_loss = self.agents[i].critic_target(next_state_n_batch, _next_action_n_batch).mean()
            _critic_target_loss.backward()
            with torch.no_grad():
                next_action_n_batch_critic = torch.cat(
                    [_next_action - eps * _next_action.grad if j != i else _next_action
                     for j, _next_action in enumerate(_next_actions)], axis=1)

            # the same worst-case perturbation for the policy update
            _actions = [self.agents[j].actor(state_n_batch[j]) for j in range(len(self.agents))]
            for j, _action in enumerate(_actions):
                if j != i:
                    _action.retain_grad()
            _action_n_batch = torch.cat(
                [_action if j != i else _action.detach() for j, _action in enumerate(_actions)], axis=1)
            _actor_target_loss = self.agents[i].critic(state_n_batch, _action_n_batch).mean()
            _actor_target_loss.backward()
            # detach the other agents' perturbed actions so that the actor loss
            # below only backpropagates through agent i's own policy
            action_n_batch_actor = torch.cat(
                [(_action - eps * _action.grad).detach() if j != i else _action
                 for j, _action in enumerate(_actions)], axis=1)

            ## critic
            agent.optimizer_critic.zero_grad()
            currentQ = agent.critic(state_n_batch, action_n_batch)
            with torch.no_grad():
                nextQ = agent.critic_target(next_state_n_batch, next_action_n_batch_critic)
                targetQ = reward_batch + self.args.gamma * not_done_batch * nextQ
            critic_loss = F.mse_loss(currentQ, targetQ)
            critic_loss.backward()
            agent.optimizer_critic.step()

            ## policy
            agent.optimizer_actor.zero_grad()
            actor_loss = - agent.critic(state_n_batch, action_n_batch_actor).mean()
            actor_loss.backward()
            agent.optimizer_actor.step()

            soft_update(agent.critic_target, agent.critic, self.args.tau)
            soft_update(agent.actor_target, agent.actor, self.args.tau)

            actor_losses.append(actor_loss.item())
            critic_losses.append(critic_loss.item())

        return actor_losses, critic_losses
--------------------------------------------------------------------------------
/memory.py:
--------------------------------------------------------------------------------
import random
from collections import deque


class ReplayMemory(object):
    def __init__(self, capacity=1e6):
        self.capacity = capacity
        self.memory = deque([], maxlen=int(capacity))

    def append(self, transition):
        self.memory.append(transition)

    def sample(self, batch_size):
        if len(self.memory) < batch_size:
            return None
        return random.sample(self.memory, batch_size)

    def reset(self):
        self.memory.clear()

    def size(self):
        return len(self.memory)

    def __len__(self):
        return len(self.memory)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class Policy(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, action_dim)

    def forward(self, state):
        h = F.relu(self.fc1(state))
        h = F.relu(self.fc2(h))
        h = self.fc3(h)
        y = torch.tanh(h)
        return y


class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, state, action):
        # drop a leading singleton agent dimension, if present
        state = state.squeeze(0)
        h = F.relu(self.fc1(torch.cat([state, action], axis=1)))
        h = F.relu(self.fc2(h))
        y = self.fc3(h)
        return y
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import gym
from m3ddpg import M3DDPG
import numpy as np
import pickle
import os
import json
import torch
import datetime
from multiagent.environment import MultiAgentEnv
import multiagent.scenarios as scenarios
import time

def make_env(args):
    scenario = scenarios.load(f'{args.env_name}.py').Scenario()
    world = scenario.make_world()
    env = MultiAgentEnv(world, scenario.reset_world,
                        scenario.reward, scenario.observation)
    return env

def make_action(actions, action_space):
    agent_actions = []
    for i, action in enumerate(actions):
        action_vec = np.zeros(action_space[i].n)
        action_vec[action] = 1
        agent_actions.append(action_vec)
    return agent_actions

def set_seed(seed, env):
    torch.manual_seed(seed)
    env.seed = seed
    np.random.seed(seed)
    return env

def evaluate(m3ddpg, args):
    env = make_env(args)
    env = set_seed(args.seed + 100, env)
    total_reward = [0] * env.n
    for _ in range(args.evaluate_num):
        state_n = env.reset()
        for _ in range(args.max_episode_len):
            action_n = m3ddpg.sample_action(state_n, greedy=True)
            agent_actions = make_action(action_n, env.action_space)
            next_state_n, reward_n, done_n, _ = env.step(agent_actions)
            if args.render:
                env.render()
                time.sleep(0.1)
            for i in range(env.n):
                total_reward[i] += reward_n[i]
            if all(done_n):
                state_n = env.reset()
                break
            state_n = next_state_n
    if args.render:
        env.close()
    return np.mean(np.array(total_reward)).tolist()

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--env_name', type=str, default='simple')
    parser.add_argument('--device', default='cpu', type=str)
    parser.add_argument('--capacity', default=1e6, type=float)
    parser.add_argument('--steps', type=float, default=1e6)
    parser.add_argument('--max_episode_len', type=int, default=25)
    parser.add_argument('--start_steps', type=float, default=1e3)
    parser.add_argument('--evaluate-interval', type=float, default=1e3)
    parser.add_argument('--evaluate_num', type=int, default=1)
    parser.add_argument('--batch_size', type=int, default=100)
    parser.add_argument('--eps', default=1e-5, type=float)
    parser.add_argument('--adv-eps', default=1e-3, type=float)
    parser.add_argument('--num-adversaries', type=int, default=0)
    parser.add_argument('--gamma', default=0.99, type=float)
    parser.add_argument('--tau', default=5e-3, type=float)
    parser.add_argument('--actor_lr', default=3e-4, type=float)
    parser.add_argument('--critic_lr', default=3e-4, type=float)
    parser.add_argument('--memo', default='', type=str)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--render', action='store_true')

    args = parser.parse_args()
    dt_now = datetime.datetime.now()
    args.logger_dir = f'experiments/{args.env_name}_{args.seed}_{dt_now:%Y%m%d-%H%M%S}'
    os.makedirs(args.logger_dir, exist_ok=False)
    with open('{}/hyperparameters.json'.format(args.logger_dir), 'w') as f:
        f.write(json.dumps(args.__dict__))

    env = make_env(args)
    env = set_seed(args.seed, env)
    m3ddpg = M3DDPG(args, env)

    total_reward_list = []
    reward_mean_list = []

    step = 0
    episode_step = 0

    state_n = env.reset()
    total_reward = [0] * env.n
    actor_loss_list = []
    critic_loss_list = []
    while True:
        if step <= args.start_steps:
            action_n = np.array([env.action_space[i].sample() for i in range(env.n)])
        else:
            action_n = m3ddpg.sample_action(state_n)
        agent_action = make_action(action_n, env.action_space)
        next_state_n, reward_n, done_n, _ = env.step(agent_action)
        for i in range(env.n):
            total_reward[i] += reward_n[i]
        m3ddpg.add_memory_list(state_n, action_n, next_state_n, reward_n, done_n)
        episode_step += 1
        if step >= args.start_steps:
            actor_loss, critic_loss = m3ddpg.update()
            if actor_loss is not None:
                actor_loss_list.append(actor_loss)
                critic_loss_list.append(critic_loss)

        state_n = next_state_n

        if step % args.evaluate_interval == 0:
            reward_mean = evaluate(m3ddpg, args)
            print('====================')
            print(f'step: {step} reward: {reward_mean}')
            print('====================')
            reward_mean_list.append(reward_mean)
            results = {
                'train_reward': total_reward_list,
                'reward_mean_list': reward_mean_list,
            }
            pickle.dump(results, open(
                '{}/results.pkl'.format(args.logger_dir), 'wb'))

        if step > args.steps:
            break
        step += 1
        if all(done_n) or episode_step > args.max_episode_len:
            total_reward_list.append(total_reward)
            actor_loss = np.mean(actor_loss_list, axis=0)
            critic_loss = np.mean(critic_loss_list, axis=0)
            print(f'step: {step} reward: {total_reward} actor loss: {actor_loss} critic loss: {critic_loss}')
            total_reward = [0] * env.n
            episode_step = 0
            state_n = env.reset()
            actor_loss_list = []
            critic_loss_list = []

    results = {
        'train_reward': total_reward_list,
        'reward_mean_list': reward_mean_list,
    }

    os.makedirs('results', exist_ok=True)
    pickle.dump(results, open(
        '{}/results{}.pkl'.format('results', args.seed), 'wb'))

    pickle.dump(results, open(
        '{}/results.pkl'.format(args.logger_dir), 'wb'))

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------