├── .gitignore ├── README.md ├── __init__.py ├── environment.py ├── img ├── ray_multi_agent_demo_model_env.png └── results_rewards.svg ├── model.py ├── multi_action_dist.py ├── multi_trainer.py ├── requirements.txt └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | .idea/ 3 | *.DS_STORE 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Minimal example for multi-agent RL in RLLib with differentiable communication channel 2 | 3 | This is a minimal example demonstrating how multi-agent reinforcement learning with differentiable communication channels and centralized critics can be realized in RLLib. This example serves as a reference implementation and starting point for making RLLib more compatible with such architectures. 4 | 5 | ## Citation 6 | This project is derived from our paper "The Emergence of Adversarial Communication in Multi-Agent Reinforcement Learning" (code available [here](https://github.com/proroklab/adversarial_comms)). If you use any part of this code in your research, please cite our paper: 7 | ``` 8 | @article{blumenkamp2020adversarial, 9 | title={The Emergence of Adversarial Communication in Multi-Agent Reinforcement Learning}, 10 | author={Blumenkamp, Jan and Prorok, Amanda}, 11 | journal={Conference on Robot Learning (CoRL)}, 12 | year={2020} 13 | } 14 | ``` 15 | 16 | ## Introduction 17 | While RLLib provides a multi-agent API, this API only supports non-differentiable communication channels, which means that agents can only communicate with each other through their actions and observations. Recently, Graph Neural Networks have gained traction and have been demonstrated to be effective at learning decentralized homogeneous control policies for multi-agent systems [[1]](https://arxiv.org/abs/2012.14906), [[2]](https://arxiv.org/abs/2008.02616), [[3]](https://arxiv.org/abs/1912.06095). While it would be possible to learn a policy that utilizes such a GNN in RLLib with a centralized model and an environment that encapsulates all observations and actions at once, it would be hard to learn anything due to the credit assignment problem, since the default OpenAI Gym interface only allows returning a single scalar reward per time step. 18 | 19 | This repository uses the `info` dict that is returned from the `step` function to propagate local per-agent rewards into RLLib. The rewards are extracted for each agent and the trajectories are processed and discounted individually. The PPO loss is then computed for each trajectory and eventually summed so that the NN can be optimized for all agents. This procedure is explained and derived in [[2]](https://arxiv.org/abs/2008.02616). 20 | 21 | ## Environment 22 | The environment is a (grid) world populated with agents that can move in 2D space, either discretely (into one of the four neighboring cells) or continuously (by dx and dy). Each agent's state consists of a 2D position and a goal, which together form the agent's local observation. The agents are ordered, and each agent is only rewarded for moving to the next agent's goal. This can only be achieved with a shared, differentiable communication channel. 23 | 24 | ## Model 25 | Instead of a GNN, we use a simple centralized feedforward network. 
Such a model assumes a fixed number of agents and will not work for most more complicated scenarios, but for the purpose of demonstrating the efficiency of the implemented trainer it is sufficient. The model is summarized in this visualization: 26 | 27 | ![overview image](https://raw.githubusercontent.com/janblumenkamp/rllib_multi_agent_demo/master/img/ray_multi_agent_demo_model_env.png "Overview") 28 | 29 | ## Setup 30 | ``` 31 | pip install -r requirements.txt 32 | ``` 33 | 34 | ## Results 35 | 36 | ![overview image](https://raw.githubusercontent.com/janblumenkamp/rllib_multi_agent_demo/master/img/results_rewards.svg "Overview") 37 | 38 | | Type | Comm | Command | Reward | Ep len | 39 | |------|------|---------------------------------------------------------------|--------|--------| 40 | | Cont | yes | `python train.py --action_space continuous` | -0.75 | 3.0 | 41 | | Dis | yes | `python train.py --action_space discrete` | -3.95 | 4.7 | 42 | | Cont | no | `python train.py --action_space continuous --disable_sharing` | -15.3 | 8.5 | 43 | | Dis | no | `python train.py --action_space discrete --disable_sharing` | -21.2 | 8.9 | 44 | 45 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proroklab/rllib_differentiable_comms/f6d4e6085f9d73bbab3c3b785ac75ad011f390a7/__init__.py -------------------------------------------------------------------------------- /environment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import gym 3 | from gym import spaces 4 | from gym.utils import seeding, EzPickle 5 | 6 | X = 1 7 | Y = 0 8 | 9 | 10 | class BaseAgent: 11 | def __init__(self, index, world_shape, random_state): 12 | self.goal = None 13 | self.pose = None 14 | self.reached_goal = None 15 | self.random_state = random_state 16 | self.index = index 17 | self.world_shape = world_shape 18 | self.reset() 19 | 20 | def is_valid_pose(self, p): 21 | return all([0 <= p[c] < self.world_shape[c] for c in [Y, X]]) 22 | 23 | def update_pose(self, delta_p): 24 | desired_pos = self.pose + delta_p 25 | if self.is_valid_pose(desired_pos): 26 | self.pose = desired_pos 27 | 28 | def get_obs(self): 29 | return np.hstack([self.goal, self.pose]) 30 | 31 | def reset(self): 32 | raise NotImplementedError() 33 | 34 | def step(self, action): 35 | raise NotImplementedError() 36 | 37 | 38 | class DiscreteAgent(BaseAgent): 39 | def __init__(self, *args, **kwargs): 40 | super().__init__(*args, **kwargs) 41 | 42 | def reset(self): 43 | self.pose = self.random_state.randint((0, 0), self.world_shape) 44 | self.goal = self.random_state.randint((0, 0), self.world_shape) 45 | self.reached_goal = False 46 | return 0 47 | 48 | def step(self, action): 49 | delta_pose = { 50 | 0: [0, 0], 51 | 1: [0, 1], 52 | 2: [0, -1], 53 | 3: [-1, 0], 54 | 4: [1, 0], 55 | }[action] 56 | self.update_pose(delta_pose) 57 | return self.get_obs() 58 | 59 | 60 | class ContinuousAgent(BaseAgent): 61 | def __init__(self, *args, **kwargs): 62 | super().__init__(*args, **kwargs) 63 | 64 | def reset(self): 65 | self.pose = self.random_state.uniform((0, 0), self.world_shape) 66 | self.goal = self.random_state.randint((0, 0), self.world_shape) 67 | self.reached_goal = False 68 | return [0, 0] 69 | 70 | def step(self, action): 71 | action_clipped = np.clip(action, -1, 1) 72 | self.update_pose(action_clipped) 73 | return self.get_obs() 74 | 75 | 76 
| class InvalidConfigParameter(Exception): 77 | """Raised when a configuration parameter is invalid""" 78 | 79 | pass 80 | 81 | 82 | class DemoMultiAgentEnv(gym.Env, EzPickle): 83 | def __init__(self, env_config): 84 | EzPickle.__init__(self) 85 | self.timestep = None 86 | self.goal_poses = None 87 | self.random_state = None 88 | self.seed(1) 89 | 90 | self.cfg = env_config 91 | 92 | self.observation_space = spaces.Dict( 93 | { 94 | "agents": spaces.Tuple( 95 | ( 96 | spaces.Box( 97 | low=0.0, 98 | high=float(max(self.cfg["world_shape"])), 99 | shape=(4,), 100 | ), 101 | ) 102 | * self.cfg["n_agents"] 103 | ), 104 | "state": spaces.Box( 105 | low=0.0, high=1.0, shape=self.cfg["world_shape"] + [2] 106 | ), 107 | } 108 | ) 109 | if self.cfg["action_space"] == "discrete": 110 | agent_action_space = spaces.Discrete(5) 111 | agent_class = DiscreteAgent 112 | elif self.cfg["action_space"] == "continuous": 113 | agent_action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=float) 114 | agent_class = ContinuousAgent 115 | else: 116 | raise InvalidConfigParameter("Invalid action_space") 117 | self.action_space = spaces.Tuple((agent_action_space,) * self.cfg["n_agents"]) 118 | 119 | self.agents = [ 120 | agent_class(i, self.cfg["world_shape"], self.random_state) 121 | for i in range(self.cfg["n_agents"]) 122 | ] 123 | 124 | self.reset() 125 | 126 | def seed(self, seed=None): 127 | self.random_state, seed = seeding.np_random(seed) 128 | return [seed] 129 | 130 | def reset(self): 131 | reset_actions = [agent.reset() for agent in self.agents] 132 | self.goal_poses = [agent.goal for agent in self.agents] 133 | self.timestep = 0 134 | return self.step(reset_actions)[0] 135 | 136 | def step(self, actions): 137 | self.timestep += 1 138 | 139 | observations = [ 140 | agent.step(action) for agent, action in zip(self.agents, actions) 141 | ] 142 | 143 | rewards = {} 144 | # shift each agent's goal so that the shared NN has to be used to solve the problem 145 | shifted_poses = ( 146 | self.goal_poses[self.cfg["goal_shift"] :] 147 | + self.goal_poses[: self.cfg["goal_shift"]] 148 | ) 149 | for i, (agent, goal) in enumerate(zip(self.agents, shifted_poses)): 150 | rewards[i] = -1 if not agent.reached_goal else 0 151 | if not agent.reached_goal and np.linalg.norm(agent.pose - goal) < 1: 152 | rewards[i] = 1 153 | agent.reached_goal = True 154 | 155 | all_reached_goal = all([agent.reached_goal for agent in self.agents]) 156 | max_timestep_reached = self.timestep == self.cfg["max_episode_len"] 157 | done = all_reached_goal or max_timestep_reached 158 | 159 | global_state = np.zeros(self.cfg["world_shape"] + [2], dtype=np.uint8) 160 | for agent in self.agents: 161 | global_state[int(agent.pose[Y]), int(agent.pose[X]), 0] = 1 162 | global_state[int(agent.goal[Y]), int(agent.goal[X]), 1] = 1 163 | 164 | obs = {"agents": tuple(observations), "state": global_state} 165 | info = {"rewards": rewards} 166 | all_rewards = sum(rewards.values()) 167 | 168 | return obs, all_rewards, done, info 169 | 170 | def render(self, mode="human"): 171 | top_bot_margin = " " + "-" * self.cfg["world_shape"][Y] * 2 + "\n" 172 | r = top_bot_margin 173 | for y in range(self.cfg["world_shape"][Y]): 174 | r += "|" 175 | for x in range(self.cfg["world_shape"][X]): 176 | c = " " 177 | for i, agent in enumerate(self.agents): 178 | if np.all(agent.pose.astype(int) == np.array([y, x])): 179 | c = "x" if agent.reached_goal else str(i) 180 | if np.all(agent.goal == np.array([y, x])): 181 | c = "abcdefghijklmnopqrstuvwxyz"[i] 182 | r += c + " " 
183 | r += "|\n" 184 | r += top_bot_margin 185 | print(r) 186 | -------------------------------------------------------------------------------- /img/ray_multi_agent_demo_model_env.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/proroklab/rllib_differentiable_comms/f6d4e6085f9d73bbab3c3b785ac75ad011f390a7/img/ray_multi_agent_demo_model_env.png -------------------------------------------------------------------------------- /img/results_rewards.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from ray.rllib.models.modelv2 import ModelV2 2 | from ray.rllib.models.torch.misc import ( 3 | SlimConv2d, 4 | ) 5 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 6 | from ray.rllib.utils import try_import_torch 7 | from ray.rllib.utils.annotations import override 8 | 9 | torch, nn = try_import_torch() 10 | 11 | 12 | class Model(TorchModelV2, nn.Module): 13 | def __init__( 14 | self, 15 | obs_space, 16 | action_space, 17 | num_outputs, 18 | model_config, 19 | name, 20 | encoder_out_features, 21 | shared_nn_out_features_per_agent, 22 | value_state_encoder_cnn_out_features, 23 | share_observations, 24 | use_beta, 25 | ): 26 | TorchModelV2.__init__( 27 | self, obs_space, action_space, num_outputs, model_config, name 28 | ) 29 | nn.Module.__init__(self) 30 | 31 | self.encoder_out_features = encoder_out_features 32 | self.shared_nn_out_features_per_agent = shared_nn_out_features_per_agent 33 | self.value_state_encoder_cnn_out_features = value_state_encoder_cnn_out_features 34 | self.share_observations = share_observations 35 | self.use_beta = use_beta 36 | 37 | self.n_agents = len(obs_space.original_space["agents"]) 38 | self.outputs_per_agent = int(num_outputs / self.n_agents) 39 | 40 | obs_shape = obs_space.original_space["agents"][0].shape 41 | state_shape = obs_space.original_space["state"].shape 42 | 43 | ########### 44 | # Action NN 45 | 46 | self.action_encoder = nn.Sequential( 47 | nn.Linear(obs_shape[0], 32), 48 | nn.ReLU(), 49 | nn.Linear(32, self.encoder_out_features), 50 | nn.ReLU(), 51 | ) 52 | 53 | share_n_agents = self.n_agents if self.share_observations else 1 54 | self.action_shared = nn.Sequential( 55 | nn.Linear(self.encoder_out_features * share_n_agents, 64), 56 | nn.ReLU(), 57 | nn.Linear(64, self.shared_nn_out_features_per_agent * share_n_agents), 58 | nn.ReLU(), 59 | ) 60 | 61 | post_logits = [ 62 | nn.Linear(self.shared_nn_out_features_per_agent, 32), 63 | nn.ReLU(), 64 | nn.Linear(32, self.outputs_per_agent), 65 | ] 66 | nn.init.xavier_uniform_(post_logits[-1].weight) 67 | nn.init.constant_(post_logits[-1].bias, 0) 68 | self.action_output = nn.Sequential(*post_logits) 69 | 70 | ########### 71 | # Value NN 72 | 73 | self.value_encoder = nn.Sequential( 74 | nn.Linear(obs_shape[0], 32), 75 | nn.ReLU(), 76 | nn.Linear(32, self.encoder_out_features), 77 | nn.ReLU(), 78 | ) 79 | 80 | self.value_encoder_state = nn.Sequential( 81 | SlimConv2d( 82 | 2, 8, 3, 2, 1 83 | ), # in_channels, out_channels, kernel, stride, padding 84 | SlimConv2d( 85 | 8, 8, 3, 2, 1 86 | ), # in_channels, out_channels, kernel, stride, padding 87 | SlimConv2d(8, self.value_state_encoder_cnn_out_features, 3, 2, 1), 88 | nn.Flatten(1, -1), 89 | ) 90 | 91 | self.value_shared = nn.Sequential( 92 | nn.Linear( 93 | 
self.encoder_out_features * self.n_agents 94 | + self.value_state_encoder_cnn_out_features, 95 | 64, 96 | ), 97 | nn.ReLU(), 98 | nn.Linear(64, self.shared_nn_out_features_per_agent * self.n_agents), 99 | nn.ReLU(), 100 | ) 101 | 102 | value_post_logits = [ 103 | nn.Linear(self.shared_nn_out_features_per_agent, 32), 104 | nn.ReLU(), 105 | nn.Linear(32, 1), 106 | ] 107 | nn.init.xavier_uniform_(value_post_logits[-1].weight) 108 | nn.init.constant_(value_post_logits[-1].bias, 0) 109 | self.value_output = nn.Sequential(*value_post_logits) 110 | 111 | @override(ModelV2) 112 | def forward(self, input_dict, state, seq_lens): 113 | batch_size = input_dict["obs"]["state"].shape[0] 114 | device = input_dict["obs"]["state"].device 115 | 116 | action_feature_map = torch.zeros( 117 | batch_size, self.n_agents, self.encoder_out_features 118 | ).to(device) 119 | value_feature_map = torch.zeros( 120 | batch_size, self.n_agents, self.encoder_out_features 121 | ).to(device) 122 | for i in range(self.n_agents): 123 | agent_obs = input_dict["obs"]["agents"][i] 124 | action_feature_map[:, i] = self.action_encoder(agent_obs) 125 | value_feature_map[:, i] = self.value_encoder(agent_obs) 126 | value_state_features = self.value_encoder_state( 127 | input_dict["obs"]["state"].permute(0, 3, 1, 2) 128 | ) 129 | 130 | if self.share_observations: 131 | # We have a big common shared center NN so that all agents have access to the encoded observations of all agents 132 | action_shared_features = self.action_shared( 133 | action_feature_map.view( 134 | batch_size, self.n_agents * self.encoder_out_features 135 | ) 136 | ).view(batch_size, self.n_agents, self.shared_nn_out_features_per_agent) 137 | else: 138 | # Each agent only has access to its own local observation 139 | action_shared_features = torch.empty( 140 | batch_size, self.n_agents, self.shared_nn_out_features_per_agent 141 | ).to(device) 142 | for i in range(self.n_agents): 143 | action_shared_features[:, i] = self.action_shared( 144 | action_feature_map[:, i] 145 | ) 146 | 147 | value_shared_features = self.value_shared( 148 | torch.cat( 149 | [ 150 | value_feature_map.view( 151 | batch_size, self.n_agents * self.encoder_out_features 152 | ), 153 | value_state_features, 154 | ], 155 | dim=1, 156 | ) 157 | ).view(batch_size, self.n_agents, self.shared_nn_out_features_per_agent) 158 | 159 | outputs = torch.empty(batch_size, self.n_agents, self.outputs_per_agent).to( 160 | device 161 | ) 162 | values = torch.empty(batch_size, self.n_agents).to(device) 163 | 164 | for i in range(self.n_agents): 165 | outputs[:, i] = self.action_output(action_shared_features[:, i]) 166 | values[:, i] = self.value_output(value_shared_features[:, i]).squeeze(1) 167 | 168 | self._cur_value = values 169 | 170 | return outputs.view(batch_size, self.n_agents * self.outputs_per_agent), state 171 | 172 | @override(ModelV2) 173 | def value_function(self): 174 | assert self._cur_value is not None, "must call forward() first" 175 | return self._cur_value 176 | -------------------------------------------------------------------------------- /multi_action_dist.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import tree 4 | from ray.rllib.models.torch.torch_action_dist import ( 5 | TorchMultiActionDistribution, 6 | TorchCategorical, 7 | TorchBeta, 8 | TorchDiagGaussian, 9 | TorchMultiCategorical, 10 | ) 11 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 12 | from ray.rllib.utils.annotations import 
override 13 | from ray.rllib.utils.framework import try_import_torch 14 | from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space 15 | from ray.rllib.utils.typing import TensorType, List, Union 16 | 17 | torch, nn = try_import_torch() 18 | 19 | 20 | class InvalidActionSpace(Exception): 21 | """Raised when the action space is invalid""" 22 | 23 | pass 24 | 25 | 26 | # Override the TorchBeta class to allow the low and high bounds to be vectors (one entry per action dimension) 27 | class TorchBetaMulti(TorchBeta): 28 | def __init__( 29 | self, 30 | inputs: List[TensorType], 31 | model: TorchModelV2, 32 | low: Union[float, TensorType] = 0.0, 33 | high: Union[float, TensorType] = 1.0, 34 | ): 35 | super().__init__(inputs, model) 36 | device = self.inputs.device 37 | self.low = torch.tensor(low).to(device) 38 | self.high = torch.tensor(high).to(device) 39 | 40 | assert len(self.low.shape) == 1, "Low vector of beta must have only 1 dimension" 41 | assert ( 42 | len(self.high.shape) == 1 43 | ), "High vector of beta must have only 1 dimension" 44 | assert ( 45 | self.low.shape[0] == 1 or self.low.shape[0] == self.inputs.shape[-1] // 2 46 | ), f"Size of low vector of beta must be either 1 or match the number of action dimensions, got {self.low.shape[0]} expected {self.inputs.shape[-1] // 2}" 47 | assert ( 48 | self.high.shape[0] == 1 or self.high.shape[0] == self.inputs.shape[-1] // 2 49 | ), f"Size of high vector of beta must be either 1 or match the number of action dimensions, got {self.high.shape[0]} expected {self.inputs.shape[-1] // 2}" 50 | 51 | 52 | class TorchHomogeneousMultiActionDistribution(TorchMultiActionDistribution): 53 | @override(TorchMultiActionDistribution) 54 | def __init__(self, inputs, model, *, child_distributions, input_lens, action_space): 55 | # Skip calling the parent constructor and call the grandparent constructor instead, because 56 | # we do not want to compute self.flat_child_distributions in the parent constructor 57 | super(TorchMultiActionDistribution, self).__init__(inputs, model) 58 | 59 | if not isinstance(inputs, torch.Tensor): 60 | inputs = torch.from_numpy(inputs) 61 | if isinstance(model, TorchModelV2): 62 | inputs = inputs.to(next(model.parameters()).device) 63 | 64 | self.action_space_struct = get_base_struct_from_space(action_space) 65 | 66 | self.input_lens = tree.flatten(input_lens) 67 | split_inputs = torch.split(inputs, self.input_lens, dim=1) 68 | self.flat_child_distributions = [] 69 | for agent_action_space, agent_inputs in zip( 70 | self.action_space_struct, split_inputs 71 | ): 72 | if isinstance(agent_action_space, gym.spaces.Box): 73 | assert len(agent_action_space.shape) == 1 74 | if model.use_beta: 75 | self.flat_child_distributions.append( 76 | TorchBetaMulti( 77 | agent_inputs, 78 | model, 79 | low=agent_action_space.low, 80 | high=agent_action_space.high, 81 | ) 82 | ) 83 | else: 84 | self.flat_child_distributions.append( 85 | TorchDiagGaussian(agent_inputs, model) 86 | ) 87 | elif isinstance(agent_action_space, gym.spaces.Discrete): 88 | self.flat_child_distributions.append( 89 | TorchCategorical(agent_inputs, model) 90 | ) 91 | elif isinstance(agent_action_space, gym.spaces.MultiDiscrete): 92 | self.flat_child_distributions.append( 93 | TorchMultiCategorical( 94 | agent_inputs, model, action_space=agent_action_space 95 | ) 96 | ) 97 | else: 98 | raise InvalidActionSpace( 99 | "Expect gym.spaces.Box, gym.spaces.Discrete or gym.spaces.MultiDiscrete action space for each agent" 100 | ) 101 | 102 | @override(TorchMultiActionDistribution) 103 | def logp(self, x): 104 | if isinstance(x, np.ndarray): 105 | x = 
torch.Tensor(x) 106 | assert isinstance(x, torch.Tensor) 107 | # x.shape = (BATCH, num_agents * action_size) 108 | logps = [] 109 | assert len(self.flat_child_distributions) == len(self.action_space_struct) 110 | i = 0 111 | for agent_distribution in self.flat_child_distributions: 112 | if isinstance(agent_distribution, TorchCategorical): 113 | a_size = 1 114 | x_agent = x[:, i].int() 115 | elif isinstance(agent_distribution, TorchMultiCategorical): 116 | a_size = int(np.prod(agent_distribution.action_space.shape)) 117 | x_agent = x[:, i : (i + a_size)].int() 118 | else: 119 | sample = agent_distribution.sample() 120 | # Cover Box(shape=()) case. 121 | if len(sample.shape) == 1: 122 | a_size = 1 123 | else: 124 | a_size = sample.size()[1] 125 | x_agent = x[:, i : (i + a_size)] 126 | 127 | i += a_size 128 | agent_logps = agent_distribution.logp(x_agent) 129 | if len(agent_logps.shape) > 1: 130 | agent_logps = torch.sum(agent_logps, dim=1) 131 | 132 | # agent_logps shape (BATCH_SIZE, 1) 133 | logps.append(agent_logps) 134 | 135 | # logps shape (BATCH_SIZE, NUM_AGENTS) 136 | return torch.stack(logps, axis=-1) 137 | 138 | @override(TorchMultiActionDistribution) 139 | def entropy(self): 140 | entropies = [] 141 | for d in self.flat_child_distributions: 142 | agent_entropy = d.entropy() 143 | if len(agent_entropy.shape) > 1: 144 | agent_entropy = torch.sum(agent_entropy, dim=1) 145 | entropies.append(agent_entropy) 146 | return torch.stack(entropies, axis=-1) 147 | 148 | @override(TorchMultiActionDistribution) 149 | def sampled_action_logp(self): 150 | return torch.stack( 151 | [d.sampled_action_logp() for d in self.flat_child_distributions], axis=-1 152 | ) 153 | 154 | @override(TorchMultiActionDistribution) 155 | def kl(self, other): 156 | kls = [] 157 | for d, o in zip(self.flat_child_distributions, other.flat_child_distributions): 158 | agent_kl = d.kl(o) 159 | if len(agent_kl.shape) > 1: 160 | agent_kl = torch.sum(agent_kl, dim=1) 161 | kls.append(agent_kl) 162 | return torch.stack( 163 | kls, 164 | axis=-1, 165 | ) 166 | -------------------------------------------------------------------------------- /multi_trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | PyTorch's policy class used for PPO. 3 | """ 4 | # Copyright (c) 2023. 5 | # ProrokLab (https://www.proroklab.org/) 6 | # All rights reserved. 
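# Overview of this module (see README): the environment returns a single scalar
# team reward to satisfy the Gym API and places the individual per-agent rewards
# in info["rewards"]. compute_gae_for_sample_batch() below pulls those rewards
# out of the info dict and computes advantages and value targets separately for
# every agent, ppo_surrogate_loss() builds one PPO loss term per agent and
# averages them, and MultiPPOTorchPolicy / MultiPPOTrainer plug both into
# RLlib's standard PPO training loop (ray[rllib] 2.1).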
7 | 8 | import logging 9 | from abc import ABC 10 | from typing import Dict 11 | from typing import List, Optional, Union 12 | from typing import Type 13 | 14 | import gym 15 | import numpy as np 16 | import ray 17 | from ray.rllib.agents.ppo import PPOTrainer 18 | from ray.rllib.algorithms.algorithm import Algorithm 19 | from ray.rllib.algorithms.ppo import PPOTorchPolicy 20 | from ray.rllib.algorithms.ppo.ppo_tf_policy import validate_config 21 | from ray.rllib.evaluation import Episode 22 | from ray.rllib.evaluation.postprocessing import Postprocessing, compute_advantages 23 | from ray.rllib.execution import synchronous_parallel_sample 24 | from ray.rllib.execution.common import ( 25 | _check_sample_batch_type, 26 | ) 27 | from ray.rllib.execution.train_ops import ( 28 | train_one_step, 29 | multi_gpu_train_one_step, 30 | ) 31 | from ray.rllib.models import ModelV2, ActionDistribution 32 | from ray.rllib.policy.policy import Policy 33 | from ray.rllib.policy.sample_batch import ( 34 | SampleBatch, 35 | DEFAULT_POLICY_ID, 36 | concat_samples, 37 | ) 38 | from ray.rllib.policy.torch_mixins import ( 39 | LearningRateSchedule, 40 | KLCoeffMixin, 41 | EntropyCoeffSchedule, 42 | ) 43 | from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 44 | from ray.rllib.utils.annotations import override 45 | from ray.rllib.utils.framework import try_import_torch 46 | from ray.rllib.utils.metrics import ( 47 | NUM_AGENT_STEPS_SAMPLED, 48 | NUM_ENV_STEPS_SAMPLED, 49 | SYNCH_WORKER_WEIGHTS_TIMER, 50 | ) 51 | from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY 52 | from ray.rllib.utils.numpy import convert_to_numpy 53 | from ray.rllib.utils.torch_utils import ( 54 | apply_grad_clipping, 55 | ) 56 | from ray.rllib.utils.torch_utils import ( 57 | warn_if_infinite_kl_divergence, 58 | explained_variance, 59 | sequence_mask, 60 | ) 61 | from ray.rllib.utils.typing import AgentID, TensorType, ResultDict 62 | from ray.rllib.utils.typing import PolicyID, SampleBatchType 63 | 64 | torch, nn = try_import_torch() 65 | 66 | logger = logging.getLogger(__name__) 67 | 68 | 69 | class InvalidActionSpace(Exception): 70 | """Raised when the action space is invalid""" 71 | 72 | pass 73 | 74 | 75 | def standardized(array: np.ndarray): 76 | """Normalize the values in an array. 77 | 78 | Args: 79 | array (np.ndarray): Array of values to normalize. 80 | 81 | Returns: 82 | array with zero mean and unit standard deviation. 83 | """ 84 | return (array - array.mean(axis=0, keepdims=True)) / array.std( 85 | axis=0, keepdims=True 86 | ).clip(min=1e-4) 87 | 88 | 89 | def standardize_fields(samples: SampleBatchType, fields: List[str]) -> SampleBatchType: 90 | """Standardize fields of the given SampleBatch""" 91 | _check_sample_batch_type(samples) 92 | wrapped = False 93 | 94 | if isinstance(samples, SampleBatch): 95 | samples = samples.as_multi_agent() 96 | wrapped = True 97 | 98 | for policy_id in samples.policy_batches: 99 | batch = samples.policy_batches[policy_id] 100 | for field in fields: 101 | if field in batch: 102 | batch[field] = standardized(batch[field]) 103 | 104 | if wrapped: 105 | samples = samples.policy_batches[DEFAULT_POLICY_ID] 106 | 107 | return samples 108 | 109 | 110 | def compute_gae_for_sample_batch( 111 | policy: Policy, 112 | sample_batch: SampleBatch, 113 | other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, 114 | episode: Optional[Episode] = None, 115 | ) -> SampleBatch: 116 | """Adds GAE (generalized advantage estimations) to a trajectory. 
117 | The trajectory contains only data from one episode and from one agent. 118 | - If `config.batch_mode=truncate_episodes` (default), sample_batch may 119 | contain a truncated (at-the-end) episode, in case the 120 | `config.rollout_fragment_length` was reached by the sampler. 121 | - If `config.batch_mode=complete_episodes`, sample_batch will contain 122 | exactly one episode (no matter how long). 123 | New columns can be added to sample_batch and existing ones may be altered. 124 | Args: 125 | policy (Policy): The Policy used to generate the trajectory 126 | (`sample_batch`) 127 | sample_batch (SampleBatch): The SampleBatch to postprocess. 128 | other_agent_batches (Optional[Dict[PolicyID, SampleBatch]]): Optional 129 | dict of AgentIDs mapping to other agents' trajectory data (from the 130 | same episode). NOTE: The other agents use the same policy. 131 | episode (Optional[MultiAgentEpisode]): Optional multi-agent episode 132 | object in which the agents operated. 133 | Returns: 134 | SampleBatch: The postprocessed, modified SampleBatch (or a new one). 135 | """ 136 | n_agents = len(policy.action_space) 137 | 138 | if sample_batch[SampleBatch.INFOS].dtype == "float32": 139 | # The trajectory view API will populate the info dict with a np.zeros((ROLLOUT_SIZE,)) 140 | # array in the first call; in that case the dtype will be float32, and we 141 | # ignore it by assigning it to all agents 142 | samplebatch_infos_rewards = concat_samples( 143 | [ 144 | SampleBatch( 145 | { 146 | str(i): sample_batch[SampleBatch.REWARDS].copy() 147 | for i in range(n_agents) 148 | } 149 | ) 150 | ] 151 | ) 152 | 153 | else: 154 | # For regular calls, we extract the rewards from the info 155 | # dict into the samplebatch_infos_rewards dict, which now holds the rewards 156 | # for all agents as a dict. 157 | 158 | # sample_batch[SampleBatch.INFOS] = list of len ROLLOUT_SIZE of which every element is 159 | # {'rewards': {0: -0.077463925, 1: -0.0029145998, 2: -0.08233316}} if there are 3 agents 160 | 161 | samplebatch_infos_rewards = concat_samples( 162 | [ 163 | SampleBatch({str(k): [np.float32(v)] for k, v in s["rewards"].items()}) 164 | for s in sample_batch[SampleBatch.INFOS] 165 | # s = {'rewards': {0: -0.077463925, 1: -0.0029145998, 2: -0.08233316}} if there are 3 agents 166 | ] 167 | ) 168 | 169 | # samplebatch_infos_rewards = SampleBatch(ROLLOUT_SIZE: ['0', '1', '2']) if there are 3 agents 170 | # (i.e. 
it has ROLLOUT_SIZE entries with keys '0','1','2') 171 | 172 | if not isinstance(policy.action_space, gym.spaces.tuple.Tuple): 173 | raise InvalidActionSpace("Expect tuple action space") 174 | 175 | keys_to_overwirte = [ 176 | SampleBatch.REWARDS, 177 | SampleBatch.VF_PREDS, 178 | Postprocessing.ADVANTAGES, 179 | Postprocessing.VALUE_TARGETS, 180 | ] 181 | 182 | original_batch = sample_batch.copy() 183 | 184 | # We prepare the sample batch to contain the agent batches 185 | for k in keys_to_overwirte: 186 | sample_batch[k] = np.zeros((len(original_batch), n_agents), dtype=np.float32) 187 | 188 | if original_batch[SampleBatch.DONES][-1]: 189 | all_values = None 190 | else: 191 | input_dict = original_batch.get_single_step_input_dict( 192 | policy.model.view_requirements, index="last" 193 | ) 194 | all_values = policy._value(**input_dict) 195 | 196 | # Create the sample_batch for each agent 197 | for key in samplebatch_infos_rewards.keys(): 198 | agent_index = int(key) 199 | sample_batch_agent = original_batch.copy() 200 | sample_batch_agent[SampleBatch.REWARDS] = samplebatch_infos_rewards[key] 201 | sample_batch_agent[SampleBatch.VF_PREDS] = original_batch[SampleBatch.VF_PREDS][ 202 | :, agent_index 203 | ] 204 | 205 | if all_values is None: 206 | last_r = 0.0 207 | # Trajectory has been truncated -> last r=VF estimate of last obs. 208 | else: 209 | last_r = ( 210 | all_values[agent_index].item() 211 | if policy.config["use_gae"] 212 | else all_values 213 | ) 214 | 215 | # Adds the policy logits, VF preds, and advantages to the batch, 216 | # using GAE ("generalized advantage estimation") or not. 217 | sample_batch_agent = compute_advantages( 218 | sample_batch_agent, 219 | last_r, 220 | policy.config["gamma"], 221 | policy.config["lambda"], 222 | use_gae=policy.config["use_gae"], 223 | use_critic=policy.config.get("use_critic", True), 224 | ) 225 | 226 | for k in keys_to_overwirte: 227 | sample_batch[k][:, agent_index] = sample_batch_agent[k] 228 | 229 | return sample_batch 230 | 231 | 232 | def ppo_surrogate_loss( 233 | policy: Policy, 234 | model: ModelV2, 235 | dist_class: Type[ActionDistribution], 236 | train_batch: SampleBatch, 237 | ) -> Union[TensorType, List[TensorType]]: 238 | """Constructs the loss for Proximal Policy Objective. 239 | Args: 240 | policy (Policy): The Policy to calculate the loss for. 241 | model (ModelV2): The Model to calculate the loss for. 242 | dist_class (Type[ActionDistribution]): The action distr. class. 243 | train_batch (SampleBatch): The training data. 244 | Returns: 245 | Union[TensorType, List[TensorType]]: A single loss tensor or a list 246 | of loss tensors. 247 | """ 248 | logits, state = model(train_batch) 249 | # logits has shape (BATCH, num_agents * num_outputs_per_agent) 250 | curr_action_dist = dist_class(logits, model) 251 | 252 | # RNN case: Mask away 0-padded chunks at end of time axis. 253 | if state: 254 | B = len(train_batch[SampleBatch.SEQ_LENS]) 255 | max_seq_len = logits.shape[0] // B 256 | mask = sequence_mask( 257 | train_batch[SampleBatch.SEQ_LENS], 258 | max_seq_len, 259 | time_major=model.is_time_major(), 260 | ) 261 | mask = torch.reshape(mask, [-1]) 262 | num_valid = torch.sum(mask) 263 | 264 | def reduce_mean_valid(t): 265 | return torch.sum(t[mask]) / num_valid 266 | 267 | # non-RNN case: No masking. 
268 | else: 269 | mask = None 270 | reduce_mean_valid = torch.mean 271 | 272 | prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS], model) 273 | # train_batch[SampleBatch.ACTIONS] has shape (BATCH, num_agents * action_size) 274 | logp_ratio = torch.exp( 275 | curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) 276 | - train_batch[SampleBatch.ACTION_LOGP] 277 | ) 278 | 279 | use_kl = policy.config["kl_coeff"] > 0.0 280 | if use_kl: 281 | action_kl = prev_action_dist.kl(curr_action_dist) 282 | else: 283 | action_kl = torch.tensor(0.0, device=logp_ratio.device) 284 | 285 | curr_entropies = curr_action_dist.entropy() 286 | 287 | # Compute a value function loss. 288 | if policy.config["use_critic"]: 289 | value_fn_out = model.value_function() 290 | else: 291 | value_fn_out = torch.tensor(0.0, device=logp_ratio.device) 292 | 293 | loss_data = [] 294 | n_agents = len(policy.action_space) 295 | for i in range(n_agents): 296 | 297 | surrogate_loss = torch.min( 298 | train_batch[Postprocessing.ADVANTAGES][..., i] * logp_ratio[..., i], 299 | train_batch[Postprocessing.ADVANTAGES][..., i] 300 | * torch.clamp( 301 | logp_ratio[..., i], 302 | 1 - policy.config["clip_param"], 303 | 1 + policy.config["clip_param"], 304 | ), 305 | ) 306 | 307 | # Compute a value function loss. 308 | if policy.config["use_critic"]: 309 | agent_value_fn_out = value_fn_out[..., i] 310 | vf_loss = torch.pow( 311 | agent_value_fn_out - train_batch[Postprocessing.VALUE_TARGETS][..., i], 312 | 2.0, 313 | ) 314 | vf_loss_clipped = torch.clamp(vf_loss, 0, policy.config["vf_clip_param"]) 315 | mean_vf_loss = reduce_mean_valid(vf_loss_clipped) 316 | # Ignore the value function. 317 | else: 318 | agent_value_fn_out = torch.tensor(0.0).to(surrogate_loss.device) 319 | vf_loss_clipped = mean_vf_loss = torch.tensor(0.0).to(surrogate_loss.device) 320 | 321 | total_loss = ( 322 | -surrogate_loss 323 | + policy.config["vf_loss_coeff"] * vf_loss_clipped 324 | - policy.entropy_coeff * curr_entropies[..., i] 325 | ) 326 | 327 | # Add mean_kl_loss if necessary. 328 | if use_kl: 329 | mean_kl_loss = reduce_mean_valid(action_kl[..., i]) 330 | total_loss += policy.kl_coeff * mean_kl_loss 331 | # TODO smorad: should we do anything besides warn? Could discard KL term 332 | # for this update 333 | warn_if_infinite_kl_divergence(policy, mean_kl_loss) 334 | else: 335 | mean_kl_loss = torch.tensor(0.0, device=logp_ratio.device) 336 | 337 | total_loss = reduce_mean_valid(total_loss) 338 | mean_policy_loss = reduce_mean_valid(-surrogate_loss) 339 | mean_entropy = reduce_mean_valid(curr_entropies[..., i]) 340 | vf_explained_var = explained_variance( 341 | train_batch[Postprocessing.VALUE_TARGETS][..., i], agent_value_fn_out 342 | ) 343 | 344 | # Store stats in policy for stats_fn. 
345 | loss_data.append( 346 | { 347 | "total_loss": total_loss, 348 | "mean_policy_loss": mean_policy_loss, 349 | "mean_vf_loss": mean_vf_loss, 350 | "mean_entropy": mean_entropy, 351 | "mean_kl": mean_kl_loss, 352 | "vf_explained_var": vf_explained_var, 353 | } 354 | ) 355 | 356 | aggregation = torch.mean 357 | total_loss = aggregation(torch.stack([o["total_loss"] for o in loss_data])) 358 | 359 | model.tower_stats["total_loss"] = total_loss 360 | model.tower_stats["mean_policy_loss"] = aggregation( 361 | torch.stack([o["mean_policy_loss"] for o in loss_data]) 362 | ) 363 | model.tower_stats["mean_vf_loss"] = aggregation( 364 | torch.stack([o["mean_vf_loss"] for o in loss_data]) 365 | ) 366 | model.tower_stats["vf_explained_var"] = aggregation( 367 | torch.stack([o["vf_explained_var"] for o in loss_data]) 368 | ) 369 | model.tower_stats["mean_entropy"] = aggregation( 370 | torch.stack([o["mean_entropy"] for o in loss_data]) 371 | ) 372 | model.tower_stats["mean_kl_loss"] = aggregation( 373 | torch.stack([o["mean_kl"] for o in loss_data]) 374 | ) 375 | 376 | return total_loss 377 | 378 | 379 | class MultiAgentValueNetworkMixin: 380 | """Assigns the `_value()` method to a TorchPolicy. 381 | 382 | This way, Policy can call `_value()` to get the current VF estimate on a 383 | single(!) observation (as done in `postprocess_trajectory_fn`). 384 | Note: When doing this, an actual forward pass is being performed. 385 | This is different from only calling `model.value_function()`, where 386 | the result of the most recent forward pass is being used to return an 387 | already calculated tensor. 388 | """ 389 | 390 | def __init__(self, config): 391 | # When doing GAE, we need the value function estimate on the 392 | # observation. 393 | if config["use_gae"]: 394 | # Input dict is provided to us automatically via the Model's 395 | # requirements. It's a single-timestep (last one in trajectory) 396 | # input_dict. 397 | def value(**input_dict): 398 | """This is exactly the same as in PPOTorchPolicy, 399 | but that one calls .item() on self.model.value_function()[0], 400 | which will not work for us since our value function returns 401 | multiple values. Instead, we call .item() in 402 | compute_gae_for_sample_batch above. 403 | """ 404 | input_dict = SampleBatch(input_dict) 405 | input_dict = self._lazy_tensor_dict(input_dict) 406 | model_out, _ = self.model(input_dict) 407 | # [0] = remove the batch dim. 408 | return self.model.value_function()[0] 409 | 410 | 411 | # When not doing GAE, we do not require the value function's output. 412 | else: 413 | 414 | def value(*args, **kwargs): 415 | return 0.0 416 | 417 | self._value = value 418 | 419 | 420 | class MultiPPOTorchPolicy(PPOTorchPolicy, MultiAgentValueNetworkMixin): 421 | def __init__(self, observation_space, action_space, config): 422 | config = dict(ray.rllib.algorithms.ppo.ppo.PPOConfig().to_dict(), **config) 423 | # TODO: Move into Policy API, if needed at all here. Why not move this into 424 | # `PPOConfig`?. 
425 | validate_config(config) 426 | 427 | TorchPolicyV2.__init__( 428 | self, 429 | observation_space, 430 | action_space, 431 | config, 432 | max_seq_len=config["model"]["max_seq_len"], 433 | ) 434 | 435 | # Only difference from ray code 436 | MultiAgentValueNetworkMixin.__init__(self, config) 437 | LearningRateSchedule.__init__(self, config["lr"], config["lr_schedule"]) 438 | EntropyCoeffSchedule.__init__( 439 | self, config["entropy_coeff"], config["entropy_coeff_schedule"] 440 | ) 441 | KLCoeffMixin.__init__(self, config) 442 | self.grad_gnorm = 0 443 | 444 | # TODO: Don't require users to call this manually. 445 | self._initialize_loss_from_dummy_batch() 446 | 447 | @override(PPOTorchPolicy) 448 | def loss(self, model, dist_class, train_batch): 449 | return ppo_surrogate_loss(self, model, dist_class, train_batch) 450 | 451 | @override(PPOTorchPolicy) 452 | def postprocess_trajectory( 453 | self, sample_batch, other_agent_batches=None, episode=None 454 | ): 455 | # Do all post-processing always with no_grad(). 456 | # Not using this here will introduce a memory leak 457 | # in torch (issue #6962). 458 | # TODO: no_grad still necessary? 459 | with torch.no_grad(): 460 | return compute_gae_for_sample_batch( 461 | self, sample_batch, other_agent_batches, episode 462 | ) 463 | 464 | @override(PPOTorchPolicy) 465 | def extra_grad_process(self, local_optimizer, loss): 466 | grad_gnorm = apply_grad_clipping(self, local_optimizer, loss) 467 | if "grad_gnorm" in grad_gnorm: 468 | self.grad_gnorm = grad_gnorm["grad_gnorm"] 469 | return grad_gnorm 470 | 471 | @override(TorchPolicyV2) 472 | def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]: 473 | return convert_to_numpy( 474 | { 475 | "cur_kl_coeff": self.kl_coeff, 476 | "cur_lr": self.cur_lr, 477 | "total_loss": torch.mean( 478 | torch.stack(self.get_tower_stats("total_loss")) 479 | ), 480 | "policy_loss": torch.mean( 481 | torch.stack(self.get_tower_stats("mean_policy_loss")) 482 | ), 483 | "vf_loss": torch.mean( 484 | torch.stack(self.get_tower_stats("mean_vf_loss")) 485 | ), 486 | "vf_explained_var": torch.mean( 487 | torch.stack(self.get_tower_stats("vf_explained_var")) 488 | ), 489 | "kl": torch.mean(torch.stack(self.get_tower_stats("mean_kl_loss"))), 490 | "entropy": torch.mean( 491 | torch.stack(self.get_tower_stats("mean_entropy")) 492 | ), 493 | "entropy_coeff": self.entropy_coeff, 494 | "grad_gnorm": self.grad_gnorm, 495 | } 496 | ) 497 | 498 | 499 | class MultiPPOTrainer(PPOTrainer, ABC): 500 | @override(PPOTrainer) 501 | def get_default_policy_class(self, config): 502 | return MultiPPOTorchPolicy 503 | 504 | @override(PPOTrainer) 505 | def training_step(self) -> ResultDict: 506 | # Collect SampleBatches from sample workers until we have a full batch. 
507 | if self._by_agent_steps: 508 | assert False 509 | train_batch = synchronous_parallel_sample( 510 | worker_set=self.workers, max_agent_steps=self.config["train_batch_size"] 511 | ) 512 | else: 513 | train_batch = synchronous_parallel_sample( 514 | worker_set=self.workers, max_env_steps=self.config["train_batch_size"] 515 | ) 516 | train_batch = train_batch.as_multi_agent() 517 | self._counters[NUM_AGENT_STEPS_SAMPLED] += train_batch.agent_steps() 518 | self._counters[NUM_ENV_STEPS_SAMPLED] += train_batch.env_steps() 519 | 520 | # Standardize advantage 521 | train_batch = standardize_fields(train_batch, ["advantages"]) 522 | # Train 523 | if self.config["simple_optimizer"]: 524 | assert False 525 | train_results = train_one_step(self, train_batch) 526 | else: 527 | train_results = multi_gpu_train_one_step(self, train_batch) 528 | 529 | global_vars = { 530 | "timestep": self._counters[NUM_AGENT_STEPS_SAMPLED], 531 | } 532 | 533 | # Update weights - after learning on the local worker - on all remote 534 | # workers. 535 | if self.workers.remote_workers(): 536 | with self._timers[SYNCH_WORKER_WEIGHTS_TIMER]: 537 | self.workers.sync_weights(global_vars=global_vars) 538 | 539 | # For each policy: update KL scale and warn about possible issues 540 | for policy_id, policy_info in train_results.items(): 541 | # Update KL loss with dynamic scaling 542 | # for each (possibly multiagent) policy we are training 543 | kl_divergence = policy_info[LEARNER_STATS_KEY].get("kl") 544 | self.get_policy(policy_id).update_kl(kl_divergence) 545 | 546 | # Warn about excessively high value function loss 547 | scaled_vf_loss = ( 548 | self.config["vf_loss_coeff"] * policy_info[LEARNER_STATS_KEY]["vf_loss"] 549 | ) 550 | policy_loss = policy_info[LEARNER_STATS_KEY]["policy_loss"] 551 | if scaled_vf_loss > 100: 552 | logger.warning( 553 | "The magnitude of your value function loss for policy: {} is " 554 | "extremely large ({}) compared to the policy loss ({}). This " 555 | "can prevent the policy from learning. Consider scaling down " 556 | "the VF loss by reducing vf_loss_coeff, or disabling " 557 | "vf_share_layers.".format(policy_id, scaled_vf_loss, policy_loss) 558 | ) 559 | # Warn about bad clipping configs. 560 | train_batch.policy_batches[policy_id].set_get_interceptor(None) 561 | mean_reward = train_batch.policy_batches[policy_id]["rewards"].mean() 562 | if mean_reward > self.config["vf_clip_param"]: 563 | self.warned_vf_clip = True 564 | logger.warning( 565 | f"The mean reward returned from the environment is {mean_reward}" 566 | f" but the vf_clip_param is set to {self.config['vf_clip_param']}." 567 | f" Consider increasing it for policy: {policy_id} to improve" 568 | " value function convergence." 569 | ) 570 | 571 | # Update global vars on local worker as well. 
572 | self.workers.local_worker().set_global_vars(global_vars) 573 | 574 | return train_results 575 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ray[rllib]==2.1.0 2 | torch==1.13.0 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import ray 3 | from ray import tune 4 | 5 | # from ray.tune.integration.wandb import WandbLoggerCallback 6 | from ray.tune.registry import register_env 7 | 8 | from environment import DemoMultiAgentEnv 9 | from model import Model 10 | from ray.rllib.models import ModelCatalog 11 | from multi_trainer import MultiPPOTrainer 12 | from multi_action_dist import TorchHomogeneousMultiActionDistribution 13 | 14 | 15 | def train( 16 | share_observations=True, use_beta=True, action_space="discrete", goal_shift=1 17 | ): 18 | ray.init() 19 | 20 | register_env("demo_env", lambda config: DemoMultiAgentEnv(config)) 21 | ModelCatalog.register_custom_model("model", Model) 22 | ModelCatalog.register_custom_action_dist( 23 | "hom_multi_action", TorchHomogeneousMultiActionDistribution 24 | ) 25 | 26 | tune.run( 27 | MultiPPOTrainer, 28 | checkpoint_freq=1, 29 | keep_checkpoints_num=1, 30 | local_dir="/tmp", 31 | # callbacks=[WandbLoggerCallback( 32 | # project="", 33 | # api_key_file="", 34 | # log_config=True 35 | # )], 36 | stop={"training_iteration": 30}, 37 | config={ 38 | "framework": "torch", 39 | "env": "demo_env", 40 | "kl_coeff": 0.0, 41 | "lambda": 0.95, 42 | "clip_param": 0.2, 43 | "entropy_coeff": 0.01, 44 | "train_batch_size": 10000, 45 | "rollout_fragment_length": 1250, 46 | "sgd_minibatch_size": 2048, 47 | "num_sgd_iter": 16, 48 | "num_gpus": 1, 49 | "num_workers": 8, 50 | "num_envs_per_worker": 1, 51 | "lr": 5e-4, 52 | "gamma": 0.99, 53 | "batch_mode": "truncate_episodes", 54 | "observation_filter": "NoFilter", 55 | "model": { 56 | "custom_model": "model", 57 | "custom_action_dist": "hom_multi_action", 58 | "custom_model_config": { 59 | "encoder_out_features": 8, 60 | "shared_nn_out_features_per_agent": 8, 61 | "value_state_encoder_cnn_out_features": 16, 62 | "share_observations": share_observations, 63 | "use_beta": use_beta, 64 | }, 65 | }, 66 | "env_config": { 67 | "world_shape": [5, 5], 68 | "n_agents": 3, 69 | "max_episode_len": 10, 70 | "action_space": action_space, 71 | "goal_shift": goal_shift, 72 | }, 73 | }, 74 | ) 75 | 76 | 77 | if __name__ == "__main__": 78 | parser = argparse.ArgumentParser( 79 | description="RLLib multi-agent with shared NN demo." 
80 | ) 81 | parser.add_argument( 82 | "--action_space", 83 | default="discrete", 84 | const="discrete", 85 | nargs="?", 86 | choices=["continuous", "discrete"], 87 | help="Train with continuous or discrete action space", 88 | ) 89 | parser.add_argument( 90 | "--disable_sharing", 91 | action="store_true", 92 | help="Do not instantiate shared central NN for sharing information", 93 | ) 94 | parser.add_argument( 95 | "--disable_beta", 96 | action="store_true", 97 | help="Use a gaussian distribution instead of the default beta distribution", 98 | ) 99 | parser.add_argument( 100 | "--goal_shift", 101 | type=int, 102 | default=1, 103 | choices=range(0, 2), 104 | help="Goal shift offset (0 means that each agent moves to its own goal, 1 to its neighbor, etc.)", 105 | ) 106 | 107 | args = parser.parse_args() 108 | train( 109 | share_observations=not args.disable_sharing, 110 | use_beta=not args.disable_beta, 111 | action_space=args.action_space, 112 | goal_shift=args.goal_shift, 113 | ) 114 | --------------------------------------------------------------------------------
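For reference, a minimal sketch of interacting with `DemoMultiAgentEnv` directly, using the same `env_config` values that `train.py` passes; it illustrates the per-agent reward contract that `multi_trainer.py` relies on:
```
from environment import DemoMultiAgentEnv

env = DemoMultiAgentEnv(
    {
        "world_shape": [5, 5],
        "n_agents": 3,
        "max_episode_len": 10,
        "action_space": "discrete",
        "goal_shift": 1,
    }
)

obs = env.reset()
# obs["agents"]: tuple of 3 local observations [goal_y, goal_x, pos_y, pos_x]
# obs["state"]: 5x5x2 global map (agent positions / goals), only used by the centralized critic

obs, team_reward, done, info = env.step(env.action_space.sample())
# team_reward is the scalar sum required by the Gym API, while info["rewards"]
# holds the per-agent rewards (e.g. {0: -1, 1: 0, 2: 1}) that
# compute_gae_for_sample_batch() in multi_trainer.py splits into one advantage
# estimate per agent.

env.render()
```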