├── .gitignore
├── LICENSE
├── README.md
├── algorithms
│   └── attention_sac.py
├── envs
│   └── mpe_scenarios
│       ├── __init__.py
│       ├── fullobs_collect_treasure.py
│       └── multi_speaker_listener.py
├── main.py
└── utils
    ├── agents.py
    ├── buffer.py
    ├── critics.py
    ├── env_wrappers.py
    ├── make_env.py
    ├── misc.py
    └── policies.py

/.gitignore:
--------------------------------------------------------------------------------
1 | # Repo Specific
2 | models
3 | fig_data
4 | notebooks
5 | multi_run*
6 | 
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__/
9 | *.py[cod]
10 | *$py.class
11 | 
12 | # C extensions
13 | *.so
14 | 
15 | # Distribution / packaging
16 | .Python
17 | env/
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | *.egg-info/
31 | .installed.cfg
32 | *.egg
33 | 
34 | # PyInstaller
35 | # Usually these files are written by a python script from a template
36 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
37 | *.manifest
38 | *.spec
39 | 
40 | # Installer logs
41 | pip-log.txt
42 | pip-delete-this-directory.txt
43 | 
44 | # Unit test / coverage reports
45 | htmlcov/
46 | .tox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | .hypothesis/
54 | 
55 | # Translations
56 | *.mo
57 | *.pot
58 | 
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | 
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 | 
67 | # Scrapy stuff:
68 | .scrapy
69 | 
70 | # Sphinx documentation
71 | docs/_build/
72 | 
73 | # PyBuilder
74 | target/
75 | 
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 | 
79 | # pyenv
80 | .python-version
81 | 
82 | # celery beat schedule file
83 | celerybeat-schedule
84 | 
85 | # SageMath parsed files
86 | *.sage.py
87 | 
88 | # dotenv
89 | .env
90 | 
91 | # virtualenv
92 | .venv
93 | venv/
94 | ENV/
95 | 
96 | # Spyder project settings
97 | .spyderproject
98 | .spyproject
99 | 
100 | # Rope project settings
101 | .ropeproject
102 | 
103 | # mkdocs documentation
104 | /site
105 | 
106 | # mypy
107 | .mypy_cache/
108 | 
--------------------------------------------------------------------------------

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2018 Shariq Iqbal
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
1 | # Multi-Actor-Attention-Critic
2 | Code for [*Actor-Attention-Critic for Multi-Agent Reinforcement Learning*](https://arxiv.org/abs/1810.02912) (Iqbal and Sha, ICML 2019)
3 | 
4 | ## Requirements
5 | * Python 3.6.1 (Minimum)
6 | * [OpenAI baselines](https://github.com/openai/baselines), commit hash: 98257ef8c9bd23a24a330731ae54ed086d9ce4a7
7 | * My [fork](https://github.com/shariqiqbal2810/multiagent-particle-envs) of Multi-agent Particle Environments
8 | * [PyTorch](http://pytorch.org/), version: 0.3.0.post4
9 | * [OpenAI Gym](https://github.com/openai/gym), version: 0.9.4
10 | * [Tensorboard](https://github.com/tensorflow/tensorboard), version: 0.4.0rc3 and [Tensorboard-Pytorch](https://github.com/lanpa/tensorboard-pytorch), version: 1.0 (for logging)
11 | 
12 | The versions are just what I used and not necessarily strict requirements.
13 | 
14 | ## How to Run
15 | 
16 | All training code is contained within `main.py`. To view options simply run:
17 | 
18 | ```shell
19 | python main.py --help
20 | ```
21 | The "Cooperative Treasure Collection" environment from our paper is referred to as `fullobs_collect_treasure` in this repo, and "Rover-Tower" is referred to as `multi_speaker_listener`.
22 | 
23 | In order to match our experiments, the maximum episode length should be set to 100 for Cooperative Treasure Collection and 25 for Rover-Tower.
24 | 
25 | ## Citing our work
26 | 
27 | If you use this repo in your work, please consider citing the corresponding paper:
28 | 
29 | ```bibtex
30 | @InProceedings{pmlr-v97-iqbal19a,
31 |   title = {Actor-Attention-Critic for Multi-Agent Reinforcement Learning},
32 |   author = {Iqbal, Shariq and Sha, Fei},
33 |   booktitle = {Proceedings of the 36th International Conference on Machine Learning},
34 |   pages = {2961--2970},
35 |   year = {2019},
36 |   editor = {Chaudhuri, Kamalika and Salakhutdinov, Ruslan},
37 |   volume = {97},
38 |   series = {Proceedings of Machine Learning Research},
39 |   address = {Long Beach, California, USA},
40 |   month = {09--15 Jun},
41 |   publisher = {PMLR},
42 |   pdf = {http://proceedings.mlr.press/v97/iqbal19a/iqbal19a.pdf},
43 |   url = {http://proceedings.mlr.press/v97/iqbal19a.html},
44 | }
45 | ```
46 | 
--------------------------------------------------------------------------------

/algorithms/attention_sac.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.optim import Adam
4 | from utils.misc import soft_update, hard_update, enable_gradients, disable_gradients
5 | from utils.agents import AttentionAgent
6 | from utils.critics import AttentionCritic
7 | 
8 | MSELoss = torch.nn.MSELoss()
9 | 
10 | class AttentionSAC(object):
11 |     """
12 |     Wrapper class for SAC agents with central attention critic in multi-agent
13 |     task
14 |     """
15 |     def __init__(self, agent_init_params, sa_size,
16 |                  gamma=0.95, tau=0.01, pi_lr=0.01, q_lr=0.01,
17 |                  reward_scale=10.,
18 |                  pol_hidden_dim=128,
19 |                  critic_hidden_dim=128, attend_heads=4,
20 |                  **kwargs):
21 |         """
22 |         Inputs:
23 |             agent_init_params (list of dict): List of dicts with parameters to
24 |                                               initialize each agent
25 |                 num_in_pol (int): Input dimensions to policy
26 |                 num_out_pol (int): Output dimensions to policy
27 |             sa_size (list of (int, int)): Size of state and action space for
28 |                                           each agent
29 |             gamma (float): 
Discount factor 30 | tau (float): Target update rate 31 | pi_lr (float): Learning rate for policy 32 | q_lr (float): Learning rate for critic 33 | reward_scale (float): Scaling for reward (has effect of optimal 34 | policy entropy) 35 | hidden_dim (int): Number of hidden dimensions for networks 36 | """ 37 | self.nagents = len(sa_size) 38 | 39 | self.agents = [AttentionAgent(lr=pi_lr, 40 | hidden_dim=pol_hidden_dim, 41 | **params) 42 | for params in agent_init_params] 43 | self.critic = AttentionCritic(sa_size, hidden_dim=critic_hidden_dim, 44 | attend_heads=attend_heads) 45 | self.target_critic = AttentionCritic(sa_size, hidden_dim=critic_hidden_dim, 46 | attend_heads=attend_heads) 47 | hard_update(self.target_critic, self.critic) 48 | self.critic_optimizer = Adam(self.critic.parameters(), lr=q_lr, 49 | weight_decay=1e-3) 50 | self.agent_init_params = agent_init_params 51 | self.gamma = gamma 52 | self.tau = tau 53 | self.pi_lr = pi_lr 54 | self.q_lr = q_lr 55 | self.reward_scale = reward_scale 56 | self.pol_dev = 'cpu' # device for policies 57 | self.critic_dev = 'cpu' # device for critics 58 | self.trgt_pol_dev = 'cpu' # device for target policies 59 | self.trgt_critic_dev = 'cpu' # device for target critics 60 | self.niter = 0 61 | 62 | @property 63 | def policies(self): 64 | return [a.policy for a in self.agents] 65 | 66 | @property 67 | def target_policies(self): 68 | return [a.target_policy for a in self.agents] 69 | 70 | def step(self, observations, explore=False): 71 | """ 72 | Take a step forward in environment with all agents 73 | Inputs: 74 | observations: List of observations for each agent 75 | Outputs: 76 | actions: List of actions for each agent 77 | """ 78 | return [a.step(obs, explore=explore) for a, obs in zip(self.agents, 79 | observations)] 80 | 81 | def update_critic(self, sample, soft=True, logger=None, **kwargs): 82 | """ 83 | Update central critic for all agents 84 | """ 85 | obs, acs, rews, next_obs, dones = sample 86 | # Q loss 87 | next_acs = [] 88 | next_log_pis = [] 89 | for pi, ob in zip(self.target_policies, next_obs): 90 | curr_next_ac, curr_next_log_pi = pi(ob, return_log_pi=True) 91 | next_acs.append(curr_next_ac) 92 | next_log_pis.append(curr_next_log_pi) 93 | trgt_critic_in = list(zip(next_obs, next_acs)) 94 | critic_in = list(zip(obs, acs)) 95 | next_qs = self.target_critic(trgt_critic_in) 96 | critic_rets = self.critic(critic_in, regularize=True, 97 | logger=logger, niter=self.niter) 98 | q_loss = 0 99 | for a_i, nq, log_pi, (pq, regs) in zip(range(self.nagents), next_qs, 100 | next_log_pis, critic_rets): 101 | target_q = (rews[a_i].view(-1, 1) + 102 | self.gamma * nq * 103 | (1 - dones[a_i].view(-1, 1))) 104 | if soft: 105 | target_q -= log_pi / self.reward_scale 106 | q_loss += MSELoss(pq, target_q.detach()) 107 | for reg in regs: 108 | q_loss += reg # regularizing attention 109 | q_loss.backward() 110 | self.critic.scale_shared_grads() 111 | grad_norm = torch.nn.utils.clip_grad_norm( 112 | self.critic.parameters(), 10 * self.nagents) 113 | self.critic_optimizer.step() 114 | self.critic_optimizer.zero_grad() 115 | 116 | if logger is not None: 117 | logger.add_scalar('losses/q_loss', q_loss, self.niter) 118 | logger.add_scalar('grad_norms/q', grad_norm, self.niter) 119 | self.niter += 1 120 | 121 | def update_policies(self, sample, soft=True, logger=None, **kwargs): 122 | obs, acs, rews, next_obs, dones = sample 123 | samp_acs = [] 124 | all_probs = [] 125 | all_log_pis = [] 126 | all_pol_regs = [] 127 | 128 | for a_i, pi, ob in 
zip(range(self.nagents), self.policies, obs): 129 | curr_ac, probs, log_pi, pol_regs, ent = pi( 130 | ob, return_all_probs=True, return_log_pi=True, 131 | regularize=True, return_entropy=True) 132 | logger.add_scalar('agent%i/policy_entropy' % a_i, ent, 133 | self.niter) 134 | samp_acs.append(curr_ac) 135 | all_probs.append(probs) 136 | all_log_pis.append(log_pi) 137 | all_pol_regs.append(pol_regs) 138 | 139 | critic_in = list(zip(obs, samp_acs)) 140 | critic_rets = self.critic(critic_in, return_all_q=True) 141 | for a_i, probs, log_pi, pol_regs, (q, all_q) in zip(range(self.nagents), all_probs, 142 | all_log_pis, all_pol_regs, 143 | critic_rets): 144 | curr_agent = self.agents[a_i] 145 | v = (all_q * probs).sum(dim=1, keepdim=True) 146 | pol_target = q - v 147 | if soft: 148 | pol_loss = (log_pi * (log_pi / self.reward_scale - pol_target).detach()).mean() 149 | else: 150 | pol_loss = (log_pi * (-pol_target).detach()).mean() 151 | for reg in pol_regs: 152 | pol_loss += 1e-3 * reg # policy regularization 153 | # don't want critic to accumulate gradients from policy loss 154 | disable_gradients(self.critic) 155 | pol_loss.backward() 156 | enable_gradients(self.critic) 157 | 158 | grad_norm = torch.nn.utils.clip_grad_norm( 159 | curr_agent.policy.parameters(), 0.5) 160 | curr_agent.policy_optimizer.step() 161 | curr_agent.policy_optimizer.zero_grad() 162 | 163 | if logger is not None: 164 | logger.add_scalar('agent%i/losses/pol_loss' % a_i, 165 | pol_loss, self.niter) 166 | logger.add_scalar('agent%i/grad_norms/pi' % a_i, 167 | grad_norm, self.niter) 168 | 169 | 170 | def update_all_targets(self): 171 | """ 172 | Update all target networks (called after normal updates have been 173 | performed for each agent) 174 | """ 175 | soft_update(self.target_critic, self.critic, self.tau) 176 | for a in self.agents: 177 | soft_update(a.target_policy, a.policy, self.tau) 178 | 179 | def prep_training(self, device='gpu'): 180 | self.critic.train() 181 | self.target_critic.train() 182 | for a in self.agents: 183 | a.policy.train() 184 | a.target_policy.train() 185 | if device == 'gpu': 186 | fn = lambda x: x.cuda() 187 | else: 188 | fn = lambda x: x.cpu() 189 | if not self.pol_dev == device: 190 | for a in self.agents: 191 | a.policy = fn(a.policy) 192 | self.pol_dev = device 193 | if not self.critic_dev == device: 194 | self.critic = fn(self.critic) 195 | self.critic_dev = device 196 | if not self.trgt_pol_dev == device: 197 | for a in self.agents: 198 | a.target_policy = fn(a.target_policy) 199 | self.trgt_pol_dev = device 200 | if not self.trgt_critic_dev == device: 201 | self.target_critic = fn(self.target_critic) 202 | self.trgt_critic_dev = device 203 | 204 | def prep_rollouts(self, device='cpu'): 205 | for a in self.agents: 206 | a.policy.eval() 207 | if device == 'gpu': 208 | fn = lambda x: x.cuda() 209 | else: 210 | fn = lambda x: x.cpu() 211 | # only need main policy for rollouts 212 | if not self.pol_dev == device: 213 | for a in self.agents: 214 | a.policy = fn(a.policy) 215 | self.pol_dev = device 216 | 217 | def save(self, filename): 218 | """ 219 | Save trained parameters of all agents into one file 220 | """ 221 | self.prep_training(device='cpu') # move parameters to CPU before saving 222 | save_dict = {'init_dict': self.init_dict, 223 | 'agent_params': [a.get_params() for a in self.agents], 224 | 'critic_params': {'critic': self.critic.state_dict(), 225 | 'target_critic': self.target_critic.state_dict(), 226 | 'critic_optimizer': self.critic_optimizer.state_dict()}} 227 | 
torch.save(save_dict, filename) 228 | 229 | @classmethod 230 | def init_from_env(cls, env, gamma=0.95, tau=0.01, 231 | pi_lr=0.01, q_lr=0.01, 232 | reward_scale=10., 233 | pol_hidden_dim=128, critic_hidden_dim=128, attend_heads=4, 234 | **kwargs): 235 | """ 236 | Instantiate instance of this class from multi-agent environment 237 | 238 | env: Multi-agent Gym environment 239 | gamma: discount factor 240 | tau: rate of update for target networks 241 | lr: learning rate for networks 242 | hidden_dim: number of hidden dimensions for networks 243 | """ 244 | agent_init_params = [] 245 | sa_size = [] 246 | for acsp, obsp in zip(env.action_space, 247 | env.observation_space): 248 | agent_init_params.append({'num_in_pol': obsp.shape[0], 249 | 'num_out_pol': acsp.n}) 250 | sa_size.append((obsp.shape[0], acsp.n)) 251 | 252 | init_dict = {'gamma': gamma, 'tau': tau, 253 | 'pi_lr': pi_lr, 'q_lr': q_lr, 254 | 'reward_scale': reward_scale, 255 | 'pol_hidden_dim': pol_hidden_dim, 256 | 'critic_hidden_dim': critic_hidden_dim, 257 | 'attend_heads': attend_heads, 258 | 'agent_init_params': agent_init_params, 259 | 'sa_size': sa_size} 260 | instance = cls(**init_dict) 261 | instance.init_dict = init_dict 262 | return instance 263 | 264 | @classmethod 265 | def init_from_save(cls, filename, load_critic=False): 266 | """ 267 | Instantiate instance of this class from file created by 'save' method 268 | """ 269 | save_dict = torch.load(filename) 270 | instance = cls(**save_dict['init_dict']) 271 | instance.init_dict = save_dict['init_dict'] 272 | for a, params in zip(instance.agents, save_dict['agent_params']): 273 | a.load_params(params) 274 | 275 | if load_critic: 276 | critic_params = save_dict['critic_params'] 277 | instance.critic.load_state_dict(critic_params['critic']) 278 | instance.target_critic.load_state_dict(critic_params['target_critic']) 279 | instance.critic_optimizer.load_state_dict(critic_params['critic_optimizer']) 280 | return instance -------------------------------------------------------------------------------- /envs/mpe_scenarios/__init__.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import os.path as osp 3 | 4 | 5 | def load(name): 6 | pathname = osp.join(osp.dirname(__file__), name) 7 | return imp.load_source('', pathname) 8 | -------------------------------------------------------------------------------- /envs/mpe_scenarios/fullobs_collect_treasure.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import seaborn as sns 3 | from multiagent.core import World, Agent, Landmark, Wall 4 | from multiagent.scenario import BaseScenario 5 | 6 | class Scenario(BaseScenario): 7 | def make_world(self): 8 | world = World() 9 | # set any world properties first 10 | world.cache_dists = True 11 | world.dim_c = 2 12 | num_agents = 8 13 | num_collectors = 6 14 | num_deposits = num_agents - num_collectors 15 | world.treasure_types = list(range(num_deposits)) 16 | world.treasure_colors = np.array( 17 | sns.color_palette(n_colors=num_deposits)) 18 | num_treasures = num_collectors 19 | # add agents 20 | world.agents = [Agent() for i in range(num_agents)] 21 | for i, agent in enumerate(world.agents): 22 | agent.i = i 23 | agent.name = 'agent %d' % i 24 | agent.collector = True if i < num_collectors else False 25 | if agent.collector: 26 | agent.color = np.array([0.85, 0.85, 0.85]) 27 | else: 28 | agent.d_i = i - num_collectors 29 | agent.color = world.treasure_colors[agent.d_i] * 0.35 30 | 
agent.collide = True 31 | agent.silent = True 32 | agent.ghost = True 33 | agent.holding = None 34 | agent.size = 0.05 if agent.collector else 0.075 35 | agent.accel = 1.5 36 | agent.initial_mass = 1.0 if agent.collector else 2.25 37 | agent.max_speed = 1.0 38 | # add treasures 39 | world.landmarks = [Landmark() for i in range(num_treasures)] 40 | for i, landmark in enumerate(world.landmarks): 41 | landmark.i = i + num_agents 42 | landmark.name = 'treasure %d' % i 43 | landmark.respawn_prob = 1.0 44 | landmark.type = np.random.choice(world.treasure_types) 45 | landmark.color = world.treasure_colors[landmark.type] 46 | landmark.alive = True 47 | landmark.collide = False 48 | landmark.movable = False 49 | landmark.size = 0.025 50 | landmark.boundary = False 51 | world.walls = [] 52 | # make initial conditions 53 | self.reset_world(world) 54 | self.reset_cached_rewards() 55 | return world 56 | 57 | def collectors(self, world): 58 | return [a for a in world.agents if a.collector] 59 | 60 | def deposits(self, world): 61 | return [a for a in world.agents if not a.collector] 62 | 63 | def reset_cached_rewards(self): 64 | self.global_collecting_reward = None 65 | self.global_holding_reward = None 66 | self.global_deposit_reward = None 67 | 68 | def post_step(self, world): 69 | self.reset_cached_rewards() 70 | for l in world.landmarks: 71 | if l.alive: 72 | for a in self.collectors(world): 73 | if a.holding is None and self.is_collision(l, a, world): 74 | l.alive = False 75 | a.holding = l.type 76 | a.color = 0.85 * l.color 77 | l.state.p_pos = np.array([-999., -999.]) 78 | break 79 | else: 80 | if np.random.uniform() <= l.respawn_prob: 81 | bound = 0.95 82 | l.state.p_pos = np.random.uniform(low=-bound, high=bound, 83 | size=world.dim_p) 84 | l.type = np.random.choice(world.treasure_types) 85 | l.color = world.treasure_colors[l.type] 86 | l.alive = True 87 | for a in self.collectors(world): 88 | if a.holding is not None: 89 | for d in self.deposits(world): 90 | if d.d_i == a.holding and self.is_collision(a, d, world): 91 | a.holding = None 92 | a.color = np.array([0.85, 0.85, 0.85]) 93 | 94 | def reset_world(self, world): 95 | # set random initial states 96 | for i, agent in enumerate(world.agents): 97 | agent.state.p_pos = np.random.uniform(low=-1, high=1, 98 | size=world.dim_p) 99 | agent.state.p_vel = np.zeros(world.dim_p) 100 | agent.state.c = np.zeros(world.dim_c) 101 | agent.holding = None 102 | if agent.collector: 103 | agent.color = np.array([0.85, 0.85, 0.85]) 104 | for i, landmark in enumerate(world.landmarks): 105 | bound = 0.95 106 | landmark.type = np.random.choice(world.treasure_types) 107 | landmark.color = world.treasure_colors[landmark.type] 108 | landmark.state.p_pos = np.random.uniform(low=-bound, high=bound, 109 | size=world.dim_p) 110 | landmark.state.p_vel = np.zeros(world.dim_p) 111 | landmark.alive = True 112 | world.calculate_distances() 113 | 114 | def benchmark_data(self, agent, world): 115 | # returns data for benchmarking purposes 116 | if agent.collector: 117 | if agent.holding is not None: 118 | for d in self.deposits(world): 119 | if d.d_i == agent.holding and self.is_collision(d, agent, world): 120 | return 1 121 | else: 122 | for t in self.treasures(world): 123 | if self.is_collision(t, agent, world): 124 | return 1 125 | else: # deposit 126 | for a in self.collectors(world): 127 | if a.holding == agent.d_i and self.is_collision(a, agent, world): 128 | return 1 129 | return 0 130 | 131 | def is_collision(self, agent1, agent2, world): 132 | dist = 
world.cached_dist_mag[agent1.i, agent2.i] 133 | dist_min = agent1.size + agent2.size 134 | return True if dist < dist_min else False 135 | 136 | def treasures(self, world): 137 | return world.landmarks 138 | 139 | def reward(self, agent, world): 140 | main_reward = (self.collector_reward(agent, world) if agent.collector 141 | else self.deposit_reward(agent, world)) 142 | return main_reward 143 | 144 | def deposit_reward(self, agent, world): 145 | rew = 0 146 | shape = True 147 | if shape: # reward can optionally be shaped 148 | # penalize by distance to closest relevant holding agent 149 | dists_to_holding = [world.cached_dist_mag[agent.i, a.i] for a in 150 | self.collectors(world) if a.holding == agent.d_i] 151 | if len(dists_to_holding) > 0: 152 | rew -= 0.1 * min(dists_to_holding) 153 | else: 154 | n_visible = 7 155 | # get positions of all entities in this agent's reference frame 156 | other_agent_inds = [a.i for a in world.agents if (a is not agent and a.collector)] 157 | closest_agents = sorted( 158 | zip(world.cached_dist_mag[other_agent_inds, agent.i], 159 | other_agent_inds))[:n_visible] 160 | closest_inds = list(i for _, i in closest_agents) 161 | closest_avg_dist_vect = world.cached_dist_vect[closest_inds, agent.i].mean(axis=0) 162 | rew -= 0.1 * np.linalg.norm(closest_avg_dist_vect) 163 | rew += self.global_reward(world) 164 | return rew 165 | 166 | def collector_reward(self, agent, world): 167 | rew = 0 168 | # penalize collisions between collectors 169 | rew -= 5 * sum(self.is_collision(agent, a, world) 170 | for a in self.collectors(world) if a is not agent) 171 | shape = True 172 | if agent.holding is None and shape: 173 | rew -= 0.1 * min(world.cached_dist_mag[t.i, agent.i] for t in 174 | self.treasures(world)) 175 | elif shape: 176 | rew -= 0.1 * min(world.cached_dist_mag[d.i, agent.i] for d in 177 | self.deposits(world) if d.d_i == agent.holding) 178 | # collectors get global reward 179 | rew += self.global_reward(world) 180 | return rew 181 | 182 | def global_reward(self, world): 183 | if self.global_deposit_reward is None: 184 | self.calc_global_deposit_reward(world) 185 | if self.global_collecting_reward is None: 186 | self.calc_global_collecting_reward(world) 187 | return self.global_deposit_reward + self.global_collecting_reward 188 | 189 | def calc_global_collecting_reward(self, world): 190 | rew = 0 191 | for t in self.treasures(world): 192 | rew += 5 * sum(self.is_collision(a, t, world) 193 | for a in self.collectors(world) 194 | if a.holding is None) 195 | self.global_collecting_reward = rew 196 | 197 | def calc_global_deposit_reward(self, world): 198 | # reward deposits for getting treasure from collectors 199 | rew = 0 200 | for d in self.deposits(world): 201 | rew += 5 * sum(self.is_collision(d, a, world) for a in 202 | self.collectors(world) if a.holding == d.d_i) 203 | self.global_deposit_reward = rew 204 | 205 | def get_agent_encoding(self, agent, world): 206 | encoding = [] 207 | n_treasure_types = len(world.treasure_types) 208 | if agent.collector: 209 | encoding.append(np.zeros(n_treasure_types)) 210 | encoding.append((np.arange(n_treasure_types) == agent.holding)) 211 | else: 212 | encoding.append((np.arange(n_treasure_types) == agent.d_i)) 213 | encoding.append(np.zeros(n_treasure_types)) 214 | return np.concatenate(encoding) 215 | 216 | def observation(self, agent, world): 217 | n_visible = 7 # number of other agents and treasures visible to each agent 218 | # get positions of all entities in this agent's reference frame 219 | other_agents = [a.i for 
a in world.agents if a is not agent] 220 | closest_agents = sorted( 221 | zip(world.cached_dist_mag[other_agents, agent.i], 222 | other_agents))[:n_visible] 223 | treasures = [t.i for t in self.treasures(world)] 224 | closest_treasures = sorted( 225 | zip(world.cached_dist_mag[treasures, agent.i], 226 | treasures))[:n_visible] 227 | 228 | n_treasure_types = len(world.treasure_types) 229 | obs = [agent.state.p_pos, agent.state.p_vel] 230 | if agent.collector: 231 | # collectors need to know their own state bc it changes 232 | obs.append((np.arange(n_treasure_types) == agent.holding)) 233 | for _, i in closest_agents: 234 | a = world.entities[i] 235 | obs.append(world.cached_dist_vect[i, agent.i]) 236 | obs.append(a.state.p_vel) 237 | obs.append(self.get_agent_encoding(a, world)) 238 | for _, i in closest_treasures: 239 | t = world.entities[i] 240 | obs.append(world.cached_dist_vect[i, agent.i]) 241 | obs.append((np.arange(n_treasure_types) == t.type)) 242 | 243 | return np.concatenate(obs) 244 | -------------------------------------------------------------------------------- /envs/mpe_scenarios/multi_speaker_listener.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import seaborn as sns 3 | from multiagent.core import World, Agent, Landmark 4 | from multiagent.scenario import BaseScenario 5 | 6 | class Scenario(BaseScenario): 7 | def make_world(self): 8 | world = World() 9 | # set any world properties first 10 | world.dim_c = 5 11 | num_listeners = 4 12 | num_speakers = 4 13 | num_landmarks = 6 14 | world.landmark_colors = np.array( 15 | sns.color_palette(n_colors=num_landmarks)) 16 | world.listeners = [] 17 | for li in range(num_listeners): 18 | agent = Agent() 19 | agent.i = li 20 | agent.name = 'agent %i' % agent.i 21 | agent.listener = True 22 | agent.collide = False 23 | agent.size = 0.075 24 | agent.silent = True 25 | agent.accel = 1.5 26 | agent.initial_mass = 1.0 27 | agent.max_speed = 1.0 28 | world.listeners.append(agent) 29 | world.speakers = [] 30 | for si in range(num_speakers): 31 | agent = Agent() 32 | agent.i = si + num_listeners 33 | agent.name = 'agent %i' % agent.i 34 | agent.listener = False 35 | agent.collide = False 36 | agent.size = 0.075 37 | agent.movable = False 38 | agent.accel = 1.5 39 | agent.initial_mass = 1.0 40 | agent.max_speed = 1.0 41 | world.speakers.append(agent) 42 | world.agents = world.listeners + world.speakers 43 | # add landmarks 44 | world.landmarks = [Landmark() for i in range(num_landmarks)] 45 | for i, landmark in enumerate(world.landmarks): 46 | landmark.i = i + num_listeners + num_speakers 47 | landmark.name = 'landmark %d' % i 48 | landmark.collide = False 49 | landmark.movable = False 50 | landmark.size = 0.04 51 | landmark.color = world.landmark_colors[i] 52 | # make initial conditions 53 | self.reset_world(world) 54 | self.reset_cached_rewards() 55 | return world 56 | 57 | def reset_cached_rewards(self): 58 | self.pair_rewards = None 59 | 60 | def post_step(self, world): 61 | self.reset_cached_rewards() 62 | 63 | def reset_world(self, world): 64 | listen_inds = list(range(len(world.listeners))) 65 | np.random.shuffle(listen_inds) # randomize which listener each episode 66 | for i, speaker in enumerate(world.speakers): 67 | li = listen_inds[i] 68 | speaker.listen_ind = li 69 | speaker.goal_a = world.listeners[li] 70 | speaker.goal_b = np.random.choice(world.landmarks) 71 | speaker.color = np.array([0.25,0.25,0.25]) 72 | world.listeners[li].color = speaker.goal_b.color + 
np.array([0.25, 0.25, 0.25]) 73 | world.listeners[li].speak_ind = i 74 | 75 | # set random initial states 76 | for agent in world.agents: 77 | agent.state.p_pos = np.random.uniform(-1,+1, world.dim_p) 78 | agent.state.p_vel = np.zeros(world.dim_p) 79 | agent.state.c = np.zeros(world.dim_c) 80 | for i, landmark in enumerate(world.landmarks): 81 | landmark.state.p_pos = np.random.uniform(-1,+1, world.dim_p) 82 | landmark.state.p_vel = np.zeros(world.dim_p) 83 | self.reset_cached_rewards() 84 | 85 | def benchmark_data(self, agent, world): 86 | # returns data for benchmarking purposes 87 | return reward(agent, world) 88 | 89 | def calc_rewards(self, world): 90 | rews = [] 91 | for speaker in world.speakers: 92 | dist = np.sum(np.square(speaker.goal_a.state.p_pos - 93 | speaker.goal_b.state.p_pos)) 94 | rew = -dist 95 | if dist < (speaker.goal_a.size + speaker.goal_b.size) * 1.5: 96 | rew += 10. 97 | rews.append(rew) 98 | return rews 99 | 100 | def reward(self, agent, world): 101 | if self.pair_rewards is None: 102 | self.pair_rewards = self.calc_rewards(world) 103 | share_rews = False 104 | if share_rews: 105 | return sum(self.pair_rewards) 106 | if agent.listener: 107 | return self.pair_rewards[agent.speak_ind] 108 | else: 109 | return self.pair_rewards[agent.goal_a.speak_ind] 110 | 111 | def observation(self, agent, world): 112 | if agent.listener: 113 | obs = [] 114 | # give listener index of their speaker 115 | obs += [agent.speak_ind == np.arange(len(world.speakers))] 116 | # give listener communication from its speaker 117 | obs += [world.speakers[agent.speak_ind].state.c] 118 | # give listener its own position/velocity, 119 | obs += [agent.state.p_pos, agent.state.p_vel] 120 | 121 | # obs += [world.speakers[agent.speak_ind].state.c] 122 | # # # give listener index of their speaker 123 | # # obs += [agent.speak_ind == np.arange(len(world.speakers))] 124 | # # # give listener all communications 125 | # # obs += [speaker.state.c for speaker in world.speakers] 126 | # # give listener its own velocity 127 | # obs += [agent.state.p_vel] 128 | # # give listener locations of all agents 129 | # # obs += [a.state.p_pos for a in world.agents] 130 | # # give listener locations of all landmarks 131 | # obs += [l.state.p_pos for l in world.landmarks] 132 | return np.concatenate(obs) 133 | else: # speaker 134 | obs = [] 135 | # give speaker index of their listener 136 | obs += [agent.listen_ind == np.arange(len(world.listeners))] 137 | # speaker gets position of listener and goal 138 | obs += [agent.goal_a.state.p_pos, agent.goal_b.state.p_pos] 139 | 140 | # # give speaker index of their listener 141 | # # obs += [agent.listen_ind == np.arange(len(world.listeners))] 142 | # # # give speaker all communications 143 | # # obs += [speaker.state.c for speaker in world.speakers] 144 | # # give speaker their goal color 145 | # obs += [agent.goal_b.color] 146 | # # give speaker their listener's position 147 | # obs += [agent.goal_a.state.p_pos] 148 | # 149 | # obs += [speaker.state.c for speaker in world.speakers] 150 | return np.concatenate(obs) 151 | 152 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import numpy as np 5 | from gym.spaces import Box, Discrete 6 | from pathlib import Path 7 | from torch.autograd import Variable 8 | from tensorboardX import SummaryWriter 9 | from utils.make_env import make_env 10 | from 
utils.buffer import ReplayBuffer 11 | from utils.env_wrappers import SubprocVecEnv, DummyVecEnv 12 | from algorithms.attention_sac import AttentionSAC 13 | 14 | 15 | def make_parallel_env(env_id, n_rollout_threads, seed): 16 | def get_env_fn(rank): 17 | def init_env(): 18 | env = make_env(env_id, discrete_action=True) 19 | env.seed(seed + rank * 1000) 20 | np.random.seed(seed + rank * 1000) 21 | return env 22 | return init_env 23 | if n_rollout_threads == 1: 24 | return DummyVecEnv([get_env_fn(0)]) 25 | else: 26 | return SubprocVecEnv([get_env_fn(i) for i in range(n_rollout_threads)]) 27 | 28 | def run(config): 29 | model_dir = Path('./models') / config.env_id / config.model_name 30 | if not model_dir.exists(): 31 | run_num = 1 32 | else: 33 | exst_run_nums = [int(str(folder.name).split('run')[1]) for folder in 34 | model_dir.iterdir() if 35 | str(folder.name).startswith('run')] 36 | if len(exst_run_nums) == 0: 37 | run_num = 1 38 | else: 39 | run_num = max(exst_run_nums) + 1 40 | curr_run = 'run%i' % run_num 41 | run_dir = model_dir / curr_run 42 | log_dir = run_dir / 'logs' 43 | os.makedirs(log_dir) 44 | logger = SummaryWriter(str(log_dir)) 45 | 46 | torch.manual_seed(run_num) 47 | np.random.seed(run_num) 48 | env = make_parallel_env(config.env_id, config.n_rollout_threads, run_num) 49 | model = AttentionSAC.init_from_env(env, 50 | tau=config.tau, 51 | pi_lr=config.pi_lr, 52 | q_lr=config.q_lr, 53 | gamma=config.gamma, 54 | pol_hidden_dim=config.pol_hidden_dim, 55 | critic_hidden_dim=config.critic_hidden_dim, 56 | attend_heads=config.attend_heads, 57 | reward_scale=config.reward_scale) 58 | replay_buffer = ReplayBuffer(config.buffer_length, model.nagents, 59 | [obsp.shape[0] for obsp in env.observation_space], 60 | [acsp.shape[0] if isinstance(acsp, Box) else acsp.n 61 | for acsp in env.action_space]) 62 | t = 0 63 | for ep_i in range(0, config.n_episodes, config.n_rollout_threads): 64 | print("Episodes %i-%i of %i" % (ep_i + 1, 65 | ep_i + 1 + config.n_rollout_threads, 66 | config.n_episodes)) 67 | obs = env.reset() 68 | model.prep_rollouts(device='cpu') 69 | 70 | for et_i in range(config.episode_length): 71 | # rearrange observations to be per agent, and convert to torch Variable 72 | torch_obs = [Variable(torch.Tensor(np.vstack(obs[:, i])), 73 | requires_grad=False) 74 | for i in range(model.nagents)] 75 | # get actions as torch Variables 76 | torch_agent_actions = model.step(torch_obs, explore=True) 77 | # convert actions to numpy arrays 78 | agent_actions = [ac.data.numpy() for ac in torch_agent_actions] 79 | # rearrange actions to be per environment 80 | actions = [[ac[i] for ac in agent_actions] for i in range(config.n_rollout_threads)] 81 | next_obs, rewards, dones, infos = env.step(actions) 82 | replay_buffer.push(obs, agent_actions, rewards, next_obs, dones) 83 | obs = next_obs 84 | t += config.n_rollout_threads 85 | if (len(replay_buffer) >= config.batch_size and 86 | (t % config.steps_per_update) < config.n_rollout_threads): 87 | if config.use_gpu: 88 | model.prep_training(device='gpu') 89 | else: 90 | model.prep_training(device='cpu') 91 | for u_i in range(config.num_updates): 92 | sample = replay_buffer.sample(config.batch_size, 93 | to_gpu=config.use_gpu) 94 | model.update_critic(sample, logger=logger) 95 | model.update_policies(sample, logger=logger) 96 | model.update_all_targets() 97 | model.prep_rollouts(device='cpu') 98 | ep_rews = replay_buffer.get_average_rewards( 99 | config.episode_length * config.n_rollout_threads) 100 | for a_i, a_ep_rew in enumerate(ep_rews): 
101 | logger.add_scalar('agent%i/mean_episode_rewards' % a_i, 102 | a_ep_rew * config.episode_length, ep_i) 103 | 104 | if ep_i % config.save_interval < config.n_rollout_threads: 105 | model.prep_rollouts(device='cpu') 106 | os.makedirs(run_dir / 'incremental', exist_ok=True) 107 | model.save(run_dir / 'incremental' / ('model_ep%i.pt' % (ep_i + 1))) 108 | model.save(run_dir / 'model.pt') 109 | 110 | model.save(run_dir / 'model.pt') 111 | env.close() 112 | logger.export_scalars_to_json(str(log_dir / 'summary.json')) 113 | logger.close() 114 | 115 | 116 | if __name__ == '__main__': 117 | parser = argparse.ArgumentParser() 118 | parser.add_argument("env_id", help="Name of environment") 119 | parser.add_argument("model_name", 120 | help="Name of directory to store " + 121 | "model/training contents") 122 | parser.add_argument("--n_rollout_threads", default=12, type=int) 123 | parser.add_argument("--buffer_length", default=int(1e6), type=int) 124 | parser.add_argument("--n_episodes", default=50000, type=int) 125 | parser.add_argument("--episode_length", default=25, type=int) 126 | parser.add_argument("--steps_per_update", default=100, type=int) 127 | parser.add_argument("--num_updates", default=4, type=int, 128 | help="Number of updates per update cycle") 129 | parser.add_argument("--batch_size", 130 | default=1024, type=int, 131 | help="Batch size for training") 132 | parser.add_argument("--save_interval", default=1000, type=int) 133 | parser.add_argument("--pol_hidden_dim", default=128, type=int) 134 | parser.add_argument("--critic_hidden_dim", default=128, type=int) 135 | parser.add_argument("--attend_heads", default=4, type=int) 136 | parser.add_argument("--pi_lr", default=0.001, type=float) 137 | parser.add_argument("--q_lr", default=0.001, type=float) 138 | parser.add_argument("--tau", default=0.001, type=float) 139 | parser.add_argument("--gamma", default=0.99, type=float) 140 | parser.add_argument("--reward_scale", default=100., type=float) 141 | parser.add_argument("--use_gpu", action='store_true') 142 | 143 | config = parser.parse_args() 144 | 145 | run(config) 146 | -------------------------------------------------------------------------------- /utils/agents.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | from torch.autograd import Variable 3 | from torch.optim import Adam 4 | from utils.misc import hard_update, gumbel_softmax, onehot_from_logits 5 | from utils.policies import DiscretePolicy 6 | 7 | class AttentionAgent(object): 8 | """ 9 | General class for Attention agents (policy, target policy) 10 | """ 11 | def __init__(self, num_in_pol, num_out_pol, hidden_dim=64, 12 | lr=0.01, onehot_dim=0): 13 | """ 14 | Inputs: 15 | num_in_pol (int): number of dimensions for policy input 16 | num_out_pol (int): number of dimensions for policy output 17 | """ 18 | self.policy = DiscretePolicy(num_in_pol, num_out_pol, 19 | hidden_dim=hidden_dim, 20 | onehot_dim=onehot_dim) 21 | self.target_policy = DiscretePolicy(num_in_pol, 22 | num_out_pol, 23 | hidden_dim=hidden_dim, 24 | onehot_dim=onehot_dim) 25 | 26 | hard_update(self.target_policy, self.policy) 27 | self.policy_optimizer = Adam(self.policy.parameters(), lr=lr) 28 | 29 | def step(self, obs, explore=False): 30 | """ 31 | Take a step forward in environment for a minibatch of observations 32 | Inputs: 33 | obs (PyTorch Variable): Observations for this agent 34 | explore (boolean): Whether or not to sample 35 | Outputs: 36 | action (PyTorch Variable): Actions for this 
agent 37 | """ 38 | return self.policy(obs, sample=explore) 39 | 40 | def get_params(self): 41 | return {'policy': self.policy.state_dict(), 42 | 'target_policy': self.target_policy.state_dict(), 43 | 'policy_optimizer': self.policy_optimizer.state_dict()} 44 | 45 | def load_params(self, params): 46 | self.policy.load_state_dict(params['policy']) 47 | self.target_policy.load_state_dict(params['target_policy']) 48 | self.policy_optimizer.load_state_dict(params['policy_optimizer']) 49 | -------------------------------------------------------------------------------- /utils/buffer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch import Tensor 3 | from torch.autograd import Variable 4 | 5 | class ReplayBuffer(object): 6 | """ 7 | Replay Buffer for multi-agent RL with parallel rollouts 8 | """ 9 | def __init__(self, max_steps, num_agents, obs_dims, ac_dims): 10 | """ 11 | Inputs: 12 | max_steps (int): Maximum number of timepoints to store in buffer 13 | num_agents (int): Number of agents in environment 14 | obs_dims (list of ints): number of obervation dimensions for each 15 | agent 16 | ac_dims (list of ints): number of action dimensions for each agent 17 | """ 18 | self.max_steps = max_steps 19 | self.num_agents = num_agents 20 | self.obs_buffs = [] 21 | self.ac_buffs = [] 22 | self.rew_buffs = [] 23 | self.next_obs_buffs = [] 24 | self.done_buffs = [] 25 | for odim, adim in zip(obs_dims, ac_dims): 26 | self.obs_buffs.append(np.zeros((max_steps, odim), dtype=np.float32)) 27 | self.ac_buffs.append(np.zeros((max_steps, adim), dtype=np.float32)) 28 | self.rew_buffs.append(np.zeros(max_steps, dtype=np.float32)) 29 | self.next_obs_buffs.append(np.zeros((max_steps, odim), dtype=np.float32)) 30 | self.done_buffs.append(np.zeros(max_steps, dtype=np.uint8)) 31 | 32 | 33 | self.filled_i = 0 # index of first empty location in buffer (last index when full) 34 | self.curr_i = 0 # current index to write to (ovewrite oldest data) 35 | 36 | def __len__(self): 37 | return self.filled_i 38 | 39 | def push(self, observations, actions, rewards, next_observations, dones): 40 | nentries = observations.shape[0] # handle multiple parallel environments 41 | if self.curr_i + nentries > self.max_steps: 42 | rollover = self.max_steps - self.curr_i # num of indices to roll over 43 | for agent_i in range(self.num_agents): 44 | self.obs_buffs[agent_i] = np.roll(self.obs_buffs[agent_i], 45 | rollover, axis=0) 46 | self.ac_buffs[agent_i] = np.roll(self.ac_buffs[agent_i], 47 | rollover, axis=0) 48 | self.rew_buffs[agent_i] = np.roll(self.rew_buffs[agent_i], 49 | rollover) 50 | self.next_obs_buffs[agent_i] = np.roll( 51 | self.next_obs_buffs[agent_i], rollover, axis=0) 52 | self.done_buffs[agent_i] = np.roll(self.done_buffs[agent_i], 53 | rollover) 54 | self.curr_i = 0 55 | self.filled_i = self.max_steps 56 | for agent_i in range(self.num_agents): 57 | self.obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack( 58 | observations[:, agent_i]) 59 | # actions are already batched by agent, so they are indexed differently 60 | self.ac_buffs[agent_i][self.curr_i:self.curr_i + nentries] = actions[agent_i] 61 | self.rew_buffs[agent_i][self.curr_i:self.curr_i + nentries] = rewards[:, agent_i] 62 | self.next_obs_buffs[agent_i][self.curr_i:self.curr_i + nentries] = np.vstack( 63 | next_observations[:, agent_i]) 64 | self.done_buffs[agent_i][self.curr_i:self.curr_i + nentries] = dones[:, agent_i] 65 | self.curr_i += nentries 66 | if self.filled_i < 
self.max_steps: 67 | self.filled_i += nentries 68 | if self.curr_i == self.max_steps: 69 | self.curr_i = 0 70 | 71 | def sample(self, N, to_gpu=False, norm_rews=True): 72 | inds = np.random.choice(np.arange(self.filled_i), size=N, 73 | replace=True) 74 | if to_gpu: 75 | cast = lambda x: Variable(Tensor(x), requires_grad=False).cuda() 76 | else: 77 | cast = lambda x: Variable(Tensor(x), requires_grad=False) 78 | if norm_rews: 79 | ret_rews = [cast((self.rew_buffs[i][inds] - 80 | self.rew_buffs[i][:self.filled_i].mean()) / 81 | self.rew_buffs[i][:self.filled_i].std()) 82 | for i in range(self.num_agents)] 83 | else: 84 | ret_rews = [cast(self.rew_buffs[i][inds]) for i in range(self.num_agents)] 85 | return ([cast(self.obs_buffs[i][inds]) for i in range(self.num_agents)], 86 | [cast(self.ac_buffs[i][inds]) for i in range(self.num_agents)], 87 | ret_rews, 88 | [cast(self.next_obs_buffs[i][inds]) for i in range(self.num_agents)], 89 | [cast(self.done_buffs[i][inds]) for i in range(self.num_agents)]) 90 | 91 | def get_average_rewards(self, N): 92 | if self.filled_i == self.max_steps: 93 | inds = np.arange(self.curr_i - N, self.curr_i) # allow for negative indexing 94 | else: 95 | inds = np.arange(max(0, self.curr_i - N), self.curr_i) 96 | return [self.rew_buffs[i][inds].mean() for i in range(self.num_agents)] 97 | -------------------------------------------------------------------------------- /utils/critics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from itertools import chain 6 | 7 | 8 | class AttentionCritic(nn.Module): 9 | """ 10 | Attention network, used as critic for all agents. Each agent gets its own 11 | observation and action, and can also attend over the other agents' encoded 12 | observations and actions. 
13 | """ 14 | def __init__(self, sa_sizes, hidden_dim=32, norm_in=True, attend_heads=1): 15 | """ 16 | Inputs: 17 | sa_sizes (list of (int, int)): Size of state and action spaces per 18 | agent 19 | hidden_dim (int): Number of hidden dimensions 20 | norm_in (bool): Whether to apply BatchNorm to input 21 | attend_heads (int): Number of attention heads to use (use a number 22 | that hidden_dim is divisible by) 23 | """ 24 | super(AttentionCritic, self).__init__() 25 | assert (hidden_dim % attend_heads) == 0 26 | self.sa_sizes = sa_sizes 27 | self.nagents = len(sa_sizes) 28 | self.attend_heads = attend_heads 29 | 30 | self.critic_encoders = nn.ModuleList() 31 | self.critics = nn.ModuleList() 32 | 33 | self.state_encoders = nn.ModuleList() 34 | # iterate over agents 35 | for sdim, adim in sa_sizes: 36 | idim = sdim + adim 37 | odim = adim 38 | encoder = nn.Sequential() 39 | if norm_in: 40 | encoder.add_module('enc_bn', nn.BatchNorm1d(idim, 41 | affine=False)) 42 | encoder.add_module('enc_fc1', nn.Linear(idim, hidden_dim)) 43 | encoder.add_module('enc_nl', nn.LeakyReLU()) 44 | self.critic_encoders.append(encoder) 45 | critic = nn.Sequential() 46 | critic.add_module('critic_fc1', nn.Linear(2 * hidden_dim, 47 | hidden_dim)) 48 | critic.add_module('critic_nl', nn.LeakyReLU()) 49 | critic.add_module('critic_fc2', nn.Linear(hidden_dim, odim)) 50 | self.critics.append(critic) 51 | 52 | state_encoder = nn.Sequential() 53 | if norm_in: 54 | state_encoder.add_module('s_enc_bn', nn.BatchNorm1d( 55 | sdim, affine=False)) 56 | state_encoder.add_module('s_enc_fc1', nn.Linear(sdim, 57 | hidden_dim)) 58 | state_encoder.add_module('s_enc_nl', nn.LeakyReLU()) 59 | self.state_encoders.append(state_encoder) 60 | 61 | attend_dim = hidden_dim // attend_heads 62 | self.key_extractors = nn.ModuleList() 63 | self.selector_extractors = nn.ModuleList() 64 | self.value_extractors = nn.ModuleList() 65 | for i in range(attend_heads): 66 | self.key_extractors.append(nn.Linear(hidden_dim, attend_dim, bias=False)) 67 | self.selector_extractors.append(nn.Linear(hidden_dim, attend_dim, bias=False)) 68 | self.value_extractors.append(nn.Sequential(nn.Linear(hidden_dim, 69 | attend_dim), 70 | nn.LeakyReLU())) 71 | 72 | self.shared_modules = [self.key_extractors, self.selector_extractors, 73 | self.value_extractors, self.critic_encoders] 74 | 75 | def shared_parameters(self): 76 | """ 77 | Parameters shared across agents and reward heads 78 | """ 79 | return chain(*[m.parameters() for m in self.shared_modules]) 80 | 81 | def scale_shared_grads(self): 82 | """ 83 | Scale gradients for parameters that are shared since they accumulate 84 | gradients from the critic loss function multiple times 85 | """ 86 | for p in self.shared_parameters(): 87 | p.grad.data.mul_(1. 
/ self.nagents) 88 | 89 | def forward(self, inps, agents=None, return_q=True, return_all_q=False, 90 | regularize=False, return_attend=False, logger=None, niter=0): 91 | """ 92 | Inputs: 93 | inps (list of PyTorch Matrices): Inputs to each agents' encoder 94 | (batch of obs + ac) 95 | agents (int): indices of agents to return Q for 96 | return_q (bool): return Q-value 97 | return_all_q (bool): return Q-value for all actions 98 | regularize (bool): returns values to add to loss function for 99 | regularization 100 | return_attend (bool): return attention weights per agent 101 | logger (TensorboardX SummaryWriter): If passed in, important values 102 | are logged 103 | """ 104 | if agents is None: 105 | agents = range(len(self.critic_encoders)) 106 | states = [s for s, a in inps] 107 | actions = [a for s, a in inps] 108 | inps = [torch.cat((s, a), dim=1) for s, a in inps] 109 | # extract state-action encoding for each agent 110 | sa_encodings = [encoder(inp) for encoder, inp in zip(self.critic_encoders, inps)] 111 | # extract state encoding for each agent that we're returning Q for 112 | s_encodings = [self.state_encoders[a_i](states[a_i]) for a_i in agents] 113 | # extract keys for each head for each agent 114 | all_head_keys = [[k_ext(enc) for enc in sa_encodings] for k_ext in self.key_extractors] 115 | # extract sa values for each head for each agent 116 | all_head_values = [[v_ext(enc) for enc in sa_encodings] for v_ext in self.value_extractors] 117 | # extract selectors for each head for each agent that we're returning Q for 118 | all_head_selectors = [[sel_ext(enc) for i, enc in enumerate(s_encodings) if i in agents] 119 | for sel_ext in self.selector_extractors] 120 | 121 | other_all_values = [[] for _ in range(len(agents))] 122 | all_attend_logits = [[] for _ in range(len(agents))] 123 | all_attend_probs = [[] for _ in range(len(agents))] 124 | # calculate attention per head 125 | for curr_head_keys, curr_head_values, curr_head_selectors in zip( 126 | all_head_keys, all_head_values, all_head_selectors): 127 | # iterate over agents 128 | for i, a_i, selector in zip(range(len(agents)), agents, curr_head_selectors): 129 | keys = [k for j, k in enumerate(curr_head_keys) if j != a_i] 130 | values = [v for j, v in enumerate(curr_head_values) if j != a_i] 131 | # calculate attention across agents 132 | attend_logits = torch.matmul(selector.view(selector.shape[0], 1, -1), 133 | torch.stack(keys).permute(1, 2, 0)) 134 | # scale dot-products by size of key (from Attention is All You Need) 135 | scaled_attend_logits = attend_logits / np.sqrt(keys[0].shape[1]) 136 | attend_weights = F.softmax(scaled_attend_logits, dim=2) 137 | other_values = (torch.stack(values).permute(1, 2, 0) * 138 | attend_weights).sum(dim=2) 139 | other_all_values[i].append(other_values) 140 | all_attend_logits[i].append(attend_logits) 141 | all_attend_probs[i].append(attend_weights) 142 | # calculate Q per agent 143 | all_rets = [] 144 | for i, a_i in enumerate(agents): 145 | head_entropies = [(-((probs + 1e-8).log() * probs).squeeze().sum(1) 146 | .mean()) for probs in all_attend_probs[i]] 147 | agent_rets = [] 148 | critic_in = torch.cat((s_encodings[i], *other_all_values[i]), dim=1) 149 | all_q = self.critics[a_i](critic_in) 150 | int_acs = actions[a_i].max(dim=1, keepdim=True)[1] 151 | q = all_q.gather(1, int_acs) 152 | if return_q: 153 | agent_rets.append(q) 154 | if return_all_q: 155 | agent_rets.append(all_q) 156 | if regularize: 157 | # regularize magnitude of attention logits 158 | attend_mag_reg = 1e-3 * 
sum((logit**2).mean() for logit in 159 | all_attend_logits[i]) 160 | regs = (attend_mag_reg,) 161 | agent_rets.append(regs) 162 | if return_attend: 163 | agent_rets.append(np.array(all_attend_probs[i])) 164 | if logger is not None: 165 | logger.add_scalars('agent%i/attention' % a_i, 166 | dict(('head%i_entropy' % h_i, ent) for h_i, ent 167 | in enumerate(head_entropies)), 168 | niter) 169 | if len(agent_rets) == 1: 170 | all_rets.append(agent_rets[0]) 171 | else: 172 | all_rets.append(agent_rets) 173 | if len(all_rets) == 1: 174 | return all_rets[0] 175 | else: 176 | return all_rets 177 | -------------------------------------------------------------------------------- /utils/env_wrappers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from OpenAI Baselines code to work with multi-agent envs 3 | """ 4 | import numpy as np 5 | from multiprocessing import Process, Pipe 6 | from baselines.common.vec_env import VecEnv, CloudpickleWrapper 7 | 8 | 9 | def worker(remote, parent_remote, env_fn_wrapper): 10 | parent_remote.close() 11 | env = env_fn_wrapper.x() 12 | while True: 13 | cmd, data = remote.recv() 14 | if cmd == 'step': 15 | ob, reward, done, info = env.step(data) 16 | if all(done): 17 | ob = env.reset() 18 | remote.send((ob, reward, done, info)) 19 | elif cmd == 'reset': 20 | ob = env.reset() 21 | remote.send(ob) 22 | elif cmd == 'reset_task': 23 | ob = env.reset_task() 24 | remote.send(ob) 25 | elif cmd == 'close': 26 | remote.close() 27 | break 28 | elif cmd == 'get_spaces': 29 | remote.send((env.observation_space, env.action_space)) 30 | elif cmd == 'get_agent_types': 31 | if all([hasattr(a, 'adversary') for a in env.agents]): 32 | remote.send(['adversary' if a.adversary else 'agent' for a in 33 | env.agents]) 34 | else: 35 | remote.send(['agent' for _ in env.agents]) 36 | else: 37 | raise NotImplementedError 38 | 39 | 40 | class SubprocVecEnv(VecEnv): 41 | def __init__(self, env_fns, spaces=None): 42 | """ 43 | envs: list of gym environments to run in subprocesses 44 | """ 45 | self.waiting = False 46 | self.closed = False 47 | nenvs = len(env_fns) 48 | self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)]) 49 | self.ps = [Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn))) 50 | for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)] 51 | for p in self.ps: 52 | p.daemon = True # if the main process crashes, we should not cause things to hang 53 | p.start() 54 | for remote in self.work_remotes: 55 | remote.close() 56 | 57 | self.remotes[0].send(('get_spaces', None)) 58 | observation_space, action_space = self.remotes[0].recv() 59 | self.remotes[0].send(('get_agent_types', None)) 60 | self.agent_types = self.remotes[0].recv() 61 | VecEnv.__init__(self, len(env_fns), observation_space, action_space) 62 | 63 | def step_async(self, actions): 64 | for remote, action in zip(self.remotes, actions): 65 | remote.send(('step', action)) 66 | self.waiting = True 67 | 68 | def step_wait(self): 69 | results = [remote.recv() for remote in self.remotes] 70 | self.waiting = False 71 | obs, rews, dones, infos = zip(*results) 72 | return np.stack(obs), np.stack(rews), np.stack(dones), infos 73 | 74 | def reset(self): 75 | for remote in self.remotes: 76 | remote.send(('reset', None)) 77 | return np.stack([remote.recv() for remote in self.remotes]) 78 | 79 | def reset_task(self): 80 | for remote in self.remotes: 81 | remote.send(('reset_task', None)) 82 | return 
np.stack([remote.recv() for remote in self.remotes]) 83 | 84 | def close(self): 85 | if self.closed: 86 | return 87 | if self.waiting: 88 | for remote in self.remotes: 89 | remote.recv() 90 | for remote in self.remotes: 91 | remote.send(('close', None)) 92 | for p in self.ps: 93 | p.join() 94 | self.closed = True 95 | 96 | 97 | class DummyVecEnv(VecEnv): 98 | def __init__(self, env_fns): 99 | self.envs = [fn() for fn in env_fns] 100 | env = self.envs[0] 101 | VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space) 102 | if all([hasattr(a, 'adversary') for a in env.agents]): 103 | self.agent_types = ['adversary' if a.adversary else 'agent' for a in 104 | env.agents] 105 | else: 106 | self.agent_types = ['agent' for _ in env.agents] 107 | self.ts = np.zeros(len(self.envs), dtype='int') 108 | self.actions = None 109 | 110 | def step_async(self, actions): 111 | self.actions = actions 112 | 113 | def step_wait(self): 114 | results = [env.step(a) for (a,env) in zip(self.actions, self.envs)] 115 | obs, rews, dones, infos = map(np.array, zip(*results)) 116 | self.ts += 1 117 | for (i, done) in enumerate(dones): 118 | if all(done): 119 | obs[i] = self.envs[i].reset() 120 | self.ts[i] = 0 121 | self.actions = None 122 | return np.array(obs), np.array(rews), np.array(dones), infos 123 | 124 | def reset(self): 125 | results = [env.reset() for env in self.envs] 126 | return np.array(results) 127 | 128 | def close(self): 129 | return -------------------------------------------------------------------------------- /utils/make_env.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code for creating a multiagent environment with one of the scenarios listed 3 | in ./scenarios/. 4 | Can be called by using, for example: 5 | env = make_env('simple_speaker_listener') 6 | After producing the env object, can be used similarly to an OpenAI gym 7 | environment. 8 | 9 | A policy using this environment must output actions in the form of a list 10 | for all agents. Each element of the list should be a numpy array, 11 | of size (env.world.dim_p + env.world.dim_c, 1). Physical actions precede 12 | communication actions in this array. See environment.py for more details. 13 | """ 14 | 15 | def make_env(scenario_name, benchmark=False, discrete_action=False): 16 | ''' 17 | Creates a MultiAgentEnv object as env. This can be used similar to a gym 18 | environment by calling env.reset() and env.step(). 19 | Use env.render() to view the environment on the screen. 
20 | 21 | Input: 22 | scenario_name : name of the scenario from ./scenarios/ to be Returns 23 | (without the .py extension) 24 | benchmark : whether you want to produce benchmarking data 25 | (usually only done during evaluation) 26 | 27 | Some useful env properties (see environment.py): 28 | .observation_space : Returns the observation space for each agent 29 | .action_space : Returns the action space for each agent 30 | .n : Returns the number of Agents 31 | ''' 32 | from multiagent.environment import MultiAgentEnv 33 | import multiagent.scenarios as old_scenarios 34 | import envs.mpe_scenarios as new_scenarios 35 | 36 | # load scenario from script 37 | try: 38 | scenario = old_scenarios.load(scenario_name + ".py").Scenario() 39 | except: 40 | scenario = new_scenarios.load(scenario_name + ".py").Scenario() 41 | # create world 42 | world = scenario.make_world() 43 | # create multiagent environment 44 | if hasattr(scenario, 'post_step'): 45 | post_step = scenario.post_step 46 | else: 47 | post_step = None 48 | if benchmark: 49 | env = MultiAgentEnv(world, reset_callback=scenario.reset_world, 50 | reward_callback=scenario.reward, 51 | observation_callback=scenario.observation, 52 | post_step_callback=post_step, 53 | info_callback=scenario.benchmark_data, 54 | discrete_action=discrete_action) 55 | else: 56 | env = MultiAgentEnv(world, reset_callback=scenario.reset_world, 57 | reward_callback=scenario.reward, 58 | observation_callback=scenario.observation, 59 | post_step_callback=post_step, 60 | discrete_action=discrete_action) 61 | return env 62 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.distributed as dist 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | # https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L11 9 | def soft_update(target, source, tau): 10 | """ 11 | Perform DDPG soft update (move target params toward source based on weight 12 | factor tau) 13 | Inputs: 14 | target (torch.nn.Module): Net to copy parameters to 15 | source (torch.nn.Module): Net whose parameters to copy 16 | tau (float, 0 < x < 1): Weight factor for update 17 | """ 18 | for target_param, param in zip(target.parameters(), source.parameters()): 19 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) 20 | 21 | # https://github.com/ikostrikov/pytorch-ddpg-naf/blob/master/ddpg.py#L15 22 | def hard_update(target, source): 23 | """ 24 | Copy network parameters from source to target 25 | Inputs: 26 | target (torch.nn.Module): Net to copy parameters to 27 | source (torch.nn.Module): Net whose parameters to copy 28 | """ 29 | for target_param, param in zip(target.parameters(), source.parameters()): 30 | target_param.data.copy_(param.data) 31 | 32 | # https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py 33 | def average_gradients(model): 34 | """ Gradient averaging. """ 35 | size = float(dist.get_world_size()) 36 | for param in model.parameters(): 37 | dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM, group=0) 38 | param.grad.data /= size 39 | 40 | # https://github.com/seba-1511/dist_tuto.pth/blob/gh-pages/train_dist.py 41 | def init_processes(rank, size, fn, backend='gloo'): 42 | """ Initialize the distributed environment. 
""" 43 | os.environ['MASTER_ADDR'] = '127.0.0.1' 44 | os.environ['MASTER_PORT'] = '29500' 45 | dist.init_process_group(backend, rank=rank, world_size=size) 46 | fn(rank, size) 47 | 48 | def onehot_from_logits(logits, eps=0.0, dim=1): 49 | """ 50 | Given batch of logits, return one-hot sample using epsilon greedy strategy 51 | (based on given epsilon) 52 | """ 53 | # get best (according to current policy) actions in one-hot form 54 | argmax_acs = (logits == logits.max(dim, keepdim=True)[0]).float() 55 | if eps == 0.0: 56 | return argmax_acs 57 | # get random actions in one-hot form 58 | rand_acs = Variable(torch.eye(logits.shape[1])[[np.random.choice( 59 | range(logits.shape[1]), size=logits.shape[0])]], requires_grad=False) 60 | # chooses between best and random actions using epsilon greedy 61 | return torch.stack([argmax_acs[i] if r > eps else rand_acs[i] for i, r in 62 | enumerate(torch.rand(logits.shape[0]))]) 63 | 64 | # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb 65 | def sample_gumbel(shape, eps=1e-20, tens_type=torch.FloatTensor): 66 | """Sample from Gumbel(0, 1)""" 67 | U = Variable(tens_type(*shape).uniform_(), requires_grad=False) 68 | return -torch.log(-torch.log(U + eps) + eps) 69 | 70 | # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb 71 | def gumbel_softmax_sample(logits, temperature, dim=1): 72 | """ Draw a sample from the Gumbel-Softmax distribution""" 73 | y = logits + sample_gumbel(logits.shape, tens_type=type(logits.data)) 74 | return F.softmax(y / temperature, dim=dim) 75 | 76 | # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb 77 | def gumbel_softmax(logits, temperature=1.0, hard=False, dim=1): 78 | """Sample from the Gumbel-Softmax distribution and optionally discretize. 79 | Args: 80 | logits: [batch_size, n_class] unnormalized log-probs 81 | temperature: non-negative scalar 82 | hard: if True, take argmax, but differentiate w.r.t. soft sample y 83 | Returns: 84 | [batch_size, n_class] sample from the Gumbel-Softmax distribution. 85 | If hard=True, then the returned sample will be one-hot, otherwise it will 86 | be a probabilitiy distribution that sums to 1 across classes 87 | """ 88 | y = gumbel_softmax_sample(logits, temperature, dim=dim) 89 | if hard: 90 | y_hard = onehot_from_logits(y, dim=dim) 91 | y = (y_hard - y).detach() + y 92 | return y 93 | 94 | def firmmax_sample(logits, temperature, dim=1): 95 | if temperature == 0: 96 | return F.softmax(logits, dim=dim) 97 | y = logits + sample_gumbel(logits.shape, tens_type=type(logits.data)) / temperature 98 | return F.softmax(y, dim=dim) 99 | 100 | def categorical_sample(probs, use_cuda=False): 101 | int_acs = torch.multinomial(probs, 1) 102 | if use_cuda: 103 | tensor_type = torch.cuda.FloatTensor 104 | else: 105 | tensor_type = torch.FloatTensor 106 | acs = Variable(tensor_type(*probs.shape).fill_(0)).scatter_(1, int_acs, 1) 107 | return int_acs, acs 108 | 109 | def disable_gradients(module): 110 | for p in module.parameters(): 111 | p.requires_grad = False 112 | 113 | def enable_gradients(module): 114 | for p in module.parameters(): 115 | p.requires_grad = True 116 | 117 | def sep_clip_grad_norm(parameters, max_norm, norm_type=2): 118 | """ 119 | Clips gradient norms calculated on a per-parameter basis, rather than over 120 | the whole list of parameters as in torch.nn.utils.clip_grad_norm. 
121 |     Code based on torch.nn.utils.clip_grad_norm
122 |     """
123 |     parameters = list(filter(lambda p: p.grad is not None, parameters))
124 |     max_norm = float(max_norm)
125 |     norm_type = float(norm_type)
126 |     for p in parameters:
127 |         if norm_type == float('inf'):
128 |             p_norm = p.grad.data.abs().max()
129 |         else:
130 |             p_norm = p.grad.data.norm(norm_type)
131 |         clip_coef = max_norm / (p_norm + 1e-6)
132 |         if clip_coef < 1:
133 |             p.grad.data.mul_(clip_coef)
134 | 
--------------------------------------------------------------------------------
/utils/policies.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from utils.misc import onehot_from_logits, categorical_sample
5 | 
6 | class BasePolicy(nn.Module):
7 |     """
8 |     Base policy network
9 |     """
10 |     def __init__(self, input_dim, out_dim, hidden_dim=64, nonlin=F.leaky_relu,
11 |                  norm_in=True, onehot_dim=0):
12 |         """
13 |         Inputs:
14 |             input_dim (int): Number of dimensions in input
15 |             out_dim (int): Number of dimensions in output
16 |             hidden_dim (int): Number of hidden dimensions
17 |             nonlin (PyTorch function): Nonlinearity to apply to hidden layers
18 |         """
19 |         super(BasePolicy, self).__init__()
20 | 
21 |         if norm_in:  # normalize inputs
22 |             self.in_fn = nn.BatchNorm1d(input_dim, affine=False)
23 |         else:
24 |             self.in_fn = lambda x: x
25 |         self.fc1 = nn.Linear(input_dim + onehot_dim, hidden_dim)
26 |         self.fc2 = nn.Linear(hidden_dim, hidden_dim)
27 |         self.fc3 = nn.Linear(hidden_dim, out_dim)
28 |         self.nonlin = nonlin
29 | 
30 |     def forward(self, X):
31 |         """
32 |         Inputs:
33 |             X (PyTorch Matrix): Batch of observations (optionally a tuple that
34 |                                 additionally includes a onehot label)
35 |         Outputs:
36 |             out (PyTorch Matrix): Actions
37 |         """
38 |         onehot = None
39 |         if type(X) is tuple:
40 |             X, onehot = X
41 |         inp = self.in_fn(X)  # don't batchnorm onehot
42 |         if onehot is not None:
43 |             inp = torch.cat((onehot, inp), dim=1)
44 |         h1 = self.nonlin(self.fc1(inp))
45 |         h2 = self.nonlin(self.fc2(h1))
46 |         out = self.fc3(h2)
47 |         return out
48 | 
49 | 
50 | class DiscretePolicy(BasePolicy):
51 |     """
52 |     Policy Network for discrete action spaces
53 |     """
54 |     def __init__(self, *args, **kwargs):
55 |         super(DiscretePolicy, self).__init__(*args, **kwargs)
56 | 
57 |     def forward(self, obs, sample=True, return_all_probs=False,
58 |                 return_log_pi=False, regularize=False,
59 |                 return_entropy=False):
60 |         out = super(DiscretePolicy, self).forward(obs)
61 |         probs = F.softmax(out, dim=1)
62 |         on_gpu = next(self.parameters()).is_cuda
63 |         if sample:
64 |             int_act, act = categorical_sample(probs, use_cuda=on_gpu)
65 |         else:
66 |             act = onehot_from_logits(probs)
67 |         rets = [act]
68 |         if return_log_pi or return_entropy:
69 |             log_probs = F.log_softmax(out, dim=1)
70 |         if return_all_probs:
71 |             rets.append(probs)
72 |         if return_log_pi:
73 |             # return log probability of selected action (requires sample=True)
74 |             rets.append(log_probs.gather(1, int_act))
75 |         if regularize:
76 |             rets.append([(out**2).mean()])
77 |         if return_entropy:
78 |             rets.append(-(log_probs * probs).sum(1).mean())
79 |         if len(rets) == 1:
80 |             return rets[0]
81 |         return rets
82 | 
--------------------------------------------------------------------------------
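
As a quick illustration of how the pieces above fit together, the following sketch (not part of the repository; the dimensions and the script itself are hypothetical, and it assumes PyTorch is installed and the script is run from the repo root) instantiates `DiscretePolicy` on a toy observation batch and draws one-hot actions both by categorical sampling and via the straight-through Gumbel-Softmax estimator from `utils/misc.py`:

```python
import torch
from torch.autograd import Variable

from utils.misc import gumbel_softmax
from utils.policies import DiscretePolicy

obs_dim, n_actions, batch_size = 10, 5, 4  # arbitrary toy sizes for illustration

# Build a discrete policy; norm_in=True (default) prepends a BatchNorm layer
policy = DiscretePolicy(obs_dim, n_actions, hidden_dim=64)
policy.eval()  # use BatchNorm running statistics for this toy batch

obs = Variable(torch.randn(batch_size, obs_dim))

# One-hot sampled actions plus per-sample log-probabilities and mean entropy
acs, log_pi, entropy = policy(obs, return_log_pi=True, return_entropy=True)

# Differentiable one-hot samples from raw logits via the straight-through trick
logits = Variable(torch.randn(batch_size, n_actions))
st_acs = gumbel_softmax(logits, temperature=0.5, hard=True)
```

With `hard=True`, `gumbel_softmax` returns a one-hot sample in the forward pass while gradients flow through the underlying soft Gumbel-Softmax sample, as implemented in `gumbel_softmax` above.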