├── utils ├── __init__.py ├── multi_discrete.py ├── util.py └── popart.py ├── algorithms ├── __init__.py ├── utils │ ├── util.py │ ├── cnn.py │ ├── mlp.py │ ├── rnn.py │ ├── distributions.py │ └── act.py ├── happo_policy.py ├── hatrpo_policy.py ├── actor_critic.py └── happo_trainer.py ├── scripts ├── __init__.py ├── train │ ├── __init__.py │ ├── train_smac.py │ └── train_mujoco.py ├── train_smac.sh └── train_mujoco.sh ├── envs ├── ma_mujoco │ ├── __init__.py │ └── multiagent_mujoco │ │ ├── assets │ │ ├── __init__.py │ │ ├── .gitignore │ │ ├── manyagent_swimmer.xml.template │ │ ├── manyagent_swimmer_bckp.xml │ │ ├── manyagent_swimmer__bckp2.xml │ │ ├── manyagent_ant.xml.template │ │ ├── manyagent_ant__stage1.xml │ │ ├── manyagent_ant.xml │ │ └── coupled_half_cheetah.xml │ │ ├── __init__.py │ │ ├── coupled_half_cheetah.py │ │ ├── multiagentenv.py │ │ ├── manyagent_swimmer.py │ │ ├── manyagent_ant.py │ │ └── mujoco_multi.py ├── __init__.py └── starcraft2 │ ├── multiagentenv.py │ └── smac_maps.py ├── runners ├── __init__.py └── separated │ ├── __init__.py │ ├── mujoco_runner.py │ ├── base_runner.py │ └── smac_runner.py ├── plots ├── smac.png ├── ma-mujoco_1.png └── ma-mujoco_2.png ├── install_sc2.sh ├── LICENSE ├── README.md ├── requirements.txt └── configs └── config.py /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /envs/ma_mujoco/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scripts/train/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/assets/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/assets/.gitignore: -------------------------------------------------------------------------------- 1 | *.auto.xml 2 | -------------------------------------------------------------------------------- /runners/__init__.py: -------------------------------------------------------------------------------- 1 | from runners import separated 2 | 3 | __all__=[ 4 | 5 | "separated" 6 | ] -------------------------------------------------------------------------------- /envs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import socket 3 | from absl import flags 4 | FLAGS = flags.FLAGS 5 | FLAGS(['train_sc.py']) 6 | 7 | 8 | -------------------------------------------------------------------------------- /plots/smac.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anonymous-ICLR22/Trust-Region-Methods-in-Multi-Agent-Reinforcement-Learning/HEAD/plots/smac.png 
-------------------------------------------------------------------------------- /plots/ma-mujoco_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anonymous-ICLR22/Trust-Region-Methods-in-Multi-Agent-Reinforcement-Learning/HEAD/plots/ma-mujoco_1.png -------------------------------------------------------------------------------- /plots/ma-mujoco_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anonymous-ICLR22/Trust-Region-Methods-in-Multi-Agent-Reinforcement-Learning/HEAD/plots/ma-mujoco_2.png -------------------------------------------------------------------------------- /runners/separated/__init__.py: -------------------------------------------------------------------------------- 1 | from runners.separated import base_runner,smac_runner 2 | 3 | __all__=[ 4 | "base_runner", 5 | "smac_runner" 6 | ] -------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/__init__.py: -------------------------------------------------------------------------------- 1 | from .mujoco_multi import MujocoMulti 2 | from .coupled_half_cheetah import CoupledHalfCheetah 3 | from .manyagent_swimmer import ManyAgentSwimmerEnv 4 | from .manyagent_ant import ManyAgentAntEnv 5 | -------------------------------------------------------------------------------- /algorithms/utils/util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | def init(module, weight_init, bias_init, gain=1): 8 | weight_init(module.weight.data, gain=gain) 9 | bias_init(module.bias.data) 10 | return module 11 | 12 | def get_clones(module, N): 13 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 14 | 15 | def check(input): 16 | output = torch.from_numpy(input) if type(input) == np.ndarray else input 17 | return output 18 | -------------------------------------------------------------------------------- /scripts/train_smac.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | env="StarCraft2" 3 | map="3s5z" 4 | algo="happo" 5 | exp="mlp" 6 | running_max=20 7 | kl_threshold=0.06 8 | echo "env is ${env}, map is ${map}, algo is ${algo}, exp is ${exp}, max seed is ${seed_max}" 9 | for number in `seq ${running_max}`; 10 | do 11 | echo "the ${number}-th running:" 12 | CUDA_VISIBLE_DEVICES=1 python train/train_smac.py --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} --map_name ${map} --running_id ${number}--n_training_threads 32 --n_rollout_threads 8 --num_mini_batch 1 --episode_length 400 --num_env_steps 20000000 --ppo_epoch 5 --stacked_frames 1 --kl_threshold ${kl_threshold} --use_value_active_masks --use_eval --add_center_xy --use_state_agent --share_policy 13 | done 14 | -------------------------------------------------------------------------------- /scripts/train_mujoco.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | env="mujoco" 3 | scenario="Ant-v2" 4 | agent_conf="2x4" 5 | agent_obsk=2 6 | algo="happo" 7 | exp="mlp" 8 | running_max=20 9 | kl_threshold=1e-4 10 | echo "env is ${env}, scenario is ${scenario}, algo is ${algo}, exp is ${exp}, max seed is ${seed_max}" 11 | for number in `seq ${running_max}`; 12 | do 13 | echo "the ${number}-th running:" 14 | CUDA_VISIBLE_DEVICES=1 python 
train/train_mujoco.py --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} --scenario ${scenario} --agent_conf ${agent_conf} --agent_obsk ${agent_obsk} --lr 5e-6 --critic_lr 5e-3 --std_x_coef 1 --std_y_coef 5e-1 --running_id ${number} --n_training_threads 8 --n_rollout_threads 4 --num_mini_batch 40 --episode_length 1000 --num_env_steps 10000000 --ppo_epoch 5 --kl_threshold ${kl_threshold} --use_value_active_masks --use_eval --add_center_xy --use_state_agent --share_policy 15 | done 16 | -------------------------------------------------------------------------------- /install_sc2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Install SC2 and add the custom maps 3 | 4 | if [ -z "$EXP_DIR" ] 5 | then 6 | EXP_DIR=~ 7 | fi 8 | 9 | echo "EXP_DIR: $EXP_DIR" 10 | cd $EXP_DIR/pymarl 11 | 12 | mkdir 3rdparty 13 | cd 3rdparty 14 | 15 | export SC2PATH=`pwd`'/StarCraftII' 16 | echo 'SC2PATH is set to '$SC2PATH 17 | 18 | if [ ! -d $SC2PATH ]; then 19 | echo 'StarCraftII is not installed. Installing now ...'; 20 | wget http://blzdistsc2-a.akamaihd.net/Linux/SC2.4.10.zip 21 | unzip -P iagreetotheeula SC2.4.10.zip 22 | rm -rf SC2.4.10.zip 23 | else 24 | echo 'StarCraftII is already installed.' 25 | fi 26 | 27 | echo 'Adding SMAC maps.' 28 | MAP_DIR="$SC2PATH/Maps/" 29 | echo 'MAP_DIR is set to '$MAP_DIR 30 | 31 | if [ ! -d $MAP_DIR ]; then 32 | mkdir -p $MAP_DIR 33 | fi 34 | 35 | cd .. 36 | wget https://github.com/oxwhirl/smac/releases/download/v0.1-beta1/SMAC_Maps.zip 37 | unzip SMAC_Maps.zip 38 | mv SMAC_Maps $MAP_DIR 39 | rm -rf SMAC_Maps.zip 40 | 41 | echo 'StarCraft II and SMAC are installed.' 42 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Tianshou contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
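The SMAC and MuJoCo launchers above select the algorithm through the `algo` variable, and the README notes that HATRPO is obtained by changing `algo=happo` to `algo=hatrpo`. A hedged example of that edit for `train_mujoco.sh`, with all other variables and flags assumed to carry over unchanged, is:
``` Bash
# scripts/train_mujoco.sh -- switch from HAPPO to HATRPO (hypothetical one-line edit)
algo="hatrpo"   # was: algo="happo"; the loop body and its flags (e.g. --kl_threshold ${kl_threshold}) stay as-is
```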
-------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/coupled_half_cheetah.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from gym.envs.mujoco import mujoco_env 4 | import os 5 | 6 | 7 | class CoupledHalfCheetah(mujoco_env.MujocoEnv, utils.EzPickle): 8 | def __init__(self, **kwargs): 9 | mujoco_env.MujocoEnv.__init__(self, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'coupled_half_cheetah.xml'), 5) 10 | utils.EzPickle.__init__(self) 11 | 12 | def step(self, action): 13 | xposbefore1 = self.sim.data.qpos[0] 14 | xposbefore2 = self.sim.data.qpos[len(self.sim.data.qpos) // 2] 15 | self.do_simulation(action, self.frame_skip) 16 | xposafter1 = self.sim.data.qpos[0] 17 | xposafter2 = self.sim.data.qpos[len(self.sim.data.qpos)//2] 18 | ob = self._get_obs() 19 | reward_ctrl1 = - 0.1 * np.square(action[0:len(action)//2]).sum() 20 | reward_ctrl2 = - 0.1 * np.square(action[len(action)//2:]).sum() 21 | reward_run1 = (xposafter1 - xposbefore1)/self.dt 22 | reward_run2 = (xposafter2 - xposbefore2) / self.dt 23 | reward = (reward_ctrl1 + reward_ctrl2)/2.0 + (reward_run1 + reward_run2)/2.0 24 | done = False 25 | return ob, reward, done, dict(reward_run1=reward_run1, reward_ctrl1=reward_ctrl1, 26 | reward_run2=reward_run2, reward_ctrl2=reward_ctrl2) 27 | 28 | def _get_obs(self): 29 | return np.concatenate([ 30 | self.sim.data.qpos.flat[1:], 31 | self.sim.data.qvel.flat, 32 | ]) 33 | 34 | def reset_model(self): 35 | qpos = self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq) 36 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 37 | self.set_state(qpos, qvel) 38 | return self._get_obs() 39 | 40 | def viewer_setup(self): 41 | self.viewer.cam.distance = self.model.stat.extent * 0.5 42 | 43 | def get_env_info(self): 44 | return {"episode_limit": self.episode_limit} -------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/assets/manyagent_swimmer.xml.template: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /algorithms/utils/cnn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .util import init 3 | 4 | """CNN Modules and utils.""" 5 | 6 | class Flatten(nn.Module): 7 | def forward(self, x): 8 | return x.view(x.size(0), -1) 9 | 10 | 11 | class CNNLayer(nn.Module): 12 | def __init__(self, obs_shape, hidden_size, use_orthogonal, use_ReLU, kernel_size=3, stride=1): 13 | super(CNNLayer, self).__init__() 14 | 15 | active_func = [nn.Tanh(), nn.ReLU()][use_ReLU] 16 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 17 | gain = nn.init.calculate_gain(['tanh', 'relu'][use_ReLU]) 18 | 19 | def init_(m): 20 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain) 21 | 22 | input_channel = obs_shape[0] 23 | input_width = obs_shape[1] 24 | input_height = obs_shape[2] 25 | 26 | self.cnn = nn.Sequential( 27 | init_(nn.Conv2d(in_channels=input_channel, 28 | out_channels=hidden_size // 2, 29 | kernel_size=kernel_size, 30 | stride=stride) 31 | ), 32 | active_func, 33 | Flatten(), 34 | init_(nn.Linear(hidden_size // 2 * (input_width - kernel_size + stride) * (input_height - kernel_size + stride), 35 | hidden_size) 36 | ), 37 
| active_func, 38 | init_(nn.Linear(hidden_size, hidden_size)), active_func) 39 | 40 | def forward(self, x): 41 | x = x / 255.0 42 | x = self.cnn(x) 43 | return x 44 | 45 | 46 | class CNNBase(nn.Module): 47 | def __init__(self, args, obs_shape): 48 | super(CNNBase, self).__init__() 49 | 50 | self._use_orthogonal = args.use_orthogonal 51 | self._use_ReLU = args.use_ReLU 52 | self.hidden_size = args.hidden_size 53 | 54 | self.cnn = CNNLayer(obs_shape, self.hidden_size, self._use_orthogonal, self._use_ReLU) 55 | 56 | def forward(self, x): 57 | x = self.cnn(x) 58 | return x 59 | -------------------------------------------------------------------------------- /envs/starcraft2/multiagentenv.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | 6 | class MultiAgentEnv(object): 7 | 8 | def step(self, actions): 9 | """Returns reward, terminated, info.""" 10 | raise NotImplementedError 11 | 12 | def get_obs(self): 13 | """Returns all agent observations in a list.""" 14 | raise NotImplementedError 15 | 16 | def get_obs_agent(self, agent_id): 17 | """Returns observation for agent_id.""" 18 | raise NotImplementedError 19 | 20 | def get_obs_size(self): 21 | """Returns the size of the observation.""" 22 | raise NotImplementedError 23 | 24 | def get_state(self): 25 | """Returns the global state.""" 26 | raise NotImplementedError 27 | 28 | def get_state_size(self): 29 | """Returns the size of the global state.""" 30 | raise NotImplementedError 31 | 32 | def get_avail_actions(self): 33 | """Returns the available actions of all agents in a list.""" 34 | raise NotImplementedError 35 | 36 | def get_avail_agent_actions(self, agent_id): 37 | """Returns the available actions for agent_id.""" 38 | raise NotImplementedError 39 | 40 | def get_total_actions(self): 41 | """Returns the total number of actions an agent could ever take.""" 42 | raise NotImplementedError 43 | 44 | def reset(self): 45 | """Returns initial observations and states.""" 46 | raise NotImplementedError 47 | 48 | def render(self): 49 | raise NotImplementedError 50 | 51 | def close(self): 52 | raise NotImplementedError 53 | 54 | def seed(self): 55 | raise NotImplementedError 56 | 57 | def save_replay(self): 58 | """Save a replay.""" 59 | raise NotImplementedError 60 | 61 | def get_env_info(self): 62 | env_info = {"state_shape": self.get_state_size(), 63 | "obs_shape": self.get_obs_size(), 64 | "obs_alone_shape": self.get_obs_alone_size(), 65 | "n_actions": self.get_total_actions(), 66 | "n_agents": self.n_agents, 67 | "episode_limit": self.episode_limit} 68 | return env_info 69 | -------------------------------------------------------------------------------- /algorithms/utils/mlp.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from .util import init, get_clones 3 | 4 | """MLP modules.""" 5 | 6 | class MLPLayer(nn.Module): 7 | def __init__(self, input_dim, hidden_size, layer_N, use_orthogonal, use_ReLU): 8 | super(MLPLayer, self).__init__() 9 | self._layer_N = layer_N 10 | 11 | active_func = [nn.Tanh(), nn.ReLU()][use_ReLU] 12 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 13 | gain = nn.init.calculate_gain(['tanh', 'relu'][use_ReLU]) 14 | 15 | def init_(m): 16 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain) 17 | 18 | self.fc1 = nn.Sequential( 19 | 
init_(nn.Linear(input_dim, hidden_size)), active_func, nn.LayerNorm(hidden_size)) 20 | # self.fc_h = nn.Sequential(init_( 21 | # nn.Linear(hidden_size, hidden_size)), active_func, nn.LayerNorm(hidden_size)) 22 | # self.fc2 = get_clones(self.fc_h, self._layer_N) 23 | self.fc2 = nn.ModuleList([nn.Sequential(init_( 24 | nn.Linear(hidden_size, hidden_size)), active_func, nn.LayerNorm(hidden_size)) for i in range(self._layer_N)]) 25 | 26 | def forward(self, x): 27 | x = self.fc1(x) 28 | for i in range(self._layer_N): 29 | x = self.fc2[i](x) 30 | return x 31 | 32 | 33 | class MLPBase(nn.Module): 34 | def __init__(self, args, obs_shape, cat_self=True, attn_internal=False): 35 | super(MLPBase, self).__init__() 36 | 37 | self._use_feature_normalization = args.use_feature_normalization 38 | self._use_orthogonal = args.use_orthogonal 39 | self._use_ReLU = args.use_ReLU 40 | self._stacked_frames = args.stacked_frames 41 | self._layer_N = args.layer_N 42 | self.hidden_size = args.hidden_size 43 | 44 | obs_dim = obs_shape[0] 45 | 46 | if self._use_feature_normalization: 47 | self.feature_norm = nn.LayerNorm(obs_dim) 48 | 49 | self.mlp = MLPLayer(obs_dim, self.hidden_size, 50 | self._layer_N, self._use_orthogonal, self._use_ReLU) 51 | 52 | def forward(self, x): 53 | if self._use_feature_normalization: 54 | x = self.feature_norm(x) 55 | 56 | x = self.mlp(x) 57 | 58 | return x -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Trust Region Policy Optimisation in Multi-Agent Reinforcement Learning 2 | Anonymous code release for ICLR 22 paper submission, named "**Trust Region Policy Optimisation in Multi-Agent Reinforcement Learning**". This repository develops *Heterogeneous Agent Trust Region Policy Optimisation (HATRPO)* and *Heterogeneous-Agent Proximal Policy Optimisation (HAPPO)* algorithms on the bechmarks of SMAC and Multi-agent MUJOCO. *HATRPO* and *HAPPO* are the first trust region methods for multi-agent reinforcement learning **with theoretically-justified monotonic improvement guarantee**. Performance wise, it is the new state-of-the-art algorithm against its rivals such as [IPPO](https://arxiv.org/abs/2011.09533), [MAPPO](https://arxiv.org/abs/2103.01955) and [MADDPG](https://arxiv.org/abs/1706.02275) 3 | 4 | ## Installation 5 | ### Create environment 6 | ``` Bash 7 | conda create -n env_name python=3.9 8 | conda activate env_name 9 | pip install -r requirements.txt 10 | conda install pytorch torchvision torchaudio cudatoolkit=11.1 -c pytorch -c nvidia 11 | ``` 12 | 13 | ### Multi-agent MuJoCo 14 | Following the instructions in https://github.com/openai/mujoco-py and https://github.com/schroederdewitt/multiagent_mujoco to setup a mujoco environment. In the end, remember to set the following environment variables: 15 | ``` Bash 16 | LD_LIBRARY_PATH=${HOME}/.mujoco/mujoco200/bin; 17 | LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libGLEW.so 18 | ``` 19 | ### StarCraft II & SMAC 20 | Run the script 21 | ``` Bash 22 | bash install_sc2.sh 23 | ``` 24 | Or you could install them manually to other path you like, just follow here: https://github.com/oxwhirl/smac. 25 | 26 | ## How to run 27 | When your environment is ready, you could run shell scripts provided. 
For example: 28 | ``` Bash 29 | cd scripts 30 | ./train_mujoco.sh # run with HAPPO/HATRPO on Multi-agent MuJoCo 31 | ./train_smac.sh # run with HAPPO/HATRPO on StarCraft II 32 | ``` 33 | 34 | If you would like to change the configs of experiments, you could modify sh files or look for config files for more details. And you can change algorithm by modify **algo=happo** as **algo=hatrpo**. 35 | 36 | 37 | 38 | ## Some experiment results 39 | 40 | ### SMAC 41 | 42 | 43 | 44 | 45 | ### Multi-agent MuJoCo on MAPPO 46 | 47 | 48 | 49 | ### 50 | 51 | 52 | -------------------------------------------------------------------------------- /utils/multi_discrete.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | # An old version of OpenAI Gym's multi_discrete.py. (Was getting affected by Gym updates) 5 | # (https://github.com/openai/gym/blob/1fb81d4e3fb780ccf77fec731287ba07da35eb84/gym/spaces/multi_discrete.py) 6 | class MultiDiscrete(gym.Space): 7 | """ 8 | - The multi-discrete action space consists of a series of discrete action spaces with different parameters 9 | - It can be adapted to both a Discrete action space or a continuous (Box) action space 10 | - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space 11 | - It is parametrized by passing an array of arrays containing [min, max] for each discrete action space where the discrete action space can take any integers from `min` to `max` (both inclusive) 12 | Note: A value of 0 always need to represent the NOOP action. 13 | e.g. Nintendo Game Controller 14 | - Can be conceptualized as 3 discrete action spaces: 15 | 1) Arrow Keys: Discrete 5 - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4] - params: min: 0, max: 4 16 | 2) Button A: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1 17 | 3) Button B: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1 18 | - Can be initialized as 19 | MultiDiscrete([ [0,4], [0,1], [0,1] ]) 20 | """ 21 | 22 | def __init__(self, array_of_param_array): 23 | self.low = np.array([x[0] for x in array_of_param_array]) 24 | self.high = np.array([x[1] for x in array_of_param_array]) 25 | self.num_discrete_space = self.low.shape[0] 26 | self.n = np.sum(self.high) + 2 27 | 28 | def sample(self): 29 | """ Returns a array with one sample from each discrete action space """ 30 | # For each row: round(random .* (max - min) + min, 0) 31 | random_array = np.random.rand(self.num_discrete_space) 32 | return [int(x) for x in np.floor(np.multiply((self.high - self.low + 1.), random_array) + self.low)] 33 | 34 | def contains(self, x): 35 | return len(x) == self.num_discrete_space and (np.array(x) >= self.low).all() and (np.array(x) <= self.high).all() 36 | 37 | @property 38 | def shape(self): 39 | return self.num_discrete_space 40 | 41 | def __repr__(self): 42 | return "MultiDiscrete" + str(self.num_discrete_space) 43 | 44 | def __eq__(self, other): 45 | return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high) 46 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch 4 | 5 | def check(input): 6 | if type(input) == np.ndarray: 7 | return torch.from_numpy(input) 8 | 9 | def get_gard_norm(it): 10 | sum_grad = 0 11 | for x in it: 12 | if x.grad is None: 13 | continue 14 | sum_grad += 
x.grad.norm() ** 2 15 | return math.sqrt(sum_grad) 16 | 17 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): 18 | """Decreases the learning rate linearly""" 19 | lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs))) 20 | for param_group in optimizer.param_groups: 21 | param_group['lr'] = lr 22 | 23 | def huber_loss(e, d): 24 | a = (abs(e) <= d).float() 25 | b = (e > d).float() 26 | return a*e**2/2 + b*d*(abs(e)-d/2) 27 | 28 | def mse_loss(e): 29 | return e**2/2 30 | 31 | def get_shape_from_obs_space(obs_space): 32 | if obs_space.__class__.__name__ == 'Box': 33 | obs_shape = obs_space.shape 34 | elif obs_space.__class__.__name__ == 'list': 35 | obs_shape = obs_space 36 | else: 37 | raise NotImplementedError 38 | return obs_shape 39 | 40 | def get_shape_from_act_space(act_space): 41 | if act_space.__class__.__name__ == 'Discrete': 42 | act_shape = 1 43 | elif act_space.__class__.__name__ == "MultiDiscrete": 44 | act_shape = act_space.shape 45 | elif act_space.__class__.__name__ == "Box": 46 | act_shape = act_space.shape[0] 47 | elif act_space.__class__.__name__ == "MultiBinary": 48 | act_shape = act_space.shape[0] 49 | else: # agar 50 | act_shape = act_space[0].shape[0] + 1 51 | return act_shape 52 | 53 | 54 | def tile_images(img_nhwc): 55 | """ 56 | Tile N images into one big PxQ image 57 | (P,Q) are chosen to be as close as possible, and if N 58 | is square, then P=Q. 59 | input: img_nhwc, list or array of images, ndim=4 once turned into array 60 | n = batch index, h = height, w = width, c = channel 61 | returns: 62 | bigim_HWc, ndarray with ndim=3 63 | """ 64 | img_nhwc = np.asarray(img_nhwc) 65 | N, h, w, c = img_nhwc.shape 66 | H = int(np.ceil(np.sqrt(N))) 67 | W = int(np.ceil(float(N)/H)) 68 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)]) 69 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c) 70 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) 71 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c) 72 | return img_Hh_Ww_c -------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/assets/manyagent_swimmer_bckp.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/multiagentenv.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import numpy as np 3 | 4 | 5 | def convert(dictionary): 6 | return namedtuple('GenericDict', dictionary.keys())(**dictionary) 7 | 8 | class MultiAgentEnv(object): 9 | 10 | def __init__(self, batch_size=None, **kwargs): 11 | # Unpack arguments from sacred 12 | args = kwargs["env_args"] 13 | if isinstance(args, dict): 14 | args = convert(args) 15 | self.args = args 16 | 17 | if getattr(args, "seed", None) is not None: 18 | self.seed = args.seed 19 | self.rs = np.random.RandomState(self.seed) # initialise numpy random state 20 | 21 | def step(self, actions): 22 | """ Returns reward, terminated, info """ 23 | raise NotImplementedError 24 | 25 | def get_obs(self): 26 | """ Returns all agent observations in a list """ 27 | raise NotImplementedError 28 | 29 | def get_obs_agent(self, agent_id): 30 | """ Returns observation for agent_id """ 31 | raise NotImplementedError 32 | 33 | def get_obs_size(self): 34 | """ Returns the shape of the observation """ 35 | raise NotImplementedError 36 | 37 | def get_state(self): 38 | raise 
NotImplementedError 39 | 40 | def get_state_size(self): 41 | """ Returns the shape of the state""" 42 | raise NotImplementedError 43 | 44 | def get_avail_actions(self): 45 | raise NotImplementedError 46 | 47 | def get_avail_agent_actions(self, agent_id): 48 | """ Returns the available actions for agent_id """ 49 | raise NotImplementedError 50 | 51 | def get_total_actions(self): 52 | """ Returns the total number of actions an agent could ever take """ 53 | # TODO: This is only suitable for a discrete 1 dimensional action space for each agent 54 | raise NotImplementedError 55 | 56 | def get_stats(self): 57 | raise NotImplementedError 58 | 59 | # TODO: Temp hack 60 | def get_agg_stats(self, stats): 61 | return {} 62 | 63 | def reset(self): 64 | """ Returns initial observations and states""" 65 | raise NotImplementedError 66 | 67 | def render(self): 68 | raise NotImplementedError 69 | 70 | def close(self): 71 | raise NotImplementedError 72 | 73 | def seed(self, seed): 74 | raise NotImplementedError 75 | 76 | def get_env_info(self): 77 | env_info = {"state_shape": self.get_state_size(), 78 | "obs_shape": self.get_obs_size(), 79 | "n_actions": self.get_total_actions(), 80 | "n_agents": self.n_agents, 81 | "episode_limit": self.episode_limit} 82 | return env_info -------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/assets/manyagent_swimmer__bckp2.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /algorithms/utils/rnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | """RNN modules.""" 5 | 6 | 7 | class RNNLayer(nn.Module): 8 | def __init__(self, inputs_dim, outputs_dim, recurrent_N, use_orthogonal): 9 | super(RNNLayer, self).__init__() 10 | self._recurrent_N = recurrent_N 11 | self._use_orthogonal = use_orthogonal 12 | 13 | self.rnn = nn.GRU(inputs_dim, outputs_dim, num_layers=self._recurrent_N) 14 | for name, param in self.rnn.named_parameters(): 15 | if 'bias' in name: 16 | nn.init.constant_(param, 0) 17 | elif 'weight' in name: 18 | if self._use_orthogonal: 19 | nn.init.orthogonal_(param) 20 | else: 21 | nn.init.xavier_uniform_(param) 22 | self.norm = nn.LayerNorm(outputs_dim) 23 | 24 | def forward(self, x, hxs, masks): 25 | if x.size(0) == hxs.size(0): 26 | x, hxs = self.rnn(x.unsqueeze(0), 27 | (hxs * masks.repeat(1, self._recurrent_N).unsqueeze(-1)).transpose(0, 1).contiguous()) 28 | x = x.squeeze(0) 29 | hxs = hxs.transpose(0, 1) 30 | else: 31 | # x is a (T, N, -1) tensor that has been flatten to (T * N, -1) 32 | N = hxs.size(0) 33 | T = int(x.size(0) / N) 34 | 35 | # unflatten 36 | x = x.view(T, N, x.size(1)) 37 | 38 | # Same deal with masks 39 | masks = masks.view(T, N) 40 | 41 | # Let's figure out which steps in the sequence have a zero for any agent 42 | # We will always assume t=0 has a zero in it as that makes the logic cleaner 43 | has_zeros = ((masks[1:] == 0.0) 44 | .any(dim=-1) 45 | .nonzero() 46 | .squeeze() 47 | .cpu()) 48 | 49 | # +1 to correct the masks[1:] 50 | if has_zeros.dim() == 0: 51 | # Deal with scalar 52 | has_zeros = [has_zeros.item() + 1] 53 | else: 54 | has_zeros = (has_zeros + 1).numpy().tolist() 55 | 56 | # add t=0 and t=T to the list 57 | has_zeros = [0] + has_zeros + [T] 58 | 59 | hxs = hxs.transpose(0, 1) 60 | 61 | outputs = [] 62 | for i in range(len(has_zeros) - 1): 63 | # We can now 
process steps that don't have any zeros in masks together! 64 | # This is much faster 65 | start_idx = has_zeros[i] 66 | end_idx = has_zeros[i + 1] 67 | temp = (hxs * masks[start_idx].view(1, -1, 1).repeat(self._recurrent_N, 1, 1)).contiguous() 68 | rnn_scores, hxs = self.rnn(x[start_idx:end_idx], temp) 69 | outputs.append(rnn_scores) 70 | 71 | # assert len(outputs) == T 72 | # x is a (T, N, -1) tensor 73 | x = torch.cat(outputs, dim=0) 74 | 75 | # flatten 76 | x = x.reshape(T * N, -1) 77 | hxs = hxs.transpose(0, 1) 78 | 79 | x = self.norm(x) 80 | return x, hxs 81 | -------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/assets/manyagent_ant.xml.template: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /utils/popart.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class PopArt(nn.Module): 9 | """ Normalize a vector of observations - across the first norm_axes dimensions""" 10 | 11 | def __init__(self, input_shape, norm_axes=1, beta=0.99999, per_element_update=False, epsilon=1e-5, device=torch.device("cpu")): 12 | super(PopArt, self).__init__() 13 | 14 | self.input_shape = input_shape 15 | self.norm_axes = norm_axes 16 | self.epsilon = epsilon 17 | self.beta = beta 18 | self.per_element_update = per_element_update 19 | self.tpdv = dict(dtype=torch.float32, device=device) 20 | 21 | self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv) 22 | self.running_mean_sq = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv) 23 | self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False).to(**self.tpdv) 24 | 25 | def reset_parameters(self): 26 | self.running_mean.zero_() 27 | self.running_mean_sq.zero_() 28 | self.debiasing_term.zero_() 29 | 30 | def running_mean_var(self): 31 | debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon) 32 | debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=self.epsilon) 33 | debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2) 34 | return debiased_mean, debiased_var 35 | 36 | def forward(self, input_vector, train=True): 37 | # Make sure input is float32 38 | if type(input_vector) == np.ndarray: 39 | input_vector = torch.from_numpy(input_vector) 40 | input_vector = input_vector.to(**self.tpdv) 41 | 42 | if train: 43 | # Detach input before adding it to running means to avoid backpropping through it on 44 | # subsequent batches. 
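# The block below keeps exponential moving averages (decay `beta`, or `beta ** batch_size`
# when `per_element_update` is set) of the batch mean and mean-of-squares taken over the
# first `norm_axes` dimensions. `debiasing_term` applies the same EMA to the constant 1.0,
# and running_mean_var() divides by it so the zero-initialised statistics are bias-corrected,
# analogous to Adam's bias correction.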
45 | detached_input = input_vector.detach() 46 | batch_mean = detached_input.mean(dim=tuple(range(self.norm_axes))) 47 | batch_sq_mean = (detached_input ** 2).mean(dim=tuple(range(self.norm_axes))) 48 | 49 | if self.per_element_update: 50 | batch_size = np.prod(detached_input.size()[:self.norm_axes]) 51 | weight = self.beta ** batch_size 52 | else: 53 | weight = self.beta 54 | 55 | self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight)) 56 | self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight)) 57 | self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight)) 58 | 59 | mean, var = self.running_mean_var() 60 | out = (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[(None,) * self.norm_axes] 61 | 62 | return out 63 | 64 | def denormalize(self, input_vector): 65 | """ Transform normalized data back into original distribution """ 66 | if type(input_vector) == np.ndarray: 67 | input_vector = torch.from_numpy(input_vector) 68 | input_vector = input_vector.to(**self.tpdv) 69 | 70 | mean, var = self.running_mean_var() 71 | out = input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * self.norm_axes] 72 | 73 | out = out.cpu().numpy() 74 | 75 | return out 76 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | aiohttp==3.6.2 3 | aioredis==1.3.1 4 | astor==0.8.0 5 | astunparse==1.6.3 6 | async-timeout==3.0.1 7 | atari-py==0.2.6 8 | atomicwrites==1.2.1 9 | attrs==18.2.0 10 | beautifulsoup4==4.9.1 11 | blessings==1.7 12 | cachetools==4.1.1 13 | certifi==2020.4.5.2 14 | cffi==1.14.1 15 | chardet==3.0.4 16 | click==7.1.2 17 | cloudpickle==1.3.0 18 | colorama==0.4.3 19 | colorful==0.5.4 20 | configparser==5.0.1 21 | contextvars==2.4 22 | cycler==0.10.0 23 | Cython==0.29.21 24 | deepdiff==4.3.2 25 | dill==0.3.2 26 | docker-pycreds==0.4.0 27 | docopt==0.6.2 28 | fasteners==0.15 29 | filelock==3.0.12 30 | funcsigs==1.0.2 31 | future==0.16.0 32 | gast==0.2.2 33 | gin==0.1.6 34 | gin-config==0.3.0 35 | gitdb==4.0.5 36 | GitPython==3.1.9 37 | glfw==1.12.0 38 | google==3.0.0 39 | google-api-core==1.22.1 40 | google-auth==1.21.0 41 | google-auth-oauthlib==0.4.1 42 | google-pasta==0.2.0 43 | googleapis-common-protos==1.52.0 44 | gpustat==0.6.0 45 | gql==0.2.0 46 | graphql-core==1.1 47 | grpcio==1.31.0 48 | gym==0.17.2 49 | h5py==2.10.0 50 | hiredis==1.1.0 51 | idna==2.7 52 | idna-ssl==1.1.0 53 | imageio==2.4.1 54 | immutables==0.14 55 | importlib-metadata==1.7.0 56 | joblib==0.16.0 57 | jsonnet==0.16.0 58 | jsonpickle==0.9.6 59 | jsonschema==3.2.0 60 | Keras-Applications==1.0.8 61 | Keras-Preprocessing==1.1.2 62 | kiwisolver==1.0.1 63 | lockfile==0.12.2 64 | Markdown==3.1.1 65 | matplotlib==3.0.0 66 | mkl-fft==1.2.0 67 | mkl-random==1.2.0 68 | mkl-service==2.3.0 69 | mock==2.0.0 70 | monotonic==1.5 71 | more-itertools==4.3.0 72 | mpi4py==3.0.3 73 | mpyq==0.2.5 74 | msgpack==1.0.0 75 | mujoco-py==2.0.2.8 76 | multidict==4.7.6 77 | munch==2.3.2 78 | numpy 79 | nvidia-ml-py3==7.352.0 80 | oauthlib==3.1.0 81 | opencensus==0.7.10 82 | opencensus-context==0.1.1 83 | opencv-python==4.2.0.34 84 | opt-einsum==3.1.0 85 | ordered-set==4.0.2 86 | packaging==20.4 87 | pandas==1.1.1 88 | pathlib2==2.3.2 89 | pathtools==0.1.2 90 | pbr==4.3.0 91 | Pillow==5.3.0 92 | pluggy==0.7.1 93 | portpicker==1.2.0 94 | probscale==0.2.3 95 | progressbar2==3.53.1 96 | prometheus-client==0.8.0 97 | promise==2.3 98 | protobuf==3.12.4 99 
| psutil==5.7.2 100 | py==1.6.0 101 | py-spy==0.3.3 102 | pyasn1==0.4.8 103 | pyasn1-modules==0.2.8 104 | pycparser==2.20 105 | pygame==1.9.4 106 | pyglet==1.5.0 107 | PyOpenGL==3.1.5 108 | PyOpenGL-accelerate==3.1.5 109 | pyparsing==2.2.2 110 | pyrsistent==0.16.0 111 | PySC2==3.0.0 112 | pytest==3.8.2 113 | python-dateutil==2.7.3 114 | python-utils==2.4.0 115 | pytz==2020.1 116 | PyYAML==3.13 117 | pyzmq==19.0.2 118 | redis==3.4.1 119 | requests==2.24.0 120 | requests-oauthlib==1.3.0 121 | rsa==4.6 122 | s2clientprotocol==4.10.1.75800.0 123 | s2protocol==4.11.4.78285.0 124 | sacred==0.7.2 125 | scipy==1.4.1 126 | seaborn==0.10.1 127 | sentry-sdk==0.18.0 128 | setproctitle==1.1.10 129 | shortuuid==1.0.1 130 | six==1.15.0 131 | sk-video==1.1.10 132 | smmap==3.0.4 133 | snakeviz==1.0.0 134 | soupsieve==2.0.1 135 | subprocess32==3.5.4 136 | tabulate==0.8.7 137 | tensorboard==2.0.2 138 | tensorboard-logger==0.1.0 139 | tensorboard-plugin-wit==1.7.0 140 | tensorboardX==2.0 141 | tensorflow==2.0.0 142 | tensorflow-estimator==2.0.0 143 | termcolor==1.1.0 144 | torch 145 | torchvision 146 | tornado 147 | tqdm==4.48.2 148 | typing-extensions==3.7.4.3 149 | urllib3==1.23 150 | watchdog==0.10.3 151 | websocket-client==0.53.0 152 | Werkzeug==0.16.1 153 | whichcraft==0.5.2 154 | wrapt==1.12.1 155 | xmltodict==0.12.0 156 | yarl==1.5.1 157 | zipp==3.1.0 158 | zmq==0.0.0 159 | -------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/manyagent_swimmer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from gym.envs.mujoco import mujoco_env 4 | import os 5 | from jinja2 import Template 6 | 7 | class ManyAgentSwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle): 8 | def __init__(self, **kwargs): 9 | agent_conf = kwargs.get("agent_conf") 10 | n_agents = int(agent_conf.split("x")[0]) 11 | n_segs_per_agents = int(agent_conf.split("x")[1]) 12 | n_segs = n_agents * n_segs_per_agents 13 | 14 | # Check whether asset file exists already, otherwise create it 15 | asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 16 | 'manyagent_swimmer_{}_agents_each_{}_segments.auto.xml'.format(n_agents, 17 | n_segs_per_agents)) 18 | # if not os.path.exists(asset_path): 19 | print("Auto-Generating Manyagent Swimmer asset with {} segments at {}.".format(n_segs, asset_path)) 20 | self._generate_asset(n_segs=n_segs, asset_path=asset_path) 21 | 22 | #asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',git p 23 | # 'manyagent_swimmer.xml') 24 | 25 | mujoco_env.MujocoEnv.__init__(self, asset_path, 4) 26 | utils.EzPickle.__init__(self) 27 | 28 | def _generate_asset(self, n_segs, asset_path): 29 | template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 30 | 'manyagent_swimmer.xml.template') 31 | with open(template_path, "r") as f: 32 | t = Template(f.read()) 33 | body_str_template = """ 34 | 35 | 36 | 37 | """ 38 | 39 | body_end_str_template = """ 40 | 41 | 42 | 43 | 44 | """ 45 | 46 | body_close_str_template ="\n" 47 | actuator_str_template = """\t \n""" 48 | 49 | body_str = "" 50 | for i in range(1,n_segs-1): 51 | body_str += body_str_template.format(i, (-1)**(i+1), i) 52 | body_str += body_end_str_template.format(n_segs-1) 53 | body_str += body_close_str_template*(n_segs-2) 54 | 55 | actuator_str = "" 56 | for i in range(n_segs): 57 | actuator_str += actuator_str_template.format(i) 58 | 59 | rt = 
t.render(body=body_str, actuators=actuator_str) 60 | with open(asset_path, "w") as f: 61 | f.write(rt) 62 | pass 63 | 64 | def step(self, a): 65 | ctrl_cost_coeff = 0.0001 66 | xposbefore = self.sim.data.qpos[0] 67 | self.do_simulation(a, self.frame_skip) 68 | xposafter = self.sim.data.qpos[0] 69 | reward_fwd = (xposafter - xposbefore) / self.dt 70 | reward_ctrl = - ctrl_cost_coeff * np.square(a).sum() 71 | reward = reward_fwd + reward_ctrl 72 | ob = self._get_obs() 73 | return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl) 74 | 75 | def _get_obs(self): 76 | qpos = self.sim.data.qpos 77 | qvel = self.sim.data.qvel 78 | return np.concatenate([qpos.flat[2:], qvel.flat]) 79 | 80 | def reset_model(self): 81 | self.set_state( 82 | self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq), 83 | self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv) 84 | ) 85 | return self._get_obs() 86 | -------------------------------------------------------------------------------- /algorithms/utils/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .util import init 4 | 5 | """ 6 | Modify standard PyTorch distributions so they to make compatible with this codebase. 7 | """ 8 | 9 | # 10 | # Standardize distribution interfaces 11 | # 12 | 13 | # Categorical 14 | class FixedCategorical(torch.distributions.Categorical): 15 | def sample(self): 16 | return super().sample().unsqueeze(-1) 17 | 18 | def log_probs(self, actions): 19 | return ( 20 | super() 21 | .log_prob(actions.squeeze(-1)) 22 | .view(actions.size(0), -1) 23 | .sum(-1) 24 | .unsqueeze(-1) 25 | ) 26 | 27 | def mode(self): 28 | return self.probs.argmax(dim=-1, keepdim=True) 29 | 30 | 31 | # Normal 32 | class FixedNormal(torch.distributions.Normal): 33 | def log_probs(self, actions): 34 | return super().log_prob(actions) 35 | # return super().log_prob(actions).sum(-1, keepdim=True) 36 | 37 | def entrop(self): 38 | return super.entropy().sum(-1) 39 | 40 | def mode(self): 41 | return self.mean 42 | 43 | 44 | # Bernoulli 45 | class FixedBernoulli(torch.distributions.Bernoulli): 46 | def log_probs(self, actions): 47 | return super.log_prob(actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1) 48 | 49 | def entropy(self): 50 | return super().entropy().sum(-1) 51 | 52 | def mode(self): 53 | return torch.gt(self.probs, 0.5).float() 54 | 55 | 56 | class Categorical(nn.Module): 57 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01): 58 | super(Categorical, self).__init__() 59 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 60 | def init_(m): 61 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) 62 | 63 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 64 | 65 | def forward(self, x, available_actions=None): 66 | x = self.linear(x) 67 | if available_actions is not None: 68 | x[available_actions == 0] = -1e10 69 | return FixedCategorical(logits=x) 70 | 71 | 72 | # class DiagGaussian(nn.Module): 73 | # def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01): 74 | # super(DiagGaussian, self).__init__() 75 | # 76 | # init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 77 | # def init_(m): 78 | # return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) 79 | # 80 | # self.fc_mean = init_(nn.Linear(num_inputs, num_outputs)) 81 | # self.logstd = 
AddBias(torch.zeros(num_outputs)) 82 | # 83 | # def forward(self, x, available_actions=None): 84 | # action_mean = self.fc_mean(x) 85 | # 86 | # # An ugly hack for my KFAC implementation. 87 | # zeros = torch.zeros(action_mean.size()) 88 | # if x.is_cuda: 89 | # zeros = zeros.cuda() 90 | # 91 | # action_logstd = self.logstd(zeros) 92 | # return FixedNormal(action_mean, action_logstd.exp()) 93 | 94 | class DiagGaussian(nn.Module): 95 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01, args=None): 96 | super(DiagGaussian, self).__init__() 97 | 98 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 99 | 100 | def init_(m): 101 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) 102 | 103 | if args is not None: 104 | self.std_x_coef = args.std_x_coef 105 | self.std_y_coef = args.std_y_coef 106 | else: 107 | self.std_x_coef = 1. 108 | self.std_y_coef = 0.5 109 | self.fc_mean = init_(nn.Linear(num_inputs, num_outputs)) 110 | log_std = torch.ones(num_outputs) * self.std_x_coef 111 | self.log_std = torch.nn.Parameter(log_std) 112 | 113 | def forward(self, x, available_actions=None): 114 | action_mean = self.fc_mean(x) 115 | action_std = torch.sigmoid(self.log_std / self.std_x_coef) * self.std_y_coef 116 | return FixedNormal(action_mean, action_std) 117 | 118 | class Bernoulli(nn.Module): 119 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01): 120 | super(Bernoulli, self).__init__() 121 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 122 | def init_(m): 123 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) 124 | 125 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 126 | 127 | def forward(self, x): 128 | x = self.linear(x) 129 | return FixedBernoulli(logits=x) 130 | 131 | class AddBias(nn.Module): 132 | def __init__(self, bias): 133 | super(AddBias, self).__init__() 134 | self._bias = nn.Parameter(bias.unsqueeze(1)) 135 | 136 | def forward(self, x): 137 | if x.dim() == 2: 138 | bias = self._bias.t().view(1, -1) 139 | else: 140 | bias = self._bias.t().view(1, -1, 1, 1) 141 | 142 | return x + bias 143 | -------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/assets/manyagent_ant__stage1.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /scripts/train/train_smac.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | import os 4 | sys.path.append("../") 5 | import socket 6 | import setproctitle 7 | import numpy as np 8 | from pathlib import Path 9 | import torch 10 | from configs.config import get_config 11 | from envs.starcraft2.StarCraft2_Env import StarCraft2Env 12 | from envs.starcraft2.smac_maps import get_map_params 13 | from envs.env_wrappers import ShareSubprocVecEnv, ShareDummyVecEnv 14 | from runners.separated.smac_runner import SMACRunner as Runner 15 | """Train script for SMAC.""" 16 | 17 | def make_train_env(all_args): 18 | def get_env_fn(rank): 19 | def init_env(): 20 | if all_args.env_name == "StarCraft2": 21 | env = StarCraft2Env(all_args) 22 | else: 23 | print("Can not support the " + all_args.env_name + "environment.") 24 | raise NotImplementedError 25 | env.seed(all_args.seed + rank * 1000) 26 | return env 27 | 28 | return init_env 29 | 30 | if 
all_args.n_rollout_threads == 1: 31 | return ShareDummyVecEnv([get_env_fn(0)]) 32 | else: 33 | return ShareSubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)]) 34 | 35 | 36 | def make_eval_env(all_args): 37 | def get_env_fn(rank): 38 | def init_env(): 39 | if all_args.env_name == "StarCraft2": 40 | env = StarCraft2Env(all_args) 41 | else: 42 | print("Can not support the " + all_args.env_name + "environment.") 43 | raise NotImplementedError 44 | env.seed(all_args.seed * 50000 + rank * 10000) 45 | return env 46 | 47 | return init_env 48 | 49 | if all_args.n_eval_rollout_threads == 1: 50 | return ShareDummyVecEnv([get_env_fn(0)]) 51 | else: 52 | return ShareSubprocVecEnv([get_env_fn(i) for i in range(all_args.n_eval_rollout_threads)]) 53 | 54 | 55 | def parse_args(args, parser): 56 | parser.add_argument('--map_name', type=str, default='3m',help="Which smac map to run on") 57 | parser.add_argument("--add_move_state", action='store_true', default=False) 58 | parser.add_argument("--add_local_obs", action='store_true', default=False) 59 | parser.add_argument("--add_distance_state", action='store_true', default=False) 60 | parser.add_argument("--add_enemy_action_state", action='store_true', default=False) 61 | parser.add_argument("--add_agent_id", action='store_true', default=False) 62 | parser.add_argument("--add_visible_state", action='store_true', default=False) 63 | parser.add_argument("--add_xy_state", action='store_true', default=False) 64 | parser.add_argument("--use_state_agent", action='store_true', default=False) 65 | parser.add_argument("--use_mustalive", action='store_false', default=True) 66 | parser.add_argument("--add_center_xy", action='store_true', default=False) 67 | parser.add_argument("--use_single_network", action='store_true', default=False) 68 | all_args = parser.parse_known_args(args)[0] 69 | 70 | return all_args 71 | 72 | 73 | def main(args): 74 | parser = get_config() 75 | all_args = parse_args(args, parser) 76 | print("all config: ", all_args) 77 | if all_args.seed_specify: 78 | all_args.seed=all_args.runing_id 79 | else: 80 | all_args.seed=np.random.randint(1000,10000) 81 | print("seed is :",all_args.seed) 82 | # cuda 83 | if all_args.cuda and torch.cuda.is_available(): 84 | print("choose to use gpu...") 85 | device = torch.device("cuda:0") 86 | torch.set_num_threads(all_args.n_training_threads) 87 | if all_args.cuda_deterministic: 88 | torch.backends.cudnn.benchmark = False 89 | torch.backends.cudnn.deterministic = True 90 | else: 91 | print("choose to use cpu...") 92 | device = torch.device("cpu") 93 | torch.set_num_threads(all_args.n_training_threads) 94 | 95 | run_dir = Path(os.path.split(os.path.dirname(os.path.abspath(__file__)))[ 96 | 0] + "/results") / all_args.env_name / all_args.map_name / all_args.algorithm_name / all_args.experiment_name / str(all_args.seed) 97 | if not run_dir.exists(): 98 | os.makedirs(str(run_dir)) 99 | 100 | if not run_dir.exists(): 101 | curr_run = 'run1' 102 | else: 103 | exst_run_nums = [int(str(folder.name).split('run')[1]) for folder in run_dir.iterdir() if 104 | str(folder.name).startswith('run')] 105 | if len(exst_run_nums) == 0: 106 | curr_run = 'run1' 107 | else: 108 | curr_run = 'run%i' % (max(exst_run_nums) + 1) 109 | run_dir = run_dir / curr_run 110 | if not run_dir.exists(): 111 | os.makedirs(str(run_dir)) 112 | 113 | setproctitle.setproctitle( 114 | str(all_args.algorithm_name) + "-" + str(all_args.env_name) + "-" + str(all_args.experiment_name) + "@" + str( 115 | all_args.user_name)) 116 | 117 | # seed 
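# Seed torch (CPU and all GPUs) and numpy with the same value chosen above (taken from the
# run id when seed_specify is set, otherwise drawn at random), so a fixed seed reproduces
# the run; full GPU determinism additionally requires cuda_deterministic.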
118 | torch.manual_seed(all_args.seed) 119 | torch.cuda.manual_seed_all(all_args.seed) 120 | np.random.seed(all_args.seed) 121 | 122 | # env 123 | envs = make_train_env(all_args) 124 | eval_envs = make_eval_env(all_args) if all_args.use_eval else None 125 | num_agents = get_map_params(all_args.map_name)["n_agents"] 126 | 127 | config = { 128 | "all_args": all_args, 129 | "envs": envs, 130 | "eval_envs": eval_envs, 131 | "num_agents": num_agents, 132 | "device": device, 133 | "run_dir": run_dir 134 | } 135 | # run experiments 136 | runner = Runner(config) 137 | runner.run() 138 | 139 | # post process 140 | envs.close() 141 | if all_args.use_eval and eval_envs is not envs: 142 | eval_envs.close() 143 | runner.writter.export_scalars_to_json(str(runner.log_dir + '/summary.json')) 144 | runner.writter.close() 145 | 146 | 147 | if __name__ == "__main__": 148 | 149 | main(sys.argv[1:]) 150 | 151 | -------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/manyagent_ant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gym import utils 3 | from gym.envs.mujoco import mujoco_env 4 | from jinja2 import Template 5 | import os 6 | 7 | class ManyAgentAntEnv(mujoco_env.MujocoEnv, utils.EzPickle): 8 | def __init__(self, **kwargs): 9 | agent_conf = kwargs.get("agent_conf") 10 | n_agents = int(agent_conf.split("x")[0]) 11 | n_segs_per_agents = int(agent_conf.split("x")[1]) 12 | n_segs = n_agents * n_segs_per_agents 13 | 14 | # Check whether asset file exists already, otherwise create it 15 | asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 16 | 'manyagent_ant_{}_agents_each_{}_segments.auto.xml'.format(n_agents, 17 | n_segs_per_agents)) 18 | #if not os.path.exists(asset_path): 19 | print("Auto-Generating Manyagent Ant asset with {} segments at {}.".format(n_segs, asset_path)) 20 | self._generate_asset(n_segs=n_segs, asset_path=asset_path) 21 | 22 | #asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',git p 23 | # 'manyagent_swimmer.xml') 24 | 25 | mujoco_env.MujocoEnv.__init__(self, asset_path, 4) 26 | utils.EzPickle.__init__(self) 27 | 28 | def _generate_asset(self, n_segs, asset_path): 29 | template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 30 | 'manyagent_ant.xml.template') 31 | with open(template_path, "r") as f: 32 | t = Template(f.read()) 33 | body_str_template = """ 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | """ 60 | 61 | body_close_str_template ="\n" 62 | actuator_str_template = """\t 63 | 64 | 65 | \n""" 66 | 67 | body_str = "" 68 | for i in range(1,n_segs): 69 | body_str += body_str_template.format(*([i]*16)) 70 | body_str += body_close_str_template*(n_segs-1) 71 | 72 | actuator_str = "" 73 | for i in range(n_segs): 74 | actuator_str += actuator_str_template.format(*([i]*8)) 75 | 76 | rt = t.render(body=body_str, actuators=actuator_str) 77 | with open(asset_path, "w") as f: 78 | f.write(rt) 79 | pass 80 | 81 | def step(self, a): 82 | xposbefore = self.get_body_com("torso_0")[0] 83 | self.do_simulation(a, self.frame_skip) 84 | xposafter = self.get_body_com("torso_0")[0] 85 | forward_reward = (xposafter - xposbefore)/self.dt 86 | ctrl_cost = .5 * np.square(a).sum() 87 | contact_cost = 0.5 * 1e-3 * np.sum( 88 | np.square(np.clip(self.sim.data.cfrc_ext, -1, 1))) 89 | survive_reward = 1.0 90 | reward = 
forward_reward - ctrl_cost - contact_cost + survive_reward 91 | state = self.state_vector() 92 | notdone = np.isfinite(state).all() \ 93 | and state[2] >= 0.2 and state[2] <= 1.0 94 | done = not notdone 95 | ob = self._get_obs() 96 | return ob, reward, done, dict( 97 | reward_forward=forward_reward, 98 | reward_ctrl=-ctrl_cost, 99 | reward_contact=-contact_cost, 100 | reward_survive=survive_reward) 101 | 102 | def _get_obs(self): 103 | return np.concatenate([ 104 | self.sim.data.qpos.flat[2:], 105 | self.sim.data.qvel.flat, 106 | np.clip(self.sim.data.cfrc_ext, -1, 1).flat, 107 | ]) 108 | 109 | def reset_model(self): 110 | qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1) 111 | qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1 112 | self.set_state(qpos, qvel) 113 | return self._get_obs() 114 | 115 | def viewer_setup(self): 116 | self.viewer.cam.distance = self.model.stat.extent * 0.5 -------------------------------------------------------------------------------- /scripts/train/train_mujoco.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import sys 3 | import os 4 | sys.path.append("../") 5 | import socket 6 | import setproctitle 7 | import numpy as np 8 | from pathlib import Path 9 | import torch 10 | from configs.config import get_config 11 | from envs.ma_mujoco.multiagent_mujoco.mujoco_multi import MujocoMulti 12 | from envs.env_wrappers import ShareSubprocVecEnv, ShareDummyVecEnv 13 | from runners.separated.mujoco_runner import MujocoRunner as Runner 14 | """Train script for Mujoco.""" 15 | 16 | 17 | def make_train_env(all_args): 18 | def get_env_fn(rank): 19 | def init_env(): 20 | if all_args.env_name == "mujoco": 21 | env_args = {"scenario": all_args.scenario, 22 | "agent_conf": all_args.agent_conf, 23 | "agent_obsk": all_args.agent_obsk, 24 | "episode_limit": 1000} 25 | env = MujocoMulti(env_args=env_args) 26 | else: 27 | print("Can not support the " + all_args.env_name + "environment.") 28 | raise NotImplementedError 29 | env.seed(all_args.seed + rank * 1000) 30 | return env 31 | 32 | return init_env 33 | 34 | if all_args.n_rollout_threads == 1: 35 | return ShareDummyVecEnv([get_env_fn(0)]) 36 | else: 37 | return ShareSubprocVecEnv([get_env_fn(i) for i in range(all_args.n_rollout_threads)]) 38 | 39 | 40 | def make_eval_env(all_args): 41 | def get_env_fn(rank): 42 | def init_env(): 43 | if all_args.env_name == "mujoco": 44 | env_args = {"scenario": all_args.scenario, 45 | "agent_conf": all_args.agent_conf, 46 | "agent_obsk": all_args.agent_obsk, 47 | "episode_limit": 1000} 48 | env = MujocoMulti(env_args=env_args) 49 | else: 50 | print("Can not support the " + all_args.env_name + "environment.") 51 | raise NotImplementedError 52 | env.seed(all_args.seed * 50000 + rank * 10000) 53 | return env 54 | 55 | return init_env 56 | 57 | if all_args.n_eval_rollout_threads == 1: 58 | return ShareDummyVecEnv([get_env_fn(0)]) 59 | else: 60 | return ShareSubprocVecEnv([get_env_fn(i) for i in range(all_args.n_eval_rollout_threads)]) 61 | 62 | 63 | def parse_args(args, parser): 64 | parser.add_argument('--scenario', type=str, default='Hopper-v2', help="Which mujoco task to run on") 65 | parser.add_argument('--agent_conf', type=str, default='3x1') 66 | parser.add_argument('--agent_obsk', type=int, default=0) 67 | parser.add_argument("--add_move_state", action='store_true', default=False) 68 | parser.add_argument("--add_local_obs", action='store_true', default=False) 69 | 
parser.add_argument("--add_distance_state", action='store_true', default=False) 70 | parser.add_argument("--add_enemy_action_state", action='store_true', default=False) 71 | parser.add_argument("--add_agent_id", action='store_true', default=False) 72 | parser.add_argument("--add_visible_state", action='store_true', default=False) 73 | parser.add_argument("--add_xy_state", action='store_true', default=False) 74 | 75 | # agent-specific state should be designed carefully 76 | parser.add_argument("--use_state_agent", action='store_true', default=False) 77 | parser.add_argument("--use_mustalive", action='store_false', default=True) 78 | parser.add_argument("--add_center_xy", action='store_true', default=False) 79 | parser.add_argument("--use_single_network", action='store_true', default=False) 80 | 81 | all_args = parser.parse_known_args(args)[0] 82 | 83 | return all_args 84 | 85 | 86 | def main(args): 87 | parser = get_config() 88 | all_args = parse_args(args, parser) 89 | print("all config: ", all_args) 90 | if all_args.seed_specify: 91 | all_args.seed=all_args.runing_id 92 | else: 93 | all_args.seed=np.random.randint(1000,10000) 94 | print("seed is :",all_args.seed) 95 | # cuda 96 | if all_args.cuda and torch.cuda.is_available(): 97 | print("choose to use gpu...") 98 | device = torch.device("cuda:0") 99 | torch.set_num_threads(all_args.n_training_threads) 100 | if all_args.cuda_deterministic: 101 | torch.backends.cudnn.benchmark = False 102 | torch.backends.cudnn.deterministic = True 103 | else: 104 | print("choose to use cpu...") 105 | device = torch.device("cpu") 106 | torch.set_num_threads(all_args.n_training_threads) 107 | 108 | run_dir = Path(os.path.split(os.path.dirname(os.path.abspath(__file__)))[ 109 | 0] + "/results") / all_args.env_name / all_args.scenario / all_args.algorithm_name / all_args.experiment_name / str(all_args.seed) 110 | if not run_dir.exists(): 111 | os.makedirs(str(run_dir)) 112 | 113 | if not run_dir.exists(): 114 | curr_run = 'run1' 115 | else: 116 | exst_run_nums = [int(str(folder.name).split('run')[1]) for folder in run_dir.iterdir() if 117 | str(folder.name).startswith('run')] 118 | if len(exst_run_nums) == 0: 119 | curr_run = 'run1' 120 | else: 121 | curr_run = 'run%i' % (max(exst_run_nums) + 1) 122 | run_dir = run_dir / curr_run 123 | if not run_dir.exists(): 124 | os.makedirs(str(run_dir)) 125 | 126 | setproctitle.setproctitle( 127 | str(all_args.algorithm_name) + "-" + str(all_args.env_name) + "-" + str(all_args.experiment_name) + "@" + str( 128 | all_args.user_name)) 129 | 130 | # seed 131 | torch.manual_seed(all_args.seed) 132 | torch.cuda.manual_seed_all(all_args.seed) 133 | np.random.seed(all_args.seed) 134 | 135 | # env 136 | envs = make_train_env(all_args) 137 | eval_envs = make_eval_env(all_args) if all_args.use_eval else None 138 | num_agents = envs.n_agents 139 | 140 | config = { 141 | "all_args": all_args, 142 | "envs": envs, 143 | "eval_envs": eval_envs, 144 | "num_agents": num_agents, 145 | "device": device, 146 | "run_dir": run_dir 147 | } 148 | 149 | # run experiments 150 | runner = Runner(config) 151 | runner.run() 152 | 153 | # post process 154 | envs.close() 155 | if all_args.use_eval and eval_envs is not envs: 156 | eval_envs.close() 157 | 158 | runner.writter.export_scalars_to_json(str(runner.log_dir + '/summary.json')) 159 | runner.writter.close() 160 | 161 | 162 | if __name__ == "__main__": 163 | main(sys.argv[1:]) 164 | -------------------------------------------------------------------------------- /algorithms/happo_policy.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from algorithms.actor_critic import Actor, Critic 3 | from utils.util import update_linear_schedule 4 | 5 | 6 | class HAPPO_Policy: 7 | """ 8 | HAPPO Policy class. Wraps actor and critic networks to compute actions and value function predictions. 9 | 10 | :param args: (argparse.Namespace) arguments containing relevant model and policy information. 11 | :param obs_space: (gym.Space) observation space. 12 | :param cent_obs_space: (gym.Space) value function input space (centralized input for HAPPO, decentralized for IPPO). 13 | :param act_space: (gym.Space) action space. 14 | :param device: (torch.device) specifies the device to run on (cpu/gpu). 15 | """ 16 | 17 | def __init__(self, args, obs_space, cent_obs_space, act_space, device=torch.device("cpu")): 18 | self.args = args 19 | self.device = device 20 | self.lr = args.lr 21 | self.critic_lr = args.critic_lr 22 | self.opti_eps = args.opti_eps 23 | self.weight_decay = args.weight_decay 24 | 25 | self.obs_space = obs_space 26 | self.share_obs_space = cent_obs_space 27 | self.act_space = act_space 28 | 29 | self.actor = Actor(args, self.obs_space, self.act_space, self.device) 30 | 31 | ######################################Please Note######################################### 32 | ##### We create one critic per agent, but all critics are trained on the same data ##### 33 | ##### with the same update settings, so their parameters stay identical; you can ##### 34 | ##### therefore regard them as a single shared critic. ##### 35 | ########################################################################################## 36 | self.critic = Critic(args, self.share_obs_space, self.device) 37 | 38 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 39 | lr=self.lr, eps=self.opti_eps, 40 | weight_decay=self.weight_decay) 41 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 42 | lr=self.critic_lr, 43 | eps=self.opti_eps, 44 | weight_decay=self.weight_decay) 45 | 46 | def lr_decay(self, episode, episodes): 47 | """ 48 | Decay the actor and critic learning rates. 49 | :param episode: (int) current training episode. 50 | :param episodes: (int) total number of training episodes. 51 | """ 52 | update_linear_schedule(self.actor_optimizer, episode, episodes, self.lr) 53 | update_linear_schedule(self.critic_optimizer, episode, episodes, self.critic_lr) 54 | 55 | def get_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, masks, available_actions=None, 56 | deterministic=False): 57 | """ 58 | Compute actions and value function predictions for the given inputs. 59 | :param cent_obs (np.ndarray): centralized input to the critic. 60 | :param obs (np.ndarray): local agent inputs to the actor. 61 | :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor. 62 | :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic. 63 | :param masks: (np.ndarray) denotes points at which RNN states should be reset. 64 | :param available_actions: (np.ndarray) denotes which actions are available to agent 65 | (if None, all actions available) 66 | :param deterministic: (bool) whether the action should be mode of distribution or should be sampled. 67 | 68 | :return values: (torch.Tensor) value function predictions. 69 | :return actions: (torch.Tensor) actions to take. 70 | :return action_log_probs: (torch.Tensor) log probabilities of chosen actions.
71 | :return rnn_states_actor: (torch.Tensor) updated actor network RNN states. 72 | :return rnn_states_critic: (torch.Tensor) updated critic network RNN states. 73 | """ 74 | actions, action_log_probs, rnn_states_actor = self.actor(obs, 75 | rnn_states_actor, 76 | masks, 77 | available_actions, 78 | deterministic) 79 | 80 | values, rnn_states_critic = self.critic(cent_obs, rnn_states_critic, masks) 81 | return values, actions, action_log_probs, rnn_states_actor, rnn_states_critic 82 | 83 | def get_values(self, cent_obs, rnn_states_critic, masks): 84 | """ 85 | Get value function predictions. 86 | :param cent_obs (np.ndarray): centralized input to the critic. 87 | :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic. 88 | :param masks: (np.ndarray) denotes points at which RNN states should be reset. 89 | 90 | :return values: (torch.Tensor) value function predictions. 91 | """ 92 | values, _ = self.critic(cent_obs, rnn_states_critic, masks) 93 | return values 94 | 95 | def evaluate_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, action, masks, 96 | available_actions=None, active_masks=None): 97 | """ 98 | Get action logprobs / entropy and value function predictions for actor update. 99 | :param cent_obs (np.ndarray): centralized input to the critic. 100 | :param obs (np.ndarray): local agent inputs to the actor. 101 | :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor. 102 | :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic. 103 | :param action: (np.ndarray) actions whose log probabilites and entropy to compute. 104 | :param masks: (np.ndarray) denotes points at which RNN states should be reset. 105 | :param available_actions: (np.ndarray) denotes which actions are available to agent 106 | (if None, all actions available) 107 | :param active_masks: (torch.Tensor) denotes whether an agent is active or dead. 108 | 109 | :return values: (torch.Tensor) value function predictions. 110 | :return action_log_probs: (torch.Tensor) log probabilities of the input actions. 111 | :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs. 112 | """ 113 | 114 | action_log_probs, dist_entropy = self.actor.evaluate_actions(obs, 115 | rnn_states_actor, 116 | action, 117 | masks, 118 | available_actions, 119 | active_masks) 120 | 121 | values, _ = self.critic(cent_obs, rnn_states_critic, masks) 122 | return values, action_log_probs, dist_entropy 123 | 124 | 125 | def act(self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False): 126 | """ 127 | Compute actions using the given inputs. 128 | :param obs (np.ndarray): local agent inputs to the actor. 129 | :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor. 130 | :param masks: (np.ndarray) denotes points at which RNN states should be reset. 131 | :param available_actions: (np.ndarray) denotes which actions are available to agent 132 | (if None, all actions available) 133 | :param deterministic: (bool) whether the action should be mode of distribution or should be sampled. 
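        Illustrative usage (a sketch; the `happo_policy` / `obs` names and the array shapes below are
        assumptions for a single rollout thread, not values defined in this file)::

            import numpy as np
            # recurrent state per agent: (n_threads, recurrent_N, hidden_size), mask: (n_threads, 1)
            rnn_states_actor = np.zeros((1, args.recurrent_N, args.hidden_size), dtype=np.float32)
            masks = np.ones((1, 1), dtype=np.float32)
            actions, rnn_states_actor = happo_policy.act(obs, rnn_states_actor, masks, deterministic=True)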
134 | """ 135 | actions, _, rnn_states_actor = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic) 136 | return actions, rnn_states_actor 137 | -------------------------------------------------------------------------------- /algorithms/hatrpo_policy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from algorithms.actor_critic import Actor, Critic 3 | from utils.util import update_linear_schedule 4 | 5 | 6 | class HATRPO_Policy: 7 | """ 8 | HATRPO Policy class. Wraps actor and critic networks to compute actions and value function predictions. 9 | 10 | :param args: (argparse.Namespace) arguments containing relevant model and policy information. 11 | :param obs_space: (gym.Space) observation space. 12 | :param cent_obs_space: (gym.Space) value function input space . 13 | :param action_space: (gym.Space) action space. 14 | :param device: (torch.device) specifies the device to run on (cpu/gpu). 15 | """ 16 | 17 | def __init__(self, args, obs_space, cent_obs_space, act_space, device=torch.device("cpu")): 18 | self.args=args 19 | self.device = device 20 | self.lr = args.lr 21 | self.critic_lr = args.critic_lr 22 | self.opti_eps = args.opti_eps 23 | self.weight_decay = args.weight_decay 24 | 25 | self.obs_space = obs_space 26 | self.share_obs_space = cent_obs_space 27 | self.act_space = act_space 28 | 29 | self.actor = Actor(args, self.obs_space, self.act_space, self.device) 30 | 31 | ######################################Please Note######################################### 32 | ##### We create one critic for each agent, but they are trained with same data ##### 33 | ##### and using same update setting. Therefore they have the same parameter, ##### 34 | ##### you can regard them as the same critic. ##### 35 | ########################################################################################## 36 | self.critic = Critic(args, self.share_obs_space, self.device) 37 | 38 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 39 | lr=self.lr, eps=self.opti_eps, 40 | weight_decay=self.weight_decay) 41 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 42 | lr=self.critic_lr, 43 | eps=self.opti_eps, 44 | weight_decay=self.weight_decay) 45 | 46 | def lr_decay(self, episode, episodes): 47 | """ 48 | Decay the actor and critic learning rates. 49 | :param episode: (int) current training episode. 50 | :param episodes: (int) total number of training episodes. 51 | """ 52 | update_linear_schedule(self.actor_optimizer, episode, episodes, self.lr) 53 | update_linear_schedule(self.critic_optimizer, episode, episodes, self.critic_lr) 54 | 55 | def get_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, masks, available_actions=None, 56 | deterministic=False): 57 | """ 58 | Compute actions and value function predictions for the given inputs. 59 | :param cent_obs (np.ndarray): centralized input to the critic. 60 | :param obs (np.ndarray): local agent inputs to the actor. 61 | :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor. 62 | :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic. 63 | :param masks: (np.ndarray) denotes points at which RNN states should be reset. 64 | :param available_actions: (np.ndarray) denotes which actions are available to agent 65 | (if None, all actions available) 66 | :param deterministic: (bool) whether the action should be mode of distribution or should be sampled. 
67 | 68 | :return values: (torch.Tensor) value function predictions. 69 | :return actions: (torch.Tensor) actions to take. 70 | :return action_log_probs: (torch.Tensor) log probabilities of chosen actions. 71 | :return rnn_states_actor: (torch.Tensor) updated actor network RNN states. 72 | :return rnn_states_critic: (torch.Tensor) updated critic network RNN states. 73 | """ 74 | actions, action_log_probs, rnn_states_actor = self.actor(obs, 75 | rnn_states_actor, 76 | masks, 77 | available_actions, 78 | deterministic) 79 | 80 | values, rnn_states_critic = self.critic(cent_obs, rnn_states_critic, masks) 81 | return values, actions, action_log_probs, rnn_states_actor, rnn_states_critic 82 | 83 | def get_values(self, cent_obs, rnn_states_critic, masks): 84 | """ 85 | Get value function predictions. 86 | :param cent_obs (np.ndarray): centralized input to the critic. 87 | :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic. 88 | :param masks: (np.ndarray) denotes points at which RNN states should be reset. 89 | 90 | :return values: (torch.Tensor) value function predictions. 91 | """ 92 | values, _ = self.critic(cent_obs, rnn_states_critic, masks) 93 | return values 94 | 95 | def evaluate_actions(self, cent_obs, obs, rnn_states_actor, rnn_states_critic, action, masks, 96 | available_actions=None, active_masks=None): 97 | """ 98 | Get action logprobs / entropy and value function predictions for actor update. 99 | :param cent_obs (np.ndarray): centralized input to the critic. 100 | :param obs (np.ndarray): local agent inputs to the actor. 101 | :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor. 102 | :param rnn_states_critic: (np.ndarray) if critic is RNN, RNN states for critic. 103 | :param action: (np.ndarray) actions whose log probabilites and entropy to compute. 104 | :param masks: (np.ndarray) denotes points at which RNN states should be reset. 105 | :param available_actions: (np.ndarray) denotes which actions are available to agent 106 | (if None, all actions available) 107 | :param active_masks: (torch.Tensor) denotes whether an agent is active or dead. 108 | 109 | :return values: (torch.Tensor) value function predictions. 110 | :return action_log_probs: (torch.Tensor) log probabilities of the input actions. 111 | :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs. 112 | """ 113 | 114 | action_log_probs, dist_entropy , action_mu, action_std, all_probs= self.actor.evaluate_actions(obs, 115 | rnn_states_actor, 116 | action, 117 | masks, 118 | available_actions, 119 | active_masks) 120 | values, _ = self.critic(cent_obs, rnn_states_critic, masks) 121 | return values, action_log_probs, dist_entropy, action_mu, action_std, all_probs 122 | 123 | 124 | 125 | def act(self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False): 126 | """ 127 | Compute actions using the given inputs. 128 | :param obs (np.ndarray): local agent inputs to the actor. 129 | :param rnn_states_actor: (np.ndarray) if actor is RNN, RNN states for actor. 130 | :param masks: (np.ndarray) denotes points at which RNN states should be reset. 131 | :param available_actions: (np.ndarray) denotes which actions are available to agent 132 | (if None, all actions available) 133 | :param deterministic: (bool) whether the action should be mode of distribution or should be sampled. 
134 | """ 135 | actions, _, rnn_states_actor = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic) 136 | return actions, rnn_states_actor 137 | -------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/assets/manyagent_ant.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /algorithms/actor_critic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from algorithms.utils.util import init, check 4 | from algorithms.utils.cnn import CNNBase 5 | from algorithms.utils.mlp import MLPBase 6 | from algorithms.utils.rnn import RNNLayer 7 | from algorithms.utils.act import ACTLayer 8 | from utils.util import get_shape_from_obs_space 9 | 10 | 11 | class Actor(nn.Module): 12 | """ 13 | Actor network class for HAPPO. Outputs actions given observations. 14 | :param args: (argparse.Namespace) arguments containing relevant model information. 15 | :param obs_space: (gym.Space) observation space. 16 | :param action_space: (gym.Space) action space. 17 | :param device: (torch.device) specifies the device to run on (cpu/gpu). 18 | """ 19 | def __init__(self, args, obs_space, action_space, device=torch.device("cpu")): 20 | super(Actor, self).__init__() 21 | self.hidden_size = args.hidden_size 22 | self.args=args 23 | self._gain = args.gain 24 | self._use_orthogonal = args.use_orthogonal 25 | self._use_policy_active_masks = args.use_policy_active_masks 26 | self._use_naive_recurrent_policy = args.use_naive_recurrent_policy 27 | self._use_recurrent_policy = args.use_recurrent_policy 28 | self._recurrent_N = args.recurrent_N 29 | self.tpdv = dict(dtype=torch.float32, device=device) 30 | 31 | obs_shape = get_shape_from_obs_space(obs_space) 32 | base = CNNBase if len(obs_shape) == 3 else MLPBase 33 | self.base = base(args, obs_shape) 34 | 35 | if self._use_naive_recurrent_policy or self._use_recurrent_policy: 36 | self.rnn = RNNLayer(self.hidden_size, self.hidden_size, self._recurrent_N, self._use_orthogonal) 37 | 38 | self.act = ACTLayer(action_space, self.hidden_size, self._use_orthogonal, self._gain, args) 39 | 40 | self.to(device) 41 | 42 | def forward(self, obs, rnn_states, masks, available_actions=None, deterministic=False): 43 | """ 44 | Compute actions from the given inputs. 45 | :param obs: (np.ndarray / torch.Tensor) observation inputs into network. 46 | :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN. 47 | :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros. 48 | :param available_actions: (np.ndarray / torch.Tensor) denotes which actions are available to agent 49 | (if None, all actions available) 50 | :param deterministic: (bool) whether to sample from action distribution or return the mode. 51 | 52 | :return actions: (torch.Tensor) actions to take. 53 | :return action_log_probs: (torch.Tensor) log probabilities of taken actions. 54 | :return rnn_states: (torch.Tensor) updated RNN hidden states. 
55 | """ 56 | obs = check(obs).to(**self.tpdv) 57 | rnn_states = check(rnn_states).to(**self.tpdv) 58 | masks = check(masks).to(**self.tpdv) 59 | if available_actions is not None: 60 | available_actions = check(available_actions).to(**self.tpdv) 61 | 62 | actor_features = self.base(obs) 63 | 64 | if self._use_naive_recurrent_policy or self._use_recurrent_policy: 65 | actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks) 66 | 67 | actions, action_log_probs = self.act(actor_features, available_actions, deterministic) 68 | 69 | return actions, action_log_probs, rnn_states 70 | 71 | def evaluate_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None): 72 | """ 73 | Compute log probability and entropy of given actions. 74 | :param obs: (torch.Tensor) observation inputs into network. 75 | :param action: (torch.Tensor) actions whose entropy and log probability to evaluate. 76 | :param rnn_states: (torch.Tensor) if RNN network, hidden states for RNN. 77 | :param masks: (torch.Tensor) mask tensor denoting if hidden states should be reinitialized to zeros. 78 | :param available_actions: (torch.Tensor) denotes which actions are available to agent 79 | (if None, all actions available) 80 | :param active_masks: (torch.Tensor) denotes whether an agent is active or dead. 81 | 82 | :return action_log_probs: (torch.Tensor) log probabilities of the input actions. 83 | :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs. 84 | """ 85 | obs = check(obs).to(**self.tpdv) 86 | rnn_states = check(rnn_states).to(**self.tpdv) 87 | action = check(action).to(**self.tpdv) 88 | masks = check(masks).to(**self.tpdv) 89 | if available_actions is not None: 90 | available_actions = check(available_actions).to(**self.tpdv) 91 | 92 | if active_masks is not None: 93 | active_masks = check(active_masks).to(**self.tpdv) 94 | 95 | actor_features = self.base(obs) 96 | 97 | if self._use_naive_recurrent_policy or self._use_recurrent_policy: 98 | actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks) 99 | 100 | if self.args.algorithm_name=="hatrpo": 101 | action_log_probs, dist_entropy ,action_mu, action_std, all_probs= self.act.evaluate_actions_trpo(actor_features, 102 | action, available_actions, 103 | active_masks= 104 | active_masks if self._use_policy_active_masks 105 | else None) 106 | 107 | return action_log_probs, dist_entropy, action_mu, action_std, all_probs 108 | else: 109 | action_log_probs, dist_entropy = self.act.evaluate_actions(actor_features, 110 | action, available_actions, 111 | active_masks= 112 | active_masks if self._use_policy_active_masks 113 | else None) 114 | 115 | return action_log_probs, dist_entropy 116 | 117 | 118 | class Critic(nn.Module): 119 | """ 120 | Critic network class for HAPPO. Outputs value function predictions given centralized input (HAPPO) or local observations (IPPO). 121 | :param args: (argparse.Namespace) arguments containing relevant model information. 122 | :param cent_obs_space: (gym.Space) (centralized) observation space. 123 | :param device: (torch.device) specifies the device to run on (cpu/gpu). 
124 | """ 125 | def __init__(self, args, cent_obs_space, device=torch.device("cpu")): 126 | super(Critic, self).__init__() 127 | self.hidden_size = args.hidden_size 128 | self._use_orthogonal = args.use_orthogonal 129 | self._use_naive_recurrent_policy = args.use_naive_recurrent_policy 130 | self._use_recurrent_policy = args.use_recurrent_policy 131 | self._recurrent_N = args.recurrent_N 132 | self.tpdv = dict(dtype=torch.float32, device=device) 133 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][self._use_orthogonal] 134 | 135 | cent_obs_shape = get_shape_from_obs_space(cent_obs_space) 136 | base = CNNBase if len(cent_obs_shape) == 3 else MLPBase 137 | self.base = base(args, cent_obs_shape) 138 | 139 | if self._use_naive_recurrent_policy or self._use_recurrent_policy: 140 | self.rnn = RNNLayer(self.hidden_size, self.hidden_size, self._recurrent_N, self._use_orthogonal) 141 | 142 | def init_(m): 143 | return init(m, init_method, lambda x: nn.init.constant_(x, 0)) 144 | 145 | self.v_out = init_(nn.Linear(self.hidden_size, 1)) 146 | 147 | self.to(device) 148 | 149 | def forward(self, cent_obs, rnn_states, masks): 150 | """ 151 | Compute actions from the given inputs. 152 | :param cent_obs: (np.ndarray / torch.Tensor) observation inputs into network. 153 | :param rnn_states: (np.ndarray / torch.Tensor) if RNN network, hidden states for RNN. 154 | :param masks: (np.ndarray / torch.Tensor) mask tensor denoting if RNN states should be reinitialized to zeros. 155 | 156 | :return values: (torch.Tensor) value function predictions. 157 | :return rnn_states: (torch.Tensor) updated RNN hidden states. 158 | """ 159 | cent_obs = check(cent_obs).to(**self.tpdv) 160 | rnn_states = check(rnn_states).to(**self.tpdv) 161 | masks = check(masks).to(**self.tpdv) 162 | 163 | critic_features = self.base(cent_obs) 164 | if self._use_naive_recurrent_policy or self._use_recurrent_policy: 165 | critic_features, rnn_states = self.rnn(critic_features, rnn_states, masks) 166 | values = self.v_out(critic_features) 167 | 168 | return values, rnn_states 169 | -------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/assets/coupled_half_cheetah.xml: -------------------------------------------------------------------------------- 1 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /algorithms/happo_trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from utils.util import get_gard_norm, huber_loss, mse_loss 5 | from utils.popart import PopArt 6 | from algorithms.utils.util import check 7 | 8 | class HAPPO(): 9 | """ 10 | Trainer class for HAPPO to update policies. 11 | :param args: (argparse.Namespace) arguments containing relevant model, policy, and env information. 12 | :param policy: (HAPPO_Policy) policy to update. 13 | :param device: (torch.device) specifies the device to run on (cpu/gpu). 
14 | """ 15 | def __init__(self, 16 | args, 17 | policy, 18 | device=torch.device("cpu")): 19 | 20 | self.device = device 21 | self.tpdv = dict(dtype=torch.float32, device=device) 22 | self.policy = policy 23 | 24 | self.clip_param = args.clip_param 25 | self.ppo_epoch = args.ppo_epoch 26 | self.num_mini_batch = args.num_mini_batch 27 | self.data_chunk_length = args.data_chunk_length 28 | self.value_loss_coef = args.value_loss_coef 29 | self.entropy_coef = args.entropy_coef 30 | self.max_grad_norm = args.max_grad_norm 31 | self.huber_delta = args.huber_delta 32 | 33 | self._use_recurrent_policy = args.use_recurrent_policy 34 | self._use_naive_recurrent = args.use_naive_recurrent_policy 35 | self._use_max_grad_norm = args.use_max_grad_norm 36 | self._use_clipped_value_loss = args.use_clipped_value_loss 37 | self._use_huber_loss = args.use_huber_loss 38 | self._use_popart = args.use_popart 39 | self._use_value_active_masks = args.use_value_active_masks 40 | self._use_policy_active_masks = args.use_policy_active_masks 41 | 42 | 43 | if self._use_popart: 44 | self.value_normalizer = PopArt(1, device=self.device) 45 | else: 46 | self.value_normalizer = None 47 | 48 | def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch): 49 | """ 50 | Calculate value function loss. 51 | :param values: (torch.Tensor) value function predictions. 52 | :param value_preds_batch: (torch.Tensor) "old" value predictions from data batch (used for value clip loss) 53 | :param return_batch: (torch.Tensor) reward to go returns. 54 | :param active_masks_batch: (torch.Tensor) denotes if agent is active or dead at a given timesep. 55 | 56 | :return value_loss: (torch.Tensor) value function loss. 57 | """ 58 | if self._use_popart: 59 | value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param, 60 | self.clip_param) 61 | error_clipped = self.value_normalizer(return_batch) - value_pred_clipped 62 | error_original = self.value_normalizer(return_batch) - values 63 | else: 64 | value_pred_clipped = value_preds_batch + (values - value_preds_batch).clamp(-self.clip_param, 65 | self.clip_param) 66 | error_clipped = return_batch - value_pred_clipped 67 | error_original = return_batch - values 68 | 69 | if self._use_huber_loss: 70 | value_loss_clipped = huber_loss(error_clipped, self.huber_delta) 71 | value_loss_original = huber_loss(error_original, self.huber_delta) 72 | else: 73 | value_loss_clipped = mse_loss(error_clipped) 74 | value_loss_original = mse_loss(error_original) 75 | 76 | if self._use_clipped_value_loss: 77 | value_loss = torch.max(value_loss_original, value_loss_clipped) 78 | else: 79 | value_loss = value_loss_original 80 | 81 | if self._use_value_active_masks: 82 | value_loss = (value_loss * active_masks_batch).sum() / active_masks_batch.sum() 83 | else: 84 | value_loss = value_loss.mean() 85 | 86 | return value_loss 87 | 88 | def ppo_update(self, sample, update_actor=True): 89 | """ 90 | Update actor and critic networks. 91 | :param sample: (Tuple) contains data batch with which to update networks. 92 | :update_actor: (bool) whether to update actor network. 93 | 94 | :return value_loss: (torch.Tensor) value function loss. 95 | :return critic_grad_norm: (torch.Tensor) gradient norm from critic update. 96 | ;return policy_loss: (torch.Tensor) actor(policy) loss value. 97 | :return dist_entropy: (torch.Tensor) action entropies. 98 | :return actor_grad_norm: (torch.Tensor) gradient norm from actor update. 
99 | :return imp_weights: (torch.Tensor) importance sampling weights. 100 | """ 101 | share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch, \ 102 | value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, \ 103 | adv_targ, available_actions_batch, factor_batch = sample 104 | 105 | 106 | 107 | old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv) 108 | adv_targ = check(adv_targ).to(**self.tpdv) 109 | 110 | 111 | value_preds_batch = check(value_preds_batch).to(**self.tpdv) 112 | return_batch = check(return_batch).to(**self.tpdv) 113 | 114 | 115 | active_masks_batch = check(active_masks_batch).to(**self.tpdv) 116 | 117 | factor_batch = check(factor_batch).to(**self.tpdv) 118 | # Reshape to do in a single forward pass for all steps 119 | values, action_log_probs, dist_entropy = self.policy.evaluate_actions(share_obs_batch, 120 | obs_batch, 121 | rnn_states_batch, 122 | rnn_states_critic_batch, 123 | actions_batch, 124 | masks_batch, 125 | available_actions_batch, 126 | active_masks_batch) 127 | # actor update 128 | imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch) 129 | 130 | surr1 = imp_weights * adv_targ 131 | surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ 132 | 133 | if self._use_policy_active_masks: 134 | policy_action_loss = (-torch.sum(factor_batch * torch.min(surr1, surr2), 135 | dim=-1, 136 | keepdim=True) * active_masks_batch).sum() / active_masks_batch.sum() 137 | else: 138 | policy_action_loss = -torch.sum(factor_batch * torch.min(surr1, surr2), dim=-1, keepdim=True).mean() 139 | 140 | policy_loss = policy_action_loss 141 | 142 | self.policy.actor_optimizer.zero_grad() 143 | 144 | if update_actor: 145 | (policy_loss - dist_entropy * self.entropy_coef).backward() 146 | 147 | if self._use_max_grad_norm: 148 | actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm) 149 | else: 150 | actor_grad_norm = get_gard_norm(self.policy.actor.parameters()) 151 | 152 | self.policy.actor_optimizer.step() 153 | 154 | value_loss = self.cal_value_loss(values, value_preds_batch, return_batch, active_masks_batch) 155 | 156 | self.policy.critic_optimizer.zero_grad() 157 | 158 | (value_loss * self.value_loss_coef).backward() 159 | 160 | if self._use_max_grad_norm: 161 | critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm) 162 | else: 163 | critic_grad_norm = get_gard_norm(self.policy.critic.parameters()) 164 | 165 | self.policy.critic_optimizer.step() 166 | 167 | return value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights 168 | 169 | def train(self, buffer, update_actor=True): 170 | """ 171 | Perform a training update using minibatch GD. 172 | :param buffer: (SharedReplayBuffer) buffer containing training data. 173 | :param update_actor: (bool) whether to update actor network. 174 | 175 | :return train_info: (dict) contains information regarding training update (e.g. loss, grad norms, etc). 
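        Before the PPO epochs, advantages are normalized over active (alive) steps only; schematically::

            advantages = (A - nanmean(A_active)) / (nanstd(A_active) + 1e-5)

        where inactive steps are first set to NaN via `buffer.active_masks` (see the body below).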
176 | """ 177 | if self._use_popart: 178 | advantages = buffer.returns[:-1] - self.value_normalizer.denormalize(buffer.value_preds[:-1]) 179 | else: 180 | advantages = buffer.returns[:-1] - buffer.value_preds[:-1] 181 | 182 | advantages_copy = advantages.copy() 183 | advantages_copy[buffer.active_masks[:-1] == 0.0] = np.nan 184 | mean_advantages = np.nanmean(advantages_copy) 185 | std_advantages = np.nanstd(advantages_copy) 186 | advantages = (advantages - mean_advantages) / (std_advantages + 1e-5) 187 | 188 | train_info = {} 189 | 190 | train_info['value_loss'] = 0 191 | train_info['policy_loss'] = 0 192 | train_info['dist_entropy'] = 0 193 | train_info['actor_grad_norm'] = 0 194 | train_info['critic_grad_norm'] = 0 195 | train_info['ratio'] = 0 196 | 197 | for _ in range(self.ppo_epoch): 198 | if self._use_recurrent_policy: 199 | data_generator = buffer.recurrent_generator(advantages, self.num_mini_batch, self.data_chunk_length) 200 | elif self._use_naive_recurrent: 201 | data_generator = buffer.naive_recurrent_generator(advantages, self.num_mini_batch) 202 | else: 203 | data_generator = buffer.feed_forward_generator(advantages, self.num_mini_batch) 204 | 205 | for sample in data_generator: 206 | value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights = self.ppo_update(sample, update_actor=update_actor) 207 | 208 | train_info['value_loss'] += value_loss.item() 209 | train_info['policy_loss'] += policy_loss.item() 210 | train_info['dist_entropy'] += dist_entropy.item() 211 | train_info['actor_grad_norm'] += actor_grad_norm 212 | train_info['critic_grad_norm'] += critic_grad_norm 213 | train_info['ratio'] += imp_weights.mean() 214 | 215 | num_updates = self.ppo_epoch * self.num_mini_batch 216 | 217 | for k in train_info.keys(): 218 | train_info[k] /= num_updates 219 | 220 | return train_info 221 | 222 | def prep_training(self): 223 | self.policy.actor.train() 224 | self.policy.critic.train() 225 | 226 | def prep_rollout(self): 227 | self.policy.actor.eval() 228 | self.policy.critic.eval() 229 | -------------------------------------------------------------------------------- /runners/separated/mujoco_runner.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from functools import reduce 4 | import torch 5 | from runners.separated.base_runner import Runner 6 | 7 | 8 | def _t2n(x): 9 | return x.detach().cpu().numpy() 10 | 11 | 12 | class MujocoRunner(Runner): 13 | """Runner class to perform training, evaluation. and data collection for SMAC. 
See parent class for details.""" 14 | 15 | def __init__(self, config): 16 | super(MujocoRunner, self).__init__(config) 17 | 18 | def run(self): 19 | self.warmup() 20 | 21 | start = time.time() 22 | episodes = int(self.num_env_steps) // self.episode_length // self.n_rollout_threads 23 | 24 | train_episode_rewards = [0 for _ in range(self.n_rollout_threads)] 25 | 26 | for episode in range(episodes): 27 | if self.use_linear_lr_decay: 28 | self.trainer.policy.lr_decay(episode, episodes) 29 | 30 | done_episodes_rewards = [] 31 | 32 | for step in range(self.episode_length): 33 | # Sample actions 34 | values, actions, action_log_probs, rnn_states, rnn_states_critic = self.collect(step) 35 | 36 | # Obser reward and next obs 37 | obs, share_obs, rewards, dones, infos, _ = self.envs.step(actions) 38 | 39 | dones_env = np.all(dones, axis=1) 40 | reward_env = np.mean(rewards, axis=1).flatten() 41 | train_episode_rewards += reward_env 42 | for t in range(self.n_rollout_threads): 43 | if dones_env[t]: 44 | done_episodes_rewards.append(train_episode_rewards[t]) 45 | train_episode_rewards[t] = 0 46 | 47 | data = obs, share_obs, rewards, dones, infos, \ 48 | values, actions, action_log_probs, \ 49 | rnn_states, rnn_states_critic 50 | 51 | # insert data into buffer 52 | self.insert(data) 53 | 54 | # compute return and update network 55 | self.compute() 56 | train_infos = self.train() 57 | 58 | # post process 59 | total_num_steps = (episode + 1) * self.episode_length * self.n_rollout_threads 60 | # save model 61 | if (episode % self.save_interval == 0 or episode == episodes - 1): 62 | self.save() 63 | 64 | # log information 65 | if episode % self.log_interval == 0: 66 | end = time.time() 67 | print("\n Scenario {} Algo {} Exp {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n" 68 | .format(self.all_args.scenario, 69 | self.algorithm_name, 70 | self.experiment_name, 71 | episode, 72 | episodes, 73 | total_num_steps, 74 | self.num_env_steps, 75 | int(total_num_steps / (end - start)))) 76 | 77 | self.log_train(train_infos, total_num_steps) 78 | 79 | if len(done_episodes_rewards) > 0: 80 | aver_episode_rewards = np.mean(done_episodes_rewards) 81 | print("some episodes done, average rewards: ", aver_episode_rewards) 82 | self.writter.add_scalars("train_episode_rewards", {"aver_rewards": aver_episode_rewards}, 83 | total_num_steps) 84 | 85 | # eval 86 | if episode % self.eval_interval == 0 and self.use_eval: 87 | self.eval(total_num_steps) 88 | 89 | def warmup(self): 90 | # reset env 91 | obs, share_obs, _ = self.envs.reset() 92 | # replay buffer 93 | if not self.use_centralized_V: 94 | share_obs = obs 95 | 96 | for agent_id in range(self.num_agents): 97 | self.buffer[agent_id].share_obs[0] = share_obs[:, agent_id].copy() 98 | self.buffer[agent_id].obs[0] = obs[:, agent_id].copy() 99 | 100 | @torch.no_grad() 101 | def collect(self, step): 102 | value_collector = [] 103 | action_collector = [] 104 | action_log_prob_collector = [] 105 | rnn_state_collector = [] 106 | rnn_state_critic_collector = [] 107 | for agent_id in range(self.num_agents): 108 | self.trainer[agent_id].prep_rollout() 109 | value, action, action_log_prob, rnn_state, rnn_state_critic \ 110 | = self.trainer[agent_id].policy.get_actions(self.buffer[agent_id].share_obs[step], 111 | self.buffer[agent_id].obs[step], 112 | self.buffer[agent_id].rnn_states[step], 113 | self.buffer[agent_id].rnn_states_critic[step], 114 | self.buffer[agent_id].masks[step]) 115 | value_collector.append(_t2n(value)) 116 | action_collector.append(_t2n(action)) 
117 | action_log_prob_collector.append(_t2n(action_log_prob)) 118 | rnn_state_collector.append(_t2n(rnn_state)) 119 | rnn_state_critic_collector.append(_t2n(rnn_state_critic)) 120 | # [self.envs, agents, dim] 121 | values = np.array(value_collector).transpose(1, 0, 2) 122 | actions = np.array(action_collector).transpose(1, 0, 2) 123 | action_log_probs = np.array(action_log_prob_collector).transpose(1, 0, 2) 124 | rnn_states = np.array(rnn_state_collector).transpose(1, 0, 2, 3) 125 | rnn_states_critic = np.array(rnn_state_critic_collector).transpose(1, 0, 2, 3) 126 | 127 | return values, actions, action_log_probs, rnn_states, rnn_states_critic 128 | 129 | def insert(self, data): 130 | obs, share_obs, rewards, dones, infos, \ 131 | values, actions, action_log_probs, rnn_states, rnn_states_critic = data 132 | 133 | dones_env = np.all(dones, axis=1) 134 | 135 | rnn_states[dones_env == True] = np.zeros( 136 | ((dones_env == True).sum(), self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32) 137 | rnn_states_critic[dones_env == True] = np.zeros( 138 | ((dones_env == True).sum(), self.num_agents, *self.buffer[0].rnn_states_critic.shape[2:]), dtype=np.float32) 139 | 140 | masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32) 141 | masks[dones_env == True] = np.zeros(((dones_env == True).sum(), self.num_agents, 1), dtype=np.float32) 142 | 143 | active_masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32) 144 | active_masks[dones == True] = np.zeros(((dones == True).sum(), 1), dtype=np.float32) 145 | active_masks[dones_env == True] = np.ones(((dones_env == True).sum(), self.num_agents, 1), dtype=np.float32) 146 | 147 | if not self.use_centralized_V: 148 | share_obs = obs 149 | 150 | for agent_id in range(self.num_agents): 151 | self.buffer[agent_id].insert(share_obs[:, agent_id], obs[:, agent_id], rnn_states[:, agent_id], 152 | rnn_states_critic[:, agent_id], actions[:, agent_id], 153 | action_log_probs[:, agent_id], 154 | values[:, agent_id], rewards[:, agent_id], masks[:, agent_id], None, 155 | active_masks[:, agent_id], None) 156 | 157 | def log_train(self, train_infos, total_num_steps): 158 | print("average_step_rewards is {}.".format(np.mean(self.buffer[0].rewards))) 159 | for agent_id in range(self.num_agents): 160 | train_infos[agent_id]["average_step_rewards"] = np.mean(self.buffer[agent_id].rewards) 161 | for k, v in train_infos[agent_id].items(): 162 | agent_k = "agent%i/" % agent_id + k 163 | self.writter.add_scalars(agent_k, {agent_k: v}, total_num_steps) 164 | 165 | @torch.no_grad() 166 | def eval(self, total_num_steps): 167 | eval_episode = 0 168 | eval_episode_rewards = [] 169 | one_episode_rewards = [] 170 | for eval_i in range(self.n_eval_rollout_threads): 171 | one_episode_rewards.append([]) 172 | eval_episode_rewards.append([]) 173 | 174 | eval_obs, eval_share_obs, _ = self.eval_envs.reset() 175 | 176 | eval_rnn_states = np.zeros((self.n_eval_rollout_threads, self.num_agents, self.recurrent_N, self.hidden_size), 177 | dtype=np.float32) 178 | eval_masks = np.ones((self.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32) 179 | 180 | while True: 181 | eval_actions_collector = [] 182 | eval_rnn_states_collector = [] 183 | for agent_id in range(self.num_agents): 184 | self.trainer[agent_id].prep_rollout() 185 | eval_actions, temp_rnn_state = \ 186 | self.trainer[agent_id].policy.act(eval_obs[:, agent_id], 187 | eval_rnn_states[:, agent_id], 188 | eval_masks[:, agent_id], 189 | deterministic=True) 190 | 
eval_rnn_states[:, agent_id] = _t2n(temp_rnn_state) 191 | eval_actions_collector.append(_t2n(eval_actions)) 192 | 193 | eval_actions = np.array(eval_actions_collector).transpose(1, 0, 2) 194 | 195 | # Obser reward and next obs 196 | eval_obs, eval_share_obs, eval_rewards, eval_dones, eval_infos, _ = self.eval_envs.step( 197 | eval_actions) 198 | for eval_i in range(self.n_eval_rollout_threads): 199 | one_episode_rewards[eval_i].append(eval_rewards[eval_i]) 200 | 201 | eval_dones_env = np.all(eval_dones, axis=1) 202 | 203 | eval_rnn_states[eval_dones_env == True] = np.zeros( 204 | ((eval_dones_env == True).sum(), self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32) 205 | 206 | eval_masks = np.ones((self.all_args.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32) 207 | eval_masks[eval_dones_env == True] = np.zeros(((eval_dones_env == True).sum(), self.num_agents, 1), 208 | dtype=np.float32) 209 | 210 | for eval_i in range(self.n_eval_rollout_threads): 211 | if eval_dones_env[eval_i]: 212 | eval_episode += 1 213 | eval_episode_rewards[eval_i].append(np.sum(one_episode_rewards[eval_i], axis=0)) 214 | one_episode_rewards[eval_i] = [] 215 | 216 | if eval_episode >= self.all_args.eval_episodes: 217 | eval_episode_rewards = np.concatenate(eval_episode_rewards) 218 | eval_env_infos = {'eval_average_episode_rewards': eval_episode_rewards, 219 | 'eval_max_episode_rewards': [np.max(eval_episode_rewards)]} 220 | self.log_env(eval_env_infos, total_num_steps) 221 | print("eval_average_episode_rewards is {}.".format(np.mean(eval_episode_rewards))) 222 | break 223 | -------------------------------------------------------------------------------- /algorithms/utils/act.py: -------------------------------------------------------------------------------- 1 | from .distributions import Bernoulli, Categorical, DiagGaussian 2 | import torch 3 | import torch.nn as nn 4 | 5 | class ACTLayer(nn.Module): 6 | """ 7 | MLP Module to compute actions. 8 | :param action_space: (gym.Space) action space. 9 | :param inputs_dim: (int) dimension of network input. 10 | :param use_orthogonal: (bool) whether to use orthogonal initialization. 11 | :param gain: (float) gain of the output layer of the network. 
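    The output head is picked from the action-space type: Discrete -> Categorical, Box -> DiagGaussian,
    MultiBinary -> Bernoulli, MultiDiscrete -> one Categorical per dimension, and otherwise a mixed
    continuous + discrete pair of heads. Illustrative construction (the gym space and feature size are
    assumptions)::

        from gym import spaces
        act_layer = ACTLayer(spaces.Discrete(5), inputs_dim=64, use_orthogonal=True, gain=0.01)
        actions, action_log_probs = act_layer(x)   # x: (batch, 64) actor features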
12 | """ 13 | def __init__(self, action_space, inputs_dim, use_orthogonal, gain, args=None): 14 | super(ACTLayer, self).__init__() 15 | self.mixed_action = False 16 | self.multi_discrete = False 17 | self.action_type = action_space.__class__.__name__ 18 | if action_space.__class__.__name__ == "Discrete": 19 | action_dim = action_space.n 20 | self.action_out = Categorical(inputs_dim, action_dim, use_orthogonal, gain) 21 | elif action_space.__class__.__name__ == "Box": 22 | action_dim = action_space.shape[0] 23 | self.action_out = DiagGaussian(inputs_dim, action_dim, use_orthogonal, gain, args) 24 | elif action_space.__class__.__name__ == "MultiBinary": 25 | action_dim = action_space.shape[0] 26 | self.action_out = Bernoulli(inputs_dim, action_dim, use_orthogonal, gain) 27 | elif action_space.__class__.__name__ == "MultiDiscrete": 28 | self.multi_discrete = True 29 | action_dims = action_space.high - action_space.low + 1 30 | self.action_outs = [] 31 | for action_dim in action_dims: 32 | self.action_outs.append(Categorical(inputs_dim, action_dim, use_orthogonal, gain)) 33 | self.action_outs = nn.ModuleList(self.action_outs) 34 | else: # discrete + continous 35 | self.mixed_action = True 36 | continous_dim = action_space[0].shape[0] 37 | discrete_dim = action_space[1].n 38 | self.action_outs = nn.ModuleList([DiagGaussian(inputs_dim, continous_dim, use_orthogonal, gain, args), 39 | Categorical(inputs_dim, discrete_dim, use_orthogonal, gain)]) 40 | 41 | def forward(self, x, available_actions=None, deterministic=False): 42 | """ 43 | Compute actions and action logprobs from given input. 44 | :param x: (torch.Tensor) input to network. 45 | :param available_actions: (torch.Tensor) denotes which actions are available to agent 46 | (if None, all actions available) 47 | :param deterministic: (bool) whether to sample from action distribution or return the mode. 48 | 49 | :return actions: (torch.Tensor) actions to take. 50 | :return action_log_probs: (torch.Tensor) log probabilities of taken actions. 51 | """ 52 | if self.mixed_action : 53 | actions = [] 54 | action_log_probs = [] 55 | for action_out in self.action_outs: 56 | action_logit = action_out(x) 57 | action = action_logit.mode() if deterministic else action_logit.sample() 58 | action_log_prob = action_logit.log_probs(action) 59 | actions.append(action.float()) 60 | action_log_probs.append(action_log_prob) 61 | 62 | actions = torch.cat(actions, -1) 63 | action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True) 64 | 65 | elif self.multi_discrete: 66 | actions = [] 67 | action_log_probs = [] 68 | for action_out in self.action_outs: 69 | action_logit = action_out(x) 70 | action = action_logit.mode() if deterministic else action_logit.sample() 71 | action_log_prob = action_logit.log_probs(action) 72 | actions.append(action) 73 | action_log_probs.append(action_log_prob) 74 | 75 | actions = torch.cat(actions, -1) 76 | action_log_probs = torch.cat(action_log_probs, -1) 77 | 78 | else: 79 | action_logits = self.action_out(x, available_actions) 80 | actions = action_logits.mode() if deterministic else action_logits.sample() 81 | action_log_probs = action_logits.log_probs(actions) 82 | 83 | return actions, action_log_probs 84 | 85 | def get_probs(self, x, available_actions=None): 86 | """ 87 | Compute action probabilities from inputs. 88 | :param x: (torch.Tensor) input to network. 
89 | :param available_actions: (torch.Tensor) denotes which actions are available to agent 90 | (if None, all actions available) 91 | 92 | :return action_probs: (torch.Tensor) 93 | """ 94 | if self.mixed_action or self.multi_discrete: 95 | action_probs = [] 96 | for action_out in self.action_outs: 97 | action_logit = action_out(x) 98 | action_prob = action_logit.probs 99 | action_probs.append(action_prob) 100 | action_probs = torch.cat(action_probs, -1) 101 | else: 102 | action_logits = self.action_out(x, available_actions) 103 | action_probs = action_logits.probs 104 | 105 | return action_probs 106 | 107 | def evaluate_actions(self, x, action, available_actions=None, active_masks=None): 108 | """ 109 | Compute log probability and entropy of given actions. 110 | :param x: (torch.Tensor) input to network. 111 | :param action: (torch.Tensor) actions whose entropy and log probability to evaluate. 112 | :param available_actions: (torch.Tensor) denotes which actions are available to agent 113 | (if None, all actions available) 114 | :param active_masks: (torch.Tensor) denotes whether an agent is active or dead. 115 | 116 | :return action_log_probs: (torch.Tensor) log probabilities of the input actions. 117 | :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs. 118 | """ 119 | if self.mixed_action: 120 | a, b = action.split((2, 1), -1) 121 | b = b.long() 122 | action = [a, b] 123 | action_log_probs = [] 124 | dist_entropy = [] 125 | for action_out, act in zip(self.action_outs, action): 126 | action_logit = action_out(x) 127 | action_log_probs.append(action_logit.log_probs(act)) 128 | if active_masks is not None: 129 | if len(action_logit.entropy().shape) == len(active_masks.shape): 130 | dist_entropy.append((action_logit.entropy() * active_masks).sum()/active_masks.sum()) 131 | else: 132 | dist_entropy.append((action_logit.entropy() * active_masks.squeeze(-1)).sum()/active_masks.sum()) 133 | else: 134 | dist_entropy.append(action_logit.entropy().mean()) 135 | 136 | action_log_probs = torch.sum(torch.cat(action_log_probs, -1), -1, keepdim=True) 137 | dist_entropy = dist_entropy[0] / 2.0 + dist_entropy[1] / 0.98 138 | 139 | elif self.multi_discrete: 140 | action = torch.transpose(action, 0, 1) 141 | action_log_probs = [] 142 | dist_entropy = [] 143 | for action_out, act in zip(self.action_outs, action): 144 | action_logit = action_out(x) 145 | action_log_probs.append(action_logit.log_probs(act)) 146 | if active_masks is not None: 147 | dist_entropy.append((action_logit.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum()) 148 | else: 149 | dist_entropy.append(action_logit.entropy().mean()) 150 | 151 | action_log_probs = torch.cat(action_log_probs, -1) 152 | dist_entropy = torch.tensor(dist_entropy).mean() 153 | 154 | else: 155 | action_logits = self.action_out(x, available_actions) 156 | action_log_probs = action_logits.log_probs(action) 157 | if active_masks is not None: 158 | if self.action_type=="Discrete": 159 | dist_entropy = (action_logits.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum() 160 | else: 161 | dist_entropy = (action_logits.entropy()*active_masks).sum()/active_masks.sum() 162 | else: 163 | dist_entropy = action_logits.entropy().mean() 164 | 165 | return action_log_probs, dist_entropy 166 | 167 | def evaluate_actions_trpo(self, x, action, available_actions=None, active_masks=None): 168 | """ 169 | Compute log probability and entropy of given actions. 170 | :param x: (torch.Tensor) input to network. 
171 | :param action: (torch.Tensor) actions whose entropy and log probability to evaluate. 172 | :param available_actions: (torch.Tensor) denotes which actions are available to agent 173 | (if None, all actions available) 174 | :param active_masks: (torch.Tensor) denotes whether an agent is active or dead. 175 | 176 | :return action_log_probs: (torch.Tensor) log probabilities of the input actions. 177 | :return dist_entropy: (torch.Tensor) action distribution entropy for the given inputs. 178 | """ 179 | 180 | if self.multi_discrete: 181 | action = torch.transpose(action, 0, 1) 182 | action_log_probs = [] 183 | dist_entropy = [] 184 | mu_collector = [] 185 | std_collector = [] 186 | probs_collector = [] 187 | for action_out, act in zip(self.action_outs, action): 188 | action_logit = action_out(x) 189 | mu = action_logit.mean 190 | std = action_logit.stddev 191 | action_log_probs.append(action_logit.log_probs(act)) 192 | mu_collector.append(mu) 193 | std_collector.append(std) 194 | probs_collector.append(action_logit.logits) 195 | if active_masks is not None: 196 | dist_entropy.append((action_logit.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum()) 197 | else: 198 | dist_entropy.append(action_logit.entropy().mean()) 199 | action_mu = torch.cat(mu_collector,-1) 200 | action_std = torch.cat(std_collector,-1) 201 | all_probs = torch.cat(probs_collector,-1) 202 | action_log_probs = torch.cat(action_log_probs, -1) 203 | dist_entropy = torch.tensor(dist_entropy).mean() 204 | 205 | else: 206 | action_logits = self.action_out(x, available_actions) 207 | action_mu = action_logits.mean 208 | action_std = action_logits.stddev 209 | action_log_probs = action_logits.log_probs(action) 210 | if self.action_type=="Discrete": 211 | all_probs = action_logits.logits 212 | else: 213 | all_probs = None 214 | if active_masks is not None: 215 | if self.action_type=="Discrete": 216 | dist_entropy = (action_logits.entropy()*active_masks.squeeze(-1)).sum()/active_masks.sum() 217 | else: 218 | dist_entropy = (action_logits.entropy()*active_masks).sum()/active_masks.sum() 219 | else: 220 | dist_entropy = action_logits.entropy().mean() 221 | 222 | return action_log_probs, dist_entropy, action_mu, action_std, all_probs 223 | -------------------------------------------------------------------------------- /envs/ma_mujoco/multiagent_mujoco/mujoco_multi.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import gym 3 | from gym.spaces import Box 4 | from gym.wrappers import TimeLimit 5 | import numpy as np 6 | 7 | from .multiagentenv import MultiAgentEnv 8 | from .manyagent_swimmer import ManyAgentSwimmerEnv 9 | from .obsk import get_joints_at_kdist, get_parts_and_edges, build_obs 10 | 11 | 12 | def env_fn(env, **kwargs) -> MultiAgentEnv: # TODO: this may be a more complex function 13 | # env_args = kwargs.get("env_args", {}) 14 | return env(**kwargs) 15 | 16 | env_REGISTRY = {} 17 | env_REGISTRY["manyagent_swimmer"] = partial(env_fn, env=ManyAgentSwimmerEnv) 18 | 19 | 20 | # using code from https://github.com/ikostrikov/pytorch-ddpg-naf 21 | class NormalizedActions(gym.ActionWrapper): 22 | 23 | def _action(self, action): 24 | action = (action + 1) / 2 25 | action *= (self.action_space.high - self.action_space.low) 26 | action += self.action_space.low 27 | return action 28 | 29 | def action(self, action_): 30 | return self._action(action_) 31 | 32 | def _reverse_action(self, action): 33 | action -= self.action_space.low 34 | action /= 
(self.action_space.high - self.action_space.low) 35 | action = action * 2 - 1 36 | return action 37 | 38 | 39 | class MujocoMulti(MultiAgentEnv): 40 | 41 | def __init__(self, batch_size=None, **kwargs): 42 | super().__init__(batch_size, **kwargs) 43 | self.scenario = kwargs["env_args"]["scenario"] # e.g. Ant-v2 44 | self.agent_conf = kwargs["env_args"]["agent_conf"] # e.g. '2x3' 45 | 46 | self.agent_partitions, self.mujoco_edges, self.mujoco_globals = get_parts_and_edges(self.scenario, 47 | self.agent_conf) 48 | 49 | self.n_agents = len(self.agent_partitions) 50 | self.n_actions = max([len(l) for l in self.agent_partitions]) 51 | self.obs_add_global_pos = kwargs["env_args"].get("obs_add_global_pos", False) 52 | 53 | self.agent_obsk = kwargs["env_args"].get("agent_obsk", 54 | None) # if None, fully observable else k>=0 implies observe nearest k agents or joints 55 | self.agent_obsk_agents = kwargs["env_args"].get("agent_obsk_agents", 56 | False) # observe full k nearest agents (True) or just single joints (False) 57 | 58 | if self.agent_obsk is not None: 59 | self.k_categories_label = kwargs["env_args"].get("k_categories") 60 | if self.k_categories_label is None: 61 | if self.scenario in ["Ant-v2", "manyagent_ant"]: 62 | self.k_categories_label = "qpos,qvel,cfrc_ext|qpos" 63 | elif self.scenario in ["Humanoid-v2", "HumanoidStandup-v2"]: 64 | self.k_categories_label = "qpos,qvel,cfrc_ext,cvel,cinert,qfrc_actuator|qpos" 65 | elif self.scenario in ["Reacher-v2"]: 66 | self.k_categories_label = "qpos,qvel,fingertip_dist|qpos" 67 | elif self.scenario in ["coupled_half_cheetah"]: 68 | self.k_categories_label = "qpos,qvel,ten_J,ten_length,ten_velocity|" 69 | else: 70 | self.k_categories_label = "qpos,qvel|qpos" 71 | 72 | k_split = self.k_categories_label.split("|") 73 | self.k_categories = [k_split[k if k < len(k_split) else -1].split(",") for k in range(self.agent_obsk + 1)] 74 | 75 | self.global_categories_label = kwargs["env_args"].get("global_categories") 76 | self.global_categories = self.global_categories_label.split( 77 | ",") if self.global_categories_label is not None else [] 78 | 79 | if self.agent_obsk is not None: 80 | self.k_dicts = [get_joints_at_kdist(agent_id, 81 | self.agent_partitions, 82 | self.mujoco_edges, 83 | k=self.agent_obsk, 84 | kagents=False, ) for agent_id in range(self.n_agents)] 85 | 86 | # load scenario from script 87 | self.episode_limit = self.args.episode_limit 88 | 89 | self.env_version = kwargs["env_args"].get("env_version", 2) 90 | if self.env_version == 2: 91 | try: 92 | self.wrapped_env = NormalizedActions(gym.make(self.scenario)) 93 | except gym.error.Error: 94 | self.wrapped_env = NormalizedActions( 95 | TimeLimit(partial(env_REGISTRY[self.scenario], **kwargs["env_args"])(), 96 | max_episode_steps=self.episode_limit)) 97 | else: 98 | assert False, "not implemented!" 
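        # Note: for the env_version == 2 path above, self.wrapped_env is typically
        #   NormalizedActions(TimeLimit(<single-agent gym env>)),
        # so the joint action passed to step() is expected in [-1, 1] and is rescaled by
        # NormalizedActions as  low + (a + 1) / 2 * (high - low)
        # before it reaches the underlying environment.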
99 | self.timelimit_env = self.wrapped_env.env
100 | self.timelimit_env._max_episode_steps = self.episode_limit
101 | self.env = self.timelimit_env.env
102 | self.timelimit_env.reset()
103 | self.obs_size = self.get_obs_size()
104 | self.share_obs_size = self.get_state_size()
105 |
106 | # COMPATIBILITY
107 | self.n = self.n_agents
108 | # self.observation_space = [Box(low=np.array([-10]*self.n_agents), high=np.array([10]*self.n_agents)) for _ in range(self.n_agents)]
109 | self.observation_space = [Box(low=-10, high=10, shape=(self.obs_size,)) for _ in range(self.n_agents)]
110 | self.share_observation_space = [Box(low=-10, high=10, shape=(self.share_obs_size,)) for _ in
111 | range(self.n_agents)]
112 |
113 | acdims = [len(ap) for ap in self.agent_partitions]
114 | self.action_space = tuple([Box(self.env.action_space.low[sum(acdims[:a]):sum(acdims[:a + 1])],
115 | self.env.action_space.high[sum(acdims[:a]):sum(acdims[:a + 1])]) for a in
116 | range(self.n_agents)])
117 |
118 | pass
119 |
120 | def step(self, actions):
121 |
122 | # need to remove dummy actions that arise due to unequal action vector sizes across agents
123 | flat_actions = np.concatenate([actions[i][:self.action_space[i].low.shape[0]] for i in range(self.n_agents)])
124 | obs_n, reward_n, done_n, info_n = self.wrapped_env.step(flat_actions)
125 | self.steps += 1
126 |
127 | info = {}
128 | info.update(info_n)
129 |
130 | # if done_n:
131 | # if self.steps < self.episode_limit:
132 | # info["episode_limit"] = False # the next state will be masked out
133 | # else:
134 | # info["episode_limit"] = True # the next state will not be masked out
135 | if done_n:
136 | if self.steps < self.episode_limit:
137 | info["bad_transition"] = False # the next state will be masked out
138 | else:
139 | info["bad_transition"] = True # the next state will not be masked out
140 |
141 | # return reward_n, done_n, info
142 | rewards = [[reward_n]] * self.n_agents
143 | dones = [done_n] * self.n_agents
144 | infos = [info for _ in range(self.n_agents)]
145 | return self.get_obs(), self.get_state(), rewards, dones, infos, self.get_avail_actions()
146 |
147 | def get_obs(self):
148 | """ Returns all agent observations in a list """
149 | state = self.env._get_obs()
150 | obs_n = []
151 | for a in range(self.n_agents):
152 | agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)
153 | agent_id_feats[a] = 1.0
154 | # obs_n.append(self.get_obs_agent(a))
155 | # obs_n.append(np.concatenate([state, self.get_obs_agent(a), agent_id_feats]))
156 | # obs_n.append(np.concatenate([self.get_obs_agent(a), agent_id_feats]))
157 | obs_i = np.concatenate([state, agent_id_feats])
158 | obs_i = (obs_i - np.mean(obs_i)) / np.std(obs_i)
159 | obs_n.append(obs_i)
160 | return obs_n
161 |
162 | def get_obs_agent(self, agent_id):
163 | if self.agent_obsk is None:
164 | return self.env._get_obs()
165 | else:
166 | # return build_obs(self.env,
167 | # self.k_dicts[agent_id],
168 | # self.k_categories,
169 | # self.mujoco_globals,
170 | # self.global_categories,
171 | # vec_len=getattr(self, "obs_size", None))
172 | return build_obs(self.env,
173 | self.k_dicts[agent_id],
174 | self.k_categories,
175 | self.mujoco_globals,
176 | self.global_categories)
177 |
178 | def get_obs_size(self):
179 | """ Returns the shape of the observation """
180 | if self.agent_obsk is None:
181 | return self.get_obs_agent(0).size
182 | else:
183 | return len(self.get_obs()[0])
184 | # return max([len(self.get_obs_agent(agent_id)) for agent_id in range(self.n_agents)])
185 |
186 | def
get_state(self, team=None): 187 | # TODO: May want global states for different teams (so cannot see what the other team is communicating e.g.) 188 | state = self.env._get_obs() 189 | share_obs = [] 190 | for a in range(self.n_agents): 191 | agent_id_feats = np.zeros(self.n_agents, dtype=np.float32) 192 | agent_id_feats[a] = 1.0 193 | # share_obs.append(np.concatenate([state, self.get_obs_agent(a), agent_id_feats])) 194 | state_i = np.concatenate([state, agent_id_feats]) 195 | state_i = (state_i - np.mean(state_i)) / np.std(state_i) 196 | share_obs.append(state_i) 197 | return share_obs 198 | 199 | def get_state_size(self): 200 | """ Returns the shape of the state""" 201 | return len(self.get_state()[0]) 202 | 203 | def get_avail_actions(self): # all actions are always available 204 | return np.ones(shape=(self.n_agents, self.n_actions,)) 205 | 206 | def get_avail_agent_actions(self, agent_id): 207 | """ Returns the available actions for agent_id """ 208 | return np.ones(shape=(self.n_actions,)) 209 | 210 | def get_total_actions(self): 211 | """ Returns the total number of actions an agent could ever take """ 212 | return self.n_actions # CAREFUL! - for continuous dims, this is action space dim rather 213 | # return self.env.action_space.shape[0] 214 | 215 | def get_stats(self): 216 | return {} 217 | 218 | # TODO: Temp hack 219 | def get_agg_stats(self, stats): 220 | return {} 221 | 222 | def reset(self, **kwargs): 223 | """ Returns initial observations and states""" 224 | self.steps = 0 225 | self.timelimit_env.reset() 226 | return self.get_obs(), self.get_state(), self.get_avail_actions() 227 | 228 | def render(self, **kwargs): 229 | self.env.render(**kwargs) 230 | 231 | def close(self): 232 | pass 233 | 234 | def seed(self, args): 235 | pass 236 | 237 | def get_env_info(self): 238 | 239 | env_info = {"state_shape": self.get_state_size(), 240 | "obs_shape": self.get_obs_size(), 241 | "n_actions": self.get_total_actions(), 242 | "n_agents": self.n_agents, 243 | "episode_limit": self.episode_limit, 244 | "action_spaces": self.action_space, 245 | "actions_dtype": np.float32, 246 | "normalise_actions": False 247 | } 248 | return env_info 249 | -------------------------------------------------------------------------------- /runners/separated/base_runner.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import os 4 | import numpy as np 5 | from itertools import chain 6 | import torch 7 | from tensorboardX import SummaryWriter 8 | from utils.separated_buffer import SeparatedReplayBuffer 9 | from utils.util import update_linear_schedule 10 | 11 | def _t2n(x): 12 | return x.detach().cpu().numpy() 13 | 14 | class Runner(object): 15 | def __init__(self, config): 16 | 17 | self.all_args = config['all_args'] 18 | self.envs = config['envs'] 19 | self.eval_envs = config['eval_envs'] 20 | self.device = config['device'] 21 | self.num_agents = config['num_agents'] 22 | 23 | # parameters 24 | self.env_name = self.all_args.env_name 25 | self.algorithm_name = self.all_args.algorithm_name 26 | self.experiment_name = self.all_args.experiment_name 27 | self.use_centralized_V = self.all_args.use_centralized_V 28 | self.use_obs_instead_of_state = self.all_args.use_obs_instead_of_state 29 | self.num_env_steps = self.all_args.num_env_steps 30 | self.episode_length = self.all_args.episode_length 31 | self.n_rollout_threads = self.all_args.n_rollout_threads 32 | self.n_eval_rollout_threads = self.all_args.n_eval_rollout_threads 33 | self.use_linear_lr_decay = 
self.all_args.use_linear_lr_decay 34 | self.hidden_size = self.all_args.hidden_size 35 | self.use_render = self.all_args.use_render 36 | self.recurrent_N = self.all_args.recurrent_N 37 | self.use_single_network = self.all_args.use_single_network 38 | # interval 39 | self.save_interval = self.all_args.save_interval 40 | self.use_eval = self.all_args.use_eval 41 | self.eval_interval = self.all_args.eval_interval 42 | self.log_interval = self.all_args.log_interval 43 | 44 | # dir 45 | self.model_dir = self.all_args.model_dir 46 | 47 | if self.use_render: 48 | import imageio 49 | self.run_dir = config["run_dir"] 50 | self.gif_dir = str(self.run_dir / 'gifs') 51 | if not os.path.exists(self.gif_dir): 52 | os.makedirs(self.gif_dir) 53 | else: 54 | self.run_dir = config["run_dir"] 55 | self.log_dir = str(self.run_dir / 'logs') 56 | if not os.path.exists(self.log_dir): 57 | os.makedirs(self.log_dir) 58 | self.writter = SummaryWriter(self.log_dir) 59 | self.save_dir = str(self.run_dir / 'models') 60 | if not os.path.exists(self.save_dir): 61 | os.makedirs(self.save_dir) 62 | 63 | 64 | if self.all_args.algorithm_name == "happo": 65 | from algorithms.happo_trainer import HAPPO as TrainAlgo 66 | from algorithms.happo_policy import HAPPO_Policy as Policy 67 | elif self.all_args.algorithm_name == "hatrpo": 68 | from algorithms.hatrpo_trainer import HATRPO as TrainAlgo 69 | from algorithms.hatrpo_policy import HATRPO_Policy as Policy 70 | else: 71 | raise NotImplementedError 72 | 73 | print("share_observation_space: ", self.envs.share_observation_space) 74 | print("observation_space: ", self.envs.observation_space) 75 | print("action_space: ", self.envs.action_space) 76 | 77 | self.policy = [] 78 | for agent_id in range(self.num_agents): 79 | share_observation_space = self.envs.share_observation_space[agent_id] if self.use_centralized_V else self.envs.observation_space[agent_id] 80 | # policy network 81 | po = Policy(self.all_args, 82 | self.envs.observation_space[agent_id], 83 | share_observation_space, 84 | self.envs.action_space[agent_id], 85 | device = self.device) 86 | self.policy.append(po) 87 | 88 | if self.model_dir is not None: 89 | self.restore() 90 | 91 | self.trainer = [] 92 | self.buffer = [] 93 | for agent_id in range(self.num_agents): 94 | # algorithm 95 | tr = TrainAlgo(self.all_args, self.policy[agent_id], device = self.device) 96 | # buffer 97 | share_observation_space = self.envs.share_observation_space[agent_id] if self.use_centralized_V else self.envs.observation_space[agent_id] 98 | bu = SeparatedReplayBuffer(self.all_args, 99 | self.envs.observation_space[agent_id], 100 | share_observation_space, 101 | self.envs.action_space[agent_id]) 102 | self.buffer.append(bu) 103 | self.trainer.append(tr) 104 | 105 | def run(self): 106 | raise NotImplementedError 107 | 108 | def warmup(self): 109 | raise NotImplementedError 110 | 111 | def collect(self, step): 112 | raise NotImplementedError 113 | 114 | def insert(self, data): 115 | raise NotImplementedError 116 | 117 | @torch.no_grad() 118 | def compute(self): 119 | for agent_id in range(self.num_agents): 120 | self.trainer[agent_id].prep_rollout() 121 | next_value = self.trainer[agent_id].policy.get_values(self.buffer[agent_id].share_obs[-1], 122 | self.buffer[agent_id].rnn_states_critic[-1], 123 | self.buffer[agent_id].masks[-1]) 124 | next_value = _t2n(next_value) 125 | self.buffer[agent_id].compute_returns(next_value, self.trainer[agent_id].value_normalizer) 126 | 127 | def train(self): 128 | train_infos = [] 129 | # random update order 
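# HAPPO sequential update: agents are trained one at a time in a randomly permuted order.
# 'factor' starts at ones; before each agent trains, the current factor is written into its
# buffer via update_factor(), and after training it is multiplied by
# exp(new_actions_logprob - old_actions_logprob) for that agent's actions, so agents updated
# later in the permutation are weighted by the probability ratios of the policies already updated.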
130 | 131 | action_dim=self.buffer[0].actions.shape[-1] 132 | factor = np.ones((self.episode_length, self.n_rollout_threads, action_dim), dtype=np.float32) 133 | 134 | for agent_id in torch.randperm(self.num_agents): 135 | self.trainer[agent_id].prep_training() 136 | self.buffer[agent_id].update_factor(factor) 137 | available_actions = None if self.buffer[agent_id].available_actions is None \ 138 | else self.buffer[agent_id].available_actions[:-1].reshape(-1, *self.buffer[agent_id].available_actions.shape[2:]) 139 | 140 | if self.all_args.algorithm_name == "hatrpo": 141 | old_actions_logprob, _, _, _, _ =self.trainer[agent_id].policy.actor.evaluate_actions(self.buffer[agent_id].obs[:-1].reshape(-1, *self.buffer[agent_id].obs.shape[2:]), 142 | self.buffer[agent_id].rnn_states[0:1].reshape(-1, *self.buffer[agent_id].rnn_states.shape[2:]), 143 | self.buffer[agent_id].actions.reshape(-1, *self.buffer[agent_id].actions.shape[2:]), 144 | self.buffer[agent_id].masks[:-1].reshape(-1, *self.buffer[agent_id].masks.shape[2:]), 145 | available_actions, 146 | self.buffer[agent_id].active_masks[:-1].reshape(-1, *self.buffer[agent_id].active_masks.shape[2:])) 147 | else: 148 | old_actions_logprob, _ =self.trainer[agent_id].policy.actor.evaluate_actions(self.buffer[agent_id].obs[:-1].reshape(-1, *self.buffer[agent_id].obs.shape[2:]), 149 | self.buffer[agent_id].rnn_states[0:1].reshape(-1, *self.buffer[agent_id].rnn_states.shape[2:]), 150 | self.buffer[agent_id].actions.reshape(-1, *self.buffer[agent_id].actions.shape[2:]), 151 | self.buffer[agent_id].masks[:-1].reshape(-1, *self.buffer[agent_id].masks.shape[2:]), 152 | available_actions, 153 | self.buffer[agent_id].active_masks[:-1].reshape(-1, *self.buffer[agent_id].active_masks.shape[2:])) 154 | train_info = self.trainer[agent_id].train(self.buffer[agent_id]) 155 | 156 | if self.all_args.algorithm_name == "hatrpo": 157 | new_actions_logprob, _, _, _, _ =self.trainer[agent_id].policy.actor.evaluate_actions(self.buffer[agent_id].obs[:-1].reshape(-1, *self.buffer[agent_id].obs.shape[2:]), 158 | self.buffer[agent_id].rnn_states[0:1].reshape(-1, *self.buffer[agent_id].rnn_states.shape[2:]), 159 | self.buffer[agent_id].actions.reshape(-1, *self.buffer[agent_id].actions.shape[2:]), 160 | self.buffer[agent_id].masks[:-1].reshape(-1, *self.buffer[agent_id].masks.shape[2:]), 161 | available_actions, 162 | self.buffer[agent_id].active_masks[:-1].reshape(-1, *self.buffer[agent_id].active_masks.shape[2:])) 163 | else: 164 | new_actions_logprob, _ =self.trainer[agent_id].policy.actor.evaluate_actions(self.buffer[agent_id].obs[:-1].reshape(-1, *self.buffer[agent_id].obs.shape[2:]), 165 | self.buffer[agent_id].rnn_states[0:1].reshape(-1, *self.buffer[agent_id].rnn_states.shape[2:]), 166 | self.buffer[agent_id].actions.reshape(-1, *self.buffer[agent_id].actions.shape[2:]), 167 | self.buffer[agent_id].masks[:-1].reshape(-1, *self.buffer[agent_id].masks.shape[2:]), 168 | available_actions, 169 | self.buffer[agent_id].active_masks[:-1].reshape(-1, *self.buffer[agent_id].active_masks.shape[2:])) 170 | 171 | factor = factor*_t2n(torch.exp(new_actions_logprob-old_actions_logprob).reshape(self.episode_length,self.n_rollout_threads,action_dim)) 172 | train_infos.append(train_info) 173 | self.buffer[agent_id].after_update() 174 | 175 | return train_infos 176 | 177 | def save(self): 178 | for agent_id in range(self.num_agents): 179 | if self.use_single_network: 180 | policy_model = self.trainer[agent_id].policy.model 181 | torch.save(policy_model.state_dict(), str(self.save_dir) 
+ "/model_agent" + str(agent_id) + ".pt")
182 | else:
183 | policy_actor = self.trainer[agent_id].policy.actor
184 | torch.save(policy_actor.state_dict(), str(self.save_dir) + "/actor_agent" + str(agent_id) + ".pt")
185 | policy_critic = self.trainer[agent_id].policy.critic
186 | torch.save(policy_critic.state_dict(), str(self.save_dir) + "/critic_agent" + str(agent_id) + ".pt")
187 |
188 | def restore(self):
189 | for agent_id in range(self.num_agents):
190 | if self.use_single_network:
191 | policy_model_state_dict = torch.load(str(self.model_dir) + '/model_agent' + str(agent_id) + '.pt')
192 | self.policy[agent_id].model.load_state_dict(policy_model_state_dict)
193 | else:
194 | policy_actor_state_dict = torch.load(str(self.model_dir) + '/actor_agent' + str(agent_id) + '.pt')
195 | self.policy[agent_id].actor.load_state_dict(policy_actor_state_dict)
196 | policy_critic_state_dict = torch.load(str(self.model_dir) + '/critic_agent' + str(agent_id) + '.pt')
197 | self.policy[agent_id].critic.load_state_dict(policy_critic_state_dict)
198 |
199 | def log_train(self, train_infos, total_num_steps):
200 | for agent_id in range(self.num_agents):
201 | for k, v in train_infos[agent_id].items():
202 | agent_k = "agent%i/" % agent_id + k
203 | self.writter.add_scalars(agent_k, {agent_k: v}, total_num_steps)
204 |
205 | def log_env(self, env_infos, total_num_steps):
206 | for k, v in env_infos.items():
207 | if len(v) > 0:
208 | self.writter.add_scalars(k, {k: np.mean(v)}, total_num_steps)
209 |
--------------------------------------------------------------------------------
/runners/separated/smac_runner.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | from functools import reduce
4 | import torch
5 | from runners.separated.base_runner import Runner
6 |
7 | def _t2n(x):
8 | return x.detach().cpu().numpy()
9 |
10 | class SMACRunner(Runner):
11 | """Runner class to perform training, evaluation, and data collection for SMAC.
See parent class for details.""" 12 | def __init__(self, config): 13 | super(SMACRunner, self).__init__(config) 14 | 15 | def run(self): 16 | self.warmup() 17 | 18 | start = time.time() 19 | episodes = int(self.num_env_steps) // self.episode_length // self.n_rollout_threads 20 | 21 | last_battles_game = np.zeros(self.n_rollout_threads, dtype=np.float32) 22 | last_battles_won = np.zeros(self.n_rollout_threads, dtype=np.float32) 23 | 24 | for episode in range(episodes): 25 | if self.use_linear_lr_decay: 26 | self.trainer.policy.lr_decay(episode, episodes) 27 | 28 | for step in range(self.episode_length): 29 | # Sample actions 30 | values, actions, action_log_probs, rnn_states, rnn_states_critic = self.collect(step) 31 | # Obser reward and next obs 32 | obs, share_obs, rewards, dones, infos, available_actions = self.envs.step(actions) 33 | 34 | data = obs, share_obs, rewards, dones, infos, available_actions, \ 35 | values, actions, action_log_probs, \ 36 | rnn_states, rnn_states_critic 37 | 38 | # insert data into buffer 39 | self.insert(data) 40 | 41 | # compute return and update network 42 | self.compute() 43 | train_infos = self.train() 44 | 45 | # post process 46 | total_num_steps = (episode + 1) * self.episode_length * self.n_rollout_threads 47 | # save model 48 | if (episode % self.save_interval == 0 or episode == episodes - 1): 49 | self.save() 50 | 51 | # log information 52 | if episode % self.log_interval == 0: 53 | end = time.time() 54 | print("\n Map {} Algo {} Exp {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n" 55 | .format(self.all_args.map_name, 56 | self.algorithm_name, 57 | self.experiment_name, 58 | episode, 59 | episodes, 60 | total_num_steps, 61 | self.num_env_steps, 62 | int(total_num_steps / (end - start)))) 63 | 64 | if self.env_name == "StarCraft2": 65 | battles_won = [] 66 | battles_game = [] 67 | incre_battles_won = [] 68 | incre_battles_game = [] 69 | 70 | for i, info in enumerate(infos): 71 | if 'battles_won' in info[0].keys(): 72 | battles_won.append(info[0]['battles_won']) 73 | incre_battles_won.append(info[0]['battles_won']-last_battles_won[i]) 74 | if 'battles_game' in info[0].keys(): 75 | battles_game.append(info[0]['battles_game']) 76 | incre_battles_game.append(info[0]['battles_game']-last_battles_game[i]) 77 | 78 | incre_win_rate = np.sum(incre_battles_won)/np.sum(incre_battles_game) if np.sum(incre_battles_game)>0 else 0.0 79 | print("incre win rate is {}.".format(incre_win_rate)) 80 | self.writter.add_scalars("incre_win_rate", {"incre_win_rate": incre_win_rate}, total_num_steps) 81 | 82 | last_battles_game = battles_game 83 | last_battles_won = battles_won 84 | # modified 85 | 86 | for agent_id in range(self.num_agents): 87 | train_infos[agent_id]['dead_ratio'] = 1 - self.buffer[agent_id].active_masks.sum() /(self.num_agents* reduce(lambda x, y: x*y, list(self.buffer[agent_id].active_masks.shape))) 88 | 89 | self.log_train(train_infos, total_num_steps) 90 | 91 | # eval 92 | if episode % self.eval_interval == 0 and self.use_eval: 93 | self.eval(total_num_steps) 94 | 95 | def warmup(self): 96 | # reset env 97 | obs, share_obs, available_actions = self.envs.reset() 98 | # replay buffer 99 | if not self.use_centralized_V: 100 | share_obs = obs 101 | for agent_id in range(self.num_agents): 102 | self.buffer[agent_id].share_obs[0] = share_obs[:,agent_id].copy() 103 | self.buffer[agent_id].obs[0] = obs[:,agent_id].copy() 104 | self.buffer[agent_id].available_actions[0] = available_actions[:,agent_id].copy() 105 | 106 | @torch.no_grad() 107 | def 
collect(self, step): 108 | value_collector=[] 109 | action_collector=[] 110 | action_log_prob_collector=[] 111 | rnn_state_collector=[] 112 | rnn_state_critic_collector=[] 113 | for agent_id in range(self.num_agents): 114 | self.trainer[agent_id].prep_rollout() 115 | value, action, action_log_prob, rnn_state, rnn_state_critic \ 116 | = self.trainer[agent_id].policy.get_actions(self.buffer[agent_id].share_obs[step], 117 | self.buffer[agent_id].obs[step], 118 | self.buffer[agent_id].rnn_states[step], 119 | self.buffer[agent_id].rnn_states_critic[step], 120 | self.buffer[agent_id].masks[step], 121 | self.buffer[agent_id].available_actions[step]) 122 | value_collector.append(_t2n(value)) 123 | action_collector.append(_t2n(action)) 124 | action_log_prob_collector.append(_t2n(action_log_prob)) 125 | rnn_state_collector.append(_t2n(rnn_state)) 126 | rnn_state_critic_collector.append(_t2n(rnn_state_critic)) 127 | # [self.envs, agents, dim] 128 | values = np.array(value_collector).transpose(1, 0, 2) 129 | actions = np.array(action_collector).transpose(1, 0, 2) 130 | action_log_probs = np.array(action_log_prob_collector).transpose(1, 0, 2) 131 | rnn_states = np.array(rnn_state_collector).transpose(1, 0, 2, 3) 132 | rnn_states_critic = np.array(rnn_state_critic_collector).transpose(1, 0, 2, 3) 133 | 134 | return values, actions, action_log_probs, rnn_states, rnn_states_critic 135 | 136 | def insert(self, data): 137 | obs, share_obs, rewards, dones, infos, available_actions, \ 138 | values, actions, action_log_probs, rnn_states, rnn_states_critic = data 139 | 140 | dones_env = np.all(dones, axis=1) 141 | 142 | rnn_states[dones_env == True] = np.zeros(((dones_env == True).sum(), self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32) 143 | rnn_states_critic[dones_env == True] = np.zeros(((dones_env == True).sum(), self.num_agents, *self.buffer[0].rnn_states_critic.shape[2:]), dtype=np.float32) 144 | 145 | masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32) 146 | masks[dones_env == True] = np.zeros(((dones_env == True).sum(), self.num_agents, 1), dtype=np.float32) 147 | 148 | active_masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32) 149 | active_masks[dones == True] = np.zeros(((dones == True).sum(), 1), dtype=np.float32) 150 | active_masks[dones_env == True] = np.ones(((dones_env == True).sum(), self.num_agents, 1), dtype=np.float32) 151 | 152 | bad_masks = np.array([[[0.0] if info[agent_id]['bad_transition'] else [1.0] for agent_id in range(self.num_agents)] for info in infos]) 153 | 154 | if not self.use_centralized_V: 155 | share_obs = obs 156 | for agent_id in range(self.num_agents): 157 | self.buffer[agent_id].insert(share_obs[:,agent_id], obs[:,agent_id], rnn_states[:,agent_id], 158 | rnn_states_critic[:,agent_id],actions[:,agent_id], action_log_probs[:,agent_id], 159 | values[:,agent_id], rewards[:,agent_id], masks[:,agent_id], bad_masks[:,agent_id], 160 | active_masks[:,agent_id], available_actions[:,agent_id]) 161 | 162 | def log_train(self, train_infos, total_num_steps): 163 | for agent_id in range(self.num_agents): 164 | train_infos[agent_id]["average_step_rewards"] = np.mean(self.buffer[agent_id].rewards) 165 | for k, v in train_infos[agent_id].items(): 166 | agent_k = "agent%i/" % agent_id + k 167 | self.writter.add_scalars(agent_k, {agent_k: v}, total_num_steps) 168 | 169 | @torch.no_grad() 170 | def eval(self, total_num_steps): 171 | eval_battles_won = 0 172 | eval_episode = 0 173 | 174 | eval_episode_rewards = [] 
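# per eval thread: one_episode_rewards buffers the step rewards of the episode in progress,
# while eval_episode_rewards collects the summed return of each finished episode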
175 | one_episode_rewards = [] 176 | for eval_i in range(self.n_eval_rollout_threads): 177 | one_episode_rewards.append([]) 178 | eval_episode_rewards.append([]) 179 | 180 | eval_obs, eval_share_obs, eval_available_actions = self.eval_envs.reset() 181 | 182 | eval_rnn_states = np.zeros((self.n_eval_rollout_threads, self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32) 183 | eval_masks = np.ones((self.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32) 184 | 185 | while True: 186 | eval_actions_collector=[] 187 | eval_rnn_states_collector=[] 188 | for agent_id in range(self.num_agents): 189 | self.trainer[agent_id].prep_rollout() 190 | eval_actions, temp_rnn_state = \ 191 | self.trainer[agent_id].policy.act(eval_obs[:,agent_id], 192 | eval_rnn_states[:,agent_id], 193 | eval_masks[:,agent_id], 194 | eval_available_actions[:,agent_id], 195 | deterministic=True) 196 | eval_rnn_states[:,agent_id]=_t2n(temp_rnn_state) 197 | eval_actions_collector.append(_t2n(eval_actions)) 198 | 199 | eval_actions = np.array(eval_actions_collector).transpose(1,0,2) 200 | 201 | 202 | # Obser reward and next obs 203 | eval_obs, eval_share_obs, eval_rewards, eval_dones, eval_infos, eval_available_actions = self.eval_envs.step(eval_actions) 204 | for eval_i in range(self.n_eval_rollout_threads): 205 | one_episode_rewards[eval_i].append(eval_rewards[eval_i]) 206 | 207 | eval_dones_env = np.all(eval_dones, axis=1) 208 | 209 | eval_rnn_states[eval_dones_env == True] = np.zeros(((eval_dones_env == True).sum(), self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32) 210 | 211 | eval_masks = np.ones((self.all_args.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32) 212 | eval_masks[eval_dones_env == True] = np.zeros(((eval_dones_env == True).sum(), self.num_agents, 1), dtype=np.float32) 213 | 214 | for eval_i in range(self.n_eval_rollout_threads): 215 | if eval_dones_env[eval_i]: 216 | eval_episode += 1 217 | eval_episode_rewards[eval_i].append(np.sum(one_episode_rewards[eval_i], axis=0)) 218 | one_episode_rewards[eval_i] = [] 219 | if eval_infos[eval_i][0]['won']: 220 | eval_battles_won += 1 221 | 222 | if eval_episode >= self.all_args.eval_episodes: 223 | eval_episode_rewards = np.concatenate(eval_episode_rewards) 224 | eval_env_infos = {'eval_average_episode_rewards': eval_episode_rewards} 225 | self.log_env(eval_env_infos, total_num_steps) 226 | eval_win_rate = eval_battles_won/eval_episode 227 | print("eval win rate is {}.".format(eval_win_rate)) 228 | self.writter.add_scalars("eval_win_rate", {"eval_win_rate": eval_win_rate}, total_num_steps) 229 | break 230 | -------------------------------------------------------------------------------- /envs/starcraft2/smac_maps.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from pysc2.maps import lib 6 | 7 | 8 | class SMACMap(lib.Map): 9 | directory = "SMAC_Maps" 10 | download = "https://github.com/oxwhirl/smac#smac-maps" 11 | players = 2 12 | step_mul = 8 13 | game_steps_per_episode = 0 14 | 15 | 16 | map_param_registry = { 17 | "3m": { 18 | "n_agents": 3, 19 | "n_enemies": 3, 20 | "limit": 60, 21 | "a_race": "T", 22 | "b_race": "T", 23 | "unit_type_bits": 0, 24 | "map_type": "marines", 25 | }, 26 | "8m": { 27 | "n_agents": 8, 28 | "n_enemies": 8, 29 | "limit": 120, 30 | "a_race": "T", 31 | "b_race": "T", 32 | "unit_type_bits": 0, 33 | "map_type": "marines", 
34 | }, 35 | "25m": { 36 | "n_agents": 25, 37 | "n_enemies": 25, 38 | "limit": 150, 39 | "a_race": "T", 40 | "b_race": "T", 41 | "unit_type_bits": 0, 42 | "map_type": "marines", 43 | }, 44 | "5m_vs_6m": { 45 | "n_agents": 5, 46 | "n_enemies": 6, 47 | "limit": 70, 48 | "a_race": "T", 49 | "b_race": "T", 50 | "unit_type_bits": 0, 51 | "map_type": "marines", 52 | }, 53 | "8m_vs_9m": { 54 | "n_agents": 8, 55 | "n_enemies": 9, 56 | "limit": 120, 57 | "a_race": "T", 58 | "b_race": "T", 59 | "unit_type_bits": 0, 60 | "map_type": "marines", 61 | }, 62 | "10m_vs_11m": { 63 | "n_agents": 10, 64 | "n_enemies": 11, 65 | "limit": 150, 66 | "a_race": "T", 67 | "b_race": "T", 68 | "unit_type_bits": 0, 69 | "map_type": "marines", 70 | }, 71 | "27m_vs_30m": { 72 | "n_agents": 27, 73 | "n_enemies": 30, 74 | "limit": 180, 75 | "a_race": "T", 76 | "b_race": "T", 77 | "unit_type_bits": 0, 78 | "map_type": "marines", 79 | }, 80 | "MMM": { 81 | "n_agents": 10, 82 | "n_enemies": 10, 83 | "limit": 150, 84 | "a_race": "T", 85 | "b_race": "T", 86 | "unit_type_bits": 3, 87 | "map_type": "MMM", 88 | }, 89 | "MMM2": { 90 | "n_agents": 10, 91 | "n_enemies": 12, 92 | "limit": 180, 93 | "a_race": "T", 94 | "b_race": "T", 95 | "unit_type_bits": 3, 96 | "map_type": "MMM", 97 | }, 98 | "2s3z": { 99 | "n_agents": 5, 100 | "n_enemies": 5, 101 | "limit": 120, 102 | "a_race": "P", 103 | "b_race": "P", 104 | "unit_type_bits": 2, 105 | "map_type": "stalkers_and_zealots", 106 | }, 107 | "3s5z": { 108 | "n_agents": 8, 109 | "n_enemies": 8, 110 | "limit": 150, 111 | "a_race": "P", 112 | "b_race": "P", 113 | "unit_type_bits": 2, 114 | "map_type": "stalkers_and_zealots", 115 | }, 116 | "3s5z_vs_3s6z": { 117 | "n_agents": 8, 118 | "n_enemies": 9, 119 | "limit": 170, 120 | "a_race": "P", 121 | "b_race": "P", 122 | "unit_type_bits": 2, 123 | "map_type": "stalkers_and_zealots", 124 | }, 125 | "3s_vs_3z": { 126 | "n_agents": 3, 127 | "n_enemies": 3, 128 | "limit": 150, 129 | "a_race": "P", 130 | "b_race": "P", 131 | "unit_type_bits": 0, 132 | "map_type": "stalkers", 133 | }, 134 | "3s_vs_4z": { 135 | "n_agents": 3, 136 | "n_enemies": 4, 137 | "limit": 200, 138 | "a_race": "P", 139 | "b_race": "P", 140 | "unit_type_bits": 0, 141 | "map_type": "stalkers", 142 | }, 143 | "3s_vs_5z": { 144 | "n_agents": 3, 145 | "n_enemies": 5, 146 | "limit": 250, 147 | "a_race": "P", 148 | "b_race": "P", 149 | "unit_type_bits": 0, 150 | "map_type": "stalkers", 151 | }, 152 | "1c3s5z": { 153 | "n_agents": 9, 154 | "n_enemies": 9, 155 | "limit": 180, 156 | "a_race": "P", 157 | "b_race": "P", 158 | "unit_type_bits": 3, 159 | "map_type": "colossi_stalkers_zealots", 160 | }, 161 | "2m_vs_1z": { 162 | "n_agents": 2, 163 | "n_enemies": 1, 164 | "limit": 150, 165 | "a_race": "T", 166 | "b_race": "P", 167 | "unit_type_bits": 0, 168 | "map_type": "marines", 169 | }, 170 | "corridor": { 171 | "n_agents": 6, 172 | "n_enemies": 24, 173 | "limit": 400, 174 | "a_race": "P", 175 | "b_race": "Z", 176 | "unit_type_bits": 0, 177 | "map_type": "zealots", 178 | }, 179 | "6h_vs_8z": { 180 | "n_agents": 6, 181 | "n_enemies": 8, 182 | "limit": 150, 183 | "a_race": "Z", 184 | "b_race": "P", 185 | "unit_type_bits": 0, 186 | "map_type": "hydralisks", 187 | }, 188 | "2s_vs_1sc": { 189 | "n_agents": 2, 190 | "n_enemies": 1, 191 | "limit": 300, 192 | "a_race": "P", 193 | "b_race": "Z", 194 | "unit_type_bits": 0, 195 | "map_type": "stalkers", 196 | }, 197 | "so_many_baneling": { 198 | "n_agents": 7, 199 | "n_enemies": 32, 200 | "limit": 100, 201 | "a_race": "P", 202 | "b_race": "Z", 203 | 
"unit_type_bits": 0, 204 | "map_type": "zealots", 205 | }, 206 | "bane_vs_bane": { 207 | "n_agents": 24, 208 | "n_enemies": 24, 209 | "limit": 200, 210 | "a_race": "Z", 211 | "b_race": "Z", 212 | "unit_type_bits": 2, 213 | "map_type": "bane", 214 | }, 215 | "2c_vs_64zg": { 216 | "n_agents": 2, 217 | "n_enemies": 64, 218 | "limit": 400, 219 | "a_race": "P", 220 | "b_race": "Z", 221 | "unit_type_bits": 0, 222 | "map_type": "colossus", 223 | }, 224 | 225 | # This is adhoc environment 226 | "1c2z_vs_1c1s1z": { 227 | "n_agents": 3, 228 | "n_enemies": 3, 229 | "limit": 180, 230 | "a_race": "P", 231 | "b_race": "P", 232 | "unit_type_bits": 3, 233 | "map_type": "colossi_stalkers_zealots", 234 | }, 235 | "1c2s_vs_1c1s1z": { 236 | "n_agents": 3, 237 | "n_enemies": 3, 238 | "limit": 180, 239 | "a_race": "P", 240 | "b_race": "P", 241 | "unit_type_bits": 3, 242 | "map_type": "colossi_stalkers_zealots", 243 | }, 244 | "2c1z_vs_1c1s1z": { 245 | "n_agents": 3, 246 | "n_enemies": 3, 247 | "limit": 180, 248 | "a_race": "P", 249 | "b_race": "P", 250 | "unit_type_bits": 3, 251 | "map_type": "colossi_stalkers_zealots", 252 | }, 253 | "2c1s_vs_1c1s1z": { 254 | "n_agents": 3, 255 | "n_enemies": 3, 256 | "limit": 180, 257 | "a_race": "P", 258 | "b_race": "P", 259 | "unit_type_bits": 3, 260 | "map_type": "colossi_stalkers_zealots", 261 | }, 262 | "1c1s1z_vs_1c1s1z": { 263 | "n_agents": 3, 264 | "n_enemies": 3, 265 | "limit": 180, 266 | "a_race": "P", 267 | "b_race": "P", 268 | "unit_type_bits": 3, 269 | "map_type": "colossi_stalkers_zealots", 270 | }, 271 | 272 | "3s5z_vs_4s4z": { 273 | "n_agents": 8, 274 | "n_enemies": 8, 275 | "limit": 150, 276 | "a_race": "P", 277 | "b_race": "P", 278 | "unit_type_bits": 2, 279 | "map_type": "stalkers_and_zealots", 280 | }, 281 | "4s4z_vs_4s4z": { 282 | "n_agents": 8, 283 | "n_enemies": 8, 284 | "limit": 150, 285 | "a_race": "P", 286 | "b_race": "P", 287 | "unit_type_bits": 2, 288 | "map_type": "stalkers_and_zealots", 289 | }, 290 | "5s3z_vs_4s4z": { 291 | "n_agents": 8, 292 | "n_enemies": 8, 293 | "limit": 150, 294 | "a_race": "P", 295 | "b_race": "P", 296 | "unit_type_bits": 2, 297 | "map_type": "stalkers_and_zealots", 298 | }, 299 | "6s2z_vs_4s4z": { 300 | "n_agents": 8, 301 | "n_enemies": 8, 302 | "limit": 150, 303 | "a_race": "P", 304 | "b_race": "P", 305 | "unit_type_bits": 2, 306 | "map_type": "stalkers_and_zealots", 307 | }, 308 | "2s6z_vs_4s4z": { 309 | "n_agents": 8, 310 | "n_enemies": 8, 311 | "limit": 150, 312 | "a_race": "P", 313 | "b_race": "P", 314 | "unit_type_bits": 2, 315 | "map_type": "stalkers_and_zealots", 316 | }, 317 | 318 | "6m_vs_6m_tz": { 319 | "n_agents": 6, 320 | "n_enemies": 6, 321 | "limit": 70, 322 | "a_race": "T", 323 | "b_race": "T", 324 | "unit_type_bits": 0, 325 | "map_type": "marines", 326 | }, 327 | "5m_vs_6m_tz": { 328 | "n_agents": 5, 329 | "n_enemies": 6, 330 | "limit": 70, 331 | "a_race": "T", 332 | "b_race": "T", 333 | "unit_type_bits": 0, 334 | "map_type": "marines", 335 | }, 336 | "3s6z_vs_3s6z": { 337 | "n_agents": 9, 338 | "n_enemies": 9, 339 | "limit": 170, 340 | "a_race": "P", 341 | "b_race": "P", 342 | "unit_type_bits": 2, 343 | "map_type": "stalkers_and_zealots", 344 | }, 345 | "7h_vs_8z": { 346 | "n_agents": 7, 347 | "n_enemies": 8, 348 | "limit": 150, 349 | "a_race": "Z", 350 | "b_race": "P", 351 | "unit_type_bits": 0, 352 | "map_type": "hydralisks", 353 | }, 354 | "2s2z_vs_zg": { 355 | "n_agents": 4, 356 | "n_enemies": 20, 357 | "limit": 200, 358 | "a_race": "P", 359 | "b_race": "Z", 360 | "unit_type_bits": 2, 361 | 
"map_type": "stalkers_and_zealots_vs_zergling", 362 | }, 363 | "1s3z_vs_zg": { 364 | "n_agents": 4, 365 | "n_enemies": 20, 366 | "limit": 200, 367 | "a_race": "P", 368 | "b_race": "Z", 369 | "unit_type_bits": 2, 370 | "map_type": "stalkers_and_zealots_vs_zergling", 371 | }, 372 | "3s1z_vs_zg": { 373 | "n_agents": 4, 374 | "n_enemies": 20, 375 | "limit": 200, 376 | "a_race": "P", 377 | "b_race": "Z", 378 | "unit_type_bits": 2, 379 | "map_type": "stalkers_and_zealots_vs_zergling", 380 | }, 381 | 382 | "2s2z_vs_zg_easy": { 383 | "n_agents": 4, 384 | "n_enemies": 18, 385 | "limit": 200, 386 | "a_race": "P", 387 | "b_race": "Z", 388 | "unit_type_bits": 2, 389 | "map_type": "stalkers_and_zealots_vs_zergling", 390 | }, 391 | "1s3z_vs_zg_easy": { 392 | "n_agents": 4, 393 | "n_enemies": 18, 394 | "limit": 200, 395 | "a_race": "P", 396 | "b_race": "Z", 397 | "unit_type_bits": 2, 398 | "map_type": "stalkers_and_zealots_vs_zergling", 399 | }, 400 | "3s1z_vs_zg_easy": { 401 | "n_agents": 4, 402 | "n_enemies": 18, 403 | "limit": 200, 404 | "a_race": "P", 405 | "b_race": "Z", 406 | "unit_type_bits": 2, 407 | "map_type": "stalkers_and_zealots_vs_zergling", 408 | }, 409 | "28m_vs_30m": { 410 | "n_agents": 28, 411 | "n_enemies": 30, 412 | "limit": 180, 413 | "a_race": "T", 414 | "b_race": "T", 415 | "unit_type_bits": 0, 416 | "map_type": "marines", 417 | }, 418 | "29m_vs_30m": { 419 | "n_agents": 29, 420 | "n_enemies": 30, 421 | "limit": 180, 422 | "a_race": "T", 423 | "b_race": "T", 424 | "unit_type_bits": 0, 425 | "map_type": "marines", 426 | }, 427 | "30m_vs_30m": { 428 | "n_agents": 30, 429 | "n_enemies": 30, 430 | "limit": 180, 431 | "a_race": "T", 432 | "b_race": "T", 433 | "unit_type_bits": 0, 434 | "map_type": "marines", 435 | }, 436 | "MMM2_test": { 437 | "n_agents": 10, 438 | "n_enemies": 12, 439 | "limit": 180, 440 | "a_race": "T", 441 | "b_race": "T", 442 | "unit_type_bits": 3, 443 | "map_type": "MMM", 444 | }, 445 | } 446 | 447 | 448 | def get_smac_map_registry(): 449 | return map_param_registry 450 | 451 | 452 | for name in map_param_registry.keys(): 453 | globals()[name] = type(name, (SMACMap,), dict(filename=name)) 454 | 455 | 456 | def get_map_params(map_name): 457 | map_param_registry = get_smac_map_registry() 458 | return map_param_registry[map_name] 459 | -------------------------------------------------------------------------------- /configs/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_config(): 4 | """ 5 | The configuration parser for common hyperparameters of all environment. 6 | Please reach each `scripts/train/_runner.py` file to find private hyperparameters 7 | only used in . 8 | 9 | Prepare parameters: 10 | --algorithm_name 11 | specifiy the algorithm, including `["happo", "hatrpo"]` 12 | --experiment_name 13 | an identifier to distinguish different experiment. 14 | --seed 15 | set seed for numpy and torch 16 | --seed_specify 17 | by default True Random or specify seed for numpy/torch 18 | --runing_id 19 | the runing index of experiment (default=1) 20 | --cuda 21 | by default True, will use GPU to train; or else will use CPU; 22 | --cuda_deterministic 23 | by default, make sure random seed effective. if set, bypass such function. 24 | --n_training_threads 25 | number of training threads working in parallel. by default 1 26 | --n_rollout_threads 27 | number of parallel envs for training rollout. by default 32 28 | --n_eval_rollout_threads 29 | number of parallel envs for evaluating rollout. 
by default 1 30 | --n_render_rollout_threads 31 | number of parallel envs for rendering, could only be set as 1 for some environments. 32 | --num_env_steps 33 | number of env steps to train (default: 10e6) 34 | 35 | 36 | Env parameters: 37 | --env_name 38 | specify the name of environment 39 | --use_obs_instead_of_state 40 | [only for some env] by default False, will use global state; or else will use concatenated local obs. 41 | 42 | Replay Buffer parameters: 43 | --episode_length 44 | the max length of episode in the buffer. 45 | 46 | Network parameters: 47 | --share_policy 48 | by default True, all agents will share the same network; set to make training agents use different policies. 49 | --use_centralized_V 50 | by default True, use centralized training mode; or else will decentralized training mode. 51 | --stacked_frames 52 | Number of input frames which should be stack together. 53 | --hidden_size 54 | Dimension of hidden layers for actor/critic networks 55 | --layer_N 56 | Number of layers for actor/critic networks 57 | --use_ReLU 58 | by default True, will use ReLU. or else will use Tanh. 59 | --use_popart 60 | by default True, use running mean and std to normalize rewards. 61 | --use_feature_normalization 62 | by default True, apply layernorm to normalize inputs. 63 | --use_orthogonal 64 | by default True, use Orthogonal initialization for weights and 0 initialization for biases. or else, will use xavier uniform inilialization. 65 | --gain 66 | by default 0.01, use the gain # of last action layer 67 | --use_naive_recurrent_policy 68 | by default False, use the whole trajectory to calculate hidden states. 69 | --use_recurrent_policy 70 | by default, use Recurrent Policy. If set, do not use. 71 | --recurrent_N 72 | The number of recurrent layers ( default 1). 73 | --data_chunk_length 74 | Time length of chunks used to train a recurrent_policy, default 10. 75 | 76 | Optimizer parameters: 77 | --lr 78 | learning rate parameter, (default: 5e-4, fixed). 79 | --critic_lr 80 | learning rate of critic (default: 5e-4, fixed) 81 | --opti_eps 82 | RMSprop optimizer epsilon (default: 1e-5) 83 | --weight_decay 84 | coefficience of weight decay (default: 0) 85 | 86 | TRPO parameters: 87 | --kl_threshold 88 | the threshold of kl-divergence (default: 0.01) 89 | --ls_step 90 | the step of line search (default: 10) 91 | --accept_ratio 92 | accept ratio of loss improve (default: 0.5) 93 | 94 | PPO parameters: 95 | --ppo_epoch 96 | number of ppo epochs (default: 15) 97 | --use_clipped_value_loss 98 | by default, clip loss value. If set, do not clip loss value. 99 | --clip_param 100 | ppo clip parameter (default: 0.2) 101 | --num_mini_batch 102 | number of batches for ppo (default: 1) 103 | --entropy_coef 104 | entropy term coefficient (default: 0.01) 105 | --use_max_grad_norm 106 | by default, use max norm of gradients. If set, do not use. 107 | --max_grad_norm 108 | max norm of gradients (default: 0.5) 109 | --use_gae 110 | by default, use generalized advantage estimation. If set, do not use gae. 111 | --gamma 112 | discount factor for rewards (default: 0.99) 113 | --gae_lambda 114 | gae lambda parameter (default: 0.95) 115 | --use_proper_time_limits 116 | by default, the return value does consider limits of time. If set, compute returns with considering time limits factor. 117 | --use_huber_loss 118 | by default, use huber loss. If set, do not use huber loss. 119 | --use_value_active_masks 120 | by default True, whether to mask useless data in value loss. 
121 | --huber_delta 122 | coefficient of huber loss. 123 | 124 | 125 | Run parameters: 126 | --use_linear_lr_decay 127 | by default, do not apply linear decay to learning rate. If set, use a linear schedule on the learning rate 128 | --save_interval 129 | time duration between contiunous twice models saving. 130 | --log_interval 131 | time duration between contiunous twice log printing. 132 | --model_dir 133 | by default None. set the path to pretrained model. 134 | 135 | Eval parameters: 136 | --use_eval 137 | by default, do not start evaluation. If set`, start evaluation alongside with training. 138 | --eval_interval 139 | time duration between contiunous twice evaluation progress. 140 | --eval_episodes 141 | number of episodes of a single evaluation. 142 | 143 | Render parameters: 144 | --save_gifs 145 | by default, do not save render video. If set, save video. 146 | --use_render 147 | by default, do not render the env during training. If set, start render. Note: something, the environment has internal render process which is not controlled by this hyperparam. 148 | --render_episodes 149 | the number of episodes to render a given env 150 | --ifi 151 | the play interval of each rendered image in saved video. 152 | 153 | Pretrained parameters: 154 | 155 | """ 156 | parser = argparse.ArgumentParser(description='onpolicy_algorithm', formatter_class=argparse.RawDescriptionHelpFormatter) 157 | 158 | # prepare parameters 159 | parser.add_argument("--algorithm_name", type=str, 160 | default=' ', choices=["happo","hatrpo"]) 161 | parser.add_argument("--experiment_name", type=str, 162 | default="check", help="an identifier to distinguish different experiment.") 163 | parser.add_argument("--seed", type=int, 164 | default=1, help="Random seed for numpy/torch") 165 | parser.add_argument("--seed_specify", action="store_true", 166 | default=False, help="Random or specify seed for numpy/torch") 167 | parser.add_argument("--runing_id", type=int, 168 | default=1, help="the runing index of experiment") 169 | parser.add_argument("--cuda", action='store_false', 170 | default=True, help="by default True, will use GPU to train; or else will use CPU;") 171 | parser.add_argument("--cuda_deterministic", action='store_false', 172 | default=True, help="by default, make sure random seed effective. 
if set, bypass such function.") 173 | parser.add_argument("--n_training_threads", type=int, 174 | default=1, help="Number of torch threads for training") 175 | parser.add_argument("--n_rollout_threads", type=int, 176 | default=32, help="Number of parallel envs for training rollouts") 177 | parser.add_argument("--n_eval_rollout_threads", type=int, 178 | default=1, help="Number of parallel envs for evaluating rollouts") 179 | parser.add_argument("--n_render_rollout_threads", type=int, 180 | default=1, help="Number of parallel envs for rendering rollouts") 181 | parser.add_argument("--num_env_steps", type=int, 182 | default=10e6, help='Number of environment steps to train (default: 10e6)') 183 | parser.add_argument("--user_name", type=str, 184 | default='marl',help="[for wandb usage], to specify user's name for simply collecting training data.") 185 | # env parameters 186 | parser.add_argument("--env_name", type=str, 187 | default='StarCraft2', help="specify the name of environment") 188 | parser.add_argument("--use_obs_instead_of_state", action='store_true', 189 | default=False, help="Whether to use global state or concatenated obs") 190 | 191 | # replay buffer parameters 192 | parser.add_argument("--episode_length", type=int, 193 | default=200, help="Max length for any episode") 194 | 195 | # network parameters 196 | parser.add_argument("--share_policy", action='store_false', 197 | default=True, help='Whether agent share the same policy') 198 | parser.add_argument("--use_centralized_V", action='store_false', 199 | default=True, help="Whether to use centralized V function") 200 | parser.add_argument("--stacked_frames", type=int, 201 | default=1, help="Dimension of hidden layers for actor/critic networks") 202 | parser.add_argument("--use_stacked_frames", action='store_true', 203 | default=False, help="Whether to use stacked_frames") 204 | parser.add_argument("--hidden_size", type=int, 205 | default=64, help="Dimension of hidden layers for actor/critic networks") 206 | parser.add_argument("--layer_N", type=int, 207 | default=1, help="Number of layers for actor/critic networks") 208 | parser.add_argument("--use_ReLU", action='store_false', 209 | default=True, help="Whether to use ReLU") 210 | parser.add_argument("--use_popart", action='store_false', 211 | default=True, help="by default True, use running mean and std to normalize rewards.") 212 | parser.add_argument("--use_valuenorm", action='store_false', 213 | default=True, help="by default True, use running mean and std to normalize rewards.") 214 | parser.add_argument("--use_feature_normalization", action='store_false', 215 | default=True, help="Whether to apply layernorm to the inputs") 216 | parser.add_argument("--use_orthogonal", action='store_false', 217 | default=True, help="Whether to use Orthogonal initialization for weights and 0 initialization for biases") 218 | parser.add_argument("--gain", type=float, 219 | default=0.01, help="The gain # of last action layer") 220 | 221 | # recurrent parameters 222 | parser.add_argument("--use_naive_recurrent_policy", action='store_true', 223 | default=False, help='Whether to use a naive recurrent policy') 224 | parser.add_argument("--use_recurrent_policy", action='store_true', 225 | default=False, help='use a recurrent policy') 226 | parser.add_argument("--recurrent_N", type=int, 227 | default=1, help="The number of recurrent layers.") 228 | parser.add_argument("--data_chunk_length", type=int, 229 | default=10, help="Time length of chunks used to train a recurrent_policy") 230 | 231 | # 
optimizer parameters 232 | parser.add_argument("--lr", type=float, 233 | default=5e-4, help='learning rate (default: 5e-4)') 234 | parser.add_argument("--critic_lr", type=float, 235 | default=5e-4, help='critic learning rate (default: 5e-4)') 236 | parser.add_argument("--opti_eps", type=float, 237 | default=1e-5, help='RMSprop optimizer epsilon (default: 1e-5)') 238 | parser.add_argument("--weight_decay", type=float, default=0) 239 | parser.add_argument("--std_x_coef", type=float, default=1) 240 | parser.add_argument("--std_y_coef", type=float, default=0.5) 241 | 242 | 243 | # trpo parameters 244 | parser.add_argument("--kl_threshold", type=float, 245 | default=0.01, help='the threshold of kl-divergence (default: 0.01)') 246 | parser.add_argument("--ls_step", type=int, 247 | default=10, help='number of line search (default: 10)') 248 | parser.add_argument("--accept_ratio", type=float, 249 | default=0.5, help='accept ratio of loss improve (default: 0.5)') 250 | 251 | # ppo parameters 252 | parser.add_argument("--ppo_epoch", type=int, 253 | default=15, help='number of ppo epochs (default: 15)') 254 | parser.add_argument("--use_clipped_value_loss", action='store_false', 255 | default=True, help="by default, clip loss value. If set, do not clip loss value.") 256 | parser.add_argument("--clip_param", type=float, 257 | default=0.2, help='ppo clip parameter (default: 0.2)') 258 | parser.add_argument("--num_mini_batch", type=int, 259 | default=1, help='number of batches for ppo (default: 1)') 260 | parser.add_argument("--entropy_coef", type=float, 261 | default=0.01, help='entropy term coefficient (default: 0.01)') 262 | parser.add_argument("--value_loss_coef", type=float, 263 | default=1, help='value loss coefficient (default: 0.5)') 264 | parser.add_argument("--use_max_grad_norm", action='store_false', 265 | default=True, help="by default, use max norm of gradients. If set, do not use.") 266 | parser.add_argument("--max_grad_norm", type=float, 267 | default=10.0, help='max norm of gradients (default: 0.5)') 268 | parser.add_argument("--use_gae", action='store_false', 269 | default=True, help='use generalized advantage estimation') 270 | parser.add_argument("--gamma", type=float, default=0.99, 271 | help='discount factor for rewards (default: 0.99)') 272 | parser.add_argument("--gae_lambda", type=float, default=0.95, 273 | help='gae lambda parameter (default: 0.95)') 274 | parser.add_argument("--use_proper_time_limits", action='store_true', 275 | default=False, help='compute returns taking into account time limits') 276 | parser.add_argument("--use_huber_loss", action='store_false', 277 | default=True, help="by default, use huber loss. 
If set, do not use huber loss.") 278 | parser.add_argument("--use_value_active_masks", action='store_false', 279 | default=True, help="by default True, whether to mask useless data in value loss.") 280 | parser.add_argument("--use_policy_active_masks", action='store_false', 281 | default=True, help="by default True, whether to mask useless data in policy loss.") 282 | parser.add_argument("--huber_delta", type=float, 283 | default=10.0, help=" coefficience of huber loss.") 284 | 285 | # run parameters 286 | parser.add_argument("--use_linear_lr_decay", action='store_true', 287 | default=False, help='use a linear schedule on the learning rate') 288 | parser.add_argument("--save_interval", type=int, 289 | default=1, help="time duration between contiunous twice models saving.") 290 | parser.add_argument("--log_interval", type=int, 291 | default=5, help="time duration between contiunous twice log printing.") 292 | parser.add_argument("--model_dir", type=str, 293 | default=None, help="by default None. set the path to pretrained model.") 294 | 295 | # eval parameters 296 | parser.add_argument("--use_eval", action='store_true', 297 | default=False, help="by default, do not start evaluation. If set`, start evaluation alongside with training.") 298 | parser.add_argument("--eval_interval", type=int, 299 | default=25, help="time duration between contiunous twice evaluation progress.") 300 | parser.add_argument("--eval_episodes", type=int, 301 | default=32, help="number of episodes of a single evaluation.") 302 | 303 | # render parameters 304 | parser.add_argument("--save_gifs", action='store_true', 305 | default=False, help="by default, do not save render video. If set, save video.") 306 | parser.add_argument("--use_render", action='store_true', 307 | default=False, help="by default, do not render the env during training. If set, start render. Note: something, the environment has internal render process which is not controlled by this hyperparam.") 308 | parser.add_argument("--render_episodes", type=int, 309 | default=5, help="the number of episodes to render a given env") 310 | parser.add_argument("--ifi", type=float, 311 | default=0.1, help="the play interval of each rendered image in saved video.") 312 | 313 | return parser 314 | --------------------------------------------------------------------------------
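Usage note: get_config() only builds and returns the shared argparse parser; as the docstring above indicates, environment-specific hyperparameters live in the per-environment train scripts, which extend this parser before parsing. The snippet below is a minimal sketch of that pattern, assuming the repository root is on PYTHONPATH; the --map_name flag and the example argument list are illustrative assumptions, not part of configs/config.py.

from configs.config import get_config

parser = get_config()
# illustrative environment-specific flag a train script might add (assumption)
parser.add_argument("--map_name", type=str, default="3s5z")

# parse an example command line; parse_known_args ignores flags this parser does not define
all_args = parser.parse_known_args(
    ["--algorithm_name", "happo", "--map_name", "3s5z", "--use_eval"]
)[0]
print(all_args.algorithm_name, all_args.kl_threshold, all_args.use_eval)

Note that flags declared with action='store_false' (e.g. --share_policy, --use_popart, --use_centralized_V) default to True and are switched off by passing the flag, which is the opposite of the usual store_true convention.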