├── README.md ├── components ├── action_selectors.py ├── buffers │ ├── __init__.py │ └── replay_buffer.py ├── make_env.py ├── misc.py ├── normalizers.py └── wrappers │ ├── episode_statistics.py │ ├── obervation_wrappers.py │ ├── reward_wrappers.py │ ├── shared_observation_wrappers.py │ ├── state_wrappers.py │ ├── time_limit.py │ └── wrappers.py ├── config ├── algos │ ├── q.yaml │ └── qmix.yaml ├── comm │ ├── disc.yaml │ └── tarmac.yaml └── default.yaml ├── envs ├── __init__.py ├── ad_hoc │ ├── ad_hoc.py │ ├── ad_hoc_entities.py │ └── ad_hoc_layouts.py ├── chan_models.py ├── common.py ├── heatmap.py └── multi_agent_env.py ├── exp_queue.py ├── learners ├── __init__.py ├── base_learner.py └── q_learner.py ├── modules ├── activations.py ├── agents │ ├── __init__.py │ ├── comm_agent.py │ ├── customized_agents.py │ └── recurrent_agent.py ├── basics.py ├── comm │ ├── __init__.py │ ├── disc.py │ └── tarmac.py ├── critics │ ├── __init__.py │ └── customized_q.py ├── dueling.py ├── encoders │ ├── __init__.py │ ├── flat_enc.py │ └── rel_enc.py ├── gn_blks.py ├── graph_nn.py ├── mixers │ ├── __init__.py │ └── qmix.py └── utils.py ├── policies ├── __init__.py └── shared_policy.py ├── process_csv.py ├── run.py └── runners ├── __init__.py ├── base_runner.py └── episode_runner.py /README.md: -------------------------------------------------------------------------------- 1 | # cross_layer_opt_with_grl 2 | Code implementation for our paper "Decentralized Routing and Radio Resource Allocation in Wireless Ad Hoc Networks via Graph Reinforcement Learning" 3 | 4 | While the main function of training loop is in `run.py`, you should take `exp_queue.py` as the main file to run. It calls `run.py` with various setups and collects results of multiple experiments (candidates). These results can be finally visualized to curves in the figures from the manuscript. 5 | 6 | The reference to our paper is 7 | 8 | X. Zhang et al., "Decentralized Routing and Radio Resource Allocation in Wireless Ad Hoc Networks via Graph Reinforcement Learning," in IEEE Transactions on Cognitive Communications and Networking, doi: 10.1109/TCCN.2024.3360517. 9 | 10 | If you have any questions, please contact me. I would try to offer help as long as I can. 11 | -------------------------------------------------------------------------------- /components/action_selectors.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from abc import abstractmethod 3 | 4 | import random 5 | import torch as th 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from torch.distributions import Categorical, OneHotCategorical 10 | import torch.nn.functional as F 11 | 12 | REGISTRY = {} 13 | 14 | 15 | class BaseActionSelector: 16 | """Base class of action selector""" 17 | 18 | def __init__(self, args): 19 | pass 20 | 21 | @ abstractmethod 22 | def select_actions(self, logits: Tensor, avail_logits: Optional[Tensor], t: Optional[int], mode: str): 23 | """Selects actions from logits.""" 24 | raise NotImplementedError 25 | 26 | 27 | class EpsilonGreedyActionSelector(BaseActionSelector): 28 | """Epsilon-greedy action selector to balance between exploration and exploitation 29 | 30 | Multi-discrete action space is accepted. 
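    For example (illustrative, derived from the code below): with `nvec = [3, 5]` the incoming
    logits have 3 + 5 = 8 columns; they are split per sub-space, an action index is chosen in each,
    and the selector returns one column per sub-space, i.e. a tensor of shape [batch_size, 2].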
31 | """ 32 | 33 | def __init__(self, args): 34 | super(EpsilonGreedyActionSelector, self).__init__(args) 35 | self.eps_start = args.eps_start 36 | self.eps_end = args.eps_end 37 | self.eps_anneal_time = args.eps_anneal_time 38 | 39 | self.is_multi_discrete = args.is_multi_discrete 40 | if self.is_multi_discrete: 41 | self.nvec = args.nvec 42 | 43 | def schedule(self, t: int) -> float: 44 | return max(self.eps_end, -(self.eps_start - self.eps_end) / self.eps_anneal_time * t + self.eps_start) 45 | 46 | def select_actions(self, logits: Tensor, avail_logits: Optional[Tensor] = None, t: Optional[int] = None, 47 | mode: str = 'explore'): 48 | # Fill available logits if not provided. 49 | if avail_logits is None: 50 | avail_logits = th.ones_like(logits) 51 | # Mask unavailable actions 52 | masked_logits = logits.clone() 53 | masked_logits[avail_logits == 0] = -float("inf") 54 | # Choose actions following the epsilon-greedy strategy. 55 | 56 | if not self.is_multi_discrete: 57 | greedy_actions = th.argmax(masked_logits, 1) # Greedy actions 58 | rand_actions = Categorical(avail_logits).sample().long() # Random samples from available actions 59 | 60 | else: 61 | # Split inputs into spaces. 62 | masked_logits_per_space = th.split(masked_logits, split_size_or_sections=self.nvec, dim=-1) 63 | avail_logits_per_space = th.split(avail_logits, split_size_or_sections=self.nvec, dim=-1) 64 | # Greedily select action for each action space. 65 | greedy_actions = th.stack([th.argmax(chunk, 1) for chunk in masked_logits_per_space]).T 66 | # Randomly select from available actions for each action space. 67 | rand_actions = th.stack([Categorical(chunk).sample().long() for chunk in avail_logits_per_space]).T 68 | 69 | # Get action according to mode. 70 | if mode == 'rand': # Randomly actions from available logits 71 | return rand_actions 72 | elif mode == 'test': # Greedy actions 73 | return greedy_actions 74 | elif mode == 'explore': 75 | eps_thres = self.schedule(t) # epsilon threshold 76 | # Determine whether to pick randon actions or greedy ones. 77 | pick_rand = (th.rand_like(logits[..., 0]) < eps_thres).long() 78 | if self.is_multi_discrete: # Multiple spaces requires an additional dimension. 79 | pick_rand = th.unsqueeze(pick_rand, -1) 80 | # Pick either random or greedy actions. 81 | picked_actions = pick_rand * rand_actions + (1 - pick_rand) * greedy_actions 82 | return picked_actions 83 | else: 84 | raise KeyError("Invalid mode of action selection.") 85 | 86 | 87 | REGISTRY['epsilon_greedy'] = EpsilonGreedyActionSelector 88 | 89 | 90 | class CategoricalActionSelector(BaseActionSelector): 91 | """Categorical action selector for stochastic discrete policies""" 92 | 93 | def select_actions(self, logits: Tensor, avail_logits: Optional[Tensor] = None, t: Optional[int] = None, 94 | mode: str = 'explore'): 95 | # Fill available logits if not provided. 96 | if avail_logits is None: 97 | avail_logits = th.ones_like(logits) 98 | 99 | # Mask unavailable actions. 100 | masked_logits = logits.clone() 101 | masked_logits[avail_logits == 0] = -float("inf") 102 | 103 | # Get action according to mode. 104 | if mode == 'rand': 105 | rand_actions = Categorical(avail_logits).sample().long() # Random samples from available actions 106 | return rand_actions 107 | elif mode == 'explore': 108 | probs = F.softmax(masked_logits, dim=-1) # Convert logits to probs. 109 | sampled_actions = Categorical(probs=probs).sample() # Get action samples from categorical distributions. 
110 | return sampled_actions.unsqueeze(-1) 111 | elif mode == 'test': 112 | return th.argmax(masked_logits, dim=-1, keepdim=True) # Greedy actions 113 | else: 114 | raise KeyError("Invalid mode of action selection.") 115 | 116 | 117 | REGISTRY['categorical'] = CategoricalActionSelector 118 | 119 | 120 | if __name__ == '__main__': 121 | from types import SimpleNamespace as SN 122 | import numpy as np 123 | import torch.nn.functional as F 124 | 125 | # args = SN(**dict(eps_start=1, eps_end=0.05, anneal_time=50000, temperature=0.1)) 126 | # logits = th.rand(3, 5) 127 | # avail_actions = th.randint(0, 2, logits.size()) 128 | # avail_actions[:, 3:] = 0 129 | # 130 | # action_selector = GumbelSoftmaxMultinomialActionSelector(args) 131 | # actions = action_selector.select_actions(logits, 10000, avail_actions, test_mode=True, explore=False) 132 | # print(actions) 133 | 134 | # a = 0.1 * th.tensor([0.5, 0.2, 0.3]) 135 | # a[0] = 0 136 | # a = th.rand(3, 4) 137 | # print(a) 138 | # print(th.argmax(a, dim=1)) 139 | # # 140 | # m = Categorical(a) 141 | # actions = m.sample() 142 | # print(actions) 143 | # print(actions.view(3, -1)) 144 | # onehot_acs = F.one_hot(actions, num_classes=a.shape[-1]) 145 | # print(onehot_acs) 146 | nvec = np.array([3, 5]) 147 | args = SN(**dict(eps_start=1, eps_end=0.05, eps_anneal_time=50000), nvec=nvec.tolist(), is_multi_discrete=False) 148 | action_selector = EpsilonGreedyActionSelector(args) 149 | 150 | batch_size = 10 151 | logits = th.rand(batch_size, nvec.sum()) 152 | acts = action_selector.select_actions(logits=logits, t=1000) 153 | 154 | 155 | 156 | -------------------------------------------------------------------------------- /components/buffers/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from components.buffers.replay_buffer import ReplayBuffer 4 | REGISTRY['replay_buffer'] = ReplayBuffer 5 | -------------------------------------------------------------------------------- /components/buffers/replay_buffer.py: -------------------------------------------------------------------------------- 1 | import random 2 | from collections import deque 3 | 4 | from components.misc import * 5 | 6 | 7 | class ReplayBuffer: 8 | """Replay buffer storing sequences of transitions.""" 9 | 10 | def __init__(self, args): 11 | 12 | self.pre_decision_fields = set(args.pre_decision_fields) # Field names before agent decision 13 | self.post_decision_fields = set(args.post_decision_fields) # Field names after agent decision 14 | self.fields = self.pre_decision_fields.union(self.post_decision_fields) # Overall fields 15 | self.fields.add('filled') 16 | 17 | self.capacity = args.buffer_size # Total number of data sequences that can be held by memory 18 | self.memory = deque(maxlen=self.capacity) # Memory holding samples 19 | self.data_chunk_len = args.data_chunk_len # Maximum length of data sequences 20 | 21 | self.sequence = None # Data sequence holding up-to-date transitions 22 | self.ptr = None # Recorder of data sequence length 23 | self._reset_sequence() 24 | 25 | def _reset_sequence(self) -> None: 26 | """cleans up the data sequence.""" 27 | self.sequence = {k: [] for k in self.fields} 28 | self.ptr = 0 29 | 30 | def insert(self, transition): 31 | """Stores a transition into memory. A transition is first held by data sequence. 32 | When maximum length is reached, contents of data sequence is stored to memory. 
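        For example, with `data_chunk_len = 10` each stored sequence holds 10 complete transitions
        plus one extra copy of the pre-decision fields, which `recall` later uses for bootstrapping.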
33 | """ 34 | 35 | # When maximum sequence length is reached, 36 | if self.ptr == self.data_chunk_len: 37 | # Append the pre-decision data beyond the last timestep to data sequence. 38 | for k in self.pre_decision_fields: 39 | self.sequence[k].append(transition.get(k, '')) 40 | # Move data sequence to memory. 41 | self.memory.append(self.sequence) 42 | # Clear the sequence and reset pointer. 43 | self._reset_sequence() 44 | # Pseudo transition is no longer added to the beginning of next sequence. 45 | if not transition.get('filled'): 46 | return 47 | 48 | # Store data specified by fields. 49 | # Note that pseudo transition is stored if not appended to the end of sequence. 50 | for k, v in transition.items(): 51 | if k in self.fields: 52 | self.sequence[k].append(v) 53 | self.ptr += 1 # A complete transition is stored. 54 | 55 | def recall(self, batch_size: int): 56 | """Selects a random batch of samples.""" 57 | assert len(self) >= batch_size, "Samples are insufficient." 58 | samples = random.sample(self.memory, batch_size) # List of samples 59 | batched_samples = {k: [] for k in self.fields} # Dict holding batch of samples. 60 | 61 | # Construct input sequences. 62 | for t in range(self.data_chunk_len): 63 | for k in self.fields: 64 | batched_samples[k].append(cat([samples[idx][k][t] for idx in range(batch_size)])) 65 | 66 | # Add pre-decision data beyond the last timestep for bootstrapping. 67 | for k in self.pre_decision_fields: 68 | batched_samples[k].append(cat([samples[idx][k][self.data_chunk_len] for idx in range(batch_size)])) 69 | 70 | return batched_samples 71 | 72 | def can_sample(self, batch_size: int) -> bool: 73 | """Whether sufficient samples are available.""" 74 | return batch_size <= len(self) 75 | 76 | def __len__(self): 77 | return len(self.memory) 78 | 79 | def __repr__(self): 80 | return f"ReplayBuffer, holding {len(self)}/{self.capacity} sequences." 81 | 82 | 83 | if __name__ == '__main__': 84 | a = {'apple'} 85 | b = {'pear', 'banana', 'apple'} 86 | buffer = ReplayBuffer(a, b, 10, 10) 87 | 88 | a = [1, 2] 89 | from typing import Iterable 90 | print(isinstance(a, Iterable)) -------------------------------------------------------------------------------- /components/make_env.py: -------------------------------------------------------------------------------- 1 | from components.wrappers.obervation_wrappers import REGISTRY as obs_REGISTRY 2 | from components.wrappers.shared_observation_wrappers import REGISTRY as sobs_REGISTRY 3 | from components.wrappers.state_wrappers import REGISTRY as stat_REGISTRY 4 | from components.wrappers.reward_wrappers import REGISTRY as rew_REGISTRY 5 | from components.wrappers.time_limit import TimeLimit 6 | from components.wrappers.episode_statistics import EpisodeStatistics 7 | 8 | 9 | def make_env(env_fn, args): 10 | """Instantiates a raw environment and wraps it.""" 11 | # Instantiate the env. 12 | env = env_fn() 13 | 14 | # Select format of local/shared observations and global states. 15 | env = obs_REGISTRY[args.obs](env) 16 | if args.shared_obs is not None: 17 | env = sobs_REGISTRY[args.shared_obs](env) 18 | if args.state is not None: 19 | env = stat_REGISTRY[args.state](env) 20 | 21 | # Enable multi-agent communication. 22 | if args.agent == 'comm': 23 | env = obs_REGISTRY['comm'](env) 24 | 25 | # Even if reward normalization is unused, reward filter is used to track episode returns. 26 | if args.normalize_reward: 27 | env = rew_REGISTRY['norm'](env, args) 28 | 29 | # Add proper time limit if requested. 
30 | if args.use_time_limit: 31 | env = TimeLimit(env, args) 32 | 33 | # Record statistics of episodes. 34 | env = EpisodeStatistics(env) 35 | 36 | return env 37 | 38 | 39 | if __name__ == '__main__': 40 | from functools import partial 41 | from types import SimpleNamespace as SN 42 | args = SN(**dict(obs='flat', shared_obs=None, state=None, comm=None, normalize_reward=True, max_reward_limit=10, 43 | use_time_limit=True, alert_episode_limit=True)) 44 | from envs import REGISTRY as env_REGISTRY 45 | env_id = 'mpe' 46 | env_kwargs = {'scenario_name': 'simple'} 47 | env_fn = partial(env_REGISTRY[env_id], **env_kwargs) 48 | env = make_env(env_fn, args) 49 | env.reset() 50 | _, _, info = env.step([0]) 51 | print(env.get_obs()) 52 | print(env.get_shared_obs()) 53 | print(env.get_state()) 54 | print(info) -------------------------------------------------------------------------------- /components/misc.py: -------------------------------------------------------------------------------- 1 | r"""Miscellaneous stuff used in learning""" 2 | 3 | from typing import Union, Optional 4 | 5 | import numpy as np 6 | from numpy import ndarray 7 | import torch as th 8 | from torch import Tensor 9 | from torch.autograd import Variable 10 | from torch.distributions.categorical import Categorical 11 | import torch.nn.functional as F 12 | import dgl 13 | from dgl import DGLGraph 14 | 15 | from gym.spaces.discrete import Discrete 16 | 17 | 18 | # --- Env --- 19 | 20 | def get_random_actions(avail_actions): 21 | """Randomly draws discrete actions for agents in the env. 22 | 23 | Note that only discrete action space is supported. 24 | """ 25 | rand_act_list = [] 26 | for avail_agent_actions in avail_actions: 27 | rand_actions = Categorical(th.tensor(avail_agent_actions, dtype=th.int)).sample().long() 28 | rand_act_list.append(rand_actions.item()) 29 | return rand_act_list 30 | 31 | # --- Shape manipulation --- 32 | 33 | 34 | def cat(chunks: list[Union[Tensor, DGLGraph]]) -> Union[Tensor, DGLGraph]: 35 | """Concatenates data held by a list.""" 36 | 37 | if isinstance(chunks[0], Tensor): 38 | return th.cat(chunks) 39 | elif isinstance(chunks[0], DGLGraph): 40 | return dgl.batch(chunks) 41 | else: 42 | raise TypeError("Unrecognised data type.") 43 | 44 | 45 | def split(data: Tensor, n_agents: int) -> list[Tensor]: 46 | """Splits concatenated Tensor.""" 47 | 48 | assert data.size(0) % n_agents == 0, "Cannot split data due to inconsistent shape." 49 | chunks = th.split(data, n_agents) 50 | return list(chunks) 51 | 52 | 53 | # --- Discrete actions handling --- 54 | 55 | 56 | def get_masked_categorical(logits: Tensor, avail_actions: Optional[Tensor] = None) -> Categorical: 57 | # probs = F.softmax(logits, dim=-1) 58 | # if avail_actions is not None: 59 | # probs[avail_actions == 0] = 0 60 | # masked_c = Categorical(probs=probs) 61 | 62 | if avail_actions is not None: 63 | logits[avail_actions == 0] = -1e10 64 | masked_categorical = Categorical(logits=logits) 65 | return masked_categorical 66 | 67 | 68 | def onehot_from_logits(logits: Tensor, avail_logits: Optional[Tensor] = None, eps: float = 0.0): 69 | """Returns one-hot samples of actions from logits using epsilon-greedy strategy.""" 70 | # Mask unavailable actions. 71 | # TODO: This operation is in-place, which collapses back-propagation. 72 | if avail_logits is not None: 73 | logits[avail_logits == 0] = -1e10 74 | # Get best actions in one-hot form. 
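    # e.g. logits [[0.1, 0.9, 0.3]] -> [[0., 1., 0.]]; ties would produce multiple ones per row.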
75 | argmax_acs = (logits == logits.max(-1, keepdim=True)[0]).float() 76 | if eps == 0.0: 77 | return argmax_acs 78 | else: 79 | # Get random actions in one-hot form. 80 | rand_acs = Variable(th.eye(logits.shape[1])[[np.random.choice(range(logits.shape[1]), size=logits.shape[0])]], requires_grad=False) 81 | # Chooses between best and random actions using epsilon greedy 82 | return th.stack([argmax_acs[i] if r > eps else rand_acs[i] for i, r in enumerate(th.rand(logits.shape[0]))]) 83 | 84 | 85 | def sample_gumbel(shape, eps=1e-20, tens_type=th.FloatTensor): 86 | """Samples from Gumbel(0, 1).""" 87 | U = Variable(tens_type(*shape).uniform_(), requires_grad=False) 88 | return -th.log(-th.log(U + eps) + eps) 89 | 90 | 91 | def gumbel_softmax_sample(logits, avail_logits, temperature): 92 | """Draws a sample from the Gumbel-Softmax distribution.""" 93 | y = logits + sample_gumbel(logits.shape, tens_type=type(logits.data)).to(logits.device) 94 | # Mask unavailable actions. 95 | if avail_logits is not None: 96 | y[avail_logits == 0] = -1e10 97 | dim = len(logits.shape) - 1 98 | return F.softmax(y / temperature, dim=dim) 99 | 100 | 101 | def gumbel_softmax(logits: Tensor, avail_logits: Optional[Tensor] = None, temperature: float = 1.0, hard: bool = False): 102 | """Sample from the Gumbel-Softmax distribution and optionally discretize. 103 | Args: 104 | logits: [batch_size, n_class] unnormalized log-probs 105 | avail_logits: Mask giving feasibility of actions in logits 106 | temperature: non-negative scalar 107 | hard: if True, take argmax, but differentiate w.r.t. soft sample y 108 | Returns: 109 | [batch_size, n_class] sample from the Gumbel-Softmax distribution. 110 | If hard=True, then the returned sample will be one-hot, 111 | otherwise it will be a probability distribution that sums to 1 across classes. 112 | """ 113 | y_soft = gumbel_softmax_sample(logits, avail_logits, temperature) 114 | if hard: 115 | y_hard = onehot_from_logits(y_soft) # Get one-hot sample. 116 | # Mask has been applied in gumbel-softmax sampling, re-adding mask in discretization causes an error. 117 | y_out = (y_hard - y_soft).detach() + y_soft 118 | return y_out 119 | else: 120 | return y_soft 121 | 122 | 123 | def onehot_from_actions(actions: Tensor, n_classes: Union[int, list]): 124 | # print(f"actions.size() = {actions.size()}") 125 | if isinstance(n_classes, int): 126 | onehot_acs = F.one_hot(actions, num_classes=n_classes) 127 | elif isinstance(n_classes, list): 128 | acs_per_space = th.split(actions, split_size_or_sections=1, dim=-1) 129 | onehot_acs_per_space = [F.one_hot(a.squeeze(-1), num_classes=n_classes[s]) for s, a in enumerate(acs_per_space)] 130 | onehot_acs = th.cat(onehot_acs_per_space, -1) 131 | # print(f"onehot_acs.size() = {onehot_acs.size()}") 132 | assert onehot_acs.size(-1) == sum(n_classes), "Incorrect dimension of one-hot actions for multi-discrete space." 133 | else: 134 | raise TypeError("Unsupported type of `n_classes` in function `onehot_from_actions`.") 135 | return onehot_acs 136 | 137 | 138 | # --- Functions shared by learners --- 139 | 140 | def get_clipped_linear_decay(total_steps, threshold): 141 | assert 1 > threshold >= 0, "Invalid threshold of linear decay." 142 | return lambda step: max(threshold, 1 - step / total_steps) # Linear decay and then flat 143 | 144 | 145 | def mse_loss(error: Tensor, mask: Optional[Tensor] = None): 146 | """Computes MSE loss from errors.""" 147 | if mask is not None: # Mask is available. 148 | # Mask invalid entries. 
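        # Worked example: error = [[1.], [2.], [3.]] with mask = [[1.], [1.], [0.]] gives
        # 0.5 * (1 + 4) / 2 = 1.25, i.e. padded timesteps affect neither numerator nor denominator.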
149 | mask = mask.expand_as(error) 150 | masked_error = error * mask 151 | # print(f"masked_error.squeeze() = \n{masked_error.squeeze()}") 152 | # Only average valid terms. 153 | return 0.5 * masked_error.pow(2).sum() / mask.sum() 154 | else: 155 | return 0.5 * error.pow(2).mean() 156 | 157 | 158 | def huber_loss(error: Tensor, mask: Optional[Tensor] = None, delta: float = 10.0): 159 | """Computes Huber loss from errors. 160 | See https://en.wikipedia.org/wiki/Huber_loss for more information. 161 | """ 162 | if mask is not None: 163 | # Mask invalid entries. 164 | mask = mask.expand_as(error) 165 | masked_error = error * mask 166 | # Determine loss type using delta as threshold. 167 | l2_rgn = (masked_error.abs() < delta).float() # Entries using MSE loss 168 | l1_rgn = 1 - l2_rgn # Entries using L1 loss 169 | # Compute loss for each entry. 170 | loss_per_element = l2_rgn * 0.5 * masked_error.pow(2) + l1_rgn * delta * (masked_error.abs() - 0.5 * delta) 171 | return loss_per_element.sum() / mask.sum() # Only average valid terms. 172 | 173 | else: 174 | l2_rgn = (error.abs() < delta).float() 175 | l1_rgn = 1 - l2_rgn 176 | loss_per_element = l2_rgn * 0.5 * error.pow(2) + l1_rgn * delta * (error.abs() - 0.5 * delta) 177 | return loss_per_element.mean() # Only average valid terms. 178 | 179 | 180 | def soft_target_update(policy, target, polyak: float): 181 | """Smoothly updates target network from learnt policy via polyak averaging.""" 182 | for p, p_targ in zip(policy.parameters(), target.parameters()): 183 | # In-place operations "mul_", "add_" are used to update target. 184 | p_targ.data.mul_(polyak) 185 | p_targ.data.add_((1 - polyak) * p.data) 186 | -------------------------------------------------------------------------------- /components/normalizers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class RunningStat(object): 5 | """Keeps track of first and second moments (mean and variance) of a streaming time series. 6 | Taken from https://github.com/joschu/modular_rl 7 | Math in http://www.johndcook.com/blog/standard_deviation/ 8 | """ 9 | 10 | def __init__(self, shape): 11 | self._n = 0 12 | self._M = np.zeros(shape) 13 | self._S = np.zeros(shape) 14 | 15 | def push(self, x): 16 | x = np.asarray(x) 17 | assert x.shape == self._M.shape 18 | self._n += 1 19 | if self._n == 1: 20 | self._M[...] = x 21 | else: 22 | oldM = self._M.copy() 23 | self._M[...] = oldM + (x - oldM) / self._n 24 | self._S[...] 
= self._S + (x - oldM) * (x - self._M) 25 | 26 | @property 27 | def n(self): 28 | return self._n 29 | 30 | @property 31 | def mean(self): 32 | return self._M 33 | 34 | @property 35 | def var(self): 36 | return self._S / (self._n - 1) if self._n > 1 else np.square(self._M) 37 | 38 | @property 39 | def std(self): 40 | return np.sqrt(self.var) 41 | 42 | @property 43 | def shape(self): 44 | return self._M.shape 45 | 46 | 47 | class ZFilter: 48 | def __init__(self, shape, center=True, scale=True): 49 | assert shape is not None 50 | self.center = center 51 | self.scale = scale 52 | self.rs = RunningStat(shape) 53 | 54 | def __call__(self, x, **kwargs): 55 | self.rs.push(x) 56 | if self.center: 57 | x = x - self.rs.mean 58 | if self.scale: 59 | if self.center: 60 | x = x / (self.rs.std + 1e-8) 61 | else: 62 | diff = x - self.rs.mean 63 | diff = diff / (self.rs.std + 1e-8) 64 | x = diff + self.rs.mean 65 | return x 66 | -------------------------------------------------------------------------------- /components/wrappers/episode_statistics.py: -------------------------------------------------------------------------------- 1 | from components.wrappers.wrappers import Wrapper 2 | 3 | 4 | class EpisodeStatistics(Wrapper): 5 | """Tracks episode length and returns.""" 6 | 7 | def __init__(self, env): 8 | super(EpisodeStatistics, self).__init__(env) 9 | self.ep_len = None # Length of episode 10 | self.ep_ret = None # Return of episode 11 | 12 | def reset(self): 13 | super(EpisodeStatistics, self).reset() 14 | 15 | # Reset statistics. 16 | self.ep_len = 0 17 | self.ep_ret = 0.0 18 | 19 | def step(self, actions): 20 | rewards, terminated, info = super(EpisodeStatistics, self).step(actions) 21 | 22 | self.ep_len += 1 # One step elapse. 23 | # Since reward filter may be used, use actual rewards as long as they are available. 24 | self.ep_ret += info.get('actual_rewards', rewards).mean() # Average cross agents. 25 | 26 | # When episode terminates, add episode statistics to info. 
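        # e.g. info might then carry {'EpLen': 25, 'EpRet': 3.7} (values purely illustrative).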
27 | if terminated: 28 | episode_statistics = dict(EpLen=self.ep_len, EpRet=self.ep_ret) 29 | info.update(episode_statistics) 30 | 31 | return rewards, terminated, info 32 | -------------------------------------------------------------------------------- /components/wrappers/obervation_wrappers.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import numpy as np 3 | import torch as th 4 | from torch import Tensor 5 | import dgl 6 | from dgl import DGLGraph 7 | 8 | from gym.spaces.utils import flatten_space, flatten 9 | from components.wrappers.wrappers import ObservationWrapper 10 | 11 | 12 | REGISTRY = {} 13 | 14 | 15 | class FlatObservation(ObservationWrapper): 16 | def __init__(self, env): 17 | super(FlatObservation, self).__init__(env) 18 | self.observation_space = [flatten_space(self.env.observation_space[i]) for i in range(self.n_agents)] 19 | 20 | def get_obs_size(self): 21 | return [self.observation_space[i].shape[0] for i in range(self.n_agents)] 22 | 23 | def observation(self, obs): 24 | flat_obs = [flatten(self.env.observation_space[i], agent_obs) for i, agent_obs in enumerate(obs)] 25 | return th.as_tensor(np.stack(flat_obs), dtype=th.float32) 26 | 27 | 28 | REGISTRY['flat'] = FlatObservation 29 | 30 | 31 | class RelationalObservation(ObservationWrapper): 32 | def __init__(self, env): 33 | super(RelationalObservation, self).__init__(env) 34 | 35 | def get_obs_size(self): 36 | sizes = [] 37 | for obs_space_agent in self.observation_space: 38 | obs_size = {} 39 | for k in obs_space_agent: 40 | obs_size[k] = obs_space_agent[k].shape[-1] 41 | if k != 'agent': 42 | obs_size[k] -= 1 # Drop visibility. 43 | sizes.append(obs_size) 44 | return sizes 45 | 46 | def observation(self, obs): 47 | rel_obs = [] 48 | for agent_obs in obs: 49 | 50 | data_dict = {('agent', 'talks', 'agent'): ([], [])} 51 | num_nodes_dict = {'agent': 1} 52 | feat = {'agent': th.as_tensor(agent_obs['agent'], dtype=th.float).unsqueeze(0)} 53 | 54 | for k, v in agent_obs.items(): 55 | if k != 'agent': 56 | ent_ids = np.equal(v[:, 0], 1) 57 | n_ents = ent_ids.sum() 58 | data_dict[(k, 'nearby', 'agent')] = (th.arange(n_ents), th.zeros(n_ents, dtype=th.long)) 59 | num_nodes_dict[k] = n_ents 60 | feat[k] = th.as_tensor(agent_obs[k][ent_ids, 1:], dtype=th.float) 61 | 62 | graph = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict) 63 | graph.ndata['feat'] = feat 64 | rel_obs.append(graph) 65 | 66 | return dgl.batch(rel_obs) 67 | 68 | 69 | REGISTRY['rel'] = RelationalObservation 70 | 71 | 72 | class CommWrapper(ObservationWrapper): 73 | def __init__(self, env): 74 | super(CommWrapper, self).__init__(env) 75 | 76 | def get_obs_size(self): 77 | return self.env.get_obs_size() 78 | 79 | def observation(self, obs): 80 | u, v = [], [] 81 | if hasattr(self.env, "get_agent_visibility_matrix"): 82 | agent_vis_mat = self.env.get_agent_visibility_matrix() 83 | else: 84 | agent_vis_mat = np.ones((self.n_agents, self.n_agents), dtype=bool) 85 | for i in range(self.n_agents): 86 | for j in range(self.n_agents): 87 | if agent_vis_mat[i, j]: 88 | u.append(i) 89 | v.append(j) 90 | 91 | if isinstance(obs, Tensor): 92 | # comm_graph = dgl.graph((u, v), num_nodes=self.n_agents) 93 | comm_graph = dgl.heterograph({('agent', 'talks', 'agent'): (u, v)}, num_nodes_dict={'agent': self.n_agents}) 94 | comm_graph.ndata['feat'] = obs 95 | 96 | elif isinstance(obs, DGLGraph): 97 | data_dict = {c_etype: ([], []) for c_etype in obs.canonical_etypes} 98 | 
data_dict.update({('agent', 'talks', 'agent'): (u, v)}) 99 | num_nodes_dict = {ntype: 0 for ntype in obs.ntypes} 100 | num_nodes_dict.update({'agent': self.n_agents}) 101 | comm_graph = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict) 102 | comm_graph = dgl.merge([obs, comm_graph]) 103 | else: 104 | raise TypeError("Unrecognized type of local observations.") 105 | return comm_graph 106 | 107 | 108 | REGISTRY['comm'] = CommWrapper 109 | 110 | 111 | class GraphObservation(ObservationWrapper): 112 | """Builds a heterograph as observations of all agents in the env. 113 | 114 | Note that this wrapper does not call `.get_shared_obs()` of the env. Instead, it requires: 115 | - method `.get_graph_inputs()` 116 | - property `obs_graph_feats` 117 | Thus it is less inclusive than `RelationalSharedObservation` (which supports any env providing dict shared obs). 118 | """ 119 | 120 | def __init__(self, env): 121 | super(GraphObservation, self).__init__(env) 122 | assert hasattr(self.env, "get_graph_inputs"), "Absence of graph obs callback!" 123 | assert hasattr(self.env, "graph_feats"), "Absence of graph obs feats!" 124 | 125 | def get_obs(self): 126 | obs_relations = self.env.get_graph_inputs() 127 | graph_data = obs_relations['graph_data'] # Define edges 128 | num_nodes_dict = obs_relations['num_nodes_dict'] # Number of nodes 129 | node_feats = obs_relations['ndata'] # Node features 130 | edge_feats = obs_relations.get('edata') # Edge features 131 | 132 | # Create heterograph as observation. 133 | obs_graph = dgl.heterograph(graph_data, num_nodes_dict=num_nodes_dict) 134 | # Add node/edge features. 135 | for ntype in obs_graph.ntypes: 136 | obs_graph.nodes[ntype].data['feat'] = th.as_tensor(node_feats[ntype], dtype=th.float) 137 | if edge_feats is not None: 138 | for etype in obs_graph.etypes: 139 | obs_graph.edges[etype].data['feat'] = th.as_tensor(edge_feats[etype], dtype=th.float) 140 | 141 | # Remove redundant nodes. 
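        # (Assumption about dgl.compact_graphs: it drops nodes that appear in no edge; the
        # always_preserve argument keeps all n_agents 'agent' nodes even when isolated.)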
142 | obs_graph = dgl.compact_graphs(obs_graph, always_preserve=dict(agent=th.arange(self.n_agents)), 143 | copy_ndata=True, copy_edata=True) 144 | 145 | return obs_graph 146 | 147 | def get_obs_size(self): 148 | return [getattr(self.env, "graph_feats")] * self.n_agents 149 | 150 | 151 | REGISTRY['graph'] = GraphObservation 152 | 153 | -------------------------------------------------------------------------------- /components/wrappers/reward_wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from components.wrappers.wrappers import RewardWrapper 3 | from components.normalizers import ZFilter 4 | 5 | REGISTRY = {} 6 | 7 | 8 | class RewardNormalizer(RewardWrapper): 9 | def __init__(self, env, args): 10 | super(RewardNormalizer, self).__init__(env) 11 | reward_shape = self.n_agents 12 | self.reward_normalizer = ZFilter(reward_shape) 13 | 14 | self.clip_range = args.max_reward_limit 15 | 16 | def reward(self, rewards): 17 | rewards = self.reward_normalizer(rewards) 18 | if self.clip_range is not None: 19 | rewards = np.clip(rewards, -self.clip_range, self.clip_range) 20 | return rewards 21 | 22 | 23 | REGISTRY['norm'] = RewardNormalizer 24 | -------------------------------------------------------------------------------- /components/wrappers/shared_observation_wrappers.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import numpy as np 3 | import torch as th 4 | from torch import Tensor 5 | import dgl 6 | from dgl import DGLGraph 7 | 8 | from gym.spaces.utils import flatten_space, flatten 9 | from components.wrappers.wrappers import SharedObservationWrapper 10 | 11 | 12 | REGISTRY = {} 13 | 14 | 15 | class FlatSharedObservation(SharedObservationWrapper): 16 | def __init__(self, env): 17 | super(FlatSharedObservation, self).__init__(env) 18 | self.shared_observation_space = [flatten_space(self.env.shared_observation_space[i]) for i in range(self.n_agents)] 19 | 20 | def get_shared_obs_size(self): 21 | return [self.shared_observation_space[i].shape[0] for i in range(self.n_agents)] 22 | 23 | def shared_observation(self, shared_obs): 24 | flat_shared_obs = [flatten(self.env.shared_observation_space[i], agent_shared_obs) for i, agent_shared_obs in enumerate(shared_obs)] 25 | return th.as_tensor(np.stack(flat_shared_obs), dtype=th.float32) 26 | 27 | 28 | REGISTRY['flat'] = FlatSharedObservation 29 | 30 | 31 | class RelationalSharedObservation(SharedObservationWrapper): 32 | def __init__(self, env): 33 | super(RelationalSharedObservation, self).__init__(env) 34 | 35 | def get_shared_obs_size(self): 36 | sizes = [] 37 | for shared_obs_space_agent in self.shared_observation_space: 38 | shared_obs_size = {} 39 | for k in shared_obs_space_agent: 40 | shared_obs_size[k] = shared_obs_space_agent[k].shape[-1] 41 | sizes.append(shared_obs_size) 42 | return sizes 43 | 44 | def shared_observation(self, shared_obs): 45 | graph_shared_obs = [] 46 | for agent_shared_obs in shared_obs: 47 | 48 | data_dict = {('agent', 'talks', 'agent'): ([], [])} 49 | num_nodes_dict = {'agent': 1} 50 | feat = {'agent': th.as_tensor(agent_shared_obs['agent'], dtype=th.float).unsqueeze(0)} 51 | 52 | for k, v in agent_shared_obs.items(): 53 | if k != 'agent': 54 | n_ents = agent_shared_obs[k].shape[0] 55 | data_dict[(k, 'nearby', 'agent')] = (th.arange(n_ents), th.zeros(n_ents, dtype=th.long)) 56 | num_nodes_dict[k] = n_ents 57 | feat[k] = th.as_tensor(agent_shared_obs[k], dtype=th.float) 
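                    # e.g. two entities of type 'nbr' yield the edge pair (tensor([0, 1]), tensor([0, 0])),
                    # i.e. both 'nbr' nodes point to the single 'agent' node of this per-agent subgraph.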
58 | 59 | graph = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict) 60 | graph.ndata['feat'] = feat 61 | graph_shared_obs.append(graph) 62 | 63 | return dgl.batch(graph_shared_obs) 64 | 65 | 66 | REGISTRY['rel'] = RelationalSharedObservation 67 | 68 | 69 | class GraphSharedObservation(SharedObservationWrapper): 70 | """Builds a heterograph as shared observations of all agents in the env. 71 | 72 | Note that this wrapper does not call `.get_shared_obs()` of the env. Instead, it requires: 73 | - method `.get_graph_inputs()` 74 | - property `graph_feats` 75 | Thus it is less inclusive than `RelationalSharedObservation` (which supports any env providing dict shared obs). 76 | """ 77 | 78 | def __init__(self, env): 79 | super(GraphSharedObservation, self).__init__(env) 80 | assert hasattr(self.env, "get_graph_inputs"), "Absence of graph shared obs callback!" 81 | assert hasattr(self.env, "graph_feats"), "Absence of graph shared obs feats!" 82 | 83 | def get_shared_obs(self): 84 | shared_obs_relations = self.env.get_graph_inputs() 85 | graph_data = shared_obs_relations['graph_data'] # Define edges 86 | num_nodes_dict = shared_obs_relations['num_nodes_dict'] # Number of nodes 87 | node_feats = shared_obs_relations['ndata'] # Node features 88 | edge_feats = shared_obs_relations.get('edata') # Edge features 89 | 90 | # Create heterograph as shared observation. 91 | shared_obs_graph = dgl.heterograph(graph_data, num_nodes_dict=num_nodes_dict) 92 | # Add node/edge features. 93 | for ntype in shared_obs_graph.ntypes: 94 | shared_obs_graph.nodes[ntype].data['feat'] = th.as_tensor(node_feats[ntype], dtype=th.float) 95 | if edge_feats is not None: 96 | for etype in shared_obs_graph.etypes: 97 | shared_obs_graph.edges[etype].data['feat'] = th.as_tensor(edge_feats[etype], dtype=th.float) 98 | 99 | # Remove redundant nodes. 
100 | shared_obs_graph = dgl.compact_graphs(shared_obs_graph, always_preserve=dict(agent=th.arange(self.n_agents)), 101 | copy_ndata=True, copy_edata=True) 102 | # print(shared_obs_graph) 103 | # print(shared_obs_graph.ndata['feat']) 104 | # print(shared_obs_graph.edata['feat']) 105 | return shared_obs_graph 106 | 107 | def get_shared_obs_size(self): 108 | return [getattr(self.env, "graph_feats")] * self.n_agents 109 | 110 | 111 | REGISTRY['graph'] = GraphSharedObservation 112 | 113 | 114 | if __name__ == '__main__': 115 | from envs.ad_hoc.ad_hoc import AdHocEnv 116 | env = AdHocEnv('2flows') 117 | env.reset() 118 | shared_obs = env.get_shared_obs() 119 | env = GraphSharedObservation(env) 120 | shared_obs = env.get_shared_obs() -------------------------------------------------------------------------------- /components/wrappers/state_wrappers.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from gym.spaces.utils import flatten_space, flatten 3 | from envs.multi_agent_env import MultiAgentEnv 4 | from components.wrappers.wrappers import StateWrapper 5 | 6 | 7 | REGISTRY = {} 8 | 9 | 10 | class FlatStateWrapper(StateWrapper): 11 | def __init__(self, env: MultiAgentEnv): 12 | super(FlatStateWrapper, self).__init__(env) 13 | self.state_space = flatten_space(self.env.state_space) 14 | 15 | def state(self, stat): 16 | stat = flatten(self.env.state_space, stat) 17 | stat = th.as_tensor(stat, dtype=th.float) # Covert to Tensor 18 | return th.atleast_2d(stat) # Check dimension 19 | 20 | def get_state_size(self): 21 | return self.state_space.shape[0] 22 | 23 | 24 | REGISTRY['flat'] = FlatStateWrapper 25 | -------------------------------------------------------------------------------- /components/wrappers/time_limit.py: -------------------------------------------------------------------------------- 1 | from components.wrappers.wrappers import Wrapper 2 | 3 | 4 | class TimeLimit(Wrapper): 5 | """Truncates episode with time limit.""" 6 | 7 | def __init__(self, env, args): 8 | super(TimeLimit, self).__init__(env) 9 | 10 | self._max_episode_steps = getattr(args, 'max_episode_steps', None) 11 | if self._max_episode_steps is None: 12 | self._max_episode_steps = self.env.get_env_info()['episode_limit'] 13 | assert self._max_episode_steps is not None, "Maximum step number is not set for wrapper `TimeLimit`." 14 | 15 | self._elapsed_steps = None # Number of elapsed steps in current episode 16 | self._alert_ep_lim = args.alert_episode_limit # Whether arrival of episode limit is alerted in info 17 | 18 | def reset(self): 19 | self._elapsed_steps = 0 20 | self.env.reset() 21 | 22 | def step(self, actions): 23 | self._elapsed_steps += 1 24 | rewards, terminated, info = self.env.step(actions) 25 | info['truncated'] = False 26 | if (self._elapsed_steps == self._max_episode_steps) and not terminated: 27 | terminated = True # Truncate episode. 28 | if self._alert_ep_lim: 29 | info['truncated'] = True # Alert arrival of episode limit into info. 
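        # Downstream learners can then distinguish truncation from a genuine terminal state, e.g.
        # bootstrapping from the final observation only when 'truncated' is True (illustrative usage).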
30 | return rewards, terminated, info 31 | 32 | def get_env_info(self): 33 | env_info = self.env.get_env_info() 34 | env_info["episode_limit"] = self._max_episode_steps 35 | return env_info 36 | -------------------------------------------------------------------------------- /components/wrappers/wrappers.py: -------------------------------------------------------------------------------- 1 | from envs.multi_agent_env import MultiAgentEnv 2 | from abc import abstractmethod 3 | 4 | 5 | class Wrapper(object): 6 | def __init__(self, env: MultiAgentEnv): 7 | self.env = env 8 | 9 | def __getattr__(self, name): 10 | if name.startswith("_"): 11 | raise AttributeError(f"attempted to get missing private attribute '{name}'") 12 | return getattr(self.env, name) 13 | 14 | def reset(self): 15 | self.env.reset() 16 | 17 | def step(self, actions): 18 | return self.env.step(actions) 19 | 20 | def close(self): 21 | self.env.close() 22 | 23 | 24 | class ObservationWrapper(Wrapper): 25 | def __init__(self, env): 26 | super(ObservationWrapper, self).__init__(env) 27 | 28 | def get_obs(self): 29 | obs = self.env.get_obs() 30 | obs = self.observation(obs) 31 | return obs 32 | 33 | @ abstractmethod 34 | def get_obs_size(self): 35 | raise NotImplementedError 36 | 37 | @abstractmethod 38 | def observation(self, obs): 39 | raise NotImplementedError 40 | 41 | def get_env_info(self): 42 | env_info = self.env.get_env_info() 43 | env_info.update(dict(obs_shape=self.get_obs_size())) 44 | return env_info 45 | 46 | 47 | class SharedObservationWrapper(Wrapper): 48 | def __init__(self, env): 49 | super(SharedObservationWrapper, self).__init__(env) 50 | 51 | def get_shared_obs(self): 52 | obs = self.env.get_shared_obs() 53 | obs = self.shared_observation(obs) 54 | return obs 55 | 56 | @ abstractmethod 57 | def get_shared_obs_size(self): 58 | raise NotImplementedError 59 | 60 | @abstractmethod 61 | def shared_observation(self, shared_obs): 62 | raise NotImplementedError 63 | 64 | def get_env_info(self): 65 | env_info = self.env.get_env_info() 66 | env_info.update(dict(shared_obs_shape=self.get_shared_obs_size())) 67 | return env_info 68 | 69 | 70 | class StateWrapper(Wrapper): 71 | def __init__(self, env): 72 | super(StateWrapper, self).__init__(env) 73 | 74 | def get_state(self): 75 | stat = self.env.get_state() 76 | stat = self.state(stat) 77 | return stat 78 | 79 | @abstractmethod 80 | def get_state_size(self): 81 | raise NotImplementedError 82 | 83 | @abstractmethod 84 | def state(self, stat): 85 | raise NotImplementedError 86 | 87 | def get_env_info(self): 88 | env_info = self.env.get_env_info() 89 | env_info.update(dict(state_shape=self.get_state_size())) 90 | return env_info 91 | 92 | 93 | class RewardWrapper(Wrapper): 94 | def __init__(self, env): 95 | super(RewardWrapper, self).__init__(env) 96 | 97 | def step(self, actions): 98 | rewards, terminated, info = self.env.step(actions) 99 | 100 | # Record actual rewards. 101 | info.update(dict(actual_rewards=rewards)) 102 | # Filter rewards. 
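        # Subclasses implement reward(); e.g. RewardNormalizer (reward_wrappers.py) applies a running
        # z-filter and clips the result to [-max_reward_limit, max_reward_limit].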
103 | rewards = self.reward(rewards) 104 | 105 | return rewards, terminated, info 106 | 107 | @abstractmethod 108 | def reward(self, rewards): 109 | raise NotImplementedError 110 | -------------------------------------------------------------------------------- /config/algos/q.yaml: -------------------------------------------------------------------------------- 1 | learner: "q" 2 | 3 | # --- Buffer --- 4 | buffer: "replay_buffer" 5 | pre_decision_fields: 6 | - "obs" 7 | - "h" 8 | post_decision_fields: 9 | - "actions" 10 | - "rewards" 11 | - "terminated" 12 | 13 | # --- Action selection --- 14 | action_selector: "epsilon_greedy" # Action selector 15 | eps_start: 1.0 # Initial exploration rate 16 | eps_end: 0.05 # Final exploration rate 17 | eps_anneal_time: 50000 # Number of env steps to anneal exploration rate 18 | 19 | # --- Q-learning extensions --- 20 | use_double_q: True # Whether double Q-learning is used 21 | use_dueling_ar: False # Whether dueling architecture is used 22 | -------------------------------------------------------------------------------- /config/algos/qmix.yaml: -------------------------------------------------------------------------------- 1 | learner: "q" 2 | state: "flat" 3 | 4 | # --- Buffer --- 5 | buffer: "replay_buffer" 6 | pre_decision_fields: 7 | - "obs" 8 | - "h" 9 | - "state" 10 | post_decision_fields: 11 | - "actions" 12 | - "rewards" 13 | - "terminated" 14 | 15 | # --- Action selection --- 16 | action_selector: "epsilon_greedy" # Action selector 17 | eps_start: 1.0 # Initial exploration rate 18 | eps_end: 0.05 # Final exploration rate 19 | eps_anneal_time: 50000 # Number of env steps to anneal exploration rate 20 | 21 | # --- Q-learning extensions --- 22 | use_double_q: True # Whether double Q-learning is used 23 | use_dueling_ar: False # Whether dueling architecture is used 24 | 25 | # --- Mixer parameters --- 26 | mixer: 'qmix' # Mixer to combine individual state-action values 27 | mixing_embed_dim: 32 # Size of mixing network 28 | hypernet_layers: 2 # Number of layers in hyper network (only support 1 and 2) 29 | hypernet_embed: 64 # Hidden size of hyper network (when hypernet_layers > 1) 30 | -------------------------------------------------------------------------------- /config/comm/disc.yaml: -------------------------------------------------------------------------------- 1 | msg_size: 64 # Size of messages in multi-agent communication 2 | -------------------------------------------------------------------------------- /config/comm/tarmac.yaml: -------------------------------------------------------------------------------- 1 | key_size: 16 # Size of signatures and queries in multi-agent communication 2 | msg_size: 64 # Size of messages in multi-agent communication 3 | n_rounds: 1 # Rounds of multi-agent communication -------------------------------------------------------------------------------- /config/default.yaml: -------------------------------------------------------------------------------- 1 | # --- Common parameters --- 2 | use_cuda: True # Whether CUDA is used 3 | cuda_idx: 0 # Index of CUDA 4 | record_tests: False # Whether to save test results 5 | use_wandb: False # Whether to use wandb to record history 6 | 7 | # --- Components --- 8 | runner: "base" # Runner 9 | policy: 'shared' # Policy to make decisions 10 | 11 | agent: "rnn" # Agent (actor) type 12 | critic: # Critic type 13 | obs: 'flat' # Local observation format 14 | shared_obs: # Shared observation format 15 | state: # State format 16 | comm: # Communication protocol 17 | 18 
| # --- Running parameters --- 19 | total_env_steps: 200000 # Total number of env steps to run 20 | rollout_len: # Number of transitions to collect in each time 21 | warmup_steps: 0 # Number of env steps for random exploration 22 | update_after: 0 # Number of env steps before starting update 23 | steps_per_session: 20000 # Number of env steps per session 24 | test_interval: 20000 # Number of env steps between tests 25 | save_interval: 100000 # Interval to save checkpoints 26 | n_test_episodes: 10 # Number of episodes to run in each test 27 | 28 | # --- RL hyperparameters --- 29 | lr: 0.0005 # Learning rate 30 | optim_eps: 0.00001 # eps used by Adam optimizer 31 | gamma: 0.99 # Discount factor 32 | buffer_size: 5000 # Maximum number of samples (sequences) held by replay buffer 33 | data_chunk_len: # Length of sequences 34 | batch_size: 32 # Minibatch size 35 | target_update_mode: "soft" # Whether soft/hard target update is used 36 | target_update_interval: 200 # Number of env steps between hard target updates 37 | polyak: 0.995 # Polyak factor used in soft target update 38 | 39 | use_huber_loss: False # Whether huber loss is used for Q function 40 | max_grad_norm: 10.0 # Maximum norm of gradients for clipping 41 | anneal_lr: False # Whether learning rate annealing is used 42 | normalize_reward: False # Whether reward normalization is used 43 | max_reward_limit: 10.0 # Reward range for clipping 44 | use_time_limit: False # Whether to truncate episode when maximum step number is reached 45 | alert_episode_limit: False # Whether arrival of episode limit is reported in info 46 | 47 | hidden_size: 64 # Hidden of policy/actor network 48 | activation: "relu" # Activation function between hidden layers 49 | n_layers: 2 # Number of fully-connected layers in flat observation encoder 50 | n_heads: 4 # Number of attention heads in relational observation encoder 51 | use_feat_norm: False # Apply layer normalization to observations 52 | use_layer_norm: False # Use layer normalization between hidden layers 53 | use_dueling: False # Use dueling layer at the end of policy/actor network 54 | -------------------------------------------------------------------------------- /envs/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from envs.ad_hoc.ad_hoc import AdHocEnv 4 | REGISTRY['ad-hoc'] = AdHocEnv 5 | -------------------------------------------------------------------------------- /envs/ad_hoc/ad_hoc.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import itertools 3 | import os 4 | import os.path as osp 5 | from copy import deepcopy 6 | 7 | import random 8 | import numpy as np 9 | 10 | import matplotlib.pyplot as plt 11 | from matplotlib.colors import Normalize 12 | 13 | from gym.spaces.discrete import Discrete 14 | from gym.spaces.dict import Dict 15 | from gym.spaces.box import Box 16 | 17 | from envs.multi_agent_env import MultiAgentEnv 18 | from envs.common import * 19 | from envs.heatmap import heatmap, annotate_heatmap 20 | from envs.chan_models import * 21 | from envs.ad_hoc.ad_hoc_entities import * 22 | from envs.ad_hoc.ad_hoc_layouts import SCENARIOS 23 | 24 | 25 | class AdHocEnv(MultiAgentEnv): 26 | """Cross-Layer Optimization in Wireless Ad Hoc Networks""" 27 | 28 | LEGAL_ROUTERS = {'c2Dst', 'mSINR'} 29 | LEGAL_POWER_CONTROLLERS = {'Full', 'Rand'} 30 | 31 | def __init__(self, 32 | scenario_name: str = '1flow', 33 | mute_interference: bool = False, # Whether to 
overlook interference. 34 | learn_power_control: bool = True, # Whether to learn power control. 35 | enforce_connection: bool = True, # Whether to enforce connection of path 36 | benchmark_routers: tuple[str, ...] = ('c2Dst', 'mSINR'), # Benchmark routers 37 | benchmark_power_controllers: tuple[str, ...] = ('Full',), # Benchmark power controllers 38 | graph_khops: int = 1, 39 | ): 40 | 41 | self._mute_inf = mute_interference 42 | self._learn_pc = learn_power_control 43 | self._force_cnct = enforce_connection 44 | 45 | self.bm_rts = benchmark_routers 46 | self.bm_pow_ctrls = benchmark_power_controllers 47 | 48 | self.khops = graph_khops 49 | 50 | # Get attributes from scenario. 51 | scenario = SCENARIOS[scenario_name] 52 | for k, v in scenario.__dict__.items(): 53 | setattr(self, k, v) 54 | self.set_entities = scenario.set_entities # Callback to set initial attributes of entities 55 | 56 | # In single-channel case, each relay has to transmit/receive on the same spectrum. 57 | self._allow_full_duplex = True if self.n_chans == 1 else False 58 | 59 | # Get benchmarks. 60 | if (self.n_pow_lvs > 1) and ('Rand' not in self.bm_pow_ctrls): # Compare learnt policy and rand choice. 61 | self.bm_pow_ctrls += ('Rand',) 62 | benchmark_schemes = list(itertools.product(self.bm_rts, self.bm_pow_ctrls)) 63 | self.bm_perf = dict.fromkeys(benchmark_schemes) # Performance of benchmark schemes 64 | self.bm_paths = dict() # Copy of benchmark routes 65 | 66 | # Create entities in the env. 67 | self.nodes = [Node(m, self.n_chans) for m in range(self.n_nodes)] 68 | 69 | def builds_power_levels(n_levels, p_max): 70 | """Builds discrete power levels.""" 71 | return (np.arange(n_levels, dtype=np.float32) + 1) / n_levels * p_max 72 | 73 | # Ambient flows take fixed Tx power with unlimited budget, 74 | # while agent flows may select power levels based on their (constrained) budget. 75 | amb_flow_kwargs = dict(p_levels=builds_power_levels(1, self.p_amb), p_budget=float('inf'), 76 | allow_full_duplex=self._allow_full_duplex) 77 | agt_flow_kwargs = dict(p_levels=builds_power_levels(self.n_pow_lvs, self.p_max), p_budget=self.p_bdg, 78 | allow_full_duplex=self._allow_full_duplex) 79 | self.amb_flows = [Flow(i, self.n_chans, self.n_nodes, **amb_flow_kwargs) for i in range(self.n_amb_flows)] 80 | agt_flow_slice = range(self.n_amb_flows, self.n_amb_flows + self.n_agt_flows) 81 | self.agt_flows = [Flow(i, self.n_chans, self.n_nodes, **agt_flow_kwargs) for i in agt_flow_slice] 82 | 83 | self.chan = CHANNEL_MODELS[self.chan_name]() # Wireless channel model 84 | self.bw = self.tot_bw / self.n_chans # Bandwidth of each sub-channel (Hz) 85 | 86 | # Define time-varying attributes characterizing relations between entities. 87 | self.d_n2n = np.empty((self.n_nodes, self.n_nodes), dtype=np.float32) # Distance between nodes 88 | self.d_n2dst = np.empty((self.n_nodes, self.n_flows), 89 | dtype=np.float32) # Distance between nodes and destinations 90 | self.chan_coef = np.empty((self.n_nodes, self.n_nodes, self.n_chans), dtype=np.float32) # Channel coefficients 91 | 92 | self.p_rx = np.empty_like(self.chan_coef) 93 | self.p_inf = np.empty_like(self.chan_coef) 94 | self.link_sinr = np.empty_like(self.chan_coef) 95 | self.link_rates = np.empty_like(self.chan_coef) 96 | 97 | # Define MDP components. 
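        # Illustrative summary of the action space built below: when power control is learnt, each
        # action of the agent flow indexes a (neighbor, power-level) pair, so its Discrete space has
        # max_nbrs * n_pow_lvs + 1 entries (the extra one being the trailing 'no-op').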
98 | self.n_agents = 1 # Single-agent env 99 | self.episode_limit = self.max_hops # Maximum number of timesteps 100 | self.agent: Flow = self.flows[-1] 101 | self.nbrs: list[Node] = [] # Neighbors around front node of agent 102 | 103 | self.all_actions = [] 104 | for flow in self.flows: 105 | action_tuples = [range(self.max_nbrs)] 106 | if self._learn_pc: 107 | action_tuples.append(range(flow.n_pow_lvs)) 108 | all_agent_actions = list(itertools.product(*action_tuples)) 109 | all_agent_actions.append('no-op') 110 | self.all_actions.append(all_agent_actions) 111 | 112 | self.observation_space = [] 113 | self.shared_observation_space = [] 114 | self.action_space = [] 115 | self.observation_space.append( 116 | Dict(spaces={ 117 | 'agent': Box(-np.inf, np.inf, shape=np.array([self.obs_own_feats_size])), 118 | 'nbr': Box(-np.inf, np.inf, shape=np.array(self.obs_nbr_feats_size)), 119 | }) 120 | ) 121 | self.shared_observation_space.append( 122 | Dict(spaces={ 123 | 'agent': Box(-np.inf, np.inf, shape=np.array([self.shared_obs_own_feats_size])), 124 | 'nbr': Box(-np.inf, np.inf, shape=np.array(self.shared_obs_nbr_feats_size)), 125 | 'nbr2': Box(-np.inf, np.inf, shape=np.array(self.shared_obs_nbr2_feats_size)), 126 | }) 127 | ) 128 | self.action_space.append(Discrete(len(self.all_actions[-1]))) 129 | 130 | @property 131 | def flows(self): 132 | """All data flows in a list""" 133 | return self.amb_flows + self.agt_flows 134 | 135 | @property 136 | def power_options(self): 137 | """Number of power options for RL algos""" 138 | return self.agent.n_pow_lvs if self._learn_pc else 1 139 | 140 | def reset(self): 141 | if self.agent != self.flows[-1]: # Finished agent-flow is not the last one. 142 | self.handover_agent(self.flows[self.flows.index(self.agent) + 1]) # Sequentially move to next flow. 143 | else: 144 | # Reset nodes and flows. 145 | pos_nodes, src_nids, dst_nids = self.set_entities() 146 | # print(f"src_nids = {src_nids}, dst_nids = {dst_nids}") 147 | for nid, node in enumerate(self.nodes): 148 | node.reset(pos_nodes[nid]) 149 | for fid, flow in enumerate(self.flows): 150 | flow.reset(self.nodes[src_nids[fid]], self.nodes[dst_nids[fid]]) 151 | 152 | # Compute the distance and channel coefficients between nodes. 153 | for m in range(self.n_nodes): 154 | for n in range(self.n_nodes): 155 | # Distance between node-m and node-n. 156 | self.d_n2n[m, n] = self.distance(self.nodes[m], self.nodes[n]) 157 | for i in range(self.n_flows): 158 | self.d_n2dst[m, i] = self.distance(self.nodes[m], self.flows[i].dst) 159 | d_n2n_copy = self.d_n2n.copy() 160 | d_n2n_copy[np.eye(self.n_nodes, dtype=bool)] = float('inf') 161 | self.chan_coef = self.chan.estimate_chan_gain(d_n2n_copy) 162 | self.chan_coef = np.tile(np.expand_dims(self.chan_coef, axis=-1), (1, 1, self.n_chans)) 163 | self._update_per_link_rate() # Compute achievable rate of all links. 164 | 165 | # Set route for ambient flows as warm-up. 166 | for flow in self.amb_flows: 167 | self.handover_agent(flow) 168 | terminated = False 169 | while not terminated: 170 | action = self.get_heuristic_action(('c2Dst', 'Full')) 171 | _, terminated, _ = self.step(action) 172 | # print(f"Ambient {flow} is set.") 173 | 174 | # Set heuristic routes for agent flows as benchmark. 
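        # e.g. with benchmark_routers=('c2Dst', 'mSINR') and power controllers ('Full', 'Rand'),
        # four (router, power-controller) schemes are rolled out and stored in self.bm_perf.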
175 | for rt in self.bm_rts: 176 | self.bm_paths[rt] = [] 177 | for pc in self.bm_pow_ctrls: 178 | for flow in self.agt_flows: 179 | self.handover_agent(flow) 180 | terminated = False 181 | while not terminated: 182 | action = self.get_heuristic_action((rt, pc)) 183 | _, terminated, _ = self.step(action) 184 | 185 | # print(f"Agent {flow} is set by heuristic scheme ({rt}, {pc}).") 186 | if len(self.bm_paths[rt]) < len(self.agt_flows): 187 | self.bm_paths[rt].append(deepcopy(flow)) 188 | 189 | self.bm_perf[(rt, pc)] = self.evaluate_performance() 190 | 191 | # Reset agent flows after evaluating performance. 192 | for flow in self.agt_flows: 193 | flow.clear_route() 194 | self._update_per_link_rate() 195 | 196 | # Assign agent flow. 197 | self.handover_agent(self.agt_flows[0]) 198 | 199 | def step(self, action) -> tuple[ndarray, bool, dict[str, Any]]: 200 | if isinstance(action, list): 201 | assert len(action) == self.n_agents, "Inconsistent numbers between actions and agents." 202 | action = action[0] 203 | 204 | # Get selected nbr, power level. 205 | flow_action = self.all_actions[self.agent.fid][action] 206 | # print(f"self.all_actions[self.agent.fid] = {self.all_actions[self.agent.fid]}") 207 | # print(f"action = {action}, flow_action = {flow_action}") 208 | if self._learn_pc: 209 | sel_nbr_idx, sel_pow_idx = flow_action 210 | else: 211 | sel_nbr_idx = flow_action[0] 212 | sel_pow_idx = self.allocate_power(self.nbrs[sel_nbr_idx], 'Rand') 213 | 214 | # Select the sub-channel with the least interference. 215 | next_node = self.nbrs[sel_nbr_idx] 216 | sel_chan_idx = self.allocate_channel(self.agent.front, next_node) 217 | 218 | # Enhance performance by eliminating unselected neighbors closer than selected one. 219 | front_nid = self.agent.front.nid 220 | for nbr in self.nbrs: 221 | is_closer_than_selected_nbr = self.d_n2n[nbr.nid, front_nid] < self.d_n2n[next_node.nid, front_nid] 222 | if is_closer_than_selected_nbr and (nbr is not self.agent.dst): 223 | self.agent.ban(nbr) 224 | 225 | # Apply action to agent. 226 | self.agent.add_hop(next_node, sel_chan_idx, sel_pow_idx) 227 | self._update_per_link_rate() 228 | self._find_neighbors() 229 | 230 | # Computes output signal from env. 231 | reward = self._get_reward() 232 | terminated = self._get_terminate() 233 | info = dict() 234 | 235 | # Evaluate performance of agent flows. 236 | agent_perf = self.evaluate_performance() 237 | for k, v in agent_perf.items(): 238 | info['Agent' + k] = v 239 | # Provide performance of benchmark schemes on above flows . 240 | for sch, perf in self.bm_perf.items(): 241 | rt, pc = sch 242 | if perf is not None: 243 | for k, v in perf.items(): 244 | info[rt + pc + k] = v 245 | 246 | return reward, terminated, info 247 | 248 | def handover_agent(self, flow: Flow): 249 | """Hands-over agent to data flow.""" 250 | self.agent = flow # Assign flow to be agent. 251 | self._find_neighbors() # Since front node is altered, call update neighbors. 252 | 253 | @staticmethod 254 | def distance(node1: Node, node2: Node): 255 | """Computes distance between two nodes.""" 256 | return np.linalg.norm(node1.pos - node2.pos) 257 | 258 | def _find_neighbors(self): 259 | """Finds qualified neighbors for front node of current agent flow.""" 260 | # Check whether maximum hop is reached or battery is depleted. 
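        # 'Overtime' means only one hop of budget remains; 'low battery' means the remaining power
        # cannot cover two more transmissions even at the lowest power level (see conditions below).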
261 | front_nid = self.agent.front.nid 262 | is_overtime = (self.agent.n_hops == self.max_hops - 1) and not self.agent.is_connected 263 | is_low_battery = (self.agent.p_rem < 2 * self.agent.p_lvs[0]) and not self.agent.is_connected 264 | 265 | if self.agent.is_connected: # When agent is connected: 266 | # `terminated` would be then activated to reset the env. 267 | nbrs = [] # No neighbor is available after termination of episode. 268 | elif is_overtime or is_low_battery: # When any failure occurs: 269 | if self._force_cnct: 270 | nbrs = [self.agent.dst] # Destination is the only neighbor. 271 | else: 272 | nbrs = [] 273 | else: 274 | # Sort all nodes in ascending order of distance to current front node. 275 | sorted_nids = np.argsort(self.d_n2n[self.agent.front.nid]) 276 | nbrs = [] 277 | for nid in sorted_nids: 278 | # Add node to neighbors when all of following conditions are met: The neighbor 279 | # 1) lies within the sensing range, 2) Meets the qualification of current agent flow. 280 | if (self.d_n2n[nid, front_nid] <= self.r_sns) and self.agent.check(self.nodes[nid]): 281 | nbrs.append(self.nodes[nid]) 282 | # End when maximum number of neighbors are collected. 283 | if len(nbrs) >= self.max_nbrs: 284 | break 285 | # Shuffle the order of neighbors. 286 | random.shuffle(nbrs) 287 | is_isolated = (len(nbrs) == 0) and not self.agent.is_connected 288 | if is_isolated and self._force_cnct: 289 | nbrs = [self.agent.dst] 290 | # Assign neighbors. 291 | if (not self._force_cnct) and (not self.agent.is_connected): 292 | assert len(nbrs) > 0, "Empty neighbor set is found." 293 | self.nbrs = nbrs 294 | 295 | def _find_2hop_neighbors(self): 296 | """Finds neighbors of current neighbors (2nd-hop). 297 | 298 | Note that 2nd-hop neighbors are not directly added to route and thus qualification is not mandatory. 299 | """ 300 | nbrs2_per_nbr = [] 301 | for nbr in self.nbrs: 302 | sorted_nids = np.argsort(self.d_n2n[nbr.nid]) 303 | nbrs2 = [] 304 | for nid in sorted_nids: 305 | if self.nodes[nid] is not nbr: 306 | nbrs2.append(self.nodes[nid]) 307 | # End when maximum number of neighbors are collected. 
308 | if len(nbrs2) >= self.max_nbrs: 309 | break 310 | nbrs2_per_nbr.append(nbrs2) 311 | return nbrs2_per_nbr 312 | 313 | def _update_per_link_rate(self): 314 | """Computes achievable rate of all links.""" 315 | p_tx = np.stack([node.p_tx for node in self.nodes]) # Tx power (Watt) 316 | self.p_rx = self.chan_coef * (1 - np.expand_dims(np.eye(self.n_nodes), axis=-1)) * p_tx # Rx power (Watt) 317 | self.p_inf = np.zeros_like(self.p_rx) if self._mute_inf else self.p_rx.sum(1, keepdims=True) - self.p_rx # Interference (Watt) 318 | self.link_sinr = self.p_rx / (self.p_inf + self.bw * self.n0) # Signal-to-interference-plus-noise ratio 319 | self.link_rates = self.bw * np.log2(1 + self.link_sinr) * 1e-6 # Achievable rates (Mbps) 320 | 321 | def get_per_hop_rate(self, flow: Flow): 322 | """Returns the rate of each hop along a data flow.""" 323 | rate_per_hop = [] 324 | for link in flow.route: 325 | rate_per_hop.append(self.link_rates[link.rx.nid, link.tx.nid, link.chan_idx]) 326 | return rate_per_hop 327 | 328 | def get_bottleneck_rate(self, flow: Flow): 329 | """Gets the bottleneck rate of a data flow.""" 330 | rate_per_hop = self.get_per_hop_rate(flow) 331 | if len(rate_per_hop) > 0: 332 | bottleneck_rate = np.amin(rate_per_hop) 333 | bottleneck_idx = np.argmin(rate_per_hop) 334 | return bottleneck_rate.item(), bottleneck_idx.item() 335 | else: 336 | return 0.0, 0 337 | 338 | def get_bottleneck_sinr(self, flow: Flow): 339 | """Gets the bottleneck SINR of a data flow.""" 340 | sinr_per_hop = [] 341 | for link in flow.route: 342 | sinr_per_hop.append(self.link_sinr[link.rx.nid, link.tx.nid, link.chan_idx]) 343 | return min(sinr_per_hop) 344 | 345 | def allocate_channel(self, tx_node: Node, rx_node: Node): 346 | """Allocates sub-channel for a transceiver pair.""" 347 | p_inf_per_chan = self.p_inf[rx_node.nid, tx_node.nid] 348 | for chan_idx in np.argsort(p_inf_per_chan): # Starting from channel with the lowest inf level: 349 | if (tx_node.idle[chan_idx] or self._allow_full_duplex) and rx_node.idle[chan_idx]: 350 | return chan_idx 351 | # This should not happen since it is prevented by `check` mechanism. 
352 | raise RuntimeError("No channel is available!")
353 |
354 | def allocate_power(self, next_node: Node, power_controller: str = 'Full'):
355 | """Allocates Tx power from the current front node to the next node."""
356 | can_afford_pow_lvs = self.get_affordable_power_levels(next_node)
357 | afford_pow_lv_idxes = np.flatnonzero(can_afford_pow_lvs)
358 | if power_controller == 'Full':
359 | sel_pow_idx = afford_pow_lv_idxes[-1] # Maximum affordable power
360 | elif power_controller == 'Rand':
361 | rand_from_afford_pow_lv_idxes = random.sample(range(afford_pow_lv_idxes.size), 1)[0]
362 | sel_pow_idx = afford_pow_lv_idxes[rand_from_afford_pow_lv_idxes] # Random affordable power
363 | else:
364 | raise KeyError("Unrecognized power controller!")
365 | # print(f"Select p_idx = {sel_pow_idx} with p_rem = {self.agent.p_rem}.")
366 | return sel_pow_idx
367 |
368 | def get_affordable_power_levels(self, next_node: Node):
369 | """Returns affordable power levels from the current agent front to the next node."""
370 | p_rem = self.agent.p_rem # Remaining power
371 | p_min = self.agent.p_lvs[0] # Minimum power level
372 | if next_node is self.agent.dst:
373 | can_afford_pow_lvs = [(p < p_rem) for p in self.agent.p_lvs]
374 | else:
375 | can_afford_pow_lvs = [(p < p_rem - p_min) for p in self.agent.p_lvs]
376 | return can_afford_pow_lvs
377 |
378 | def get_heuristic_action(self, scheme: tuple[str, str]):
379 | """Returns the agent action chosen by a heuristic (router, power controller) scheme."""
380 | # Interpret heuristic scheme.
381 | rt, pc = scheme
382 | assert (rt in self.LEGAL_ROUTERS) and (pc in self.LEGAL_POWER_CONTROLLERS), \
383 | f"Unrecognized heuristic router/power controller ({scheme}) is received."
384 |
385 | # Select next hop from neighbors.
386 | if rt == 'c2Dst':
387 | d_nbr2dst = [self.d_n2dst[nbr.nid, self.agent.fid] for nbr in self.nbrs]
388 | sel_nbr_idx = np.argmin(d_nbr2dst)
389 | # print(f"d_nbr2dst = {d_nbr2dst}, sel_nbr_idx = {sel_nbr_idx}")
390 | elif rt == 'mSINR':
391 | front_nid = self.agent.front.nid
392 | sinr_per_nbr = []
393 | for nbr in self.nbrs:
394 | sinr = self.chan_coef[nbr.nid, front_nid] * self.agent.p_lvs[-1] / (self.n0 * self.bw + self.p_inf[nbr.nid, front_nid])
395 | sinr_per_nbr.append(sinr)
396 | sel_nbr_idx = np.argmax(sinr_per_nbr)
397 | else:
398 | raise KeyError(f"Unrecognized heuristic scheme `{scheme}` to select next node.")
399 |
400 | # Allocate Tx power and get index of discrete action.
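The Tx power chosen in the next few lines comes from `allocate_power`, which filters levels through the affordability rule of `get_affordable_power_levels` above: at least the lowest level is held in reserve for a later hop unless the next node is the destination. A minimal check with assumed numbers (not taken from the paper or the configs):

```python
p_lvs = [0.025, 0.05, 0.075, 0.1]  # assumed discrete Tx power levels (Watt)
p_rem = 0.12                       # assumed remaining power budget (Watt)
p_min = p_lvs[0]

# Hop to an intermediate relay: keep at least p_min in reserve for a later hop.
print([p < p_rem - p_min for p in p_lvs])  # [True, True, True, False]
# Final hop to the destination: the whole remaining budget may be spent.
print([p < p_rem for p in p_lvs])          # [True, True, True, True]
```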
401 | if self._learn_pc: 402 | sel_pow_idx = self.allocate_power(self.nbrs[sel_nbr_idx], pc) 403 | act_idx = sel_nbr_idx * self.agent.n_pow_lvs + sel_pow_idx 404 | else: 405 | act_idx = sel_nbr_idx 406 | return act_idx 407 | 408 | def get_total_actions(self): 409 | return [self.action_space[i].n for i in range(self.n_agents)] 410 | 411 | def get_avail_actions(self): 412 | avail_actions = [] 413 | 414 | no_op = True if self.agent.is_connected else False 415 | if self._learn_pc: 416 | avail_agent_actions = np.zeros((self.max_nbrs, self.agent.n_pow_lvs), dtype=bool) 417 | if not self.agent.is_connected: 418 | for nbr_idx, nbr in enumerate(self.nbrs): 419 | can_afford_pow_lvs = self.get_affordable_power_levels(nbr) 420 | avail_agent_actions[nbr_idx, can_afford_pow_lvs] = True 421 | else: 422 | avail_agent_actions = np.zeros(self.max_nbrs, dtype=bool) 423 | if not self.agent.is_connected: 424 | for nbr_idx in range(len(self.nbrs)): 425 | avail_agent_actions[nbr_idx] = True 426 | avail_actions.append(avail_agent_actions.flatten().tolist() + [no_op]) 427 | 428 | return avail_actions 429 | 430 | def get_obs(self): 431 | obs = [] 432 | 433 | own_feats = np.zeros(self.obs_own_feats_size, dtype=np.float32) 434 | nbr_feats = np.zeros(self.obs_nbr_feats_size, dtype=np.float32) 435 | 436 | # Get features of agent flow 437 | ind = 0 438 | own_feats[ind:ind + self.dim_pos] = self.agent.front.pos / self.range_pos # Position of front node 439 | ind += self.dim_pos 440 | own_feats[ind:ind + self.dim_pos] = (self.agent.dst.pos - self.agent.front.pos) / self.range_pos # Distance to destination 441 | ind += self.dim_pos 442 | if self.agent.p_bdg < float('inf'): # If power budget is limited: 443 | own_feats[ind] = self.agent.p_rem / self.agent.p_bdg # Remaining power 444 | ind += 1 445 | 446 | # Get features of neighbor nodes. 447 | front_nid = self.agent.front.nid 448 | p_max = self.agent.p_lvs[-1] 449 | for m, nbr in enumerate(self.nbrs): 450 | ind = 0 451 | # Availability 452 | nbr_feats[m, ind] = 1 453 | ind += 1 454 | # Relative distance to front node of agent 455 | nbr_feats[m, ind:ind + self.dim_pos] = (nbr.pos - self.agent.front.pos) / self.range_pos 456 | ind += self.dim_pos 457 | # Relative distance to destination of agent 458 | nbr_feats[m, ind:ind + self.dim_pos] = (nbr.pos - self.agent.dst.pos) / self.range_pos 459 | ind += self.dim_pos 460 | # SINR in dB 461 | sinr_per_chan = self.chan_coef[nbr.nid, front_nid] * p_max / (self.n0 * self.bw + self.p_inf[nbr.nid, front_nid]) 462 | nbr_feats[m, ind] = np.log10(max(np.max(sinr_per_chan), 1e-10)) # Avoid extreme value of SINR in dB. 463 | 464 | obs.append(dict(agent=own_feats, nbr=nbr_feats)) 465 | return obs 466 | 467 | @property 468 | def obs_own_feats_size(self): 469 | nf_own = self.dim_pos + self.dim_pos 470 | if self.agent.p_bdg < float('inf'): 471 | nf_own += 1 # Scalar to show remaining energy 472 | return nf_own 473 | 474 | @property 475 | def obs_nbr_feats_size(self): 476 | nf_nbr = 1 + self.dim_pos + self.dim_pos # Availability, distance to agent, distance to destination 477 | nf_nbr += 1 # Cumulative interference 478 | return self.max_nbrs, nf_nbr 479 | 480 | def get_obs_size(self): 481 | return [self.obs_own_feats_size + np.prod(self.obs_nbr_feats_size)] * self.n_agents 482 | 483 | def get_graph_inputs(self): 484 | k = self.khops 485 | 486 | # ============ Count number of nodes of each type. ============ 487 | num_nodes_dict = dict(agent=1, nbr=self.n_nodes) 488 | 489 | # ============ Get node features. 
============ 490 | own_feats = np.zeros(self.obs_own_feats_size, dtype=np.float32) 491 | # Get features of agent flow 492 | ind = 0 493 | own_feats[ind:ind + self.dim_pos] = self.agent.front.pos / self.range_pos # Position of front node 494 | ind += self.dim_pos 495 | own_feats[ind:ind + self.dim_pos] = (self.agent.dst.pos - self.agent.front.pos) / self.range_pos # Distance to destination 496 | ind += self.dim_pos 497 | if self.agent.p_bdg < float('inf'): # If power budget is limited: 498 | own_feats[ind] = self.agent.p_rem / self.agent.p_bdg # Remaining power 499 | ind += 1 500 | 501 | nbr_feats = np.zeros((self.n_nodes, self.graph_nbr_feats)) 502 | for nid, node in enumerate(self.nodes): 503 | ind = 0 504 | nbr_feats[nid, ind:ind + self.dim_pos] = (self.agent.dst.pos - node.pos) / self.range_pos 505 | ind += self.dim_pos 506 | 507 | node_feats = {'agent': np.expand_dims(own_feats, 0), 'nbr': nbr_feats} 508 | 509 | # ============ Define edges and their features. ============ 510 | graph_data = {('nbr', '1hop', 'agent'): ([], [])} 511 | edge_feats = {'1hop': []} 512 | if k == 2: 513 | graph_data.update({('nbr', '2hop', 'nbr'): ([], [])}) 514 | edge_feats.update({'2hop': []}) 515 | 516 | nbrs2_per_nbr = self._find_2hop_neighbors() 517 | front_nid = self.agent.front.nid 518 | p_max = self.agent.p_lvs[-1] 519 | 520 | for nbr_idx, nbr in enumerate(self.nbrs): 521 | # Agent and its neighbors defines 1-hop relations. 522 | graph_data[('nbr', '1hop', 'agent')][0].append(nbr.nid) 523 | graph_data[('nbr', '1hop', 'agent')][1].append(0) 524 | 525 | h1_feats = np.zeros(self.graph_hop_feats, dtype=np.float32) 526 | ind = 0 527 | h1_feats[ind:ind + self.dim_pos] = (nbr.pos - self.agent.front.pos) / self.range_pos 528 | ind += self.dim_pos 529 | # SINR in dB 530 | sinr_per_chan = self.chan_coef[nbr.nid, front_nid] * p_max / (self.n0 * self.bw + self.p_inf[nbr.nid, front_nid]) 531 | h1_feats[ind] = np.log10(max(np.max(sinr_per_chan), 1e-10)) # Avoid extreme value of SINR in dB. 532 | 533 | edge_feats['1hop'].append(h1_feats) 534 | 535 | if k == 2: 536 | for nbr2_idx, nbr2 in enumerate(nbrs2_per_nbr[nbr_idx]): 537 | # Neighbors and their neighbors defines 2-hop relations. 538 | graph_data[('nbr', '2hop', 'nbr')][0].append(nbr2.nid) 539 | graph_data[('nbr', '2hop', 'nbr')][1].append(nbr.nid) 540 | 541 | h2_feats = np.zeros(self.graph_hop_feats, dtype=np.float32) 542 | ind = 0 543 | h2_feats[ind:ind + self.dim_pos] = (nbr2.pos - nbr.pos) / self.range_pos 544 | ind += self.dim_pos 545 | # SINR in dB 546 | sinr_per_chan = self.chan_coef[nbr2.nid, nbr.nid] * p_max / (self.n0 * self.bw + self.p_inf[nbr2.nid, nbr.nid]) 547 | h2_feats[ind] = np.log10(max(np.max(sinr_per_chan), 1e-10)) # Avoid extreme value of SINR in dB. 
548 | 549 | edge_feats['2hop'].append(h2_feats) 550 | 551 | for etype, edata in edge_feats.items(): 552 | if len(edata) == 0: 553 | edge_feats[etype] = np.zeros((0, self.graph_feats['hop']), dtype=np.float32) 554 | else: 555 | edge_feats[etype] = np.stack(edata) 556 | 557 | graph_inputs = { 558 | 'graph_data': graph_data, # Define edges 559 | 'num_nodes_dict': num_nodes_dict, # Number of nodes 560 | 'ndata': node_feats, # Node features 561 | 'edata': edge_feats, # Edge features 562 | } 563 | return graph_inputs 564 | 565 | @property 566 | def graph_feats(self, ): 567 | return { 568 | 'agent': self.graph_own_feats, 569 | 'nbr': self.graph_nbr_feats, 570 | 'hop': self.graph_hop_feats, 571 | } 572 | 573 | @property 574 | def graph_own_feats(self): 575 | return self.obs_own_feats_size 576 | 577 | @property 578 | def graph_nbr_feats(self): 579 | return self.dim_pos 580 | 581 | @property 582 | def graph_hop_feats(self): 583 | return self.dim_pos + 1 584 | 585 | def get_shared_obs(self): 586 | 587 | # Get local observations. 588 | local_obs = self.get_obs() 589 | 590 | nbrs2_per_nbr = self._find_2hop_neighbors() 591 | nbr2_feats = np.zeros(self.shared_obs_nbr2_feats_size, dtype=np.float32) 592 | p_max = self.agent.p_lvs[-1] 593 | for nbr_idx, nbr in enumerate(self.nbrs): 594 | nbrs2 = nbrs2_per_nbr[nbr_idx] 595 | for nbr2_idx, nbr2 in enumerate(nbrs2): 596 | ind = 0 597 | nbr2_feats[nbr_idx, nbr2_idx, ind] = 1 598 | ind += 1 599 | nbr2_feats[nbr_idx, nbr2_idx, ind:ind + self.dim_pos] = (nbr2.pos - nbr.pos) / self.range_pos 600 | ind += self.dim_pos 601 | nbr2_feats[nbr_idx, nbr2_idx, ind:ind + self.dim_pos] = (nbr2.pos - self.agent.dst.pos) / self.range_pos 602 | ind += self.dim_pos 603 | sinr_per_chan = self.chan_coef[nbr2.nid, nbr.nid] * p_max / (self.n0 * self.bw + self.p_inf[nbr2.nid, nbr.nid]) 604 | nbr2_feats[nbr_idx, nbr2_idx, ind] = np.log10(max(np.max(sinr_per_chan), 1e-10)) # Avoid extreme value of SINR in dB. 605 | 606 | shared_obs = [dict(nbr2=nbr2_feats, **local_obs_agent) for local_obs_agent in local_obs] 607 | return shared_obs 608 | 609 | def get_shared_obs_size(self): 610 | return [self.shared_obs_own_feats_size + np.prod(self.shared_obs_nbr_feats_size) + np.prod(self.shared_obs_nbr2_feats_size)] * self.n_agents 611 | 612 | @property 613 | def shared_obs_own_feats_size(self): 614 | return self.obs_own_feats_size 615 | 616 | @property 617 | def shared_obs_nbr_feats_size(self): 618 | return self.max_nbrs, self.obs_nbr_feats_size[-1] 619 | 620 | @property 621 | def shared_obs_nbr2_feats_size(self): 622 | return self.max_nbrs, self.max_nbrs, self.obs_nbr_feats_size[-1] 623 | 624 | def _get_reward(self): 625 | # Reward SINR of bottleneck link. 
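The reward computed below can be read as r = 1[connected] · max(log10(SINR_bottleneck) + 5, 0): it is paid out only once the route reaches the destination, grows with the bottleneck SINR, and turns positive once that SINR exceeds roughly -50 dB. A small sketch of the same shaping (the bias of 5 mirrors the constant used below; the SINR values are made up):

```python
import numpy as np

def sketch_reward(bottleneck_sinr: float, is_connected: bool, bias: float = 5.0) -> float:
    # Log-scale bottleneck SINR, shifted by `bias`, clipped at zero,
    # and only paid out once the flow is connected.
    return float(is_connected) * max(np.log10(bottleneck_sinr) + bias, 0.0)

print(sketch_reward(1e-3, True))   # 2.0  (bottleneck SINR of -30 dB)
print(sketch_reward(1e-6, True))   # 0.0  (below -50 dB, clipped)
print(sketch_reward(10.0, False))  # 0.0  (no reward before the route closes)
```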
626 | bias = 5 627 | bottleneck_sinr = np.log10(self.get_bottleneck_sinr(self.agent)) + bias 628 | reward = self.agent.is_connected * max(bottleneck_sinr, 0) 629 | return np.array([reward], dtype=np.float32) 630 | 631 | def _get_terminate(self): 632 | """Returns termination of episode.""" 633 | if self.agent.is_connected or (len(self.nbrs) == 0): 634 | return True 635 | else: 636 | return False 637 | 638 | def measure_link_distance(self, link: Link): 639 | tx_nid, rx_nid = link.tx.nid, link.rx.nid 640 | return self.d_n2n[rx_nid, tx_nid] 641 | 642 | def evaluate_performance(self): 643 | """Evaluates the overall performance of agent flows.""" 644 | perf_ind_dict = { 645 | 'BottleneckRate': [self.get_bottleneck_rate(flow)[0] for flow in self.agt_flows], 646 | 'Hops': [flow.n_hops for flow in self.agt_flows], 647 | } 648 | 649 | # When power budget is finite, record the total amount of power consumption. 650 | if self.p_bdg < float('inf'): 651 | perf_ind_dict['TotalPowCost'] = [flow.p_tot for flow in self.agt_flows] 652 | 653 | # Check whether the data is connected autonomously. 654 | def check_autonomous_connection(flow: Flow): 655 | if not flow.is_connected: 656 | return 0 657 | else: 658 | for link in flow.route: 659 | if self.measure_link_distance(link) > self.r_sns: 660 | return 0 661 | return 1 662 | 663 | perf_ind_dict['ConnectionProb'] = [check_autonomous_connection(flow) for flow in self.agt_flows] 664 | 665 | return perf_ind_dict 666 | 667 | def render(self): 668 | pass 669 | 670 | def save_replay(self, show_img: bool = False, save_dir: str = None, tag: str = None): 671 | self.visualize_policy(show_img, save_dir, tag) 672 | self.visualize_route(show_img, save_dir, tag) 673 | self.visualize_power(show_img, save_dir, tag) 674 | 675 | def add_boundary(self, ax): 676 | """Adds boundary to ax.""" 677 | boundary = plot_boundary(self.range_pos) 678 | ax.plot(*boundary, color='black') 679 | 680 | ax.axis([-0.1 * self.range_pos, 1.1 * self.range_pos, -0.1 * self.range_pos, 1.1 * self.range_pos]) 681 | ax.set_xlabel('x (m)') 682 | ax.set_ylabel('y (m)') 683 | 684 | def plot_nodes(self, ax, ignore_idle_nodes: bool = True): 685 | """Plots nodes to ax.""" 686 | # Styles of nodes according to their role 687 | node_styles = { 688 | 'idle': {'marker': 'o', 'alpha': 0.5, 's': 70, 'color': 'grey'}, 689 | 'src': {'marker': 'v', 'alpha': 0.75, 's': 100, 'color': 'tomato'}, 690 | 'dst': {'marker': 's', 'alpha': 0.75, 's': 100, 'color': 'tomato'}, 691 | 'rly': {'marker': 'o', 'alpha': 0.75, 's': 85, 'color': 'lightskyblue'}, 692 | } 693 | txt_offset = 7.5 # Offset of text 694 | font_size = 6 # Font size 695 | 696 | for node in self.nodes: 697 | # Determine the role played by node. 698 | role = 'idle' if node.idle.all() else node.role 699 | txt_offset_bias = 0 if node.nid < 10 else txt_offset 700 | if role == 'idle': 701 | if not ignore_idle_nodes: # Idle nodes are plotted only if specially requested. 
702 | ax.scatter(node.pos[0], node.pos[1], **node_styles['idle']) 703 | ax.text(node.pos[0] - txt_offset - txt_offset_bias, node.pos[1] - txt_offset, f"{node.nid}", 704 | fontsize=font_size, alpha=0.5, weight='light') 705 | else: 706 | ax.scatter(node.pos[0], node.pos[1], **node_styles[role]) 707 | ax.text(node.pos[0] - txt_offset - txt_offset_bias, node.pos[1] - txt_offset, f"{node.nid}", 708 | fontsize=font_size) 709 | 710 | def visualize_policy(self, show_img: bool = False, save_dir: str = None, tag: str = None, **kwargs): 711 | """Shows the routes/resource allocation decisions of policy.""" 712 | dpi = 200 # Resolution of figures 713 | fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(7, 3.25), layout="constrained", dpi=dpi) 714 | fig.suptitle('Route and Resource Allocation in Wireless Ad Hoc Network', fontsize=12) 715 | 716 | # ============ Sub-figure 1: Plot route. ============ 717 | ax1 = axs[0] 718 | ax1.set_aspect('equal') 719 | 720 | # Style of different heuristic routers 721 | rt_styles = { 722 | 'c2Dst': {'color': 'pink', 'linewidth': 2.5, 'alpha': 0.5}, 723 | 'mSINR': {'color': 'yellowgreen', 'linewidth': 2.5, 'alpha': 0.5}, 724 | } 725 | 726 | from matplotlib.collections import LineCollection 727 | import matplotlib.lines as mlines 728 | 729 | # Plot routes set by benchmark routers. 730 | rt_labels = [] 731 | for rt, flows in self.bm_paths.items(): 732 | xs, ys = [], [] 733 | for flow in flows: 734 | for link in flow.route: 735 | xs.append(np.linspace(link.tx.pos[0], link.rx.pos[0], 25)) 736 | ys.append(np.linspace(link.tx.pos[1], link.rx.pos[1], 25)) 737 | segs = [np.stack((x, y)).T for x, y in zip(xs, ys)] 738 | ax1.add_collection(LineCollection(segs, zorder=-1, **rt_styles[rt])) 739 | rt_labels.append(mlines.Line2D([], [], label=rt, **rt_styles[rt])) 740 | ax1.legend(handles=rt_labels, loc='lower right', prop={'size': 6}) 741 | 742 | # Get the range of rates. 743 | min_rate, max_rate = float('inf'), 0 744 | for flow in self.flows: 745 | if flow.n_hops > 0: 746 | rate_per_hop = self.get_per_hop_rate(flow) 747 | min_rate = min(min_rate, min(rate_per_hop)) 748 | max_rate = max(max_rate, max(rate_per_hop)) 749 | max_rate += 1e-2 750 | min_rate -= 1e-2 751 | # Plot routes of all flows with color specifying link rates. 752 | cmap = plt.colormaps["magma"] 753 | for flow in self.flows: 754 | for link in flow.route: 755 | link_rate = self.link_rates[link.rx.nid, link.tx.nid, link.chan_idx] 756 | norm_r = (link_rate - min_rate) / (max_rate - min_rate) 757 | ax1.arrow(link.tx.pos[0], link.tx.pos[1], 758 | link.rx.pos[0] - link.tx.pos[0], link.rx.pos[1] - link.tx.pos[1], 759 | shape='full', length_includes_head=True, width=5, color=cmap(norm_r), 760 | alpha=norm_r * 0.8 + 0.1) 761 | 762 | # Plot all nodes. 763 | self.plot_nodes(ax1, ignore_idle_nodes=False) 764 | fig.colorbar(plt.cm.ScalarMappable(norm=Normalize(min_rate, max_rate), cmap=cmap), ax=ax1, 765 | fraction=0.05, pad=0.05, label="Rate (Mbps)") 766 | 767 | # Plot the boundary of legal region. 768 | self.add_boundary(ax1) 769 | ax1.set_title("Route") 770 | 771 | # ============ Sub-figure 2: Plot resource allocation. ============ 772 | ax2 = axs[1] 773 | 774 | # Extract busy nodes 775 | busy_nids, p_tx = [], [] 776 | for node in self.nodes: 777 | if not node.idle.all(): 778 | p_tx.append(node.p_tx) 779 | busy_nids.append(node.nid) 780 | 781 | if len(busy_nids) > 0: 782 | # Stack Tx power of nodes. 783 | p_tx = np.stack(p_tx) 784 | # Create heatmap. 
785 | im, cbar = heatmap(p_tx, busy_nids, range(self.n_chans), ax=ax2, 786 | cmap="magma", cbarlabel="Tx power (Watt)") 787 | # Annotate heatmap. 788 | texts = annotate_heatmap(im, valfmt="{x:.2f}", size=7, threshold=0.04, 789 | textcolors=("white", "black")) 790 | 791 | ax2.set_title("Resource Allocation") 792 | ax2.set_xlabel('Chan Idx') 793 | ax2.set_ylabel('NID') 794 | 795 | # ============ Show/save figure. ============ 796 | 797 | # Display the image. 798 | if show_img: 799 | plt.show() 800 | 801 | # Write results to disk. 802 | if save_dir is not None: 803 | os.makedirs(save_dir, exist_ok=True) 804 | fig_name = 'rt_ra.pdf' 805 | if tag is not None: 806 | fig_name = tag + '_' + fig_name 807 | fig_path = osp.join(save_dir, fig_name) 808 | plt.savefig(fig_path) 809 | plt.close() 810 | 811 | def visualize_route(self, show_img: bool = False, save_dir: str = None, tag: str = None, **kwargs): 812 | """Shows the routes/resource allocation decisions of policy.""" 813 | dpi = 200 # Resolution of figures 814 | fig, ax1 = plt.subplots(ncols=1, nrows=1, figsize=(3.5, 3.25), layout="constrained", dpi=dpi) 815 | 816 | # ============ Sub-figure 1: Plot route. ============ 817 | ax1.set_aspect('equal') 818 | 819 | # Style of different heuristic routers 820 | rt_styles = { 821 | 'c2Dst': {'color': 'pink', 'linewidth': 2.5, 'alpha': 0.5}, 822 | 'mSINR': {'color': 'yellowgreen', 'linewidth': 2.5, 'alpha': 0.5}, 823 | } 824 | 825 | from matplotlib.collections import LineCollection 826 | import matplotlib.lines as mlines 827 | 828 | # Plot routes set by benchmark routers. 829 | rt_labels = [] 830 | for rt, flows in self.bm_paths.items(): 831 | xs, ys = [], [] 832 | for flow in flows: 833 | for link in flow.route: 834 | if self.measure_link_distance(link) <= self.r_sns: 835 | xs.append(np.linspace(link.tx.pos[0], link.rx.pos[0], 25)) 836 | ys.append(np.linspace(link.tx.pos[1], link.rx.pos[1], 25)) 837 | segs = [np.stack((x, y)).T for x, y in zip(xs, ys)] 838 | ax1.add_collection(LineCollection(segs, zorder=-1, **rt_styles[rt])) 839 | rt_labels.append(mlines.Line2D([], [], label=rt, **rt_styles[rt])) 840 | ax1.legend(handles=rt_labels, loc='lower right', prop={'size': 6}) 841 | 842 | # Get the range of rates. 843 | min_rate, max_rate = float('inf'), 0 844 | for flow in self.flows: 845 | if flow.n_hops > 0: 846 | rate_per_hop = self.get_per_hop_rate(flow) 847 | min_rate = min(min_rate, min(rate_per_hop)) 848 | max_rate = max(max_rate, max(rate_per_hop)) 849 | max_rate += 1e-2 850 | min_rate -= 1e-2 851 | # Plot routes of all flows with color specifying link rates. 852 | cmap = plt.colormaps["magma"] 853 | for flow in self.flows: 854 | for link in flow.route: 855 | if self.measure_link_distance(link) <= self.r_sns: 856 | link_rate = self.link_rates[link.rx.nid, link.tx.nid, link.chan_idx] 857 | norm_r = (link_rate - min_rate) / (max_rate - min_rate) 858 | ax1.arrow(link.tx.pos[0], link.tx.pos[1], 859 | link.rx.pos[0] - link.tx.pos[0], link.rx.pos[1] - link.tx.pos[1], 860 | shape='full', length_includes_head=True, width=5, color=cmap(norm_r), 861 | alpha=norm_r * 0.75 + 0.25) 862 | 863 | # Plot all nodes. 864 | self.plot_nodes(ax1, ignore_idle_nodes=False) 865 | fig.colorbar(plt.cm.ScalarMappable(norm=Normalize(min_rate, max_rate), cmap=cmap), ax=ax1, 866 | fraction=0.05, pad=0.05, label="Rate (Mbit/s)") 867 | 868 | # Plot the boundary of legal region. 869 | self.add_boundary(ax1) 870 | 871 | # ============ Show/save figure. ============ 872 | 873 | # Display the image. 
874 | if show_img: 875 | plt.show() 876 | 877 | # Write results to disk. 878 | if save_dir is not None: 879 | os.makedirs(save_dir, exist_ok=True) 880 | fig_name = 'rt.pdf' 881 | if tag is not None: 882 | fig_name = tag + '_' + fig_name 883 | fig_path = osp.join(save_dir, fig_name) 884 | plt.savefig(fig_path) 885 | plt.close() 886 | 887 | def visualize_power(self, show_img: bool = False, save_dir: str = None, tag: str = None): 888 | """Shows power relation between nodes.""" 889 | dpi = 200 # Resolution of figures 890 | fig, axs = plt.subplots(ncols=2, nrows=1, figsize=(7, 3.25), layout="constrained", dpi=dpi) 891 | fig.suptitle('Received Power', fontsize=12) 892 | 893 | snr_min, snr_max = -10, 40 # Min/Max value of power/N (dB) 894 | cmap = plt.colormaps["viridis"] # Colormap 895 | 896 | def normalize_snr(p): 897 | snr_db = 10 * np.log10(max(p / (self.n0 * self.bw), 1e-8)) 898 | rescaled_snr_db = (np.clip(snr_db, snr_min, snr_max) - snr_min) / (snr_max - snr_min) 899 | return rescaled_snr_db 900 | 901 | # ============ Sub-figure 1: Plot direct signals. ============ 902 | ax1 = axs[0] 903 | ax1.set_aspect('equal') 904 | 905 | for flow in self.flows: 906 | _, bottleneck_idx = self.get_bottleneck_rate(flow) 907 | for l_idx, link in enumerate(flow.route): 908 | tx, rx = link.tx, link.rx 909 | # Accentuate bottleneck link. 910 | if l_idx == bottleneck_idx: 911 | ax1.plot(np.linspace(tx.pos[0], rx.pos[0], 25), np.linspace(tx.pos[1], rx.pos[1], 25), 912 | alpha=0.25, color='grey', linewidth=7.5) 913 | snr_val = normalize_snr(self.p_rx[rx.nid, tx.nid, link.chan_idx]) 914 | ax1.arrow(tx.pos[0], tx.pos[1], rx.pos[0] - tx.pos[0], rx.pos[1] - link.tx.pos[1], 915 | shape='full', length_includes_head=True, color=cmap(snr_val), alpha=snr_val, width=5) 916 | 917 | # Plot nodes that are not idle. 918 | self.plot_nodes(ax1) 919 | # Plot the boundary of legal region. 920 | self.add_boundary(ax1) 921 | ax1.set_title("SNR") 922 | 923 | # ============ Sub-figure 2: Plot interference signals. ============ 924 | ax2 = axs[1] 925 | ax2.set_aspect('equal') 926 | 927 | for flow in self.flows: 928 | for link in flow.route: 929 | tx, rx = link.tx, link.rx 930 | for inf_node in self.nodes: 931 | if (inf_node not in {tx, rx}) and inf_node.p_tx[link.chan_idx] > 0: 932 | if self.p_rx[rx.nid, inf_node.nid, link.chan_idx] > self.n0 * self.bw: 933 | inr_val = normalize_snr(self.p_rx[rx.nid, inf_node.nid, link.chan_idx]) 934 | ax2.arrow(inf_node.pos[0], inf_node.pos[1], 935 | rx.pos[0] - inf_node.pos[0], rx.pos[1] - inf_node.pos[1], 936 | shape='left', length_includes_head=True, width=5, 937 | alpha=inr_val, color=cmap(inr_val)) 938 | 939 | # Plot nodes that are not idle. 940 | self.plot_nodes(ax2) 941 | # Plot the boundary of legal region. 942 | self.add_boundary(ax2) 943 | ax2.set_title("INR") 944 | 945 | fig.colorbar(plt.cm.ScalarMappable(norm=Normalize(snr_min, snr_max), cmap=cmap), ax=ax2, 946 | fraction=0.05, pad=0.05, label="SNR/INR in dB") 947 | 948 | # ============ Show/save figure. ============ 949 | # Display the image. 950 | if show_img: 951 | plt.show() 952 | 953 | # Write results to disk. 
954 | if save_dir is not None: 955 | os.makedirs(save_dir, exist_ok=True) 956 | fig_name = 'snr_inr.pdf' 957 | if tag is not None: 958 | fig_name = tag + '_' + fig_name 959 | fig_path = osp.join(save_dir, fig_name) 960 | plt.savefig(fig_path) 961 | plt.close() 962 | 963 | 964 | if __name__ == '__main__': 965 | from components.misc import get_random_actions 966 | env = AdHocEnv('cls-500') 967 | env.reset() 968 | terminated = False 969 | while not terminated: 970 | avail_actions = env.get_avail_actions() 971 | print(f"At n{env.agent.front.nid}, nbrs: {[nbr.nid for nbr in env.nbrs]}") 972 | rand_action = get_random_actions(avail_actions) 973 | _, terminated, _ = env.step(rand_action) 974 | print(env.agent) 975 | env.save_replay(save_dir='./') 976 | 977 | # shared_obs = env.get_shared_obs_relations() 978 | # for k, v in shared_obs.items(): 979 | # print(f"shared_obs[{k}] = \n{v}.") 980 | -------------------------------------------------------------------------------- /envs/ad_hoc/ad_hoc_entities.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import numpy as np 3 | from numpy import ndarray 4 | 5 | 6 | class Node: 7 | """Wireless node in Ad Hoc networks""" 8 | 9 | legal_roles = {'src', 'rly', 'dst'} # Legal roles of nodes are source/relay/destination 10 | 11 | def __init__(self, nid: int, n_chans: int): 12 | 13 | self.nid = nid # Identifier of node 14 | self.n_chans = n_chans # Number of sub-channels 15 | self.pos: Optional[ndarray] = None 16 | self.role: str = 'rly' 17 | 18 | self.p_tx = np.zeros(self.n_chans, dtype=np.float32) # Tx power on each sub-channel (Watt) 19 | self.idle = np.zeros(self.n_chans, dtype=bool) 20 | 21 | def reset(self, pos_node: ndarray): 22 | self.pos = pos_node 23 | self.role = 'rly' 24 | self.p_tx = np.zeros(self.n_chans, dtype=np.float32) 25 | self.idle = np.ones(self.n_chans, dtype=bool) 26 | 27 | 28 | class Link: 29 | """Wireless link between a transceiver pair""" 30 | 31 | def __init__(self, tx_node: Node, rx_node: Node, chan_idx: int, p_tx: float): 32 | self.tx: Node = tx_node 33 | self.rx: Node = rx_node 34 | self.chan_idx: int = chan_idx 35 | 36 | self.tx.p_tx[self.chan_idx] = p_tx 37 | self.tx.idle[self.chan_idx] = False 38 | self.rx.idle[self.chan_idx] = False 39 | 40 | @property 41 | def p_tx(self) -> float: 42 | return self.tx.p_tx[self.chan_idx] 43 | 44 | def remove(self): 45 | self.tx.p_tx[self.chan_idx] = 0 46 | self.tx.idle[self.chan_idx] = True 47 | self.rx.idle[self.chan_idx] = True 48 | 49 | 50 | class Flow: 51 | """Data flow with a multi-hop route""" 52 | 53 | def __init__(self, 54 | fid: int, # Identifier of node 55 | n_chans: int, # Number of sub-channels 56 | n_nodes: int, # Number of nodes 57 | p_levels: list[float], # Discrete levels of transmit power (Watt) 58 | p_budget: float, # Total power budget of the flow (Watt) 59 | allow_full_duplex: bool = False, # Whether to allow simultaneous Tx/Rx on one channel 60 | ): 61 | 62 | self.fid = fid 63 | self.n_chans = n_chans 64 | self.n_nodes = n_nodes 65 | 66 | self.route: list[Link] = [] # Route list 67 | self.src: Optional[Node] = None # Source node 68 | self.dst: Optional[Node] = None # Destination node 69 | 70 | self.qual_nodes: ndarray = np.ones(self.n_nodes, dtype=bool) # Qualification of candidate nodes 71 | self.p_bdg = p_budget 72 | self.p_lvs = p_levels 73 | 74 | self._allow_full_duplex = allow_full_duplex 75 | 76 | def reset(self, src_node: Node, dst_node: Node): 77 | # Assign source/destination. 
78 | src_node.role = 'src' 79 | self.src = src_node 80 | dst_node.role = 'dst' 81 | self.dst = dst_node 82 | 83 | # Clear the route. 84 | self.clear_route() 85 | 86 | def clear_route(self): 87 | """Clears route.""" 88 | # Remove all links in the route. 89 | for link in self.route: 90 | link.remove() 91 | self.route = [] # Empty route. 92 | 93 | # All nodes but source can be added to route in the future. 94 | self.qual_nodes = np.ones(self.n_nodes, dtype=bool) 95 | self.qual_nodes[self.src.nid] = 0 96 | 97 | def check(self, node: Node) -> bool: 98 | """Checks whether a node is qualified to be added.""" 99 | if node.nid in self.nids_in_route: 100 | return False 101 | if (node.role == 'rly') and self.qual_nodes[node.nid]: 102 | if self._allow_full_duplex: # Simultaneous transmitting/receiving on the same channel is allowed. 103 | if node.idle.any(): 104 | return True 105 | else: # Transmitting/receiving must take different sub-channels. 106 | # Then, a qualified next relay must meet two criteria: 107 | # 1) An idle sub-channel is shared by next/front nodes; 108 | # 2) Another sub-channel is available for later hop. 109 | if (node.idle.sum() >= 2) and (node.idle * self.front.idle).any(): 110 | return True 111 | elif node is self.dst: 112 | if self._allow_full_duplex: 113 | if node.idle.any(): 114 | return True 115 | else: 116 | if (node.idle * self.front.idle).any(): 117 | return True 118 | else: 119 | return False 120 | 121 | def ban(self, node: Node) -> None: 122 | """Bans a node from the route.""" 123 | self.qual_nodes[node.nid] = False 124 | 125 | def add_hop(self, next_node: Node, chan_idx: int, p_idx: int) -> None: 126 | """Adds a hop to the route.""" 127 | # Create a link between current front node and next node. 128 | link = Link(self.front, next_node, chan_idx, self.p_lvs[p_idx]) 129 | # Add link to route. 130 | self.route.append(link) 131 | # Disqualify added node to prevent loop in the route. 
132 | self.qual_nodes[next_node.nid] = 0 133 | 134 | @property 135 | def n_hops(self) -> int: 136 | """Number of hops""" 137 | return len(self.route) 138 | 139 | @property 140 | def is_connected(self) -> bool: 141 | """Returns whether source node is connected to destination with current route""" 142 | return self.front == self.dst 143 | 144 | @property 145 | def front(self) -> Node: 146 | """Current frontier node""" 147 | return self.src if len(self.route) == 0 else self.route[-1].rx 148 | 149 | @property 150 | def n_pow_lvs(self) -> int: 151 | """Number of Tx power levels""" 152 | return len(self.p_lvs) 153 | 154 | @property 155 | def p_tot(self) -> float: 156 | """Total power cost""" 157 | return sum([link.p_tx for link in self.route]) 158 | 159 | @property 160 | def p_rem(self) -> float: 161 | """Remaining power budget""" 162 | return self.p_bdg - self.p_tot 163 | 164 | @property 165 | def nids_in_route(self): 166 | return [self.src.nid] + [link.rx.nid for link in self.route] 167 | 168 | def __repr__(self): 169 | return f"Flow-{self.fid} from n{self.src.nid} to n{self.dst.nid} with route {self.nids_in_route}" 170 | -------------------------------------------------------------------------------- /envs/ad_hoc/ad_hoc_layouts.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from abc import abstractmethod 3 | import random 4 | import numpy as np 5 | 6 | from envs.common import * 7 | 8 | 9 | class AdHocLayout: 10 | """Base class for all layouts of Ad Hoc Networks""" 11 | 12 | @abstractmethod 13 | def set_entities(self) -> tuple[ndarray, ndarray, ndarray]: 14 | """Sets positions of nodes and source/destination of data flows.""" 15 | raise NotImplementedError 16 | 17 | 18 | @dataclass 19 | class Stationary(AdHocLayout): 20 | """Stationary layout""" 21 | 22 | range_pos: float = 500 # Range of position (m) 23 | dim_pos: int = 2 # Dimension of position 24 | min_link_dist: float = 1 # Minimum distance between nodes (m) 25 | r_sns: float = 250 # Range of neighbor detection 26 | 27 | n_agt_flows: int = 1 # Number of data flows controlled by agents 28 | n_amb_flows: int = 1 # Number of ambient data flows 29 | max_nbrs: int = 10 # Maximum number of neighbors to consider 30 | n_nodes_per_rgn: tuple = (6, 5, 6, 5, 6, 5, 6, 5, 6) # Nodes density profiled over subregions 31 | max_hops: int = 20 # Maximum number of hops 32 | 33 | chan_name: str = 'itu1411' # Channel model 34 | p_max_dbm: float = 20 # Maximum Tx power of data flows (dBm) 35 | n_pow_lvs: int = 4 # Number of discrete power levels for agent flows 36 | p_bdg: float = float('inf') # Power budget (Watt) 37 | n0_dbm: float = -150 # PSD of Rx noise (dBm/Hz) 38 | p_amb_dbm: float = 10 # Tx power of ambient flows (dBm) 39 | 40 | tot_bw: float = 5e6 # Total available bandwidth (Hz) 41 | n_chans: int = 1 # Number of sub-channels 42 | 43 | def __post_init__(self): 44 | self.n_nodes = sum(self.n_nodes_per_rgn) # Number of nodes 45 | self.n_flows = self.n_agt_flows + self.n_amb_flows 46 | 47 | # Divide area into subregions. 
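For context on the fields above and the `__post_init__` below: the default 9-entry `n_nodes_per_rgn` implies a 3x3 grid of subregions, whose origins are later used in `set_stationary_node_positions` via the `rgn_idx // 3, rgn_idx % 3` mapping. A small sketch using the dataclass defaults (illustration only):

```python
import numpy as np

range_pos, n_rgns = 500.0, 9          # defaults of `Stationary`
n_divs = int(np.sqrt(n_rgns))         # 3 divisions per axis
range_rgn = range_pos / n_divs        # ~166.7 m per subregion

origins = [range_rgn * np.array([i // n_divs, i % n_divs]) for i in range(n_rgns)]
print(origins[0], origins[4], origins[8])
# ~[0. 0.] [166.7 166.7] [333.3 333.3]
```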
48 | self.n_rgns = len(self.n_nodes_per_rgn) # Number of subregions 49 | n_divs = int(np.sqrt(len(self.n_nodes_per_rgn))) # Division in each dimension 50 | self.range_rgn = self.range_pos / n_divs # Range of each subregion (m) 51 | 52 | self.p_max = 1e-3 * np.power(10, self.p_max_dbm / 10) # Tx power of data flows (Watt) 53 | self.n0 = 1e-3 * np.power(10, self.n0_dbm / 10) # PSD of Rx noise (Watt/Hz) 54 | self.p_amb = 1e-3 * np.power(10, self.p_amb_dbm / 10) # Tx power of ambient flows (Watt) 55 | 56 | def set_stationary_node_positions(self): 57 | """Evenly sets positions of nodes in each subregion.""" 58 | pos_nodes = [] 59 | # Randomly draw positions of nodes in each subregion. 60 | for rgn_idx in range(self.n_rgns): 61 | o_rgn = self.range_rgn * np.array([rgn_idx // 3, rgn_idx % 3]) # Origin of subrigion 62 | coor_nodes_in_rgn = select_from_box(self.n_nodes_per_rgn[rgn_idx], 0, int(self.range_rgn // self.min_link_dist), self.dim_pos) 63 | pos_nodes_in_rgn = o_rgn + self.min_link_dist * coor_nodes_in_rgn 64 | pos_nodes.append(pos_nodes_in_rgn) 65 | pos_nodes = np.concatenate(pos_nodes) 66 | return pos_nodes 67 | 68 | def sample_nodes_from_region(self, rgn_idx, n_nodes): 69 | """Samples distinct nodes in a subregion.""" 70 | nids = random.sample(range(self.n_nodes_per_rgn[rgn_idx]), n_nodes) 71 | nids = np.array(nids, dtype=int) + sum(self.n_nodes_per_rgn[:rgn_idx]) 72 | return nids 73 | 74 | def set_entities(self): 75 | pos_nodes = self.set_stationary_node_positions() 76 | 77 | # Ambient flows 78 | amb_rgn_pairs = random.sample([(2, 4), (6, 4)], 1) 79 | src_rgn_idx, dst_rgn_idx = amb_rgn_pairs[0] 80 | amb_src_nids = self.sample_nodes_from_region(src_rgn_idx, self.n_amb_flows) 81 | amb_dst_nids = self.sample_nodes_from_region(dst_rgn_idx, self.n_amb_flows) 82 | 83 | # Agent flows 84 | src_rgn_idx, dst_rgn_idx = (0, 8) 85 | agt_src_nids = self.sample_nodes_from_region(src_rgn_idx, self.n_agt_flows) 86 | agt_dst_nids = self.sample_nodes_from_region(dst_rgn_idx, self.n_agt_flows) 87 | 88 | src_nids = np.append(amb_src_nids, agt_src_nids) 89 | dst_nids = np.append(amb_dst_nids, agt_dst_nids) 90 | 91 | return pos_nodes, src_nids, dst_nids 92 | 93 | 94 | class QuasiStationary(Stationary): 95 | def set_entities(self): 96 | pos_nodes = self.set_stationary_node_positions() 97 | 98 | agt_rgn_pairs = random.sample([(0, 8), (8, 0), (2, 6), (6, 2)], 1)[0] 99 | if agt_rgn_pairs in {(0, 8), (8, 0)}: 100 | amb_rgn_pairs = random.sample([(2, 4), (6, 4)], 1)[0] 101 | else: 102 | amb_rgn_pairs = random.sample([(0, 4), (8, 4)], 1)[0] 103 | 104 | # Ambient flows 105 | src_rgn_idx, dst_rgn_idx = amb_rgn_pairs 106 | amb_src_nids = self.sample_nodes_from_region(src_rgn_idx, self.n_amb_flows) 107 | amb_dst_nids = self.sample_nodes_from_region(dst_rgn_idx, self.n_amb_flows) 108 | 109 | # Agent flows 110 | src_rgn_idx, dst_rgn_idx = agt_rgn_pairs 111 | agt_src_nids = self.sample_nodes_from_region(src_rgn_idx, self.n_agt_flows) 112 | agt_dst_nids = self.sample_nodes_from_region(dst_rgn_idx, self.n_agt_flows) 113 | 114 | src_nids = np.append(amb_src_nids, agt_src_nids) 115 | dst_nids = np.append(amb_dst_nids, agt_dst_nids) 116 | 117 | return pos_nodes, src_nids, dst_nids 118 | 119 | 120 | @dataclass 121 | class Clusters(AdHocLayout): 122 | range_pos: float = 500 # Range of position (m) 123 | dim_pos: int = 2 # Dimension of position 124 | min_link_dist: float = 1 # Minimum distance between nodes (m) 125 | r_sns: float = 250 # Range of neighbor detection 126 | 127 | n_agt_flows: int = 1 # Number of data flows 
controlled by agents 128 | n_amb_flows: int = 0 # Number of ambient data flows 129 | max_nbrs: int = 10 # Maximum number of neighbors to consider 130 | n_nodes_per_cl: tuple = (5, 3, 3, 4, 4, 5) 131 | max_hops: int = 20 # Maximum number of hops 132 | 133 | chan_name: str = 'itu1411' # Channel model 134 | p_max_dbm: float = 20 # Maximum Tx power of data flows (dBm) 135 | n_pow_lvs: int = 4 # Number of discrete power levels for agent flows 136 | p_bdg: float = float('inf') # Power budget (Watt) 137 | n0_dbm: float = -150 # PSD of Rx noise (dBm/Hz) 138 | p_amb_dbm: float = 10 # Tx power of ambient flows (dBm) 139 | 140 | tot_bw: float = 5e6 # Total available bandwidth (Hz) 141 | n_chans: int = 1 # Number of sub-channels 142 | 143 | def __post_init__(self): 144 | 145 | self.n_cls = len(self.n_nodes_per_cl) 146 | self.n_nodes = sum(self.n_nodes_per_cl) # Number of nodes 147 | self.n_flows = self.n_agt_flows + self.n_amb_flows 148 | 149 | self.p_max = 1e-3 * np.power(10, self.p_max_dbm / 10) # Tx power of data flows (Watt) 150 | self.n0 = 1e-3 * np.power(10, self.n0_dbm / 10) # PSD of Rx noise (Watt/Hz) 151 | self.p_amb = 1e-3 * np.power(10, self.p_amb_dbm / 10) # Tx power of ambient flows (Watt) 152 | 153 | def set_entities(self) -> tuple[ndarray, ndarray, ndarray]: 154 | pos_cls = [] 155 | # # range_pos = 600m 156 | # pos_cls.append(np.array([[300, 50], [100, 100], [500, 100]])) 157 | # sign = 2 * (np.random.randint(0, 2) - 0.5) 158 | # pos_cls.append(np.array([[300 + sign * 200, 300], [300 + sign * 200, 500]])) 159 | # pos_cls.append(np.array([[300, 550]])) 160 | 161 | # range_pos = 500m 162 | pos_cls.append(np.array([[250, 50], [50, 100], [450, 100]])) 163 | sign = 2 * (np.random.randint(0, 2) - 0.5) 164 | pos_cls.append(np.array([[250 + sign * 200, 250], [250 + sign * 200, 450]])) 165 | pos_cls.append(np.array([[250, 450]])) 166 | 167 | pos_cls = np.concatenate(pos_cls, axis=0) 168 | 169 | pos_nodes = [] 170 | for cl_idx, pos_cl in enumerate(pos_cls): 171 | r_cls = 100 if cl_idx == 0 else 75 172 | pos_nodes.append(pos_cl + select_from_box(self.n_nodes_per_cl[cl_idx], 0, r_cls, 2) - r_cls / 2) 173 | pos_nodes = np.concatenate(pos_nodes, axis=0) 174 | 175 | x, y = pos_nodes[:, 0], pos_nodes[:, 1] 176 | ang = np.random.randint(0, 4) * np.pi / 2 177 | x_o, y_o = self.range_pos / 2, self.range_pos / 2 178 | x_rot = (x - x_o) * np.cos(ang) + (y - y_o) * np.sin(ang) + x_o 179 | y_rot = - (x - x_o) * np.sin(ang) + (y - y_o) * np.cos(ang) + y_o 180 | pos_nodes = np.vstack((x_rot, y_rot)).T 181 | pos_nodes = np.clip(pos_nodes, 0, self.range_pos) 182 | 183 | src_nids = np.array([0]) 184 | dst_nids = np.array([self.n_nodes - 1]) 185 | return pos_nodes, src_nids, dst_nids 186 | 187 | 188 | kwargs_1flow = dict(n_amb_flows=0, n_nodes_per_rgn=tuple([3] * 9), max_nbrs=6) 189 | SCENARIOS = { 190 | 'debug': Stationary(max_nbrs=4, n_chans=3, n_pow_lvs=2, n_nodes_per_rgn=tuple([3] * 9)), 191 | 'debug-full-pow': Stationary(max_nbrs=4, n_chans=3, n_pow_lvs=1, n_nodes_per_rgn=tuple([3] * 9)), 192 | 193 | '1flow': QuasiStationary(n_amb_flows=0, n_nodes_per_rgn=tuple([3] * 9), max_nbrs=6,), 194 | '1flow-full-pow': QuasiStationary(n_amb_flows=0, n_nodes_per_rgn=tuple([3] * 9), max_nbrs=6, n_pow_lvs=1), 195 | 196 | 'cls': Clusters(), 197 | 'cls-full-pow': Clusters(n_pow_lvs=1), 198 | 199 | '2flows': QuasiStationary(), 200 | '2flows-full-pow': QuasiStationary(n_pow_lvs=1), 201 | 202 | '1f': QuasiStationary(n_amb_flows=0, n_nodes_per_rgn=tuple([3] * 9), max_nbrs=6,), 203 | '1f-full-pow': QuasiStationary(n_amb_flows=0, 
n_nodes_per_rgn=tuple([3] * 9), max_nbrs=6, n_pow_lvs=1), 204 | 205 | '1fb': QuasiStationary(n_amb_flows=0), 206 | '1fb-full-pow': QuasiStationary(n_amb_flows=0, n_pow_lvs=1), 207 | } 208 | -------------------------------------------------------------------------------- /envs/chan_models.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy import ndarray 3 | 4 | 5 | CHANNEL_MODELS = {} 6 | 7 | 8 | class FSPLChannel(object): 9 | """Free space path loss (FSPL) channel model""" 10 | def __init__(self, fc: float = 2.4e9) -> None: 11 | self.fc = fc # Central carrier frequency (Hz) 12 | 13 | def estimate_chan_gain(self, d: ndarray) -> ndarray: 14 | fspl = (4 * np.pi * self.fc * d / 3e8) ** 2 # The Friis equation 15 | return 1 / fspl 16 | 17 | 18 | CHANNEL_MODELS['fspl'] = FSPLChannel 19 | 20 | 21 | class ITU1411Channel(object): 22 | """Short-range outdoor model ITU-1411""" 23 | 24 | def __init__(self, fc: float = 2.4e9, antenna_gain_db: float = 2.5, antenna_height: float = 1.5): 25 | self.fc = fc # Carrier frequency (Hz) 26 | self.ant_gain_db = antenna_gain_db # Antenna gain in dBi 27 | self.h_ant = antenna_height # Height of antenna (m) 28 | 29 | def estimate_chan_gain(self, d: ndarray) -> ndarray: 30 | h1, h2 = self.h_ant, self.h_ant 31 | signal_lambda = 2.998e8 / self.fc 32 | # compute relevant quantity. 33 | Rbp = 4 * h1 * h2 / signal_lambda 34 | Lbp = abs(20 * np.log10(np.power(signal_lambda, 2) / (8 * np.pi * h1 * h2))) 35 | sum_term = 20 * np.log10(d / Rbp) 36 | Tx_over_Rx = Lbp + 6 + sum_term + ((d > Rbp).astype(int)) * sum_term # Adjust for longer path loss 37 | pl = -Tx_over_Rx + self.ant_gain_db # Add antenna gain 38 | return np.power(10, (pl / 10)) # convert from decibel to absolut 39 | 40 | 41 | CHANNEL_MODELS['itu1411'] = ITU1411Channel 42 | 43 | 44 | class AirToGroundChannel(object): 45 | """Air-to-ground (ATG) channel model proposed in "Optimal LAP Altitude for Maximum Coverage". 46 | Parameter a and b are provided in paper 47 | "Efficient 3-D Placement of an Aerial Base Station in Next Generation Cellular Networks". 48 | """ 49 | 50 | ATG_CHAN_PARAMS = { 51 | 'suburban': {'a': 4.88, 'b': 0.43, 'eta_los': 0.1, 'eta_nlos': 21}, 52 | 'urban': {'a': 9.61, 'b': 0.16, 'eta_los': 1, 'eta_nlos': 20}, 53 | 'dense-urban': {'a': 12.08, 'b': 0.11, 'eta_los': 1.6, 'eta_nlos': 23}, 54 | 'high-rise-urban': {'a': 27.23, 'b': 0.08, 'eta_los': 2.3, 'eta_nlos': 34} 55 | } 56 | 57 | def __init__(self, scene: str = 'urban', fc: float = 2.4e9) -> None: 58 | # Set scene-specific parameters. 59 | for k, v in self.ATG_CHAN_PARAMS[scene].items(): 60 | self.__setattr__(k, v) 61 | self.fc = fc # Central carrier frequency (Hz) 62 | 63 | def estimate_chan_gain(self, d_ground: ndarray, h_ubs: float) -> ndarray: 64 | """Estimates the channel gain from horizontal distance.""" 65 | # Get direct link distance. 66 | d_link = np.sqrt(np.square(d_ground) + np.square(h_ubs)) 67 | # Estimate probability of LoS link emergence. 68 | p_los = 1 / (1 + self.a * np.exp(-self.b * (180 / np.pi * np.arcsin(h_ubs / d_link) - self.a))) 69 | # Compute free space path loss (FSPL). 70 | fspl = (4 * np.pi * self.fc * d_link / 3e8) ** 2 71 | # Path loss is the weighted average of LoS and NLoS cases. 
72 | pl = p_los * fspl * 10 ** (self.eta_los / 10) + (1 - p_los) * fspl * 10 ** (self.eta_nlos / 10) 73 | return 1 / pl 74 | 75 | 76 | CHANNEL_MODELS['a2g'] = AirToGroundChannel 77 | 78 | 79 | def compute_antenna_gain(theta: ndarray, psi: float) -> ndarray: 80 | """Computes antenna gain using simple two-lobe approximation. 81 | Direction of UBS antenna is perpendicular to ground. 82 | See "Joint Altitude and Beamwidth Optimization for UAV-Enabled Multiuser Communications" for more explanation. 83 | Args: 84 | theta (float): elevation angle from GTs to UBSs between (0, pi/2]. 85 | psi (float): Half-beamwidth of directional antennas (rad) 86 | """ 87 | g_main = 2.285 / np.power(psi, 2) # Constant gain of main lobe 88 | g_side = 0 # Ignored gain of side lobe 89 | return ((np.pi / 2 - theta) <= psi) * g_main + ((np.pi / 2 - theta) > psi) * g_side 90 | 91 | 92 | if __name__ == '__main__': 93 | d = np.arange(0, 1000, 100) 94 | print(d) 95 | chan_model = AirToGroundChannel() 96 | g = chan_model.estimate_chan_gain(d, 100) 97 | print(10 * np.log10(g)) -------------------------------------------------------------------------------- /envs/common.py: -------------------------------------------------------------------------------- 1 | r"""Code blocks shared across envs""" 2 | 3 | from abc import abstractmethod 4 | 5 | from random import sample 6 | from itertools import product, repeat 7 | 8 | import numpy as np 9 | from numpy import ndarray 10 | 11 | from envs.multi_agent_env import MultiAgentEnv 12 | 13 | # --- Math --- 14 | 15 | 16 | def get_one_hot(x: int, n: int) -> ndarray: 17 | """Gets one-hot encoding of integer.""" 18 | v = np.zeros(n, dtype=np.float32) 19 | v[x] = 1 20 | return v 21 | 22 | 23 | def select_from_box(n_els: int, min_val: int, max_val: int, n_dims: int) -> ndarray: 24 | """Selects non-repetitive elements from a box.""" 25 | legal_points = list(product(*list(repeat(np.arange(min_val, max_val), n_dims)))) 26 | return np.array(sample(legal_points, n_els), dtype=np.float32) 27 | 28 | 29 | def compute_jain_fairness_index(x: ndarray) -> float: 30 | """Computes the Jain's fairness index of entries in given ndarray.""" 31 | if x.size > 0: 32 | x = np.clip(x, 1e-6, np.inf) 33 | return np.square(x.sum()) / (x.size * np.square(x).sum()) 34 | else: 35 | return 1 36 | 37 | 38 | def build_discrete_moves(n_dirs: int, v_max: float, v_levels: int) -> ndarray: 39 | """Builds discrete moves from kronecker product of directions and velocities.""" 40 | move_amounts = [] 41 | v = v_max 42 | for i in range(v_levels): 43 | move_amounts.append(v) 44 | v /= 2 45 | move_amounts = np.array(move_amounts).reshape(-1, 1) # Amounts of movement in each timestep (m) 46 | ang = 2 * np.pi * np.arange(n_dirs) / n_dirs # Possible flying angles 47 | move_dirs = np.stack([np.cos(ang), np.sin(ang)]).T # Moving directions of UBSs 48 | avail_moves = np.concatenate((np.zeros((1, 2)), np.kron(move_amounts, move_dirs))) # Available moves 49 | return avail_moves 50 | 51 | 52 | # --- Recorder --- 53 | 54 | class Recorder(object): 55 | """Tool to record condition of env at each timestep.""" 56 | 57 | def __init__(self, env: MultiAgentEnv, variables) -> None: 58 | self.env = env 59 | self.variables = variables 60 | self.film = {k: [] for k in self.variables} # A dict to hold records of variables 61 | 62 | def __getattr__(self, item): 63 | if item == '__setstate__': 64 | raise AttributeError(item) 65 | else: 66 | return getattr(self.env, item) 67 | 68 | def reload(self) -> None: 69 | """Clears the film to prepare for new 
episode.""" 70 | self.film = {k: [] for k in self.variables} 71 | 72 | def click(self) -> None: 73 | """Takes the snapshot at each timestep.""" 74 | for k in self.variables: 75 | v = self.__getattr__(k) 76 | if isinstance(v, ndarray): 77 | self.film[k].append(v.copy()) 78 | else: 79 | self.film[k].append(v) 80 | 81 | @ abstractmethod 82 | def replay(self, **kwargs): 83 | """Replays the entire episode from recording.""" 84 | raise NotImplementedError 85 | 86 | 87 | # --- Functions to plot simple objects --- 88 | 89 | 90 | def plot_line(a: ndarray, b: ndarray) -> list[ndarray]: 91 | """Plots a line from point a to point b.""" 92 | assert a.shape == b.shape, "Inconsistent dimension between a and b." 93 | return [np.linspace(a[d], b[d], 50) for d in range(a.size)] 94 | 95 | 96 | def plot_circ(x_o: ndarray, y_o: ndarray, r: float) -> tuple[ndarray, ndarray]: 97 | """Plots a circle centered at given origin (x_0, y_o) with radius r.""" 98 | assert x_o.shape == x_o.shape, "Inconsistent dimension between x_o and y_o." 99 | t = np.linspace(0, 2 * np.pi, 100) 100 | x_data, y_data = r * np.cos(t), r * np.sin(t) 101 | return x_o + x_data, y_o + y_data 102 | 103 | 104 | def plot_segments(points: list[ndarray]): 105 | """Plots segments between points.""" 106 | for p in points: 107 | assert p.ndim == 1 and p.size == 2, "Invalid shape of points." 108 | 109 | x, y = [], [] 110 | for s in range(len(points) - 1): 111 | seg = plot_line(points[s], points[s + 1]) 112 | x.append(seg[0]) 113 | y.append(seg[1]) 114 | return np.concatenate(x), np.concatenate(y) 115 | 116 | 117 | def plot_boundary(range_pos: float, symmetric: bool = False): 118 | if not symmetric: 119 | b = plot_segments([np.array([0, 0]), np.array([range_pos, 0]), np.array([range_pos, range_pos]), 120 | np.array([0, range_pos]), np.array([0, 0])]) 121 | else: 122 | b = plot_segments([np.array([-range_pos, -range_pos]), np.array([-range_pos, range_pos]), 123 | np.array([range_pos, range_pos]), np.array([range_pos, -range_pos]), 124 | np.array([-range_pos, -range_pos])]) 125 | return b 126 | -------------------------------------------------------------------------------- /envs/heatmap.py: -------------------------------------------------------------------------------- 1 | r"""Plot and annotate heatmaps with matplotlib. 2 | 3 | See https://matplotlib.org/stable/gallery/images_contours_and_fields/image_annotated_heatmap.html#. 4 | """ 5 | 6 | import numpy as np 7 | import matplotlib 8 | import matplotlib.pyplot as plt 9 | 10 | 11 | def heatmap(data, row_labels, col_labels, ax=None, 12 | cbar_kw=None, cbarlabel="", **kwargs): 13 | """ 14 | Create a heatmap from a numpy array and two lists of labels. 15 | 16 | Parameters 17 | ---------- 18 | data 19 | A 2D numpy array of shape (M, N). 20 | row_labels 21 | A list or array of length M with the labels for the rows. 22 | col_labels 23 | A list or array of length N with the labels for the columns. 24 | ax 25 | A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If 26 | not provided, use current axes or create a new one. Optional. 27 | cbar_kw 28 | A dictionary with arguments to `matplotlib.Figure.colorbar`. Optional. 29 | cbarlabel 30 | The label for the colorbar. Optional. 31 | **kwargs 32 | All other arguments are forwarded to `imshow`. 
33 | """ 34 | 35 | if ax is None: 36 | ax = plt.gca() 37 | 38 | if cbar_kw is None: 39 | cbar_kw = {} 40 | 41 | # Plot the heatmap 42 | im = ax.imshow(data, **kwargs) 43 | 44 | # Create colorbar 45 | cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw, label=cbarlabel) 46 | # cbar.ax.set_ylabel(cbarlabel, rotation=-90, va="bottom") 47 | 48 | # Show all ticks and label them with the respective list entries. 49 | ax.set_xticks(np.arange(data.shape[1]), labels=col_labels) 50 | ax.set_yticks(np.arange(data.shape[0]), labels=row_labels) 51 | 52 | # Let the horizontal axes labeling appear on top. 53 | ax.tick_params(top=True, bottom=False, 54 | labeltop=True, labelbottom=False) 55 | 56 | # Rotate the tick labels and set their alignment. 57 | # plt.setp(ax.get_xticklabels(), rotation=30, ha="right", 58 | # rotation_mode="anchor") 59 | 60 | # Turn spines off and create white grid. 61 | ax.spines[:].set_visible(False) 62 | 63 | ax.set_xticks(np.arange(data.shape[1]+1)-.5, minor=True) 64 | ax.set_yticks(np.arange(data.shape[0]+1)-.5, minor=True) 65 | ax.grid(which="minor", color="w", linestyle='-', linewidth=3) 66 | ax.tick_params(which="minor", bottom=False, left=False) 67 | 68 | return im, cbar 69 | 70 | 71 | def annotate_heatmap(im, data=None, valfmt="{x:.2f}", 72 | textcolors=("black", "white"), 73 | threshold=None, **textkw): 74 | """ 75 | A function to annotate a heatmap. 76 | 77 | Parameters 78 | ---------- 79 | im 80 | The AxesImage to be labeled. 81 | data 82 | Data used to annotate. If None, the image's data is used. Optional. 83 | valfmt 84 | The format of the annotations inside the heatmap. This should either 85 | use the string format method, e.g. "$ {x:.2f}", or be a 86 | `matplotlib.ticker.Formatter`. Optional. 87 | textcolors 88 | A pair of colors. The first is used for values below a threshold, 89 | the second for those above. Optional. 90 | threshold 91 | Value in data units according to which the colors from textcolors are 92 | applied. If None (the default) uses the middle of the colormap as 93 | separation. Optional. 94 | **kwargs 95 | All other arguments are forwarded to each call to `text` used to create 96 | the text labels. 97 | """ 98 | 99 | if not isinstance(data, (list, np.ndarray)): 100 | data = im.get_array() 101 | 102 | # Normalize the threshold to the images color range. 103 | if threshold is not None: 104 | threshold = im.norm(threshold) 105 | else: 106 | threshold = im.norm(data.max())/2. 107 | 108 | # Set default alignment to center, but allow it to be 109 | # overwritten by textkw. 110 | kw = dict(horizontalalignment="center", 111 | verticalalignment="center") 112 | kw.update(textkw) 113 | 114 | # Get the formatter in case a string is supplied 115 | if isinstance(valfmt, str): 116 | valfmt = matplotlib.ticker.StrMethodFormatter(valfmt) 117 | 118 | # Loop over the data and create a `Text` for each "pixel". 119 | # Change the text's color depending on the data. 
120 | texts = [] 121 | for i in range(data.shape[0]): 122 | for j in range(data.shape[1]): 123 | kw.update(color=textcolors[int(im.norm(data[i, j]) > threshold)]) 124 | text = im.axes.text(j, i, valfmt(data[i, j], None), **kw) 125 | texts.append(text) 126 | 127 | return texts -------------------------------------------------------------------------------- /envs/multi_agent_env.py: -------------------------------------------------------------------------------- 1 | class MultiAgentEnv(object): 2 | def step(self, actions): 3 | """Returns reward, terminated, info.""" 4 | raise NotImplementedError 5 | 6 | def get_obs(self): 7 | """Returns all agent observations in a list.""" 8 | raise NotImplementedError 9 | 10 | def get_obs_agent(self, agent_id): 11 | """Returns observation for agent_id.""" 12 | raise NotImplementedError 13 | 14 | def get_obs_size(self): 15 | """Returns the size of the observation.""" 16 | raise NotImplementedError 17 | 18 | def get_shared_obs(self): 19 | """Returns all agent shared observations in a list.""" 20 | raise NotImplementedError 21 | 22 | def get_shared_obs_size(self): 23 | """Returns the size of the shared observation.""" 24 | return 25 | 26 | def get_state(self): 27 | """Returns the global state.""" 28 | raise NotImplementedError 29 | 30 | def get_state_size(self): 31 | """Returns the size of the global state.""" 32 | return 33 | 34 | def get_avail_actions(self): 35 | """Returns the available actions of all models in a list.""" 36 | raise NotImplementedError 37 | 38 | def get_avail_agent_actions(self, agent_id): 39 | """Returns the available actions for agent_id.""" 40 | raise NotImplementedError 41 | 42 | def get_total_actions(self): 43 | """Returns the total number of actions an agent could ever take.""" 44 | raise NotImplementedError 45 | 46 | def reset(self): 47 | """Returns initial observations and states.""" 48 | raise NotImplementedError 49 | 50 | def render(self): 51 | raise NotImplementedError 52 | 53 | def close(self): 54 | pass 55 | 56 | def seed(self): 57 | raise NotImplementedError 58 | 59 | def save_replay(self): 60 | """Save a replay.""" 61 | raise NotImplementedError 62 | 63 | def get_env_info(self): 64 | env_info = { 65 | "obs_shape": self.get_obs_size(), 66 | "shared_obs_shape": self.get_shared_obs_size(), 67 | "state_shape": self.get_state_size(), 68 | "n_actions": self.get_total_actions(), 69 | "n_agents": self.n_agents, 70 | "episode_limit": getattr(self, 'episode_limit', None), 71 | } 72 | return env_info 73 | -------------------------------------------------------------------------------- /exp_queue.py: -------------------------------------------------------------------------------- 1 | from run import * 2 | 3 | if __name__ == '__main__': 4 | 5 | # Setup common parameters shared across runs. 6 | base_scene = '1fb' 7 | exp_tag = base_scene 8 | common_kwargs = { 9 | 'cuda_idx': 0, 10 | 'env_id': 'ad-hoc', 11 | 'env_kwargs': {'scenario_name': base_scene}, 12 | 13 | 'record_tests': True, 14 | 'use_wandb': True, 15 | 16 | 'total_env_steps': 1000000, 17 | 'steps_per_session': 50000, 18 | 'test_interval': 10000, 19 | 'n_test_episodes': 25, 20 | 'save_interval': 2000000, 21 | 22 | 'rollout_len': 20, 23 | 'data_chunk_len': 10, 24 | 'eps_anneal_time': 100000, 25 | 'gamma': 0.98, 26 | 'batch_size': 32, 27 | } 28 | 29 | mlp_agent = {'hidden_size': 256} 30 | graph_agent = {'obs': 'graph', 'agent': 'g-adhoc', 'hidden_size': 128} 31 | 32 | # Form a queue of runs. 
33 | queues = { 34 | # Full power transmission 35 | 'full_rnn': {'algo_name': 'q', **mlp_agent, 36 | 'env_kwargs': {'scenario_name': f'{base_scene}-full-pow'}}, 37 | 'full_1hop-gnn': {'algo_name': 'q', **graph_agent, 38 | 'env_kwargs': {'scenario_name': f'{base_scene}-full-pow'}}, 39 | 'full_2hop-gnn': {'algo_name': 'q', **graph_agent, 40 | 'env_kwargs': {'scenario_name': f'{base_scene}-full-pow', 'graph_khops': 2}}, 41 | # Random power selection 42 | 'rand_rnn': {'algo_name': 'q', **mlp_agent, 43 | 'env_kwargs': {'learn_power_control': False}}, 44 | 'rand_1hop-gnn': {'algo_name': 'q', **graph_agent, 45 | 'env_kwargs': {'learn_power_control': False}}, 46 | 'rand_2hop-gnn': {'algo_name': 'q', **graph_agent, 47 | 'env_kwargs': {'learn_power_control': False, 'graph_khops': 2}}, 48 | # Learn cross-layer optimization 49 | 'pc_rnn': {'algo_name': 'q', **mlp_agent}, 50 | 'pc_1hop-gnn': {'algo_name': 'q', **graph_agent}, 51 | 'pc_2hop-gnn': {'algo_name': 'q', **graph_agent, 'env_kwargs': {'graph_khops': 2}}, 52 | } 53 | 54 | # Assign random seeds. 55 | seeds = [0, 1, 2] 56 | # Display run names. 57 | print(f"Following {len(seeds) * len(queues)} runs are launched in total:") 58 | for run_name in queues: 59 | print(run_name) 60 | # Sequentially start each run. 61 | for seed in seeds: 62 | for tag, param_dict in queues.items(): 63 | # Build tag of the run. 64 | run_tag = exp_tag + '_' + tag if exp_tag is not None else tag 65 | 66 | # Get all kwargs of the run. 67 | run_kwargs = config_copy(common_kwargs) 68 | run_kwargs = recursive_dict_update(run_kwargs, param_dict) 69 | # Extract `env_id`, `env_kwargs` and `algo_name`. 70 | env_id = run_kwargs['env_id'] 71 | env_kwargs = run_kwargs['env_kwargs'] 72 | algo_name = run_kwargs['algo_name'] 73 | # Drop irrelevant items and `get train_kwargs`. 74 | train_kwargs = config_copy(run_kwargs) 75 | del train_kwargs['env_id'], train_kwargs['env_kwargs'], train_kwargs['algo_name'] 76 | 77 | # Call run. 
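The loop above assembles each run's configuration by deep-copying `common_kwargs` and merging the per-run overrides with `recursive_dict_update`; both helpers are defined in `run.py` and pulled in via `from run import *`. As a rough sketch of the merge semantics the loop relies on (not necessarily the exact implementation in `run.py`):

```python
from collections.abc import Mapping
from copy import deepcopy

def recursive_dict_update(d: dict, u: Mapping) -> dict:
    """Merge `u` into `d` key by key, recursing into nested dicts such as `env_kwargs`."""
    for k, v in u.items():
        if isinstance(v, Mapping):
            d[k] = recursive_dict_update(d.get(k, {}), v)
        else:
            d[k] = v
    return d

# Merging a queue entry keeps existing nested keys and only adds/overrides the new ones:
base = deepcopy({'env_kwargs': {'scenario_name': '1fb'}})
run_kwargs = recursive_dict_update(base, {'env_kwargs': {'graph_khops': 2}})
# -> {'env_kwargs': {'scenario_name': '1fb', 'graph_khops': 2}}
```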
78 | run(env_id, env_kwargs, seed, algo_name, train_kwargs, run_tag, add_suffix=True, suffix=base_scene) 79 | -------------------------------------------------------------------------------- /learners/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from learners.q_learner import QLearner 4 | REGISTRY['q'] = QLearner 5 | 6 | DETERMINISTIC_POLICY_GRADIENT_ALGOS = {'ddpg'} # Algorithms using deterministic policy gradients 7 | -------------------------------------------------------------------------------- /learners/base_learner.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | 4 | class BaseLearner: 5 | """Base class of learners to train agents as well as auxiliary modules""" 6 | 7 | def eval(self): 8 | """Sets auxiliary modules to eval mode.""" 9 | return 10 | 11 | def init_hidden(self): 12 | """Initializes RNN states for auxiliary modules""" 13 | return dict() 14 | 15 | def step(self, *args): 16 | """Gets output of auxiliary modules for env step.""" 17 | return dict() 18 | 19 | def schedule_lr(self): 20 | """Calls step of learning rate scheduler(s).""" 21 | raise NotImplementedError 22 | 23 | @abstractmethod 24 | def reset(self): 25 | """Resets learner before training loop.""" 26 | raise NotImplementedError 27 | 28 | @ abstractmethod 29 | def update(self, buffer, batch_size: int): 30 | """Updates parameters of modules from collected data.""" 31 | raise NotImplementedError 32 | 33 | def soft_target_sync(self): 34 | raise NotImplementedError 35 | 36 | def hard_target_sync(self): 37 | raise NotImplementedError 38 | 39 | @abstractmethod 40 | def save_checkpoint(self, path): 41 | raise NotImplementedError 42 | 43 | @abstractmethod 44 | def load_checkpoint(self, path): 45 | raise NotImplementedError 46 | -------------------------------------------------------------------------------- /learners/q_learner.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import torch as th 3 | import torch.nn as nn 4 | 5 | from components.misc import * 6 | from learners.base_learner import BaseLearner 7 | from modules.mixers import REGISTRY as mix_REGISTRY 8 | 9 | 10 | class QLearner(BaseLearner): 11 | """Multi-agent Q learning with recurrent models""" 12 | 13 | def __init__(self, env_info, policy, args) -> None: 14 | 15 | self.device = args.device 16 | self.online_policy = policy 17 | print(self.online_policy) 18 | self.params = list(self.online_policy.parameters()) 19 | self.target_policy = deepcopy(self.online_policy) 20 | 21 | self.args = args # Arguments 22 | self.n_agents = policy.n_agents # Number of agents 23 | self.n_updates = None # Number of completed updates 24 | 25 | # Set mixer to combine individual state-action values to global ones. 26 | self.mixer = None 27 | if hasattr(args, 'mixer'): # Mixer is specified. 28 | self.mixer = mix_REGISTRY[args.mixer](env_info['state_shape'], self.n_agents, args).to(self.device) 29 | print(f"Mixer: \n{self.mixer}") 30 | self.params += list(self.mixer.parameters()) 31 | self.target_mixer = deepcopy(self.mixer) 32 | 33 | # Define optimizer. 34 | self._use_huber_loss = args.use_huber_loss # Whether Huber loss is used. 35 | self.optimizer = th.optim.Adam(self.params, lr=args.lr, eps=args.optim_eps) 36 | 37 | # Set learning rate scheduler. 
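The learning-rate scheduler configured right below calls `get_clipped_linear_decay(total_steps=10, threshold=0.4)` from `components/misc.py`, which is not reproduced in this listing. Judging from the name and its use inside `LambdaLR`, it presumably returns a multiplier that decays linearly from 1 and is clipped from below; a plausible sketch only, not the repository's actual code:

```python
def get_clipped_linear_decay(total_steps: int, threshold: float):
    """Return a LambdaLR multiplier: linear decay from 1.0, clipped below at `threshold`."""
    return lambda step: max(threshold, 1.0 - step / float(total_steps))
```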
38 | if self.args.anneal_lr: 39 | lr_lam = get_clipped_linear_decay(total_steps=10, threshold=0.4) 40 | self.lr_scheduler = th.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lr_lam, verbose=True) 41 | 42 | self._use_double_q = args.use_double_q # Whether double Q-learning is used 43 | 44 | def reset(self): 45 | """Resets the learner.""" 46 | self.n_updates = 0 # Reset number of updates 47 | 48 | def schedule_lr(self): 49 | if self.args.anneal_lr: 50 | self.lr_scheduler.step() 51 | 52 | def update(self, buffer, batch_size: int): 53 | """Updates parameters of recurrent models via BPTT.""" 54 | self.online_policy.train() # Set policy to train mode. 55 | self.target_policy.train() 56 | 57 | batch = buffer.recall(batch_size) # Batched samples from reply buffer 58 | 59 | obs = [batch['obs'][t].to(self.device) for t in range(len(batch['obs']))] 60 | avail_actions = th.stack(batch['avail_actions']).to(self.device) if 'avail_actions' in batch else None 61 | actions = th.stack(batch['actions']).to(self.device) # Shape (data_chunk_len, batch_size * n_agents, 1) 62 | rewards = th.stack(batch['rewards']).to(self.device) # Shape (data_chunk_len, batch_size, n_agents) 63 | terminated = th.stack(batch['terminated']).to(self.device) # Shape (data_chunk_len, batch_size, 1) 64 | mask = th.stack(batch['filled']).to(self.device) # Shape (data_chunk_len, batch_size, 1) 65 | h, h_targ = batch['h'][0].to(self.device), batch['h'][1].to(self.device) # Get initial hidden states. 66 | 67 | # print(f"rewards.size() = {rewards.size()}, mask.size() = {mask.size()}") 68 | # for t in range(mask.size(0)): 69 | # print(f"t = {t}, rewards = {rewards[t].squeeze()}, terminated = {terminated[t].squeeze()}, mask = {mask[t].squeeze()}") 70 | 71 | assert self.args.data_chunk_len == len(obs) - 1, "Improper length of sequences found in mini-batch." 72 | # Get maximum filled steps. 73 | horizon = 0 # Time horizon of forward computation 74 | for t in range(self.args.data_chunk_len): 75 | if mask[t].any(): 76 | horizon = t + 1 77 | # Truncate batch sequences to horizon. 78 | obs, avail_actions = obs[:horizon + 1], avail_actions[:horizon + 1] 79 | actions, rewards, terminated, mask = actions[:horizon], rewards[:horizon], terminated[:horizon], mask[:horizon] 80 | 81 | # print(f"len(obs) = {len(obs)}") 82 | # print(f"actions.size() = {actions.size()}") 83 | # print(f"After truncation,") 84 | # print(f"rewards = \n{rewards.squeeze(-1)}") 85 | # print(f"terminated = \n{terminated.squeeze(-1)}") 86 | # print(f"mask = \n{mask.squeeze(-1)}") 87 | 88 | online_agent_out, target_agent_out = [], [] 89 | for t in range(horizon): 90 | # Policy network predicts the Q(o_{t},a). 91 | logits, h = self.online_policy.forward(obs[t], h) 92 | online_agent_out.append(logits) 93 | # Reset RNN states of policy network when termination of episode is encountered. 94 | h = h * (1 - terminated[t]).expand(batch_size, self.n_agents).reshape(-1, 1) 95 | 96 | # Target network predicts Q(o_{t+1}, a). 97 | with th.no_grad(): 98 | next_logits, h_targ = self.target_policy.forward(obs[t + 1], h_targ) 99 | target_agent_out.append(next_logits) 100 | if t + 1 < horizon: 101 | # Reset RNN states of target network when termination of episode is encountered. 102 | h_targ = h_targ * (1 - terminated[t + 1]).expand(batch_size, self.n_agents).reshape(-1, 1) 103 | 104 | # Let policy network make predictions for next obs of the last timestep. 
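To make the horizon computation above concrete, here is a toy illustration (made-up mask, `data_chunk_len = 5`, `batch_size = 2`) of how the filled-step mask determines how far the BPTT unroll runs and how the sequences are truncated:

```python
import torch as th

mask = th.tensor([[1., 1.], [1., 1.], [1., 0.], [0., 0.], [0., 0.]]).unsqueeze(-1)
horizon = 0
for t in range(mask.size(0)):       # mirrors the loop in QLearner.update()
    if mask[t].any():
        horizon = t + 1
print(horizon)                      # 3: the last timestep with any filled entry
# obs/avail_actions keep horizon + 1 entries (indices 0..3, including the next obs),
# while actions, rewards, terminated and mask keep the first `horizon` entries (0..2).
```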
105 | logits, h = self.online_policy.forward(obs[horizon], h) 106 | online_agent_out.append(logits) 107 | 108 | # Stack outputs of policy/target networks over time. 109 | online_logits, target_logits = th.stack(online_agent_out), th.stack(target_agent_out) 110 | 111 | # Get Q(o_{t}, a_{t}) from online logits. 112 | q_vals = online_logits[:-1].gather(2, actions).view(horizon, batch_size, self.n_agents) 113 | 114 | # Get V(o_{t+1}) from target logits. 115 | if not self._use_double_q: # Greedy selection of a_{t+1} 116 | if avail_actions is not None: 117 | target_logits[avail_actions[1:] == 0] = -1e10 # Mask unavailable actions. 118 | # Pick the largest Q values from target network. 119 | next_v = target_logits.max(2, keepdim=True)[0] 120 | else: # Double Q-learning is used. 121 | online_logits_detach = online_logits.clone().detach() # Duplicate output from policy network. 122 | if avail_actions is not None: 123 | online_logits_detach[avail_actions == 0] = -1e10 # Mask unavailable actions. 124 | next_actions = th.argmax(online_logits_detach[1:], 2, keepdim=True) # Next actions selected by policy network 125 | next_v = target_logits.gather(2, next_actions) # Pick Q values of specified by next actions. 126 | 127 | next_v = next_v.view(horizon, batch_size, self.n_agents) # Reshaped V(o_{t+1}). 128 | 129 | # When mixer is used, combine individual Q values from agents. 130 | if self.mixer is not None: 131 | # Aggregate individual rewards to global ones. 132 | rewards = rewards.mean(axis=-1, keepdims=True) 133 | # Get env states. 134 | states = th.stack(batch['state']).to(self.device) 135 | # Mix individual values v(o_{t}^{1}),...,v(o_{t}^{n}) and states s_{t} to global ones v(s_{t}). 136 | q_vals = self.mixer(q_vals, states[:-1]) 137 | with th.no_grad(): 138 | # Likewise, get v(s_{t+1}). 139 | next_v = self.target_mixer(next_v, states[1:]) 140 | 141 | # Obtain one-step target of Q-learning as r_{t} + gamma * (1 - d) * v(s_{t+1}). 142 | q_targets = rewards + self.args.gamma * (1 - terminated) * next_v 143 | 144 | # Compute masked MSE loss. 145 | td_error = q_vals - q_targets # TD error 146 | loss = huber_loss(td_error, mask) if self._use_huber_loss else mse_loss(td_error, mask) # Q loss 147 | 148 | # Call one step of gradient descent. 149 | self.optimizer.zero_grad() 150 | loss.backward() # Back propagation 151 | grad_norm = nn.utils.clip_grad_norm_(self.params, max_norm=self.args.max_grad_norm) # Gradient-clipping 152 | self.optimizer.step() # Call update. 153 | 154 | self.n_updates += 1 # Finish one step of policy update. 155 | # Sync parameters of target networks. 
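`mse_loss`, `huber_loss` and `soft_target_update` come from `components/misc.py`, which is not part of this listing. The sketches below are only meant to convey the behaviour the update assumes (averaging over filled entries only, Polyak averaging toward the online weights); the repository's actual implementations and its Polyak convention may differ:

```python
import torch as th
import torch.nn as nn

def mse_loss(td_error: th.Tensor, mask: th.Tensor) -> th.Tensor:
    """Mean squared TD error over filled (mask == 1) entries only."""
    mask = mask.expand_as(td_error)
    return (td_error ** 2 * mask).sum() / mask.sum()

def huber_loss(td_error: th.Tensor, mask: th.Tensor, delta: float = 1.0) -> th.Tensor:
    """Masked Huber loss: quadratic near zero, linear beyond `delta`."""
    mask = mask.expand_as(td_error)
    abs_err = td_error.abs()
    quad = th.clamp(abs_err, max=delta)
    loss = 0.5 * quad ** 2 + delta * (abs_err - quad)
    return (loss * mask).sum() / mask.sum()

@th.no_grad()
def soft_target_update(online: nn.Module, target: nn.Module, polyak: float) -> None:
    """target <- polyak * target + (1 - polyak) * online, parameter by parameter."""
    for p, p_targ in zip(online.parameters(), target.parameters()):
        p_targ.data.mul_(polyak).add_((1.0 - polyak) * p.data)
```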
156 |         if self.args.target_update_mode == "soft":
157 |             self.soft_target_sync()
158 |         elif self.args.target_update_mode == "hard":
159 |             if self.n_updates % self.args.target_update_interval == 0:
160 |                 self.hard_target_sync()
161 | 
162 |         update_info = dict(LossQ=loss.item(), QVals=q_vals.detach().cpu().numpy(), GradNorm=grad_norm)
163 |         return update_info
164 | 
165 |     @th.no_grad()
166 |     def soft_target_sync(self) -> None:
167 |         """Applies soft update of target networks via Polyak averaging."""
168 |         soft_target_update(self.online_policy.model, self.target_policy.model, self.args.polyak)
169 |         if self.mixer is not None:
170 |             soft_target_update(self.mixer, self.target_mixer, self.args.polyak)
171 | 
172 |     @th.no_grad()
173 |     def hard_target_sync(self) -> None:
174 |         """Copies parameters of policy networks to target networks."""
175 |         self.target_policy.load_state(self.online_policy)
176 |         if self.mixer is not None:
177 |             self.target_mixer.load_state_dict(self.mixer.state_dict())
178 | 
179 |     def save_checkpoint(self, path) -> None:
180 |         """Saves states of components."""
181 |         checkpoint = dict()  # Checkpoint holding component states
182 |         checkpoint['model'] = self.online_policy.model.state_dict()  # Parameters of policy network
183 |         if self.mixer is not None:
184 |             checkpoint['mixer'] = self.mixer.state_dict()  # Parameters of mixer
185 |         checkpoint['opt'] = self.optimizer.state_dict()  # State of optimizer
186 |         if self.args.anneal_lr:
187 |             checkpoint['scheduler'] = self.lr_scheduler.state_dict()  # State of learning rate scheduler
188 |         # Save to path.
189 |         th.save(checkpoint, path)
190 | 
191 |     def load_checkpoint(self, path):
192 |         """Loads states of components to resume training or test."""
193 |         # Load checkpoint.
194 |         checkpoint = th.load(path, map_location=self.device)
195 |         # Load parameters of policy and target networks.
196 |         self.online_policy.model.load_state_dict(checkpoint['model'])
197 |         self.target_policy.model.load_state_dict(checkpoint['model'])
198 |         # Load parameters of mixers.
199 |         if self.mixer is not None:
200 |             self.mixer.load_state_dict(checkpoint['mixer'])
201 |             self.target_mixer.load_state_dict(self.mixer.state_dict())
202 |         # Load state of optimizer.
203 |         self.optimizer.load_state_dict(checkpoint['opt'])
204 |         # Load state of learning rate scheduler.
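`save_checkpoint` bundles the policy network, mixer, optimizer and scheduler state dicts into a single file, and `load_checkpoint` restores them and copies the policy weights into the target network. A self-contained toy round-trip with stand-in modules (the real code stores the agent network and mixer shown above; the path is illustrative):

```python
import torch as th
import torch.nn as nn

model, mixer = nn.Linear(4, 2), nn.Linear(3, 1)      # stand-ins for policy net and mixer
checkpoint = {'model': model.state_dict(), 'mixer': mixer.state_dict()}
th.save(checkpoint, 'ckpt.pt')                       # same structure as save_checkpoint()

restored = th.load('ckpt.pt', map_location='cpu')
model.load_state_dict(restored['model'])             # nn.Module restores via load_state_dict
mixer.load_state_dict(restored['mixer'])
```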
205 | if self.args.anneal_lr: 206 | self.lr_scheduler.load_state_dict(checkpoint['scheduler']) 207 | -------------------------------------------------------------------------------- /modules/activations.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | REGISTRY = { 4 | 'relu': nn.ReLU, 5 | 'tanh': nn.Tanh, 6 | } 7 | -------------------------------------------------------------------------------- /modules/agents/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from .recurrent_agent import RecurrentAgent 4 | REGISTRY['rnn'] = RecurrentAgent 5 | 6 | from .comm_agent import CommunicativeAgent 7 | REGISTRY['comm'] = CommunicativeAgent 8 | 9 | from .customized_agents import AdHocRelationalController 10 | REGISTRY['r-adhoc'] = AdHocRelationalController 11 | 12 | from .customized_agents import AdHocGraphController 13 | REGISTRY['g-adhoc'] = AdHocGraphController 14 | -------------------------------------------------------------------------------- /modules/agents/comm_agent.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import Tensor 3 | import torch.nn as nn 4 | 5 | from dgl import DGLGraph 6 | from modules.basics import * 7 | from modules.dueling import DuelingLayer 8 | from modules.encoders import REGISTRY as enc_REGISTRY 9 | from modules.comm import REGISTRY as comm_REGISTRY 10 | 11 | 12 | class CommunicativeAgent(nn.Module): 13 | """Communicative agent""" 14 | 15 | def __init__(self, obs_shape, act_size, args) -> None: 16 | super(CommunicativeAgent, self).__init__() 17 | self._obs_shape = obs_shape # Shape of observations 18 | self._act_size = act_size # Number of discrete actions 19 | self._obs = args.obs # Observation format 20 | self._hidden_size = args.hidden_size # Hidden size 21 | 22 | self._comm = args.comm # Communication protocol 23 | assert self._comm in comm_REGISTRY, "Unrecognised communication protocol." 24 | 25 | self.f_enc = enc_REGISTRY[self._obs](self._obs_shape, self._hidden_size, args) # Observation encoder 26 | self.f_comm = comm_REGISTRY[self._comm](self._hidden_size, args) # Modules performing multi-agent communication 27 | if getattr(args, "use_dueling", False): 28 | self.f_out = DuelingLayer(self._hidden_size, self._act_size, args) 29 | else: 30 | self.f_out = nn.Linear(self._hidden_size, self._act_size) # Output layer 31 | 32 | def init_hidden(self) -> Tensor: 33 | """Initializes RNN hidden states.""" 34 | return th.zeros(1, self._hidden_size) 35 | 36 | def forward(self, obs, h): 37 | # Encode observations. 38 | x = self.f_enc(obs) 39 | # Apply multi-agent communication and update RNN hidden states after communicating among models. 40 | x, h = self.f_comm(obs['talks'], x, h) # Note that call subgraph by edge name may cause ambiguity. 41 | # Compute logits for action selection. 
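The agent classes registered above are looked up by name at run time (see `SharedPolicy._build_agents` further down): the `agent` field of the experiment config selects the key, e.g. `'g-adhoc'` in the `graph_agent` preset of `exp_queue.py`, while the MLP/RNN baseline presumably falls back to a default `'rnn'` entry in the YAML config. A minimal lookup, with the config value assumed for illustration:

```python
from types import SimpleNamespace as SN
from modules.agents import REGISTRY as agent_REGISTRY

args = SN(agent='g-adhoc')                # taken from the experiment config
agent_cls = agent_REGISTRY[args.agent]    # -> AdHocGraphController
print(agent_cls.__name__)
```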
42 | logits = self.f_out(x) 43 | return logits, h 44 | 45 | 46 | if __name__ == '__main__': 47 | from types import SimpleNamespace as SN 48 | args = SN(**dict(hidden_size=128)) 49 | print(type(args)) 50 | -------------------------------------------------------------------------------- /modules/agents/customized_agents.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import Tensor 3 | import torch.nn as nn 4 | 5 | import dgl 6 | import dgl.nn.pytorch as dglnn 7 | from dgl import DGLGraph 8 | from modules.basics import * 9 | from modules.dueling import DuelingLayer 10 | from modules.encoders import REGISTRY as enc_REGISTRY 11 | from modules.graph_nn import NeighborSelector 12 | from modules.gn_blks import * 13 | from modules.utils import pad_edge_output 14 | 15 | 16 | class AdHocRelationalController(nn.Module): 17 | """Graph agent""" 18 | 19 | def __init__(self, obs_shape, act_size, args) -> None: 20 | super(AdHocRelationalController, self).__init__() 21 | self._obs_shape = obs_shape # Shape of observations 22 | self._act_size = act_size # Number of discrete actions 23 | self._obs = args.obs # Observation format 24 | self._hidden_size = args.hidden_size # Hidden size 25 | 26 | self.f_enc = enc_REGISTRY[self._obs](self._obs_shape, self._hidden_size, args) # Observation encoder 27 | self.rnn = RnnLayer(self._hidden_size, args) 28 | self.f_out = NeighborSelector(self._obs_shape['nbr'], self._hidden_size, args.n_pow_opts, 1, self._hidden_size, args) 29 | 30 | def init_hidden(self) -> Tensor: 31 | """Initializes RNN hidden states.""" 32 | return th.zeros(1, self._hidden_size) 33 | 34 | def forward(self, obs, h): 35 | # Encode observations. 36 | x = self.f_enc(obs) 37 | # Update RNN hidden states independently when communication is disabled. 38 | x, h = self.rnn(x, h) 39 | # Compute logits for action selection. 
40 | logits = self.f_out(obs[('nbr', 'nearby', 'agent')], {'agent': x, 'nbr': obs.nodes['nbr'].data['feat']}) 41 | # print(f"logits.size() = {logits.size()}") 42 | return logits, h 43 | 44 | 45 | class AdHocGraphController(nn.Module): 46 | def __init__(self, obs_shape, act_size, args): 47 | super(AdHocGraphController, self).__init__() 48 | 49 | self._obs_shape = obs_shape # Shape of observations 50 | self._act_size = act_size # Number of discrete actions 51 | self.args = args 52 | 53 | self._hidden_size = args.hidden_size # Hidden size 54 | self._activation = act_REGISTRY[args.activation] # Callable to instantiate an activation function 55 | 56 | self.khops = args.khops 57 | if self.khops == 1: 58 | self.enc = nn.ModuleDict({ 59 | '1hop': NodeGNBlock((self._obs_shape['nbr'], self._obs_shape['agent']), 60 | self._obs_shape['hop'], 61 | self._hidden_size, 62 | activation_type=args.activation), 63 | }) 64 | elif self.khops == 2: 65 | self.enc = nn.ModuleDict({ 66 | '2hop': NodeGNBlock((self._obs_shape['nbr'], self._obs_shape['nbr']), 67 | self._obs_shape['hop'], 68 | self._hidden_size, 69 | activation_type=args.activation), 70 | '1hop': NodeGNBlock((self._hidden_size, self._obs_shape['agent']), 71 | self._obs_shape['hop'], 72 | self._hidden_size, 73 | activation_type=args.activation), 74 | }) 75 | 76 | self.rnn = RnnLayer(self._hidden_size, args) 77 | inter_nbr_feats = self._hidden_size if self.khops == 2 else self._obs_shape['nbr'] 78 | self.f_out = EdgeGNBlock((inter_nbr_feats, self._hidden_size), 79 | self._obs_shape['hop'], 80 | 1, 81 | args.n_pow_opts, 82 | self._hidden_size) 83 | 84 | def init_hidden(self, batch_size: int = 1) -> Tensor: 85 | """Initializes RNN hidden states.""" 86 | return th.zeros(1, self._hidden_size).expand(batch_size, -1) 87 | 88 | def forward(self, obs, h): 89 | feat = obs.ndata['feat'] 90 | 91 | if self.khops == 2: 92 | g_2hop = obs['2hop'] 93 | x_nbr = self.enc['2hop'](g_2hop, feat['nbr'], g_2hop.edata['feat']) 94 | else: 95 | x_nbr = feat['nbr'] 96 | 97 | g_1hop = obs['1hop'] 98 | x = self.enc['1hop'](g_1hop, (x_nbr, feat['agent']), g_1hop.edata['feat']) 99 | 100 | # Update RNN hidden states independently when communication is disabled. 101 | x, h = self.rnn(x, h) 102 | # Compute logits for action selection. 
103 | agent_out, nbr_out = self.f_out(g_1hop, (x_nbr, x), g_1hop.edata['feat']) 104 | # print(f"agent_out.size() = {agent_out.size()}") 105 | # print(f"nbr_out.size() = {nbr_out.size()}") 106 | padded_nbr_out = pad_edge_output(g_1hop, nbr_out, self.args.max_nbrs) 107 | q_vals = th.cat([padded_nbr_out, agent_out], dim=1) 108 | # print(f"q_vals.size() = {q_vals.size()}") 109 | return q_vals, h 110 | -------------------------------------------------------------------------------- /modules/agents/recurrent_agent.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import Tensor 3 | import torch.nn as nn 4 | 5 | from modules.basics import * 6 | from modules.dueling import DuelingLayer 7 | from modules.encoders import REGISTRY as enc_REGISTRY 8 | 9 | 10 | class RecurrentAgent(nn.Module): 11 | """Recurrent agent""" 12 | 13 | def __init__(self, obs_shape, act_size, args) -> None: 14 | super(RecurrentAgent, self).__init__() 15 | self._obs_shape = obs_shape # Shape of observations 16 | self._act_size = act_size # Number of discrete actions 17 | self._obs = args.obs # Observation form 18 | self._hidden_size = args.hidden_size # Hidden size 19 | 20 | self.f_enc = enc_REGISTRY[self._obs](self._obs_shape, self._hidden_size, args) # Observation encoder 21 | self.rnn = RnnLayer(self._hidden_size, args) 22 | if getattr(args, "use_dueling", False): 23 | self.f_out = DuelingLayer(self._hidden_size, self._act_size, args) 24 | else: 25 | self.f_out = nn.Linear(self._hidden_size, self._act_size) # Output layer 26 | 27 | def init_hidden(self) -> Tensor: 28 | """Initializes RNN hidden states.""" 29 | return th.zeros(1, self._hidden_size) 30 | 31 | def forward(self, obs, h): 32 | # Encode observations. 33 | x = self.f_enc(obs) 34 | # Update RNN hidden states independently when communication is disabled. 35 | x, h = self.rnn(x, h) 36 | # Compute logits for action selection. 
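For the plain `RecurrentAgent` defined here, a minimal smoke test might look like the sketch below. The hidden size, shapes and `args` fields are illustrative (only the attributes actually read by `FlatEncoder`, `RnnLayer` and `RecurrentAgent` are set), and it assumes the repository's dependencies (torch, dgl) are installed:

```python
from types import SimpleNamespace as SN
import torch as th
from modules.agents.recurrent_agent import RecurrentAgent

args = SN(obs='flat', hidden_size=64, n_layers=2, activation='relu',
          use_feat_norm=False, use_layer_norm=False)
agent = RecurrentAgent(obs_shape=10, act_size=5, args=args)

h = agent.init_hidden().expand(4, -1)   # batch of 4 agents, as SharedPolicy does
obs = th.rand(4, 10)
logits, h = agent(obs, h)               # logits: (4, 5), h: (4, 64)
```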
37 | logits = self.f_out(x) 38 | return logits, h 39 | 40 | 41 | if __name__ == '__main__': 42 | from types import SimpleNamespace as SN 43 | args = SN(**dict(hidden_size=128)) 44 | print(type(args)) 45 | -------------------------------------------------------------------------------- /modules/basics.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | 4 | 5 | class RnnLayer(nn.Module): 6 | """Recurrent layer enabling layer normalization""" 7 | 8 | def __init__(self, hidden_size, args): 9 | super(RnnLayer, self).__init__() 10 | self._hidden_size = hidden_size 11 | self.rnn = nn.GRUCell(self._hidden_size, self._hidden_size) 12 | 13 | self._use_layer_norm = args.use_layer_norm 14 | if self._use_layer_norm: 15 | self.norm = nn.LayerNorm(self._hidden_size) 16 | 17 | def forward(self, x, h): 18 | h = self.rnn(x, h) 19 | if self._use_layer_norm: 20 | y = self.norm(h) 21 | else: 22 | y = h 23 | return y, h 24 | -------------------------------------------------------------------------------- /modules/comm/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from modules.comm.tarmac import TarMAC 4 | REGISTRY['tarmac'] = TarMAC 5 | 6 | from modules.comm.disc import DiscreteCommunication 7 | REGISTRY['disc'] = DiscreteCommunication 8 | -------------------------------------------------------------------------------- /modules/comm/disc.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import Tensor 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from dgl import DGLGraph 7 | 8 | 9 | class DiscreteCommunication(nn.Module): 10 | """Discrete Communication""" 11 | 12 | def __init__(self, hidden_size, args) -> None: 13 | super(DiscreteCommunication, self).__init__() 14 | 15 | self._hidden_size = hidden_size # Size of hidden states 16 | self._msg_size = args.msg_size # Size of messages 17 | # Note: In discrete communication, we use 2 digits to denote 1 bit as either (0, 1) or (1, 0). 18 | # Therefore, outputs from message encoder take twice the number of digits of continuous counterparts. 19 | 20 | self.f_enc = nn.Linear(self._hidden_size + self._hidden_size, 2 * self._msg_size) # Message function 21 | self.f_dec = nn.Linear(2 * self._msg_size, 2 * self._msg_size) # Decoder of aggregated messages 22 | self.f_udt = nn.GRUCell(self._hidden_size + 2 * self._msg_size, self._hidden_size) # RNN unit 23 | 24 | def msg_func(self, edges): 25 | """Encodes discrete messages from local inputs and detached hidden states.""" 26 | # Get logits from message function. 27 | logits = self.f_enc(th.cat((edges.src['x'], edges.src['h']), 1)) 28 | # When discrete communication is required, 29 | # we use Gumbel-Softmax function to sample binary messages while keeping gradients for backpropagation. 30 | disc_msg = F.gumbel_softmax(logits.view(-1, self._msg_size, 2), tau=0.5, hard=True) 31 | return dict(m=disc_msg.flatten(1)) 32 | 33 | def aggr_func(self, nodes): 34 | """Aggregates incoming discrete messages by element-wise OR operation.""" 35 | aggr_msg = nodes.mailbox['m'].max(1)[0] 36 | return dict(c=aggr_msg) 37 | 38 | def forward(self, g: DGLGraph, feat: Tensor, h: Tensor) -> tuple[Tensor, Tensor]: 39 | with g.local_scope(): 40 | g.ndata['x'], g.ndata['h'] = feat, h.detach() # Get inputs and the latest hidden states. 
41 | 42 | if g.number_of_edges() == 0: 43 | # When no edge is created, paddle zeros. 44 | g.dstdata['c'] = th.zeros(feat.shape[0], 2 * self._msg_size) 45 | else: 46 | # Otherwise, call message passing between nodes. 47 | g.update_all(self.msg_func, self.aggr_func) 48 | 49 | # Update the hidden states using inputs, aggregated messages and hidden states. 50 | h = self.f_udt(th.cat((g.ndata['x'], self.f_dec(g.ndata['c'])), 1), h) 51 | return h, h 52 | -------------------------------------------------------------------------------- /modules/comm/tarmac.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | from torch import Tensor 4 | import dgl 5 | from dgl import DGLGraph 6 | import dgl.nn.pytorch as dglnn 7 | import dgl.function as fn 8 | from dgl.nn.functional import edge_softmax 9 | 10 | 11 | class TarMAC(nn.Module): 12 | """TarMAC: Targeted Multi-Agent Communication""" 13 | 14 | def __init__(self, hidden_size, args): 15 | super(TarMAC, self).__init__() 16 | self._hidden_size = hidden_size # Size of hidden states 17 | self._msg_size = args.msg_size # Size of messages 18 | self._key_size = args.key_size # Size of signatures and queries 19 | self._n_rounds = args.n_rounds # Number of multi-round communication 20 | 21 | self.f_val = nn.Linear(2 * self._hidden_size, self._msg_size) # Value function (producing messages) 22 | self.f_sign = nn.Linear(2 * self._hidden_size, self._key_size) # Signature function (predicting keys at Tx) 23 | self.f_que = nn.Linear(2 * self._hidden_size, self._key_size) # Query function (predicting keys at Rx) 24 | self.f_udt = nn.GRUCell(self._hidden_size + self._msg_size, self._hidden_size) # RNN update function 25 | 26 | def forward(self, g: DGLGraph, feat: Tensor, h: Tensor) -> tuple[Tensor, Tensor]: 27 | with g.local_scope(): 28 | g.ndata['x'] = feat 29 | for l in range(self._n_rounds): 30 | g.ndata['h'] = h # Get the latest hidden states. 31 | 32 | # Build inputs to modules for communication. 33 | inputs = th.cat((g.srcdata['x'], g.srcdata['h'].detach()), 1) 34 | # Get signatures, values at source nodes. 35 | g.srcdata.update(dict(v=self.f_val(inputs), s=self.f_sign(inputs))) 36 | # Predict queries at destination nodes. 37 | g.dstdata.update(dict(q=self.f_que(inputs))) 38 | 39 | # Get attention scores on each edge by Dot-product of signature and query. 40 | g.apply_edges(fn.u_dot_v('s', 'q', 'e')) 41 | # Normalize attention scores 42 | e = g.edata.pop('e') / self._key_size # Divide by key-size 43 | g.edata['a'] = edge_softmax(g, e) 44 | # Aggregated messages by weighted sum 45 | g.update_all(fn.u_mul_e('v', 'a', 'm'), fn.sum('m', 'c')) 46 | 47 | # Update the hidden states of GRU. 
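Both communication modules share the same `forward(graph, feat, h)` interface and are driven by the agents' `'talks'` graph (see `CommunicativeAgent`). A minimal, self-contained run of `DiscreteCommunication` on a fully connected three-agent graph, with illustrative sizes:

```python
import dgl
import torch as th
from types import SimpleNamespace as SN
from modules.comm.disc import DiscreteCommunication

comm = DiscreteCommunication(hidden_size=16, args=SN(msg_size=4))

# Fully connected 'talks' graph over 3 agents (no self-loops).
src = th.tensor([0, 0, 1, 1, 2, 2])
dst = th.tensor([1, 2, 0, 2, 0, 1])
g = dgl.graph((src, dst), num_nodes=3)

x, h = th.rand(3, 16), th.zeros(3, 16)   # encoded observations and RNN hidden states
y, h_new = comm(g, x, h)                 # each of shape (3, 16)
```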
48 | h = self.f_udt(th.cat((g.ndata['x'], g.ndata['c']), 1), h) 49 | return h, h 50 | -------------------------------------------------------------------------------- /modules/critics/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from .customized_q import AdHocRelationalQ 4 | REGISTRY['r-adhoc'] = AdHocRelationalQ 5 | 6 | from .customized_q import AdHocGraphQ 7 | REGISTRY['g-adhoc'] = AdHocGraphQ 8 | -------------------------------------------------------------------------------- /modules/critics/customized_q.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | from torch import Tensor 3 | import torch.nn as nn 4 | 5 | import dgl 6 | import dgl.nn.pytorch as dglnn 7 | from dgl import DGLGraph 8 | from modules.basics import * 9 | from modules.dueling import DuelingLayer 10 | from modules.encoders import REGISTRY as enc_REGISTRY 11 | from modules.graph_nn import NeighborSelector 12 | from modules.gn_blks import * 13 | from modules.utils import pad_edge_output 14 | 15 | 16 | class AdHocRelationalQ(nn.Module): 17 | """Graph agent""" 18 | 19 | def __init__(self, shared_obs_shape, act_size, args) -> None: 20 | super(AdHocRelationalQ, self).__init__() 21 | self._shared_obs_shape = shared_obs_shape # Shape of observations 22 | self._act_size = act_size # Number of discrete actions 23 | self._hidden_size = args.critic_hidden_size # Hidden size 24 | self._shared_obs = args.shared_obs if 'shared_obs' in args.pre_decision_fields else args.obs 25 | 26 | self.f_enc = enc_REGISTRY[self._shared_obs](self._shared_obs_shape, self._hidden_size, args) # Observation encoder 27 | self.rnn = RnnLayer(self._hidden_size, args) 28 | self.f_out = NeighborSelector(self._shared_obs_shape['nbr'], self._hidden_size, args.n_pow_opts, 1, self._hidden_size, args) 29 | 30 | def init_hidden(self) -> Tensor: 31 | """Initializes RNN hidden states.""" 32 | return th.zeros(1, self._hidden_size) 33 | 34 | def forward(self, shared_obs, h): 35 | # Encode observations. 36 | x = self.f_enc(shared_obs) 37 | # Update RNN hidden states independently when communication is disabled. 38 | x, h = self.rnn(x, h) 39 | # Compute logits for action selection. 
40 | logits = self.f_out(shared_obs[('nbr', 'nearby', 'agent')], {'agent': x, 'nbr': shared_obs.nodes['nbr'].data['feat']}) 41 | # print(f"logits.size() = {logits.size()}") 42 | return logits, h 43 | 44 | 45 | class AdHocGraphQ(nn.Module): 46 | def __init__(self, shared_obs_shape, act_size, args): 47 | super(AdHocGraphQ, self).__init__() 48 | 49 | self._shared_obs_shape = shared_obs_shape # Shape of observations 50 | self._act_size = act_size # Number of discrete actions 51 | self.args = args 52 | 53 | self._hidden_size = args.critic_hidden_size # Hidden size 54 | self._shared_obs = args.shared_obs if 'shared_obs' in args.pre_decision_fields else args.obs 55 | 56 | self._activation = act_REGISTRY[args.activation] # Callable to instantiate an activation function 57 | 58 | self.khops = args.khops 59 | if self.khops == 1: 60 | self.enc = nn.ModuleDict({ 61 | '1hop': NodeGNBlock((self._shared_obs_shape['nbr'], self._shared_obs_shape['agent']), 62 | self._shared_obs_shape['hop'], 63 | self._hidden_size, 64 | activation_type=args.activation), 65 | }) 66 | elif self.khops == 2: 67 | self.enc = nn.ModuleDict({ 68 | '2hop': NodeGNBlock((self._shared_obs_shape['nbr'], self._shared_obs_shape['nbr']), 69 | self._shared_obs_shape['hop'], 70 | self._hidden_size, 71 | activation_type=args.activation), 72 | '1hop': NodeGNBlock((self._hidden_size, self._shared_obs_shape['agent']), 73 | self._shared_obs_shape['hop'], 74 | self._hidden_size, 75 | activation_type=args.activation), 76 | }) 77 | 78 | self.rnn = RnnLayer(self._hidden_size, args) 79 | inter_nbr_feats = self._hidden_size if self.khops == 2 else self._obs_shape['nbr'] 80 | self.f_out = EdgeGNBlock((inter_nbr_feats, self._hidden_size), 81 | self._shared_obs_shape['hop'], 82 | 1, 83 | args.n_pow_opts, 84 | self._hidden_size) 85 | 86 | def init_hidden(self, batch_size: int = 1) -> Tensor: 87 | """Initializes RNN hidden states.""" 88 | return th.zeros(1, self._hidden_size).expand(batch_size, -1) 89 | 90 | def forward(self, shared_obs, h): 91 | feat = shared_obs.ndata['feat'] 92 | 93 | if self.khops == 2: 94 | g_2hop = shared_obs['2hop'] 95 | x_nbr = self.enc['2hop'](g_2hop, feat['nbr'], g_2hop.edata['feat']) 96 | else: 97 | x_nbr = feat['nbr'] 98 | 99 | g_1hop = shared_obs['1hop'] 100 | x = self.enc['1hop'](g_1hop, (x_nbr, feat['agent']), g_1hop.edata['feat']) 101 | 102 | # Update RNN hidden states independently when communication is disabled. 103 | x, h = self.rnn(x, h) 104 | # Compute logits for action selection. 105 | agent_out, nbr_out = self.f_out(g_1hop, (x_nbr, x), g_1hop.edata['feat']) 106 | # print(f"agent_out.size() = {agent_out.size()}") 107 | # print(f"nbr_out.size() = {nbr_out.size()}") 108 | padded_nbr_out = pad_edge_output(g_1hop, nbr_out, self.args.max_nbrs) 109 | q_vals = th.cat([padded_nbr_out, agent_out], dim=1) 110 | # print(f"q_vals.size() = {q_vals.size()}") 111 | return q_vals, h 112 | -------------------------------------------------------------------------------- /modules/dueling.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | import torch.nn as nn 3 | 4 | from modules.activations import REGISTRY as act_REGISTRY 5 | 6 | 7 | class DuelingLayer(nn.Module): 8 | """Dueling architecture separating state values and advantage functions 9 | 10 | For more details, refer to "Dueling Network Architectures for Deep Reinforcement Learning". 
11 | """ 12 | def __init__(self, hidden_size, n_actions, args): 13 | super(DuelingLayer, self).__init__() 14 | self._hidden_size = hidden_size 15 | self._activation = act_REGISTRY[args.activation] # Activation function 16 | 17 | self.f_val = [nn.Linear(self._hidden_size, self._hidden_size)] 18 | self.f_adv = [nn.Linear(self._hidden_size, self._hidden_size)] 19 | if args.use_layer_norm: # Layer normalization is used. 20 | self.f_val += [nn.LayerNorm(self._hidden_size)] 21 | self.f_adv += [nn.LayerNorm(self._hidden_size)] 22 | self.f_val += [nn.ReLU(), nn.Linear(self._hidden_size, 1)] 23 | self.f_adv += [nn.ReLU(), nn.Linear(self._hidden_size, n_actions)] 24 | 25 | self.f_val = nn.Sequential(*self.f_val) 26 | self.f_adv = nn.Sequential(*self.f_adv) 27 | 28 | def forward(self, x): 29 | vals = self.f_val(x) 30 | advs = self.f_adv(x) 31 | # print(f"vals.size() = {vals.size()}") 32 | # print(f"advs.size() = {advs.size()}") 33 | # print(f"advs.mean(-1) = {advs.mean(-1)}") 34 | # print(f"advs.mean(-1, keepdim=True).expand_as(a) = \n{advs.mean(-1, keepdim=True).expand_as(advs)}") 35 | return vals + advs - advs.mean(-1, keepdim=True) 36 | 37 | 38 | if __name__ == '__main__': 39 | from types import SimpleNamespace as SN 40 | args = SN(**dict(activation='relu', use_layer_norm='False')) 41 | hidden_size = 32 42 | n_actions = 5 43 | dueling = DuelingLayer(hidden_size, n_actions, args) 44 | 45 | batch_size = 10 46 | x = th.rand(batch_size, hidden_size) 47 | y = dueling(x) 48 | print(f"y.size() = {y.size()}") 49 | print(getattr(args, "name", 0)) 50 | -------------------------------------------------------------------------------- /modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from modules.encoders.flat_enc import FlatEncoder 4 | REGISTRY['flat'] = FlatEncoder 5 | 6 | from modules.encoders.rel_enc import RelationalEncoder 7 | REGISTRY['rel'] = RelationalEncoder 8 | -------------------------------------------------------------------------------- /modules/encoders/flat_enc.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import torch as th 3 | from torch import Tensor 4 | import torch.nn as nn 5 | from dgl import DGLGraph 6 | 7 | from modules.activations import REGISTRY as act_REGISTRY 8 | 9 | 10 | class FlatEncoder(nn.Module): 11 | """Flat input encoder""" 12 | 13 | def __init__(self, input_shape: int, hidden_size: int, args): 14 | super(FlatEncoder, self).__init__() 15 | 16 | self._hidden_size = hidden_size # Hidden size 17 | self._n_layers = args.n_layers # Number of fully-connected layers 18 | self._activation = act_REGISTRY[args.activation] # Activation function 19 | self._use_feat_norm = args.use_feat_norm 20 | self._use_layer_norm = args.use_layer_norm 21 | 22 | # Define input layer. 23 | if self._use_feat_norm: # Feature normalization is used. 24 | self.feat_norm = nn.LayerNorm(input_shape) 25 | layers = [nn.Linear(input_shape, self._hidden_size)] 26 | if self._use_layer_norm: # Add layer normalization. 27 | layers += [nn.LayerNorm(self._hidden_size)] 28 | layers += [self._activation()] 29 | # Define hidden layers. 30 | for l in range(1, self._n_layers): 31 | layers += [nn.Linear(self._hidden_size, self._hidden_size)] 32 | if self._use_layer_norm: # Add layer normalization. 
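The `DuelingLayer` above recombines the two streams as Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a'), which keeps the value/advantage decomposition identifiable. A tiny numeric check of that aggregation:

```python
import torch as th

vals = th.tensor([[1.0]])                 # V(s)
advs = th.tensor([[2.0, 0.0, -2.0]])      # A(s, a) for three actions
q = vals + advs - advs.mean(-1, keepdim=True)
print(q)                                  # tensor([[ 3.,  1., -1.]]) since mean(A) = 0
```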
33 | layers += [nn.LayerNorm(self._hidden_size)] 34 | layers += [self._activation()] 35 | self.layers = nn.Sequential(*layers) 36 | 37 | def forward(self, inputs: Union[Tensor, DGLGraph]) -> Tensor: 38 | # Extract feature observation is a graph. 39 | if isinstance(inputs, DGLGraph): 40 | inputs = inputs.ndata['feat'] 41 | if self._use_feat_norm: 42 | inputs = self.feat_norm(inputs) 43 | return self.layers(inputs) 44 | 45 | 46 | if __name__ == '__main__': 47 | from types import SimpleNamespace as SN 48 | import torch as th 49 | import torch.nn as nn 50 | 51 | args = SN(**dict(hidden_size=16, n_layers=2, batch_size=10, n_heads=4, activation=nn.ELU)) 52 | 53 | obs_shape = 5 54 | enc = FlatEncoder(obs_shape, args) 55 | print(enc) 56 | obs = th.rand(args.batch_size, obs_shape) 57 | rst = enc(obs) 58 | -------------------------------------------------------------------------------- /modules/encoders/rel_enc.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Mapping 2 | import torch as th 3 | from torch import Tensor 4 | import torch.nn as nn 5 | import dgl 6 | import dgl.nn.pytorch as dglnn 7 | from dgl import DGLHeteroGraph 8 | 9 | from modules.activations import REGISTRY as act_REGISTRY 10 | 11 | 12 | class RelationalEncoder(nn.Module): 13 | """Relational input encoder 14 | Observations are heterogeneous graphs of which edges are established from observed entities to models. 15 | """ 16 | 17 | def __init__(self, in_feats_size_dict: Mapping[str, int], hidden_size, args) -> None: 18 | super(RelationalEncoder, self).__init__() 19 | 20 | assert 'agent' in in_feats_size_dict, "agent features must be reserved in observations." 21 | in_feats_size_dict_ = in_feats_size_dict.copy() 22 | agent_feats_size = in_feats_size_dict_.pop('agent') # Size of input features of models 23 | self._ntypes = tuple(in_feats_size_dict_.keys()) # Number of entity types in observations (except agent) 24 | 25 | self._hidden_size = hidden_size # Hidden size 26 | self._n_heads = args.n_heads # Number of attention heads 27 | feats_per_head = self._hidden_size // self._n_heads # Size of output features per head 28 | self._activation = act_REGISTRY[args.activation] # Callable to instantiate an activation function 29 | 30 | # Define a separate module to process each type of entities. 31 | mods = dict() 32 | for ntype, feats_size in in_feats_size_dict_.items(): 33 | mods[ntype] = dglnn.GATv2Conv((feats_size, agent_feats_size), feats_per_head, self._n_heads, 34 | allow_zero_in_degree=True, residual=True, activation=self._activation()) 35 | self.f_conv = nn.ModuleDict(mods) # Dict holding graph convolution layers 36 | self.f_aggr = nn.Sequential(nn.Linear(self._hidden_size * len(self._ntypes), self._hidden_size), 37 | self._activation()) # MLP aggregator 38 | 39 | def forward(self, g: DGLHeteroGraph) -> Tensor: 40 | feat = g.ndata['feat'] 41 | outputs = {} 42 | # Go through all types of entities. 43 | for stype, etype, dtype in g.canonical_etypes: 44 | # When an entity type is not specified by modules, skip it. 45 | if (stype not in self.f_conv) or (etype != 'nearby') or (dtype != 'agent'): 46 | continue 47 | # Extract subgraph and apply graph convolution. 48 | rel_g = g[stype, etype, dtype] 49 | outputs[stype] = self.f_conv[stype](rel_g, (feat[stype], feat['agent'])) 50 | # Aggregate outputs from different relations to obtain final results. 
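As a usage sketch for the `FlatEncoder` above (the `args` fields are the ones its constructor actually reads; the sizes are illustrative):

```python
from types import SimpleNamespace as SN
import torch as th
from modules.encoders.flat_enc import FlatEncoder

args = SN(n_layers=2, activation='relu', use_feat_norm=False, use_layer_norm=False)
enc = FlatEncoder(input_shape=5, hidden_size=16, args=args)   # signature: (input_shape, hidden_size, args)
obs = th.rand(10, 5)                                          # a batch of flat observations
out = enc(obs)                                                # -> shape (10, 16)
```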
51 | rsts = self._aggr_func(outputs) 52 | return rsts 53 | 54 | def _aggr_func(self, outputs: Mapping[str, Tensor]) -> Tensor: 55 | """Aggregates outputs from multiple relations. 56 | An MLP aggregator is used to transform stacked outputs into expected shape. 57 | """ 58 | # Stack outputs from relations in order. 59 | stacked = th.stack([outputs[ntype] for ntype in self._ntypes], dim=1) 60 | # Flatten stacked outputs and pass them through MLP aggregator. 61 | return self.f_aggr(stacked.flatten(1)) 62 | 63 | 64 | if __name__ == '__main__': 65 | from types import SimpleNamespace as SN 66 | import torch as th 67 | import torch.nn as nn 68 | import dgl 69 | 70 | args = SN(**dict(hidden_size=16, n_layers=2, batch_size=10, n_heads=4, activation=nn.ELU)) 71 | 72 | obs_feats_size_dict = {'gt': 3, 'uav': 2, 'agent': 3} 73 | enc = RelationalEncoder(obs_feats_size_dict, args) 74 | print(enc) 75 | graph_data = { 76 | ('gt', 'nearby', 'agent'): ((0, 1, 0), (0, 0, 1)), 77 | ('uav', 'nearby', 'agent'): ((0, 1), (1, 0)), 78 | ('agent', 'talks', 'agent'): ((0, 1), (1, 0)), 79 | } 80 | num_nodes_dict = {'gt': 2, 'uav': 2, 'agent': 2} 81 | obs_g = dgl.heterograph(graph_data, num_nodes_dict=num_nodes_dict) 82 | print(obs_g) 83 | feat = {ntype: th.rand(num_nodes_dict[ntype], obs_feats_size_dict[ntype]) for ntype in obs_feats_size_dict} 84 | rsts = enc(obs_g, feat) 85 | print(rsts) 86 | -------------------------------------------------------------------------------- /modules/gn_blks.py: -------------------------------------------------------------------------------- 1 | r"""Graph network (GN) blocks""" 2 | 3 | from typing import Union 4 | 5 | import torch as th 6 | import torch.nn as nn 7 | 8 | import dgl 9 | from dgl.base import DGLError, DGLWarning 10 | from dgl.utils import expand_as_pair 11 | import dgl.function as fn 12 | from modules.activations import REGISTRY as act_REGISTRY 13 | 14 | 15 | class NodeGNBlock(nn.Module): 16 | """Node-focused graph network (GN) block 17 | 18 | Components and update rules follow the paper 19 | "Relational inductive biases, deep learning, and graph networks" (https://arxiv.org/abs/1806.01261v1). 20 | """ 21 | def __init__(self, 22 | in_node_feats: Union[int, tuple[int, int]], # Size of input node features 23 | in_edge_feats: int, # Size of input edge features 24 | out_node_feats: int, # Size of output node features 25 | aggregator_type: str = 'mean', 26 | activation_type: str = 'relu', 27 | ): 28 | super(NodeGNBlock, self).__init__() 29 | 30 | self._in_src_node_feats, self._in_dst_node_feats = expand_as_pair(in_node_feats) 31 | self._in_edge_feats = in_edge_feats 32 | self._out_node_feats = out_node_feats 33 | 34 | if aggregator_type not in ('sum', 'max', 'mean'): 35 | raise KeyError('Aggregator type {} not recognized.'.format(aggregator_type)) 36 | self._aggr_type = aggregator_type 37 | self.activation = act_REGISTRY[activation_type] 38 | 39 | # Edge update function. 40 | self.f_e = nn.Sequential( 41 | nn.Linear(self._in_src_node_feats + self._in_edge_feats + self._in_dst_node_feats, self._out_node_feats), 42 | self.activation() 43 | ) 44 | # Node update function. 
45 | self.f_v = nn.Sequential( 46 | nn.Linear(self._out_node_feats + self._in_dst_node_feats, self._out_node_feats), 47 | self.activation(), 48 | ) 49 | 50 | def edge_update_func(self, edges): 51 | x = th.cat([edges.src['v_i'], edges.data['e'], edges.dst['v_j']], 1) 52 | return {'m': self.f_e(x)} 53 | 54 | def forward(self, graph, node_feats, edge_feats): 55 | _reducer = getattr(fn, self._aggr_type) 56 | with graph.local_scope(): 57 | if isinstance(node_feats, tuple): 58 | src_node_feats, dst_node_feats = node_feats 59 | else: 60 | src_node_feats = dst_node_feats = node_feats 61 | 62 | graph.srcdata.update({'v_i': src_node_feats}) 63 | graph.dstdata.update({'v_j': dst_node_feats}) 64 | graph.edata.update({'e': edge_feats}) 65 | 66 | graph.update_all(self.edge_update_func, _reducer('m', 'neigh')) 67 | return self.f_v(th.cat([graph.dstdata['neigh'], graph.dstdata['v_j']], 1)) 68 | 69 | 70 | class EdgeGNBlock(nn.Module): 71 | def __init__(self, 72 | in_node_feats: Union[int, tuple[int, int]], # Size of input node features 73 | in_edge_feats: int, # Size of input edge features 74 | out_node_feats: int, # Size of output node features 75 | out_edge_feats: int, # Size of output edge features 76 | hidden_size: int, # Hidden size 77 | activation_type: str = 'relu', # Activation function 78 | ): 79 | super(EdgeGNBlock, self).__init__() 80 | 81 | self._in_src_node_feats, self._in_dst_node_feats = expand_as_pair(in_node_feats) 82 | self._in_edge_feats = in_edge_feats 83 | self._out_node_feats = out_node_feats 84 | self._out_edge_feats = out_edge_feats 85 | self._hidden_size = hidden_size 86 | 87 | self.activation = act_REGISTRY[activation_type] 88 | # Edge update function. 89 | self.f_e = nn.Sequential( 90 | nn.Linear(self._in_src_node_feats + self._in_edge_feats + self._in_dst_node_feats, self._hidden_size), 91 | self.activation(), 92 | nn.Linear(self._hidden_size, self._out_edge_feats) 93 | ) 94 | 95 | self.f_v = nn.Linear(self._in_dst_node_feats, self._out_node_feats) 96 | 97 | def edge_update_func(self, edges): 98 | x = th.cat([edges.src['v_i'], edges.data['e'], edges.dst['v_j']], 1) 99 | return {'e2': self.f_e(x)} 100 | 101 | def node_update_func(self, nodes): 102 | return {'v2_j': self.f_v(nodes.data['v_j'])} 103 | 104 | def forward(self, graph, node_feats, edge_feats): 105 | assert graph.is_unibipartite, "Only uni-bipartite graph is supported by `EdgeGNBlock`." 
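To complement the `EdgeGNBlock` demo in the `__main__` block further down, here is an analogous usage sketch for the `NodeGNBlock` defined above, on the same kind of uni-bipartite `'nbr' -> 'agent'` graph (all feature sizes are made up):

```python
import dgl
import torch as th
from modules.gn_blks import NodeGNBlock

# 4 neighbour nodes feeding 2 agent nodes, 6 edges in total.
g = dgl.heterograph({('nbr', 'nearby', 'agent'):
                     (th.tensor([0, 1, 2, 1, 2, 3]), th.tensor([0, 0, 0, 1, 1, 1]))})
blk = NodeGNBlock(in_node_feats=(8, 7), in_edge_feats=6, out_node_feats=16)

node_feats = (th.rand(4, 8), th.rand(2, 7))   # (source 'nbr', destination 'agent') features
edge_feats = th.rand(6, 6)
out = blk(g, node_feats, edge_feats)          # updated agent features, shape (2, 16)
```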
106 | with graph.local_scope(): 107 | if isinstance(node_feats, tuple): 108 | src_node_feats, dst_node_feats = node_feats 109 | else: 110 | src_node_feats = dst_node_feats = node_feats 111 | 112 | graph.srcdata.update({'v_i': src_node_feats}) 113 | graph.dstdata.update({'v_j': dst_node_feats}) 114 | graph.edata.update({'e': edge_feats}) 115 | 116 | graph.apply_edges(self.edge_update_func) 117 | v2_j = self.f_v(dst_node_feats) 118 | return v2_j, graph.edata['e2'] 119 | 120 | 121 | if __name__ == '__main__': 122 | g = dgl.heterograph({('nbr', 'nearby', 'agent'): (th.tensor([0, 1, 2, 1, 2, 3]), th.tensor([0, 0, 0, 1, 1, 1]))}) 123 | print(g[('nbr', 'nearby', 'agent')]) 124 | in_node_feats = (8, 7) 125 | in_edge_feats = 6 126 | out_node_feats = 5 127 | out_edge_feats = 4 128 | 129 | gnn = EdgeGNBlock(in_node_feats, in_edge_feats, out_node_feats, out_edge_feats, 32) 130 | node_feats = (th.rand(4, in_node_feats[0]), th.rand(2, in_node_feats[1])) 131 | edge_feats = th.rand(6, in_edge_feats) 132 | node_feats, edge_feats = gnn(g, node_feats, edge_feats) 133 | print(f"node_feats.size() = {node_feats}, \nedge_feats.size() = {edge_feats}") -------------------------------------------------------------------------------- /modules/graph_nn.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | import torch as th 3 | import torch.nn as nn 4 | import dgl 5 | 6 | from modules.activations import REGISTRY as act_REGISTRY 7 | 8 | 9 | class NeighborSelector(nn.Module): 10 | def __init__(self, nbr_in_feats, agent_in_feats, nbr_out_feats, agent_out_feats, hidden_size, args): 11 | super(NeighborSelector, self).__init__() 12 | 13 | self.nbr_in_feats = nbr_in_feats 14 | self.agent_in_feats = agent_in_feats 15 | self.nbr_out_feats = nbr_out_feats 16 | self.agent_out_feats = agent_out_feats 17 | 18 | self.device = args.device 19 | self.max_nbrs = args.max_nbrs 20 | 21 | self._hidden_size = hidden_size 22 | self._activation = act_REGISTRY[args.activation] # Activation function 23 | self.nbr_predictor = nn.Sequential( 24 | nn.Linear(nbr_in_feats + agent_in_feats, self._hidden_size), 25 | self._activation(), 26 | nn.Linear(self._hidden_size, nbr_out_feats) 27 | ) 28 | self.agent_predictor = nn.Linear(self._hidden_size, agent_out_feats) 29 | 30 | def predict_entity_score(self, edges): 31 | h_ent = edges.src['x'] 32 | h_agent = edges.dst['x'] 33 | # print(f"h_ent.size() = {h_ent.size()}, h_agent.size() = {h_agent.size()}") 34 | score = self.nbr_predictor(th.cat([h_ent, h_agent], 1)) 35 | return {'score': score} # (batch_size * n_agents, nbr_out_feats) 36 | 37 | def forward(self, graph, x): 38 | with graph.local_scope(): 39 | # Get scores of neighbors and own with graph convolution. 40 | graph.ndata['x'] = x 41 | graph.apply_edges(self.predict_entity_score) 42 | nbr_scores = graph.edata['score'] # shape (batch_size * n_agents * nbr_per_agent, 1) 43 | own_score = self.agent_predictor(x['agent']) # shape (batch_size * n_agents, 1) 44 | # print(f"nbr_scores.size() = {nbr_scores.size()}, own_score.size() = {own_score.size()}") 45 | 46 | nbrs_per_agent = graph.in_degrees().tolist() # Number of neighbors around each agent 47 | # Split nbr scores for each agent and pad zero to `max_nbr`. 
48 | nbr_scores = th.split(nbr_scores, split_size_or_sections=nbrs_per_agent, dim=0) # Each entry has shape (n_nbrs, nbr_out_feats) 49 | pad_zeros = [th.zeros(self.nbr_out_feats * (self.max_nbrs - nbrs_per_agent[i]), dtype=th.float, device=self.device) for i in range(len(nbrs_per_agent))] # all-zero with shape (max_nbrs - n_nbrs) of each agent 50 | # print(f"nbr_scores.size() = {nbr_scores}") 51 | # print(f"self.nbr_out_feats * nbrs_per_agent = {self.nbr_out_feats * nbrs_per_agent}") 52 | # print(f"pad_zeros = \n{pad_zeros}") 53 | # print(f"nbr_scores = \n{nbr_scores}") 54 | # Pad nbr scores of each agent with all-zero vector. 55 | padded_nbr_scores = [] 56 | for agent_idx, score in enumerate(nbr_scores): 57 | padded_nbr_scores.append(th.cat((score.flatten(), pad_zeros[agent_idx]))) 58 | padded_nbr_scores = th.stack(padded_nbr_scores) # Shape (batch_size * n_agents, max_nbrs) 59 | all_scores = th.cat([padded_nbr_scores, own_score], 1) 60 | # print(f"padded_nbr_scores has size {padded_nbr_scores.size()} = \n{padded_nbr_scores}") 61 | # print(f"all_scores.size() = {all_scores.size()}") 62 | return all_scores 63 | -------------------------------------------------------------------------------- /modules/mixers/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from modules.mixers.qmix import QMixer 4 | REGISTRY['qmix'] = QMixer 5 | -------------------------------------------------------------------------------- /modules/mixers/qmix.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch as th 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class QMixer(nn.Module): 8 | """QMIX: Monotonic Value Function Factorisation for Deep Multi-Agent Reinforcement Learning""" 9 | 10 | def __init__(self, state_shape, n_agents, args): 11 | super(QMixer, self).__init__() 12 | 13 | self.n_agents = n_agents 14 | self.state_dim = int(np.prod(state_shape)) 15 | self.embed_dim = args.mixing_embed_dim 16 | 17 | if getattr(args, "hypernet_layers", 1) == 1: 18 | self.hyper_w_1 = nn.Linear(self.state_dim, self.embed_dim * self.n_agents) 19 | self.hyper_w_final = nn.Linear(self.state_dim, self.embed_dim) 20 | elif getattr(args, "hypernet_layers", 1) == 2: 21 | hypernet_embed = args.hypernet_embed 22 | self.hyper_w_1 = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed), 23 | nn.ReLU(), 24 | nn.Linear(hypernet_embed, self.embed_dim * self.n_agents)) 25 | self.hyper_w_final = nn.Sequential(nn.Linear(self.state_dim, hypernet_embed), 26 | nn.ReLU(), 27 | nn.Linear(hypernet_embed, self.embed_dim)) 28 | elif getattr(args, "hypernet_layers", 1) > 2: 29 | raise Exception("Sorry >2 hypernet layers is not implemented!") 30 | else: 31 | raise Exception("Error setting number of hypernet layers.") 32 | 33 | # State dependent bias for hidden layer 34 | self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim) 35 | 36 | # V(s) instead of a bias for the last layers 37 | self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), 38 | nn.ReLU(), 39 | nn.Linear(self.embed_dim, 1)) 40 | 41 | def forward(self, agent_qs, states): 42 | ts, bs = agent_qs.size(0), agent_qs.size(1) 43 | states = states.reshape(-1, self.state_dim) 44 | agent_qs = agent_qs.view(-1, 1, self.n_agents) 45 | # First layer 46 | w1 = th.abs(self.hyper_w_1(states)) 47 | b1 = self.hyper_b_1(states) 48 | w1 = w1.view(-1, self.n_agents, self.embed_dim) 49 | b1 = b1.view(-1, 1, self.embed_dim) 50 | 
hidden = F.elu(th.bmm(agent_qs, w1) + b1) 51 | # Second layer 52 | w_final = th.abs(self.hyper_w_final(states)) 53 | w_final = w_final.view(-1, self.embed_dim, 1) 54 | # State-dependent bias 55 | v = self.V(states).view(-1, 1, 1) 56 | # Compute final output 57 | y = th.bmm(hidden, w_final) + v 58 | # Reshape and return 59 | q_tot = y.view(ts, bs, 1) 60 | return q_tot 61 | 62 | 63 | if __name__ == '__main__': 64 | from types import SimpleNamespace as SN 65 | 66 | n_agents = 3 67 | state_shape = 12 68 | mixing_embed_dim = 32 69 | config = dict(n_agents=n_agents, state_shape=state_shape, mixing_embed_dim=mixing_embed_dim, hypernet_embed=64) 70 | args = SN(**config) 71 | bs = 8 72 | T = 25 73 | mixer = QMixer(state_shape, n_agents, args) 74 | states = th.rand(T, bs, state_shape) 75 | qs = th.rand(T, bs, n_agents) 76 | out = mixer(qs, states) 77 | -------------------------------------------------------------------------------- /modules/utils.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | 4 | def pad_edge_output(graph, edge_feats, max_degree): 5 | 6 | out_edge_feats = edge_feats.size(-1) 7 | degrees_per_dst = graph.in_degrees().tolist() 8 | 9 | edge_feats_per_dst = th.split(edge_feats, split_size_or_sections=degrees_per_dst, dim=0) 10 | pad_zeros = [th.zeros(max_degree - d, out_edge_feats, dtype=th.float, device=edge_feats.device) for d in degrees_per_dst] 11 | 12 | # print(f"pad_zeros = {pad_zeros}") 13 | padded_edge_feats = [] 14 | for dst_nid, e_feats in enumerate(edge_feats_per_dst): 15 | # print(f"e_feats.size() = {e_feats.size()}, pad_zeros[dst_nid].size() = {pad_zeros[dst_nid].size()}") 16 | 17 | padded_edge_feats.append( 18 | th.cat([e_feats, pad_zeros[dst_nid]], dim=0).flatten() 19 | ) 20 | return th.stack(padded_edge_feats) 21 | -------------------------------------------------------------------------------- /policies/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from policies.shared_policy import SharedPolicy 4 | REGISTRY['shared'] = SharedPolicy 5 | -------------------------------------------------------------------------------- /policies/shared_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch as th 3 | from torch import Tensor 4 | from modules.agents import REGISTRY as agent_REGISTRY 5 | from components.action_selectors import REGISTRY as action_REGISTRY 6 | 7 | 8 | class SharedPolicy: 9 | """Policy shared by all homogeneous agents""" 10 | 11 | def __init__(self, env_info, args) -> None: 12 | self.n_agents = env_info.get('n_agents') 13 | 14 | obs_shape = env_info['obs_shape'][0] # Shape of observations 15 | self._build_agents(obs_shape, args.act_size, args) # Agents 16 | self.action_selector = action_REGISTRY[args.action_selector](args) # Action selector 17 | 18 | def _build_agents(self, obs_shape, act_size, args) -> None: 19 | self.model = agent_REGISTRY[args.agent](obs_shape, act_size, args) 20 | 21 | def init_hidden(self, batch_size: int = 1) -> Tensor: 22 | """Initializes RNN hidden states of a batch of multi-models.""" 23 | return self.model.init_hidden().expand(batch_size * self.n_agents, -1) 24 | 25 | def forward(self, obs, h: Tensor): 26 | logits, h = self.model(obs, h) 27 | return logits, h 28 | 29 | @ th.no_grad() 30 | def act(self, obs, h: Tensor, avail_actions: Optional[Tensor] = None, t: Optional[int] = None, 31 | mode: str = 'explore', 
**kwargs): 32 | """Selects actions of models from observations.""" 33 | logits, next_h = self.forward(obs, h) 34 | actions = self.action_selector.select_actions(logits, avail_actions, t, mode) 35 | return actions, next_h 36 | 37 | def parameters(self): 38 | """Returns parameters of neural network.""" 39 | return self.model.parameters() 40 | 41 | def to(self, device) -> None: 42 | """Moves neural network to device.""" 43 | self.model.to(device) 44 | 45 | def load_state(self, other_policy) -> None: 46 | """Loads the parameters from another policy.""" 47 | self.model.load_state_dict(other_policy.model.state_dict()) 48 | 49 | def eval(self) -> None: 50 | """Sets agent to eval mode""" 51 | self.model.eval() 52 | 53 | def train(self) -> None: 54 | """Sets agent to training mode.""" 55 | self.model.train() 56 | 57 | def __repr__(self): 58 | return f"Shared policy: \n{self.model}" 59 | -------------------------------------------------------------------------------- /process_csv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | 5 | # def moving_average(a, n=5) : 6 | # ret = np.cumsum(a, dtype=float) 7 | # ret[n:] = ret[n:] - ret[:-n] 8 | # return ret[n - 1:] / n 9 | 10 | 11 | def moving_average(x, smooth=5): 12 | y = np.ones(smooth) 13 | z = np.ones(len(x)) 14 | smoothed_x = np.convolve(x, y, 'same') / np.convolve(z, y, 'same') 15 | return smoothed_x 16 | 17 | 18 | def smooth_curves(filename): 19 | df = pd.read_csv(f'./{filename}.csv') 20 | # max_index = 51 21 | # df = df[0:max_index] 22 | 23 | for col in df.columns: 24 | seq = df[col].values 25 | if col == 'Step': 26 | df.loc[:, col] = seq / 1e6 27 | else: 28 | # print(f"df.index = {df.index}") 29 | # print(f"seq.shape = {seq.shape}") 30 | smooth_seq = moving_average(seq, 5) 31 | # print(f"smooth_seq.shape = {smooth_seq.shape}") 32 | # print(f"df[col] = {df[col]}") 33 | df.loc[:, col] = smooth_seq 34 | 35 | df.to_csv(f'./smoothed_{filename}.csv') 36 | 37 | 38 | if __name__ == '__main__': 39 | for exp_idx in ['3_1fb']: 40 | files = [f'exp{exp_idx}_agent_bottleneck_rate', 41 | f'exp{exp_idx}_c2dst_full_bottleneck_rate', f'exp{exp_idx}_c2dst_rand_bottleneck_rate', 42 | f'exp{exp_idx}_msinr_full_bottleneck_rate', f'exp{exp_idx}_msinr_rand_bottleneck_rate'] 43 | for file in files: 44 | smooth_curves(file) 45 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import json 4 | import yaml 5 | from typing import Any, Optional 6 | from collections.abc import Mapping 7 | from copy import deepcopy 8 | from types import SimpleNamespace as SN 9 | from functools import partial 10 | 11 | import random 12 | import numpy as np 13 | import torch as th 14 | import wandb 15 | 16 | from gym.spaces.discrete import Discrete 17 | from gym.spaces.multi_discrete import MultiDiscrete 18 | from envs import REGISTRY as env_REGISTRY 19 | from envs.multi_agent_env import MultiAgentEnv 20 | 21 | from policies import REGISTRY as policy_REGISTRY 22 | from components.buffers import REGISTRY as buff_REGISTRY 23 | from learners import REGISTRY as learn_REGISTRY, DETERMINISTIC_POLICY_GRADIENT_ALGOS 24 | from runners import REGISTRY as run_REGISTRY 25 | 26 | DEFAULT_DATA_DIR = osp.join(osp.abspath(osp.dirname(__file__)), 'results') 27 | 28 | 29 | def recursive_dict_update(d, u): 30 | """Merges two dictionaries.""" 31 | 32 | for k, v in 
u.items(): 33 | if isinstance(v, Mapping): 34 | d[k] = recursive_dict_update(d.get(k, {}), v) 35 | else: 36 | d[k] = v 37 | return d 38 | 39 | 40 | def config_copy(config): 41 | """Copies configuration.""" 42 | 43 | if isinstance(config, dict): 44 | return {k: config_copy(v) for k, v in config.items()} 45 | elif isinstance(config, list): 46 | return [config_copy(v) for v in config] 47 | else: 48 | return deepcopy(config) 49 | 50 | 51 | def check_args_sanity(config: Mapping[str, Any]) -> dict[str, Any]: 52 | """Checks the feasibility of configuration.""" 53 | 54 | # Setup correct device. 55 | if config['use_cuda'] and th.cuda.is_available(): 56 | config['device'] = 'cuda:{}'.format(config['cuda_idx']) 57 | else: 58 | config['use_cuda'] = False 59 | config['device'] = 'cpu' 60 | print(f"Choose to use {config['device']}.") 61 | 62 | # Env specific requirements 63 | if config['env_id'] == 'mpe': 64 | assert config['obs'] == 'flat', "MPE only supports flat obs." 65 | if config['shared_obs'] is not None: 66 | assert config['shared_obs'] == 'flat', "MPE only supports flat shared obs." 67 | if config['state'] is not None: 68 | assert config['state'] == 'flat', f"Unsupported state format 's{config['state']}' is encountered." 69 | 70 | return config 71 | 72 | 73 | def update_args_from_env(env: MultiAgentEnv, args): 74 | """Updates args from env.""" 75 | 76 | env_info = env.get_env_info() 77 | args.n_agents = env_info['n_agents'] 78 | 79 | if args.env_id.startswith('ad-hoc'): 80 | args.max_nbrs = getattr(env, 'max_nbrs', None) 81 | args.n_pow_opts = getattr(env, 'power_options', env.n_pow_lvs) 82 | args.khops = getattr(env, 'khops', 1) 83 | 84 | if args.runner == 'base': 85 | if args.rollout_len is None: 86 | args.rollout_len = env_info['episode_limit'] 87 | print(f"`rollout_len` is set to `episode_limit` as {args.rollout_len}.") 88 | if args.data_chunk_len is None: 89 | args.data_chunk_len = env_info['episode_limit'] 90 | print(f"`data_chunk_len` is set to `episode_limit` as {args.data_chunk_len}.") 91 | assert args.rollout_len is not None and args.data_chunk_len is not None, "Invalid rollout/data chunk length" 92 | elif args.runner == 'episode': 93 | args.data_chunk_len = env_info['episode_limit'] 94 | print(f"`data_chunk_len` is set to `episode_limit` as {env_info['episode_limit']}.") 95 | else: 96 | raise KeyError("Unrecognized name of runner") 97 | 98 | # Assume that all agents share the same action space and retrieve action info. 99 | act_space = env.action_space[0] 100 | args.act_size = env_info['n_actions'][0] 101 | # Note that `act_size` specifies output layer of modules, 102 | # while `act_shape` indicates the shape of actions stored in buffers (which may be different from `act_size`). 103 | if isinstance(act_space, Discrete): 104 | args.is_discrete = True 105 | args.is_multi_discrete = False 106 | args.act_shape = 1 if args.learner not in DETERMINISTIC_POLICY_GRADIENT_ALGOS else args.act_size 107 | elif isinstance(act_space, MultiDiscrete): 108 | args.is_discrete = True # Multi-discrete space is generalization of discrete space. 109 | args.is_multi_discrete = True 110 | args.nvec = act_space.nvec.tolist() # Number of actions in each space 111 | args.act_shape = len(args.nvec) if args.learner not in DETERMINISTIC_POLICY_GRADIENT_ALGOS else args.act_size 112 | else: # TODO: Continuous action use Box space. 113 | args.is_discrete = False 114 | # Discrete action selectors use available action mask. 
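# Illustrative example of the shapes above: a MultiDiscrete space with nvec=[4, 3]
# gives act_shape=2 for value-based learners (the two chosen sub-action indices are
# stored), whereas deterministic policy-gradient learners store the full one-hot
# encoding of length act_size instead.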
115 | if args.is_discrete: 116 | args.pre_decision_fields.append('avail_actions') 117 | return args 118 | 119 | 120 | def run(env_id: str, env_kwargs: Mapping[str, Any], seed: int = 0, algo_name: str = 'q', 121 | train_kwargs: Mapping[str, Any] = dict(), run_tag: Optional[str] = None, 122 | add_suffix: bool = False, suffix: Optional[str] = None) -> None: 123 | """Main function to run the training loop""" 124 | 125 | # Set random seed. 126 | random.seed(seed) 127 | np.random.seed(seed) 128 | th.manual_seed(seed) 129 | 130 | # Load the default configuration. 131 | with open(os.path.join(os.path.dirname(__file__), 'config', "default.yaml"), "r") as f: 132 | config = yaml.safe_load(f) 133 | # Load hyper-params of algo. 134 | with open(os.path.join(os.path.dirname(__file__), 'config', f"algos/{algo_name}.yaml"), "r") as f: 135 | algo_config = yaml.safe_load(f) 136 | config = recursive_dict_update(config, algo_config) 137 | # Load mac parameters of communicative agents. 138 | if config['agent'] == 'comm': 139 | assert config['comm'] is not None, "Absence of communication protocol for communicative agents!" 140 | with open(os.path.join(os.path.dirname(__file__), 'config', f"comm/{config['comm']}.yaml"), "r") as f: 141 | comm_config = yaml.safe_load(f) 142 | config = recursive_dict_update(config, comm_config) 143 | # Load preference from train_kwargs. 144 | config = recursive_dict_update(config, train_kwargs) 145 | # Add env id. 146 | config['env_id'] = env_id 147 | # Make sure the legitimacy of configuration. 148 | config = check_args_sanity(config) 149 | del algo_config, train_kwargs # Delete redundant variables. 150 | args = SN(**config) # Convert to simple namespace. 151 | 152 | # Get directory to store models/results. 153 | # Project identifier includes `env_id` and probably `scenario`. 154 | scenario = env_kwargs.get('scenario_name', None) 155 | if add_suffix: 156 | if suffix is not None: 157 | project_name = env_id + '_' + suffix 158 | elif scenario is not None: 159 | project_name = env_id + '_' + scenario 160 | else: 161 | raise Exception("Suffix of project is unavailable.") 162 | else: 163 | project_name = env_id 164 | # Multiple runs are distinguished by algo name and tag. 165 | run_name = run_tag if run_tag is not None else algo_name 166 | # Create a subdirectory to distinguish runs with different random seeds. 167 | args.run_dir = osp.join(DEFAULT_DATA_DIR, project_name, run_name + f"_seed{seed}") 168 | if not osp.exists(args.run_dir): 169 | os.makedirs(args.run_dir) 170 | print(f"Run '{run_name}' under directory '{args.run_dir}'.") 171 | 172 | if args.use_wandb: # If W&B is used, 173 | # Runs with the same config except for rand seeds are grouped and their histories are plotted together. 174 | wandb.init(config=args, project=project_name, group=run_name, name=run_name + f"_seed{seed}", dir=args.run_dir, 175 | reinit=True) 176 | args.wandb_run_dir = wandb.run.dir 177 | 178 | # Define env function. 179 | env_fn = partial(env_REGISTRY[env_id], **env_kwargs) # Env function 180 | test_env_fn = partial(env_REGISTRY[env_id], **env_kwargs) # Test env function 181 | 182 | # Create runner holding instance(s) of env and get info. 183 | runner = run_REGISTRY[args.runner](env_fn, test_env_fn, args) 184 | args = update_args_from_env(runner.env, args) # Adapt args to env. 185 | 186 | # Setup key components. 187 | env_info = runner.get_env_info() 188 | policy = policy_REGISTRY[args.policy](env_info, args) # Policy making decisions 189 | policy.to(args.device) # Move policy to device. 
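# The remaining components are wired up below: the replay buffer stores collected
# transitions, the learner updates the policy from sampled batches, and
# runner.add_components() ties them all to the runner.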
190 | 191 | buffer = buff_REGISTRY[args.buffer](args) # Buffer holding experiences 192 | learner = learn_REGISTRY[args.learner](env_info, policy, args) # Algorithm training policy 193 | runner.add_components(policy, buffer, learner) # Add above components to runner. 194 | 195 | # Run the main loop of training. 196 | runner.run() 197 | # Clean-up after training. 198 | runner.cleanup() 199 | 200 | 201 | if __name__ == '__main__': 202 | import argparse 203 | 204 | parser = argparse.ArgumentParser() 205 | parser.add_argument('--seed', '-s', type=int, default=10) 206 | parser.add_argument('--algo', type=str, default='q') 207 | parser.add_argument('--tag', type=str, default=None) 208 | args = parser.parse_args() 209 | 210 | # print(type(args)) 211 | # a = dict(name='bob') 212 | # args = SN(**a) 213 | # args2 = argparse.Namespace(**a) 214 | # print(args) 215 | # print(vars(args)) 216 | # print(type(args)) 217 | # print(args2) 218 | # print(vars((args2))) 219 | # raise ValueError 220 | 221 | # Train UBS coverage. 222 | # env_id = 'ubs' 223 | # env_kwargs = dict(scenario_name='simple') 224 | 225 | # # Train Ad Hoc route. 226 | env_id = 'ad-hoc' 227 | env_kwargs = dict() 228 | 229 | train_kwargs = dict(use_cuda=True, cuda_idx=0, use_wandb=False, record_tests=True, rollout_len=10, data_chunk_len=5) 230 | run(env_id, env_kwargs, args.seed, algo_name=args.algo, train_kwargs=train_kwargs, run_tag=args.tag) 231 | -------------------------------------------------------------------------------- /runners/__init__.py: -------------------------------------------------------------------------------- 1 | REGISTRY = {} 2 | 3 | from runners.base_runner import BaseRunner 4 | REGISTRY['base'] = BaseRunner 5 | 6 | from runners.episode_runner import EpisodeRunner 7 | REGISTRY['episode'] = EpisodeRunner 8 | -------------------------------------------------------------------------------- /runners/base_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from typing import Any 4 | import json 5 | 6 | import numpy as np 7 | from numpy import ndarray 8 | import torch as th 9 | from torch import Tensor 10 | import torch.nn.functional as F 11 | from torch.utils.tensorboard import SummaryWriter 12 | import wandb 13 | 14 | from components.make_env import make_env 15 | from components.misc import * 16 | from learners import DETERMINISTIC_POLICY_GRADIENT_ALGOS 17 | import time 18 | 19 | 20 | class BaseRunner: 21 | """Base class of runners""" 22 | 23 | def __init__(self, env_fn, test_env_fn, args): 24 | 25 | self.env = make_env(env_fn, args) # Env used for training 26 | self.test_env = make_env(test_env_fn, args) # Separate copy of env used for evaluation 27 | 28 | self.args = args 29 | self.device = args.device 30 | self.run_dir = args.run_dir 31 | 32 | if not args.use_wandb: # When W&B is not used, 33 | # Create a directory to save intermediate results as a Tensorboard event. 34 | self.log_dir = osp.join(self.run_dir, 'log') 35 | if not osp.exists(self.log_dir): 36 | os.makedirs(self.log_dir) 37 | 38 | # Manually save config locally. 39 | # Note that wandb automatically stores config. 40 | config_path = osp.join(self.log_dir, 'run_config.json') 41 | with open(config_path, 'w') as f: 42 | f.write(json.dumps(vars(self.args), separators=(',', ':\t'), indent=4, sort_keys=True)) 43 | 44 | # Setup SummaryWriter of Tensorboard. 45 | self.writer = SummaryWriter(log_dir=self.log_dir) 46 | 47 | # Create a directory to save model parameters. 
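# Checkpoints written periodically during collect() end up here, under <run_dir>/model.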
48 | self.model_dir = osp.join(self.run_dir, 'model') 49 | if not osp.exists(self.model_dir): 50 | os.makedirs(self.model_dir) 51 | 52 | # Some empty attributes to be added later 53 | self.policy = None # Policy to make decision from observations 54 | self.buffer = None # Replay buffer holding experiences 55 | self.learner = None # Algorithm to learn optimal policy 56 | 57 | self.start_time = None # Record wall-clock time. 58 | self.t_warmup = None 59 | self.t_env = None # Counter of env steps 60 | 61 | def add_components(self, policy, buffer, learner): 62 | """Adds key components of RL.""" 63 | self.policy = policy 64 | self.buffer = buffer 65 | self.learner = learner 66 | 67 | def get_env_info(self): 68 | return self.env.get_env_info() 69 | 70 | def run(self): 71 | """Main function running the training loop.""" 72 | 73 | # Prepare for training. 74 | self.start_time = time.time() # Record the start time of training. 75 | rnn_states = self.warmup() 76 | self.test_agent() 77 | 78 | # Run the main loop of training. 79 | # print(f"self.args.update_after = {self.args.update_after}, self.args.batch_size = {self.args.batch_size}.") 80 | while self.t_env < self.args.total_env_steps: 81 | # Collect rollout with exploration. 82 | rnn_states = self.collect(rnn_states) 83 | # print(f"Finish collect at t = {self.t_env}") 84 | # print(f"(self.t_env >= self.args.update_after) = {self.t_env >= self.args.update_after} and " 85 | # f"self.buffer.can_sample(self.args.batch_size) = {self.buffer.can_sample(self.args.batch_size)}") 86 | # print(f"len(self.buffer) = {len(self.buffer)}") 87 | # Perform update when feasible. 88 | if (self.t_env >= self.args.update_after) and self.buffer.can_sample(self.args.batch_size): 89 | diagnostic = self.learner.update(self.buffer, self.args.batch_size) 90 | self.log_metrics(diagnostic) 91 | 92 | def warmup(self) -> dict[str, Tensor]: 93 | """Prepares for training.""" 94 | # Reset variables. 95 | self.t_env = 0 # Reset timestep counter. 96 | 97 | # Reset components. 98 | self.env.reset() # Reset env. 99 | self.learner.reset() # Reset learner. 100 | rnn_states = self.get_init_hidden() # Initial RNN states 101 | 102 | # Let agents select random actions to fill replay buffer. 103 | self.policy.eval() # Set policy to eval mode. 104 | self.learner.eval() # Set modules held by learner to eval mode. 105 | self.t_warmup = 0 # Reset warm-up step counter. 106 | while (self.t_warmup < self.args.warmup_steps) or not self.buffer.can_sample(self.args.batch_size): 107 | rnn_states = self.interact(rnn_states, mode='rand') # Random interaction 108 | self.t_warmup += 1 # One frozen step is completed. 109 | 110 | return rnn_states # Leave RNN states. 111 | 112 | def collect(self, rnn_states: dict[str, Tensor]) -> dict[str, Tensor]: 113 | """Explores a fixed number of timesteps to collect experiences.""" 114 | self.policy.eval() # Set policy to eval mode. 115 | self.learner.eval() # Set modules held by learner to eval mode. 116 | 117 | for t in range(self.args.rollout_len): 118 | rnn_states = self.interact(rnn_states, mode='explore') 119 | self.t_env += 1 # Another env step is finished. 120 | 121 | # Regularly save checkpoints. 122 | if (self.t_env % self.args.save_interval == 0) or (self.t_env >= self.args.total_env_steps): 123 | save_path = osp.join(self.model_dir, f'checkpoint_t{self.t_env}.pt') 124 | self.learner.save_checkpoint(save_path) 125 | 126 | # Test the performance of trained models. 
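# Evaluation runs on the separate test_env with greedy (mode='test') action selection;
# see test_agent() below.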
127 | if self.t_env % self.args.test_interval == 0: 128 | self.test_agent() 129 | 130 | # Handle the end of session. 131 | if self.t_env % self.args.steps_per_session == 0: 132 | session = self.t_env // self.args.steps_per_session # Index of session 133 | print(f"Finish session {session} at step {self.t_env}/{self.args.total_env_steps} " 134 | f"after {(time.time() - self.start_time):.1f}s.") 135 | # Call lr scheduler(s) if enabled. 136 | if self.args.anneal_lr: 137 | self.learner.schedule_lr() 138 | return rnn_states # Leave RNN states for next rollout. 139 | 140 | def interact(self, rnn_states: dict[str, Tensor], mode: str = 'explore'): 141 | """Completes an agent-env interaction loop of MDP.""" 142 | 143 | # Build pre-decision data. 144 | pre_decision_data = dict(**self.get_inputs_from_env(self.env), **rnn_states) 145 | 146 | # Select actions following epsilon-greedy strategy. 147 | # print(f"self.env.nbrs = {[nbr.nid for nbr in self.env.nbrs]}, " 148 | # f"self.env.agent.is_connected = {self.env.agent.is_connected}") 149 | # print(f"avail_acts = \n{pre_decision_data['avail_actions']}") 150 | actions, h = self.policy.act(pre_decision_data['obs'], pre_decision_data['h'], 151 | pre_decision_data['avail_actions'], self.t_env, mode=mode) 152 | 153 | # Call environment step. 154 | rewards, terminated, info = self.env.step(self.get_actions_to_env(actions, self.env)) 155 | 156 | # When deterministic policy gradient is used, apply one-hot encoding to discrete actions. 157 | if (self.args.learner in DETERMINISTIC_POLICY_GRADIENT_ALGOS) and self.args.is_discrete: 158 | n_classes = self.args.nvec if self.args.is_multi_discrete else self.args.act_size 159 | actions = onehot_from_actions(actions, n_classes) 160 | 161 | # Call learner step to get required data (e.g., critic RNN states for off-policy AC algos). 162 | rnn_states = dict(h=h, **self.learner.step(pre_decision_data, actions)) # Next RNN states 163 | 164 | # Collect post-decision data. 165 | post_decision_data = { 166 | 'actions': actions, 'rewards': rewards, 167 | 'terminated': terminated != info.get('truncated', False), 168 | } 169 | # Save transition to replay buffer. 170 | self.cache(**pre_decision_data, **post_decision_data, filled=True) 171 | 172 | # Reach the end of an episode. 173 | if terminated: 174 | # # Log episode info. 175 | # if 'truncated' in info: # Drop indicator of episode limit. 176 | # del info['truncated'] 177 | # self.log_metrics(info) 178 | 179 | # Append a pseudo-transition for correct bootstrapping. 180 | # This transition does not actually occur and only pre-decision data is used. 181 | last_inputs = self.get_inputs_from_env(self.env) 182 | # Let agents select last actions. 183 | actions, _ = self.policy.act(last_inputs['obs'], rnn_states['h'], last_inputs['avail_actions'], 184 | self.t_env + 1, mode=mode) 185 | if (self.args.learner in DETERMINISTIC_POLICY_GRADIENT_ALGOS) and self.args.is_discrete: 186 | n_classes = self.args.nvec if self.args.is_multi_discrete else self.args.act_size 187 | actions = onehot_from_actions(actions, n_classes) 188 | # Take empty rewards. 189 | empty_rewards = np.zeros(self.policy.n_agents, dtype=np.float32) 190 | # Forge the spurious transition and store to replay buffer. 191 | pseudo_transition = dict(actions=actions, rewards=empty_rewards, terminated=True, filled=False, 192 | **last_inputs, **rnn_states) 193 | self.cache(**pseudo_transition) 194 | 195 | # Reset env and RNN states. 
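# After caching the pseudo-transition, start a fresh episode so that the next call to
# interact() begins from a clean env state with re-initialized hidden states.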
196 | self.env.reset() 197 | rnn_states = self.get_init_hidden() 198 | 199 | return rnn_states 200 | 201 | def get_init_hidden(self) -> dict[str, Tensor]: 202 | """Gets initial RNN states of policy and other modules.""" 203 | h_policy = self.policy.init_hidden() # RNN states of policy 204 | h_others = self.learner.init_hidden() # Dict holding RNN states of other components 205 | rnn_states = dict(h=h_policy.to(self.device), **{k: v.to(self.device) for k, v in h_others.items()}) 206 | return rnn_states 207 | 208 | def get_inputs_from_env(self, env, train_mode: bool = True) -> dict[str, Any]: 209 | """Gets inputs from env as part of pre-decision data.""" 210 | # obs and avail_actions are required by all algorithms. 211 | obs = env.get_obs() 212 | obs = obs.to(self.device) 213 | avail_actions = env.get_avail_actions() 214 | if avail_actions is not None: 215 | avail_actions = th.tensor(avail_actions, dtype=th.float32, device=self.device) 216 | inputs_from_env = dict(obs=obs, avail_actions=avail_actions) 217 | 218 | # Note that following items can only be used in training rather than execution. 219 | if train_mode: 220 | # Shared observations and states are obtained only if they are specified in fields. 221 | if 'shared_obs' in self.args.pre_decision_fields: 222 | shared_obs = env.get_shared_obs() 223 | shared_obs = shared_obs.to(self.device) 224 | inputs_from_env['shared_obs'] = shared_obs 225 | if 'state' in self.args.pre_decision_fields: 226 | state = env.get_state().to(self.device) 227 | inputs_from_env['state'] = state 228 | return inputs_from_env 229 | 230 | def get_actions_to_env(self, actions: Tensor, env) -> list: 231 | """Transforms actions from policy to a list.""" 232 | acts = actions.cpu().numpy() # Convert to ndarray. 233 | acts_per_agent = np.split(acts, env.n_agents, axis=0) # Each entry correspond to action of an agent. 234 | if self.args.is_discrete: # When discrete action is adopted, 235 | if self.args.is_multi_discrete: 236 | acts_per_agent = [act.squeeze().tolist() for act in acts_per_agent] # Convert ndarray to list. 237 | else: 238 | acts_per_agent = [act.item() for act in acts_per_agent] # Convert ndarray to scalar. 239 | return acts_per_agent 240 | 241 | def cache(self, obs, actions, rewards, terminated, avail_actions, filled=True, **kwargs): 242 | """Stores a transition to replay buffer.""" 243 | # Reshape entries and hold them with a dict. 244 | transition = dict(obs=obs, avail_actions=avail_actions, 245 | actions=actions.view(self.env.n_agents, self.args.act_shape), 246 | rewards=th.tensor(rewards, dtype=th.float, device=self.device).reshape(1, self.env.n_agents), 247 | terminated=th.tensor(terminated, dtype=th.int, device=self.device).reshape(1, 1), 248 | filled=th.tensor(filled, dtype=th.int, device=self.device).reshape(1, 1)) 249 | transition.update(**kwargs) 250 | # Insert transition to replay buffer. 251 | self.buffer.insert(transition) 252 | 253 | def test_agent(self): 254 | """Tests the performance of trained agent.""" 255 | test_ep_rsts = {} 256 | self.policy.eval() # Set policy to eval mode. 257 | for j in range(self.args.n_test_episodes): 258 | self.test_env.reset() # Reset test env. 259 | h, terminated = self.policy.init_hidden().to(self.device), False # Reset RNN states and terminated. 260 | 261 | # Run an episode. 262 | self.test_env.render() # Render test env. 263 | while not terminated: 264 | # Get observations and available actions. 265 | inputs = self.get_inputs_from_env(self.test_env, train_mode=False) 266 | # Take (quasi) deterministic actions. 
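# In 'test' mode the action selector picks greedy actions, so evaluation is
# deterministic up to any stochasticity in the environment itself.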
267 | actions, h = self.policy.act(inputs['obs'], h, inputs['avail_actions'], mode='test') 268 | # Call test env step. 269 | _, terminated, info = self.test_env.step(self.get_actions_to_env(actions, self.test_env)) 270 | # Render test env. 271 | self.test_env.render() 272 | 273 | # Save figure of rendered test env. 274 | if self.args.record_tests and j < 10: # Save storage. 275 | self.test_env.save_replay(save_dir=osp.join(self.run_dir, f't{self.t_env}'), tag=f'ep{j}') 276 | # Record episode info. 277 | for name, rst in info.items(): 278 | if name != 'truncated': # All entries other than episode limit are recorded 279 | if name not in test_ep_rsts: 280 | test_ep_rsts[name] = [] 281 | test_ep_rsts[name].append(rst) 282 | 283 | # Log test results. 284 | self.log_metrics(test_ep_rsts, prefix='Test') 285 | 286 | def log_metrics(self, info: dict[str, Any], prefix: str = None) -> None: 287 | """Logs scalar metrics held in a dict.""" 288 | for name, value in info.items(): 289 | if prefix is not None: 290 | name = prefix + name 291 | # Transform vector metrics into scalars by averaging. 292 | if isinstance(value, list): 293 | value = np.array(value) 294 | if isinstance(value, ndarray) or isinstance(value, Tensor): 295 | value = value.mean() 296 | # Log given metrics. 297 | if not self.args.use_wandb: # Tensorboard 298 | self.writer.add_scalar(name, value, self.t_env) 299 | else: # wandb 300 | wandb.log({name: value}, step=self.t_env) 301 | 302 | def cleanup(self) -> None: 303 | """Terminates utils after training.""" 304 | 305 | self.env.close() 306 | self.test_env.close() 307 | 308 | if not self.args.use_wandb: 309 | self.writer.flush() 310 | self.writer.close() 311 | print(f"Use command `tensorboard --logdir={self.run_dir}` to view results.") 312 | else: 313 | wandb.finish() 314 | # TODO: Export history to local .csv file. 315 | # Note that history data can be visited by using the Public API. 316 | # See https://docs.wandb.ai/guides/track/public-api-guide for detailed explanation. 317 | -------------------------------------------------------------------------------- /runners/episode_runner.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import time 3 | 4 | import torch as th 5 | from torch import Tensor 6 | 7 | from components.misc import * 8 | from learners import DETERMINISTIC_POLICY_GRADIENT_ALGOS 9 | from runners.base_runner import BaseRunner 10 | 11 | 12 | class EpisodeRunner(BaseRunner): 13 | """Episode runner 14 | 15 | Each call of `collect` method runs an episode to its maximum length. 16 | """ 17 | 18 | def __init__(self, env_fn, test_env_fn, args): 19 | super(EpisodeRunner, self).__init__(env_fn, test_env_fn, args) 20 | self.max_episode_steps = self.env.get_env_info()['episode_limit'] 21 | print(f"`max_episode_steps` of EpisodeRunner is set to `episode_limit` as {self.max_episode_steps}.") 22 | assert self.max_episode_steps is not None, "Maximum episode length is absent for `EpisodeRunner`." 23 | 24 | self.curr_episode = [] 25 | 26 | def warmup(self): 27 | """Prepares for training.""" 28 | # Reset variables. 29 | self.t_env = 0 # Reset timestep counter. 30 | 31 | # Reset components. 32 | self.env.reset() # Reset env. 33 | self.learner.reset() # Reset learner. 34 | rnn_states = self.get_init_hidden() # Initial RNN states 35 | 36 | # Let agents select random actions to fill replay buffer. 37 | self.t_warmup = 0 # Reset warm-up step counter. 
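# Unlike BaseRunner, warm-up here rolls out whole random episodes (collect with
# frozen=True) until both the warm-up budget is spent and the buffer can be sampled.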
38 | while (self.t_warmup < self.args.warmup_steps) or not self.buffer.can_sample(self.args.batch_size): 39 | rnn_states = self.collect(rnn_states, frozen=True) # Random interaction 40 | 41 | return rnn_states 42 | 43 | def collect(self, rnn_states: dict[str, Tensor], frozen: bool = False) -> dict[str, Tensor]: 44 | """Explores a fixed number of timesteps to collect experiences.""" 45 | act_mode = 'rand' if frozen else 'explore' 46 | self.policy.eval() # Set policy to eval mode. 47 | self.learner.eval() # Set modules held by learner to eval mode. 48 | 49 | terminated, filled = False, True 50 | # print(f"\nEpisode begin. Call reset.") 51 | self.env.reset() 52 | # print(f"Complete reset. agent = {self.env.agent}") 53 | 54 | self.curr_episode = [] # Clear current episode 55 | for t in range(self.max_episode_steps + 1): 56 | if not terminated: # Episode continues. 57 | if frozen: 58 | self.t_warmup += 1 59 | else: 60 | self.t_env += 1 61 | else: # Fill the episode to maximum length after termination. 62 | filled = False 63 | 64 | # Build pre-decision data. 65 | pre_decision_data = dict(**self.get_inputs_from_env(self.env), **rnn_states) 66 | 67 | # Select actions following epsilon-greedy strategy. 68 | actions, h = self.policy.act(pre_decision_data['obs'], pre_decision_data['h'], 69 | pre_decision_data['avail_actions'], self.t_env, mode=act_mode) 70 | 71 | if filled: # Call environment step. 72 | # print(f"t = {t}, step with action {actions.squeeze()}") 73 | rewards, terminated, info = self.env.step(self.get_actions_to_env(actions, self.env)) 74 | else: # Pseudo-transition 75 | # print(f"t = {t}, pseudo-step with action {actions.squeeze()}") 76 | rewards, terminated, info = np.zeros(self.env.n_agents, dtype=np.float32), True, dict() 77 | 78 | # When deterministic policy gradient is used, apply one-hot encoding to discrete actions. 79 | if (self.args.learner in DETERMINISTIC_POLICY_GRADIENT_ALGOS) and self.args.is_discrete: 80 | n_classes = self.args.nvec if self.args.is_multi_discrete else self.args.act_size 81 | actions = onehot_from_actions(actions, n_classes) 82 | 83 | # Call learner step to get required data (e.g., critic RNN states for off-policy AC algos). 84 | rnn_states = dict(h=h, **self.learner.step(pre_decision_data, actions)) # Next RNN states 85 | 86 | # Collect post-decision data. 87 | post_decision_data = { 88 | 'actions': actions, 'rewards': rewards, 89 | 'terminated': terminated != info.get('truncated', False), 90 | } 91 | # Save transition to replay buffer. 92 | self.cache(**pre_decision_data, **post_decision_data, filled=filled) 93 | 94 | if filled and not frozen: # Avoid recurring call of following functions. 95 | # if terminated: # End of episode handling. 96 | # if 'truncated' in info: # Drop indicator of episode limit. 97 | # del info['truncated'] 98 | # self.log_metrics(info) 99 | 100 | # Regularly save checkpoints. 101 | if (self.t_env % self.args.save_interval == 0) or (self.t_env >= self.args.total_env_steps): 102 | save_path = osp.join(self.model_dir, f'checkpoint_t{self.t_env}.pt') 103 | self.learner.save_checkpoint(save_path) 104 | 105 | # Test the performance of trained models. 106 | if self.t_env % self.args.test_interval == 0: 107 | self.test_agent() 108 | 109 | # Handle the end of session. 
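# A 'session' is simply a fixed number of env steps after which progress is printed
# and, when anneal_lr is enabled, the learning-rate scheduler(s) are stepped.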
110 | if self.t_env % self.args.steps_per_session == 0: 111 | session = self.t_env // self.args.steps_per_session # Index of session 112 | print(f"Finish session {session} at step {self.t_env}/{self.args.total_env_steps} " 113 | f"after {(time.time() - self.start_time):.1f}s.") 114 | # Call lr scheduler(s) if enabled. 115 | if self.args.anneal_lr: 116 | self.learner.schedule_lr() 117 | 118 | # if getattr(self.args, 'retrace_rewards', False): 119 | # # print(f"Final agent route = {self.env.agent}.") 120 | # # Reshape rewards at the end of episode. 121 | # after_rewards = self.env.retrace_onward_bottleneck_rate() 122 | # # print(f"connected = {self.env.agent.is_connected}, after_rewards = {after_rewards}, per_hop_rate = {self.env.get_per_hop_rate(self.env.agent)}") 123 | # # if self.env.agent.is_connected: 124 | # # self.env.save_replay(show_img=True) 125 | # # raise ValueError 126 | # for t, r in enumerate(after_rewards): 127 | # if t < len(self.curr_episode): 128 | # self.curr_episode[t]["rewards"] = th.tensor(r, dtype=th.float, device=self.device).reshape(1, self.env.n_agents) 129 | # else: 130 | # print("Number of hops > `max_episode_steps`.") 131 | # print(f"agent = {self.env.agent}, n_hops = {self.env.agent.n_hops}") 132 | # print(f"after_rewards = {after_rewards}") 133 | # self.env.save_replay(show_img=True) 134 | # raise ValueError 135 | 136 | # Insert transitions to replay buffer. 137 | for transition in self.curr_episode: 138 | self.buffer.insert(transition) 139 | 140 | return rnn_states 141 | 142 | def cache(self, obs, actions, rewards, terminated, avail_actions, filled=True, **kwargs): 143 | # Reshape entries and hold them with a dict. 144 | transition = dict(obs=obs, avail_actions=avail_actions, 145 | actions=actions.view(self.env.n_agents, self.args.act_shape), 146 | rewards=th.tensor(rewards, dtype=th.float, device=self.device).reshape(1, self.env.n_agents), 147 | terminated=th.tensor(terminated, dtype=th.int, device=self.device).reshape(1, 1), 148 | filled=th.tensor(filled, dtype=th.int, device=self.device).reshape(1, 1)) 149 | transition.update(**kwargs) 150 | self.curr_episode.append(transition) 151 | 152 | --------------------------------------------------------------------------------
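The same fixed-width zero-padding pattern shows up twice in the modules above (in `NeighborSelector.forward` and in `modules/utils.pad_edge_output`): outputs attached to a variable number of edges are split per destination node, flattened, and padded to `max_degree * feat_dim` entries so they can be stacked into one rectangular tensor. Below is a minimal, torch-only sketch of that idea; the degrees are made up for illustration (the real code reads them from `graph.in_degrees()`), and `pad_and_stack` is a hypothetical helper, not part of the repository.

import torch as th

def pad_and_stack(edge_feats, degrees, max_degree):
    # Split per-edge features into one chunk per destination node, then pad each
    # chunk with zeros up to `max_degree` rows and flatten it into a single vector.
    feat_dim = edge_feats.size(-1)
    chunks = th.split(edge_feats, split_size_or_sections=degrees, dim=0)
    rows = []
    for d, chunk in zip(degrees, chunks):
        pad = th.zeros(max_degree - d, feat_dim, dtype=edge_feats.dtype)
        rows.append(th.cat([chunk, pad], dim=0).flatten())
    return th.stack(rows)  # shape: (n_dst_nodes, max_degree * feat_dim)

degrees = [3, 1, 2]                    # hypothetical in-degrees of three destination nodes
edge_feats = th.rand(sum(degrees), 4)  # one 4-dim feature vector per edge
out = pad_and_stack(edge_feats, degrees, max_degree=3)
print(out.size())                      # torch.Size([3, 12])

The actual implementations differ from this sketch only in where the chunk sizes come from (DGL in-degrees) and in device placement.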