├── README.md
├── envs.py
├── gen_data.py
├── nn_modules.py
├── train_dnn.py
├── train_hignn.py
└── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # hignn
2 | Code implementation of our paper [Scalable Power Control/Beamforming in Heterogeneous Wireless Networks with Graph Neural Networks](https://ieeexplore.ieee.org/document/9685457).
3 | 
4 | - `envs.py` defines classes of heterogeneous wireless channels and provides an implementation of the closed-form FP algorithm in heterogeneous settings.
5 | - `gen_data.py` generates the datasets for training/testing.
6 | - `utils.py` includes functions shared by both `train_hignn.py` and `train_dnn.py`.
7 | - `nn_modules.py` defines the neural network (NN) modules.
8 | - `train_hignn.py` is the main file carrying out the training loop of *heterogeneous interference graph neural networks (HIGNNs)*.
9 | - `train_dnn.py` is the main file carrying out the training loop of deep neural networks (DNNs) for comparison.
10 | 
11 | If you use this code, please cite our work:
12 | 
13 | ```tex
14 | @INPROCEEDINGS{9685457,
15 |   author={Zhang, Xiaochen and Zhao, Haitao and Xiong, Jun and Liu, Xiaoran and Zhou, Li and Wei, Jibo},
16 |   booktitle={2021 IEEE Global Communications Conference (GLOBECOM)},
17 |   title={Scalable Power Control/Beamforming in Heterogeneous Wireless Networks with Graph Neural Networks},
18 |   year={2021},
19 |   volume={},
20 |   number={},
21 |   pages={01-06},
22 |   doi={10.1109/GLOBECOM46510.2021.9685457}
23 | }
24 | ```
25 | 
26 | If you have any questions, please contact zhangxiaochen14@nudt.edu.cn.
27 | 
--------------------------------------------------------------------------------
/envs.py:
--------------------------------------------------------------------------------
1 | r"""
2 | This script defines
3 | - classes of heterogeneous MISO wireless channels
4 | - benchmark algorithms for beamforming
5 | - other required components
6 | """
7 | 
8 | # Import functional tools.
9 | import copy
10 | # Import packages for visualization.
11 | import matplotlib.pyplot as plt
12 | # Import packages for mathematics.
13 | import numpy as np
14 | 
15 | 
16 | class InterferenceChannel:
17 |     """Base class of general heterogeneous interference channels (IFCs)"""
18 | 
19 |     def __init__(self, num_links, num_tx_ants, p_max, var_awgn, weight, bandwidth=1.):
20 | 
21 |         # Parameters shared by all links within each type
22 |         self.num_links = copy.deepcopy(num_links)  # Number of links for each link type
23 |         self.num_tx_ants = copy.deepcopy(num_tx_ants)  # Number of Tx antennas for each link type
24 | 
25 |         # Parameters individually held by each link
26 |         self.p_max = {k: v * np.ones(self.num_links[k]) for k, v in p_max.items()}  # Power constraint on each transmitter (Watt)
27 |         self.var_awgn = {k: v * np.ones(self.num_links[k]) for k, v in var_awgn.items()}  # Noise variance at each receiver (Watt)
28 |         self.weight = copy.deepcopy(weight)  # Weight of each link
29 | 
30 |         # Parameters shared by the entire IFC
31 |         self.bandwidth = bandwidth  # Total bandwidth shared by all links (Hz)
32 |         self.total_num_links = sum(num_links.values())  # Total number of links
33 | 
34 |         # Display the specs of the channel.
35 |         print("An instance of {} is created with specs:".format(self.__class__.__name__))
36 |         print("num_links = {}".format(self.num_links))
37 |         print("weight = \n{}".format(self.weight))
38 |         print("bandwidth = {} (Hz)".format(self.bandwidth))
39 |         print("var_awgn = \n{} (Watt)".format(self.var_awgn))
40 |         print("p_max = \n{} (Watt)".format(self.p_max))
41 | 
42 |     def channel_response(self):
43 |         """Generates zero-mean complex Gaussian variables with unit variance as the channel response for all links."""
44 |         h = {}  # A dictionary holding the channel response of all relations
45 |         # h[(rtype, ttype)] refers to a relation between transmitters of ttype and receivers of rtype,
46 |         # of which the (i, j)-th entry gives the CSI between the j-th transmitter of ttype and the i-th receiver of rtype.
47 |         for rtype in self.num_links.keys():
48 |             for ttype in self.num_links.keys():
49 |                 h[(rtype, ttype)] = 1 / np.sqrt(2) * (
50 |                     np.random.randn(self.num_links[rtype], self.num_links[ttype], self.num_tx_ants[ttype])
51 |                     + 1j * np.random.randn(self.num_links[rtype], self.num_links[ttype], self.num_tx_ants[ttype])
52 |                 )
53 |         return h
54 | 
55 | 
56 | class HeteroD2DNetwork(InterferenceChannel):
57 |     """Heterogeneous device-to-device (D2D) network"""
58 | 
59 |     def __init__(self, num_links, num_tx_ants, p_max, var_awgn, weight, bandwidth=5e6,
60 |                  len_area=400., dmin=2, dmax=50):
61 |         super(HeteroD2DNetwork, self).__init__(num_links, num_tx_ants, p_max, var_awgn, weight, bandwidth)
62 | 
63 |         self.len_area = len_area  # Length of the square area
64 |         self.dmin, self.dmax = dmin, dmax  # Minimum/maximum link distance
65 | 
66 |         self.pos_tx, self.pos_rx = {}, {}  # Positions of Tx/Rx
67 |         self.dist = {}  # Distance between each Tx and Rx
68 | 
69 |     def layout(self):
70 |         """Randomly determines the Tx/Rx positions."""
71 |         for ltype in self.num_links.keys():
72 |             # Randomly determine the positions of Tx in the square area.
73 |             self.pos_tx[ltype] = self.len_area * np.random.rand(self.num_links[ltype], 2)
74 |             self.pos_rx[ltype] = []
75 | 
76 |             for idx in range(self.num_links[ltype]):
77 |                 is_valid_pos = False
78 |                 while not is_valid_pos:
79 |                     radius = self.dmin + np.random.rand() * (self.dmax - self.dmin)  # Distance of the Rx from its Tx
80 |                     ang = 2 * np.pi * np.random.rand()  # Angle of the Rx relative to its Tx
81 |                     pos = self.pos_tx[ltype][idx] + radius * np.array([np.cos(ang), np.sin(ang)])  # Position of the Rx
82 |                     # The generated position is kept only if it lies in the valid range.
83 |                     if 0 <= pos[0] <= self.len_area and 0 <= pos[1] <= self.len_area:
84 |                         is_valid_pos = True
85 |                 self.pos_rx[ltype].append(pos)
86 |             self.pos_rx[ltype] = np.array(self.pos_rx[ltype])
87 | 
88 |         # Compute the distance between each Rx and Tx.
89 |         for rtype in self.num_links.keys():
90 |             for ttype in self.num_links.keys():
91 |                 # The (i, j)-th entry is the distance between Rx i of rtype and Tx j of ttype.
92 |                 self.dist[(rtype, ttype)] = np.empty([self.num_links[rtype], self.num_links[ttype]])
93 |                 for i in range(self.num_links[rtype]):
94 |                     for j in range(self.num_links[ttype]):
95 |                         self.dist[(rtype, ttype)][i, j] = np.sqrt(np.square(self.pos_rx[rtype][i] - self.pos_tx[ttype][j]).sum())
96 |                 # Clip distances so that nodes are neither too close nor farther apart than the area diagonal.
97 |                 self.dist[(rtype, ttype)] = np.clip(self.dist[(rtype, ttype)], self.dmin, np.sqrt(2) * self.len_area)
98 | 
99 |     def channel_response(self):
100 |         h = {}  # A dictionary holding the channel response of all relations
101 |         # Reset the layout of the network.
102 |         self.layout()
103 |         # Generate the channel response for each relation.
104 |         for rtype in self.num_links.keys():
105 |             for ttype in self.num_links.keys():
106 |                 shadowing = np.random.randn(*self.dist[(rtype, ttype)].shape)
107 |                 large_scale_fading = 4.4 * 10 ** 5 / ((self.dist[(rtype, ttype)] ** 1.88) * (10 ** (shadowing * 6.3 / 20)))
108 |                 small_scale_fading = 1 / np.sqrt(2) * (
109 |                     np.random.randn(self.num_links[rtype], self.num_links[ttype], self.num_tx_ants[ttype]) +
110 |                     1j * np.random.randn(self.num_links[rtype], self.num_links[ttype], self.num_tx_ants[ttype]))
111 |                 h[(rtype, ttype)] = np.sqrt(np.expand_dims(large_scale_fading, axis=-1)) * small_scale_fading
112 |         return h
113 | 
114 | 
115 | def link_capacity(h, x, var_awgn):
116 |     """
117 |     Computes the rate profile for each link in the IFC.
118 | 
119 |     Arguments:
120 |         h (dict of ndarrays): channel response of all relations
121 |         x (dict of ndarrays): beamforming vectors of each transmitter
122 |         var_awgn (dict of ndarrays): noise variance at each receiver
123 | 
124 |     Returns:
125 |         a dict of ndarrays listing the capacity of each link
126 |     """
127 |     # At each receiver, compute the received signal power from each transmitter.
128 |     rx_power, interference, sinr, rate = {}, {}, {}, {}
129 |     for rtype, ttype in h.keys():
130 |         rx_power[(rtype, ttype)] = np.square(np.abs((h[(rtype, ttype)] * x[ttype]).sum(-1)))
131 | 
132 |     # Compute the interference level at each receiver.
133 |     for rtype in x.keys():
134 |         interference[rtype] = - np.diag(rx_power[(rtype, rtype)])
135 |         for ttype in x.keys():
136 |             interference[rtype] += rx_power[(rtype, ttype)].sum(-1)
137 | 
138 |     # Compute the signal-to-interference-plus-noise ratio (SINR) and achievable rate.
139 |     for rtype in x.keys():
140 |         sinr[rtype] = np.diag(rx_power[(rtype, rtype)]) / (interference[rtype] + var_awgn[rtype])
141 |         rate[rtype] = np.log(1 + sinr[rtype])
142 |     return rate
143 | 
144 | 
145 | def weighted_sum_rate(h, x, var_awgn, w):
146 |     """
147 |     Computes the weighted sum rate (WSR) of the IFC.
148 | 
149 |     Arguments:
150 |         h (dict of ndarrays): channel response of all relations
151 |         x (dict of ndarrays): transmission scheme of each transmitter
152 |         var_awgn (dict of ndarrays): noise variance at each receiver
153 |         w (dict of ndarrays): weight of each link
154 | 
155 |     Returns:
156 |         the weighted sum rate of all links in the channel
157 |     """
158 |     # Compute the rate for each link.
159 |     rate = link_capacity(h, x, var_awgn)
160 |     # Accumulate the WSR over link types.
161 |     wsr = 0.  # Weighted sum rate
162 |     for ltype in w.keys():
163 |         wsr += np.dot(w[ltype], rate[ltype])
164 |     return wsr
165 | 
166 | 
167 | def max_wsr_cf_fp(h, var_awgn, p_max, w, max_num_iters=500):
168 |     """
169 |     Maximizes the WSR of the IFC via closed-form fractional programming (FP), which was originally proposed in
170 | 
171 |     @ARTICLE{8314727,
172 |         author={K. {Shen} and W. {Yu}},
173 |         journal={IEEE Transactions on Signal Processing},
174 |         title={Fractional Programming for Communication Systems—Part I: Power Control and Beamforming},
175 |         year={2018},
176 |         volume={66},
177 |         number={10},
178 |         pages={2616-2630},
179 |         doi={10.1109/TSP.2018.2812733}
180 |     }
181 | 
182 |     Arguments:
183 |         h (dict of ndarrays): channel response of all relations
184 |         var_awgn (dict of ndarrays): noise variance at each receiver
185 |         p_max (dict of ndarrays): power constraint on each transmitter
186 |         w (dict of ndarrays): weight of each link
187 |         max_num_iters (int): maximum number of iterations
188 | 
189 |     Returns:
190 |         x (dict of ndarrays): optimal transmission scheme for each type of link
191 |         log_wsrs (list): WSRs achieved through the iterations
192 |     """
193 |     # Initialize the transmission scheme.
194 |     num_links, num_tx_ants = {}, {}  # Number of links and Tx antennas
195 |     x, gamma, y = {}, {}, {}  # Transmission scheme and auxiliary variables
196 |     for ltype in p_max.keys():
197 |         num_links[ltype] = h[(ltype, ltype)].shape[1]
198 |         num_tx_ants[ltype] = h[(ltype, ltype)].shape[-1]
199 |         x[ltype] = np.expand_dims(np.sqrt(p_max[ltype] / num_tx_ants[ltype]), -1) * np.ones(
200 |             (num_links[ltype], num_tx_ants[ltype]), dtype=complex)
201 | 
202 |     epsilon = 1e-5  # Threshold to terminate loops
203 |     # Record the WSR achieved by the initial transmission scheme.
204 |     wsrs = [weighted_sum_rate(h, x, var_awgn, w)]  # Log of WSRs
205 | 
206 |     for num_iter in range(max_num_iters):
207 |         x_pre = copy.deepcopy(x)  # Copy of the solution from the previous iteration
208 |         # Compute the received signal, power and interference.
209 |         rx_signal, rx_power, interference = {}, {}, {}
210 |         for rtype, ttype in h.keys():
211 |             rx_signal[(rtype, ttype)] = (h[(rtype, ttype)] * x[ttype]).sum(-1)
212 |             rx_power[(rtype, ttype)] = np.square(np.abs(rx_signal[(rtype, ttype)]))
213 |         for rtype in x.keys():
214 |             interference[rtype] = -np.diag(rx_power[(rtype, rtype)])
215 |             for ttype in x.keys():
216 |                 interference[rtype] += rx_power[(rtype, ttype)].sum(-1)
217 |         for ltype in x.keys():
218 |             # Update gamma by (57), i.e., set gamma_i to the SINR under the current beamformers.
219 |             gamma[ltype] = np.diag(rx_power[(ltype, ltype)]) / (interference[ltype] + var_awgn[ltype])
220 |             # Update y by (60), the optimal auxiliary variable of the quadratic transform.
221 |             y[ltype] = np.sqrt(w[ltype] * (1 + gamma[ltype])) * np.diag(rx_signal[(ltype, ltype)]) / (
222 |                 interference[ltype] + np.diag(rx_power[(ltype, ltype)]) + var_awgn[ltype])
223 |         # Update x by (61).
224 |         hy, hyyh, sum_hyyh = {}, {}, {}
225 |         for rtype in x.keys():
226 |             for ttype in x.keys():
227 |                 # Compute $h_{ji}^{H}y_{j}$ for each (i, j).
228 |                 hy[(rtype, ttype)] = np.conj(h[(rtype, ttype)]) * y[rtype].reshape(-1, 1, 1)
229 |                 # Compute $h_{ji}^{H}y_{j}y_{j}^{H}h_{ji}$ for each (i, j).
230 |                 hyyh[(rtype, ttype)] = np.matmul(np.expand_dims(hy[(rtype, ttype)], -1),
231 |                                                  np.expand_dims(np.conj(hy[(rtype, ttype)]), -2))
232 | 
233 |         # Sum $h_{ji}^{H}y_{j}y_{j}^{H}h_{ji}$ along j (the Rx index).
234 |         for ttype in x.keys():
235 |             sum_hyyh[ttype] = 0
236 |             for rtype in x.keys():
237 |                 sum_hyyh[ttype] += hyyh[(rtype, ttype)].sum(axis=0)
238 | 
239 |         for ltype in x.keys():
240 |             for i in range(num_links[ltype]):
241 |                 eta_l, eta_r = 0., 1.
242 |                 x[ltype][i] = np.sqrt(w[ltype][i] * (1 + gamma[ltype][i])) * np.dot(
243 |                     np.linalg.inv(eta_r * np.eye(num_tx_ants[ltype]) + sum_hyyh[ltype][i]), hy[(ltype, ltype)][i, i])
244 | 
245 |                 # Scale the lower/upper bounds such that the optimal eta lies in (eta_l, eta_r].
246 |                 while np.linalg.norm(x[ltype][i]) > np.sqrt(p_max[ltype][i]):
247 |                     eta_l = eta_r
248 |                     eta_r = 2 * eta_r
249 |                     x[ltype][i] = np.sqrt(w[ltype][i] * (1 + gamma[ltype][i])) * np.dot(
250 |                         np.linalg.inv(eta_r * np.eye(num_tx_ants[ltype]) + sum_hyyh[ltype][i]),
251 |                         hy[(ltype, ltype)][i, i])
252 | 
253 |                 # Use bisection search to find the optimal eta.
254 |                 while eta_r - eta_l > epsilon:
255 |                     eta = (eta_l + eta_r) / 2
256 |                     x[ltype][i] = np.sqrt(w[ltype][i] * (1 + gamma[ltype][i])) * np.dot(
257 |                         np.linalg.inv(eta * np.eye(num_tx_ants[ltype]) + sum_hyyh[ltype][i]), hy[(ltype, ltype)][i, i])
258 |                     if np.linalg.norm(x[ltype][i]) < np.sqrt(p_max[ltype][i]):
259 |                         eta_r = eta
260 |                     else:
261 |                         eta_l = eta
262 |         # Record the WSR achieved in the current iteration.
263 |         wsrs.append(weighted_sum_rate(h, x, var_awgn, w))
264 |         # Break the loop early when the solution converges.
265 |         diff = 0.  # Difference between the current solution and the previous one
266 |         for ltype in x.keys():
267 |             diff += np.linalg.norm(x[ltype] - x_pre[ltype])
268 |         if diff <= epsilon:
269 |             break
270 |     return x, wsrs
271 | 
272 | 
273 | if __name__ == '__main__':
274 |     # Execute only if the script is run as the main file.
275 | 
276 |     # Create an instance of a wireless network as the env.
277 |     env = HeteroD2DNetwork(**{'num_links': {'siso': 8, 'miso': 4}, 'num_tx_ants': {'siso': 1, 'miso': 2},
278 |                               'p_max': {'siso': 1., 'miso': 1.}, 'var_awgn': {'siso': 1., 'miso': 1.},
279 |                               'weight': {'siso': np.ones(8), 'miso': np.ones(4)}})
280 |     # Compute the channel response.
281 |     h = env.channel_response()
282 |     # Check the range of channel coefficients.
283 |     print("Check the range of channel coefficients:")
284 |     for ltype in h.keys():
285 |         print("|h[{}]| lies in [{}, {}]".format(ltype, np.min(np.abs(h[ltype])), np.max(np.abs(h[ltype]))))
286 | 
287 |     # Test the func 'weighted_sum_rate'.
288 |     print("Test the func 'weighted_sum_rate':")
289 |     x = {'siso': np.random.randn(env.num_links['siso'], env.num_tx_ants['siso']),
290 |          'miso': np.random.randn(env.num_links['miso'], env.num_tx_ants['miso'])}
291 |     r = link_capacity(h, x, env.var_awgn)
292 |     print("r = {}".format(r))
293 |     wsr = weighted_sum_rate(h, x, var_awgn=env.var_awgn, w=env.weight)
294 |     print("wsr = {}".format(wsr))
295 | 
296 |     # Test the correctness of the benchmark algorithm.
297 |     _, wsrs = max_wsr_cf_fp(h, env.var_awgn, env.p_max, env.weight)
298 | 
299 |     plt.figure()
300 |     plt.plot(wsrs, marker='.', markevery=50, label='Closed-form FP')
301 |     plt.legend()
302 |     plt.title("Weighted sum rate (WSR) vs. iterations")
303 |     plt.xlabel("iteration")
304 |     plt.ylabel("WSR (nat/sec)")
305 |     plt.show()
306 | 
--------------------------------------------------------------------------------
/gen_data.py:
--------------------------------------------------------------------------------
1 | r"""
2 | This script generates training/test sets.
3 | """
4 | 
5 | # Import functional tools.
6 | import pickle
7 | from tqdm import tqdm
8 | # Import packages for mathematics.
9 | import numpy as np
10 | # Import user-defined modules.
11 | import envs
12 | 
13 | 
14 | def gen_samples(size, chan_type, requires_label=True):
15 |     """
16 |     Arguments:
17 |         size (int): number of samples to generate
18 |         chan_type (str): type of channel
19 |         requires_label (bool): whether optimal beamforming vectors are computed
20 | 
21 |     Returns:
22 |         a list of sample dicts and the specs used to create the environment
23 |     """
24 |     # Ensure the channel type is available.
25 |     assert chan_type in ['gaussian', 'd2d']
26 |     n_links = {'siso': 8, 'miso': 4}
27 |     n_t = {'siso': 1, 'miso': 2}
28 |     specs = {'num_links': n_links, 'num_tx_ants': n_t,
29 |              'p_max': {k: 1. for k in n_links.keys()},
30 |              'var_awgn': {k: 1. for k in n_links.keys()},
31 |              'weight': {k: np.ones(v) for k, v in n_links.items()}}
32 |     if chan_type == 'gaussian':
33 |         env = envs.InterferenceChannel(**specs)
34 |     else:
35 |         env = envs.HeteroD2DNetwork(**specs)
36 | 
37 |     # Generate samples.
38 |     data_list = []
39 |     for idx in tqdm(range(size)):
40 |         h = env.channel_response()
41 |         sample = {'num_links': env.num_links, 'num_tx_ants': env.num_tx_ants,
42 |                   'var_awgn': env.var_awgn, 'weight': env.weight, 'p_max': env.p_max,
43 |                   'h': h}
44 |         # Compute the labels if 'requires_label' is True.
45 |         if requires_label:
46 |             # Compute the optimal transmission scheme and WSR with the benchmark algorithm.
47 |             x, wsrs = envs.max_wsr_cf_fp(h=h, var_awgn=env.var_awgn, p_max=env.p_max, w=env.weight)
48 |             sample['x_targ'], sample['wsr_targ'] = x, wsrs[-1]
49 |         # Append the sample to the list.
50 |         data_list.append(sample)
51 |     return data_list, specs
52 | 
53 | 
54 | if __name__ == '__main__':
55 |     # Execute only if the script is run as the main file.
56 | 
57 |     # Create and store samples.
58 |     size_trainset = 500000  # Number of samples to be generated
59 |     size_testset = 1000
60 |     chan_type = 'd2d'  # Type of channel
61 |     PATH = './datasets/d2d_12links/'
62 | 
63 |     # Generate training/test data.
64 |     train_data, specs = gen_samples(size_trainset, chan_type, requires_label=False)
65 |     test_data, specs = gen_samples(size_testset, chan_type, requires_label=True)
66 | 
67 |     # Save the training set.
68 |     dir_trainset = PATH + 'train.pickle'
69 |     file = open(dir_trainset, 'wb')
70 |     pickle.dump(train_data, file)
71 |     file.close()
72 |     print("Training set is stored at: {}".format(dir_trainset))
73 | 
74 |     # Save the test set.
75 |     dir_testset = PATH + 'test.pickle'
76 |     file = open(dir_testset, 'wb')
77 |     pickle.dump(test_data, file)
78 |     file.close()
79 |     print("Test set is stored at: {}".format(dir_testset))
80 | 
81 |     # Save the data specifications.
82 |     dir_specs = PATH + 'specs.pickle'
83 |     file = open(dir_specs, 'wb')
84 |     pickle.dump(specs, file)
85 |     file.close()
86 | 
87 |     # Read and check the data.
88 |     with open(dir_trainset, 'rb') as file:
89 |         train_data = pickle.load(file)
90 |     with open(dir_testset, 'rb') as file:
91 |         test_data = pickle.load(file)
92 |     print("Size of training set: {}".format(len(train_data)))
93 |     print("Size of test set: {}".format(len(test_data)))
94 |     print("specs: {}".format(specs))
--------------------------------------------------------------------------------
/nn_modules.py:
--------------------------------------------------------------------------------
1 | r"""
2 | This script defines the neural network (NN) modules.
3 | """
4 | 
5 | # Import functional tools.
6 | import copy
7 | # Import packages for mathematics.
8 | import math
9 | # Import packages for ML.
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 | import dgl.function as fn
14 | 
15 | 
16 | def create_mlp(sizes):
17 |     """
18 |     Creates a multi-layer perceptron (MLP).
19 | 
20 |     Arguments:
21 |         sizes (list): integers specifying the number of neurons in each layer
22 | 
23 |     Returns:
24 |         an MLP holding the layers in a sequential container
25 |     """
26 |     assert len(sizes) > 1
27 |     layers = []
28 |     for l in range(len(sizes) - 1):
29 |         layers += [nn.Linear(sizes[l], sizes[l + 1])]
30 |         # layers += [nn.BatchNorm1d(sizes[l + 1])]
31 |         layers += [nn.ReLU()] if l < len(sizes) - 2 else []
32 |     return nn.Sequential(*layers)
33 | 
34 | 
35 | class GraphConv(nn.Module):
36 |     """Graph convolution layer"""
37 | 
38 |     def __init__(self, num_src_feats, num_dst_feats, num_edge_feats, msg_size, num_out_feats):
39 |         super(GraphConv, self).__init__()
40 |         # The message function takes the input features of src nodes and the edge features.
41 |         self.fc_msg = create_mlp([num_src_feats + num_edge_feats, 16, msg_size])
42 |         # self.fc_msg = create_mlp([num_src_feats + num_edge_feats, 64, 32, msg_size])
43 |         # The update function takes the input features of dst nodes and the aggregated messages.
44 |         self.fc_udt = create_mlp([num_dst_feats + msg_size, 16, num_out_feats])
45 |         # self.fc_udt = create_mlp([num_dst_feats + msg_size, 64, 32, num_out_feats])
46 | 
47 |     def message(self, edges):
48 |         # msg = torch.cat([edges.data['h'].flatten(start_dim=1), edges.src['in_feats']], dim=1)
49 |         msg = torch.cat([edges.data['h'], edges.src['in_feats']], dim=1)
50 |         return {'m': self.fc_msg(msg)}
51 | 
52 |     def forward(self, graph, x):
53 |         num_nan = 0  # Debug check: count NaNs in the layer weights before message passing.
54 |         for item in self.fc_msg:
55 |             if isinstance(item, nn.Linear):
56 |                 num_nan += item.weight.isnan().sum()
57 |         for item in self.fc_udt:
58 |             if isinstance(item, nn.Linear):
59 |                 num_nan += item.weight.isnan().sum()
60 |         if num_nan > 0:
61 |             print("nan is found in model parameters.")
62 |         graph.ndata['in_feats'] = x
63 |         graph.update_all(self.message, fn.max('m', 'r'))
64 |         return self.fc_udt(torch.cat([graph.dstdata['in_feats'], graph.dstdata['r']], dim=1))
65 | 
66 | 
67 | class HeteroGraphConv(nn.Module):
68 |     """Heterograph convolution layer"""
69 | 
70 |     def __init__(self, num_tx_ants, num_in_feats, msg_size, num_out_feats):
71 |         super(HeteroGraphConv, self).__init__()
72 |         mods = {}
73 |         for stype in num_tx_ants.keys():
74 |             for dtype in num_tx_ants.keys():
75 |                 mods[stype + '-to-' + dtype] = GraphConv(num_in_feats[stype], num_in_feats[dtype],
76 |                                                          2 * (num_tx_ants[stype] + num_tx_ants[dtype]),
77 |                                                          msg_size,
78 |                                                          num_out_feats[dtype])
79 |         # Hold the submodules in a dictionary.
80 |         self.mods = nn.ModuleDict(mods)
81 | 
82 |     def forward(self, g, inputs):
83 |         """Applies graph convolution on the heterograph g."""
84 |         # Each node type has individual outputs.
85 |         outputs = {ntype: [] for ntype in g.dsttypes}
86 |         # Graph convolution is executed for each relation.
87 |         for stype, etype, dtype in g.canonical_etypes:
88 |             # Create a subgraph for the current relation.
89 |             rel_graph = g[stype, etype, dtype]
90 |             # Skip if the relation contains no edge.
91 |             if rel_graph.number_of_edges() == 0:
92 |                 continue
93 |             # Skip if no input is provided for the src nodes.
94 |             if stype not in inputs:
95 |                 continue
96 |             # Pass the subgraph and inputs to the corresponding module.
97 |             if rel_graph.is_homogeneous:
98 |                 # If the subgraph is homogeneous, the input is a Tensor selected by node type.
99 |                 dstdata = self.mods[stype + '-to-' + dtype](rel_graph, inputs[stype])
100 |             else:
101 |                 # If the subgraph is heterogeneous, the inputs are passed directly as a dict.
102 |                 dstdata = self.mods[stype + '-to-' + dtype](rel_graph, inputs)
103 |             outputs[dtype].append(dstdata)
104 |         # Aggregate the outputs from all relations for each node type.
105 |         rsts = {}
106 |         for ntype, alist in outputs.items():
107 |             if len(alist) != 0:
108 |                 # rsts[ntype] = torch.mean(torch.stack(alist), 0)
109 |                 rsts[ntype], indices = torch.max(torch.stack(alist), 0)
110 |         return rsts
111 | 
112 | 
113 | class HIGNN(nn.Module):
114 |     """Heterogeneous interference graph neural networks"""
115 | 
116 |     def __init__(self, num_tx_ants, p_max, msg_size=8, udt_size=8, num_layers=3):
117 |         super(HIGNN, self).__init__()
118 |         self.num_tx_ants = num_tx_ants
119 |         self.num_hidden_layers = max(0, num_layers - 2)
120 |         self.max_output = {k: math.sqrt(v) for k, v in p_max.items()}
121 | 
122 |         self.enc = HeteroGraphConv(**{'num_tx_ants': num_tx_ants, 'msg_size': msg_size,
123 |                                       'num_in_feats': {k: 2 * v for k, v in num_tx_ants.items()},
124 |                                       'num_out_feats': {k: udt_size for k, v in num_tx_ants.items()}})
125 | 
126 |         self.core = HeteroGraphConv(**{'num_tx_ants': num_tx_ants, 'msg_size': msg_size,
127 |                                        'num_in_feats': {k: udt_size + 2 * v for k, v in num_tx_ants.items()},
128 |                                        'num_out_feats': {k: udt_size for k, v in num_tx_ants.items()}})
129 | 
130 |         self.dec = HeteroGraphConv(**{'num_tx_ants': num_tx_ants, 'msg_size': msg_size,
131 |                                       'num_in_feats': {k: udt_size + 2 * v for k, v in num_tx_ants.items()},
132 |                                       'num_out_feats': {k: 2 * v for k, v in num_tx_ants.items()}})
133 | 
134 |     def forward(self, g, x):
135 |         feats = copy.deepcopy(x)
136 |         # Encoder
137 |         x = self.enc(g, x)
138 |         # Processor
139 |         for i in range(self.num_hidden_layers):
140 |             x = {k: torch.cat([x[k], feats[k]], dim=1) for k in x.keys()}
141 |             x = self.core(g, x)
142 |             x = {k: F.relu(v) for k, v in x.items()}
143 |         # Decoder
144 |         x = {k: torch.cat([x[k], feats[k]], dim=1) for k in x.keys()}
145 |         x = self.dec(g, x)
146 |         # Impose the power constraints at the output layer.
147 |         x_norm = {k: torch.sqrt(v.pow(2).sum(-1, keepdim=True)) for k, v in x.items()}
148 |         x = {k: self.max_output[k] * torch.div(v, torch.max(x_norm[k], torch.ones(x_norm[k].size()))).view(
149 |             -1, self.num_tx_ants[k], 2) for k, v in x.items()}
150 |         return x
151 | 
152 | 
--------------------------------------------------------------------------------
/train_dnn.py:
--------------------------------------------------------------------------------
1 | r"""
2 | This script trains a deep neural network (DNN) for resource allocation (RA) in heterogeneous wireless networks.
3 | """
4 | 
5 | # Import functional tools.
6 | from tqdm import tqdm
7 | import copy
8 | import pickle
9 | # Import packages for mathematics.
10 | import numpy as np
11 | # Import packages for ML.
12 | import torch
13 | from torch.utils.data import DataLoader
14 | # Import user-defined modules.
15 | import utils, nn_modules as nn_mods
16 | 
17 | 
18 | def compute_statistics(samples):
19 |     """Computes the mean/variance of the given samples."""
20 |     statistics = {}
21 |     for rtype in num_links.keys():
22 |         for ttype in num_links.keys():
23 |             statistics[(rtype, ttype)] = {}
24 |             h_rel = torch.view_as_real(torch.stack(
25 |                 [torch.tensor(samples[i]['h'][(rtype, ttype)], dtype=torch.cfloat) for i in range(len(samples))]))
26 |             statistics[(rtype, ttype)] = {'mean': h_rel.mean(dim=0), 'var': h_rel.var(dim=0)}
27 | 
28 |     for k, v in statistics.items():
29 |         for kk, vv in v.items():
30 |             print("statistics[{}][{}].size() = {}".format(k, kk, vv.size()))
31 |     return statistics
32 | 
33 | 
34 | def proc_data(samples, requires_label):
35 |     """Preprocesses the given samples and creates a dataset."""
36 |     inputs, labels = [], []
37 |     for i in range(len(samples)):
38 |         h = copy.deepcopy(samples[i]['h'])
39 |         h_temp = {k: torch.view_as_real(torch.as_tensor(v, dtype=torch.cfloat)) for k, v in h.items()}
40 |         # The dictionary of 3-D Tensors is flattened to a single 1-D Tensor.
41 |         in_vec = torch.cat([((h_temp[('siso', 'siso')] - statistics[('siso', 'siso')]['mean']) / statistics[('siso', 'siso')]['var'].sqrt()).flatten(),
42 |                             ((h_temp[('miso', 'siso')] - statistics[('miso', 'siso')]['mean']) / statistics[('miso', 'siso')]['var'].sqrt()).flatten(),
43 |                             ((h_temp[('siso', 'miso')] - statistics[('siso', 'miso')]['mean']) / statistics[('siso', 'miso')]['var'].sqrt()).flatten(),
44 |                             ((h_temp[('miso', 'miso')] - statistics[('miso', 'miso')]['mean']) / statistics[('miso', 'miso')]['var'].sqrt()).flatten()], dim=0)
45 |         inputs.append(in_vec)
46 |         # Add the optimal beamforming vectors from FP if required.
47 |         if requires_label:
48 |             labels.append({'h': {k: torch.as_tensor(v, dtype=torch.cfloat) for k, v in h.items()},
49 |                            'wsr_targ': samples[i]['wsr_targ']})
50 |         else:
51 |             labels.append({'h': {k: torch.as_tensor(v, dtype=torch.cfloat) for k, v in h.items()}})
52 |     return list(zip(inputs, labels))
53 | 
54 | 
55 | def collate(data):
56 |     """collate_fn for the training set"""
57 |     inputs, labels = map(list, zip(*data))
58 |     inputs = torch.stack(inputs)
59 |     # Build a block-diagonal channel matrix per relation from all samples in the batch.
60 |     h = {}
61 |     c_etypes = [('siso', 'siso'), ('miso', 'siso'), ('siso', 'miso'), ('miso', 'miso')]
62 |     for stype, dtype in c_etypes:
63 |         # For each relation, put the channel matrices from different samples into a list.
64 |         h_rel = [item['h'][(stype, dtype)] for item in labels]
65 |         # Construct the block-diagonal matrix.
66 |         h[(stype, dtype)] = utils.build_diag_block(h_rel)
67 |     # Seal the reconstructed h and the concatenated wsr_targ into batched labels.
68 |     if 'wsr_targ' in labels[0].keys():
69 |         labels = {'wsr_targ': [item['wsr_targ'] for item in labels],
70 |                   'h': h}
71 |     else:
72 |         labels = {'h': h}
73 |     return inputs, labels
74 | 
75 | 
76 | def bf_recovery(outputs):
77 |     """Recovers the beamforming vectors for each link type from the outputs of the DNN."""
78 |     x = {}
79 |     for i in range(len(output_sizes) - 1):
80 |         x[ntypes[i]] = outputs[:, output_sizes[i]:output_sizes[i + 1]].reshape(-1, num_tx_ants[ntypes[i]] * 2)
81 |     x_norm = {k: torch.sqrt(v.pow(2).sum(-1, keepdim=True)) for k, v in x.items()}
82 |     x = {k: np.sqrt(p_max[k]) * torch.div(v, torch.max(x_norm[k], torch.ones(x_norm[k].size()))).view(
83 |         -1, num_tx_ants[k], 2) for k, v in x.items()}
84 |     return x
85 | 
86 | 
87 | def train_per_epoch():
88 |     """Executes one epoch of training."""
89 |     # Set the model to training mode.
90 |     model.train()
91 |     for i, data in tqdm(enumerate(train_loader)):
92 |         # Extract data from the train_loader.
93 |         inputs, labels = data
94 |         # Reset the optimizer.
95 |         optimizer.zero_grad()
96 |         # Pass the batched data to the model.
97 |         outputs = model(inputs)
98 |         x = bf_recovery(outputs)
99 |         # Compute the negative WSR as the loss.
100 |         loss = utils.weighted_sum_rate(labels['h'], x,
101 |                                        {k: torch.ones(v.size(0), dtype=torch.float) for k, v in x.items()},
102 |                                        {k: torch.ones(v.size(0), dtype=torch.float) for k, v in x.items()}).neg()
103 |         # Call back-propagation and execute one step of optimization.
104 |         loss.backward()
105 |         optimizer.step()
106 | 
107 | 
108 | def test():
109 |     """Tests the performance of the trained model."""
110 |     # Set the model to evaluation mode.
111 |     model.eval()
112 |     wsr_model, wsr_targ = 0., 0.
113 |     with torch.no_grad():
114 |         for i, data in enumerate(test_loader):
115 |             # Extract data from the test_loader.
116 |             inputs, labels = data
117 |             # Pass the batched data to the model.
118 |             outputs = model(inputs)
119 |             x = bf_recovery(outputs)
120 |             var_awgn = {k: torch.ones(v.size(0), dtype=torch.float) for k, v in x.items()}
121 |             weight = {k: torch.ones(v.size(0), dtype=torch.float) for k, v in x.items()}
122 |             # Accumulate the utilities (WSR).
123 |             wsr_model += utils.weighted_sum_rate(labels['h'], x, var_awgn, weight).item()
124 |             wsr_targ += torch.tensor(labels['wsr_targ']).sum().item()
125 | 
126 |     # Record and display the average performance.
127 |     acc.append(wsr_model / wsr_targ)
128 |     test_epochs.append(epoch)
129 |     if (epoch >= 1) and (wsr_model / wsr_targ == np.max(np.array(acc))):
130 |         torch.save(model.state_dict(), model_path)
131 |         print("epoch: {}, acc: {:.2%}, model is saved.".format(epoch, wsr_model / wsr_targ))
132 |     else:
133 |         print("epoch: {}, acc: {:.2%}.".format(epoch, wsr_model / wsr_targ))
134 | 
135 | 
136 | if __name__ == '__main__':
137 |     # Read the training/test data and specifications.
138 |     PATH = './datasets/d2d_12links/'  # PATH should be consistent with the one used in `gen_data.py`.
139 |     print("Path of datasets: {}".format(PATH))
140 |     dir_trainset = PATH + 'train.pickle'
141 |     dir_testset = PATH + 'test.pickle'
142 |     dir_specs = PATH + 'specs.pickle'
143 |     model_path = 'model_dnn.pth'
144 | 
145 |     with open(dir_trainset, 'rb') as file:
146 |         train_data = pickle.load(file)
147 |     with open(dir_testset, 'rb') as file:
148 |         test_data = pickle.load(file)
149 |     with open(dir_specs, 'rb') as file:
150 |         specs = pickle.load(file)
151 | 
152 |     num_links, num_tx_ants, p_max = specs['num_links'], specs['num_tx_ants'], specs['p_max']  # Parameters used globally
153 |     train_data = train_data[:100000]  # Select the actual size of the training set.
154 |     print("Size of training set: {}".format(len(train_data)))
155 |     print("Size of test set: {}".format(len(test_data)))
156 |     print("Specs: {}".format(specs))
157 | 
158 |     # Process the data, build the datasets and create the DataLoaders.
159 |     statistics = compute_statistics(train_data)
160 |     trainset = proc_data(train_data, requires_label=False)
161 |     train_loader = DataLoader(trainset, collate_fn=collate, batch_size=256)
162 |     testset = proc_data(test_data, requires_label=True)
163 |     test_loader = DataLoader(testset, collate_fn=collate, batch_size=128)
164 | 
165 |     # Compute the sizes of the input/output layers of the DNN.
166 |     in_size = 0
167 |     for rtype in num_links.keys():
168 |         for ttype in num_links.keys():
169 |             in_size += num_links[rtype] * num_links[ttype] * num_tx_ants[ttype] * 2
170 |     out_size = {k: num_links[k] * num_tx_ants[k] * 2 for k in num_links.keys()}
171 |     out_size = sum(out_size.values())
172 |     # Create an instance of the DNN.
173 |     model = nn_mods.create_mlp([in_size, 512, 512, 512, out_size])
174 |     # model.load_state_dict(torch.load(model_path))
175 |     print(model)
176 |     optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=5e-4)  # Adam optimizer
177 |     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)  # Learning rate decay
178 | 
179 |     # Compute the indices of the entries for different node types, which are later used in bf_recovery.
180 |     ntypes = ['siso', 'miso']
181 |     output_sizes = [2 * num_tx_ants[k] * num_links[k] for k in ntypes]
182 |     output_sizes.insert(0, 0)
183 |     output_sizes = torch.tensor(output_sizes)
184 |     for i in range(1, len(output_sizes)):
185 |         output_sizes[i:] += output_sizes[i - 1]
186 | 
187 |     # Training loop.
188 |     num_epochs = 100  # Number of epochs
189 |     test_epochs, acc = [], []  # Record of test results
190 |     epoch = -1
191 |     test()
192 |     for epoch in range(num_epochs):
193 |         # Train the model.
194 |         train_per_epoch()
195 |         # Test the model every epoch for the first 20 epochs, then every 5 epochs.
196 |         if epoch < 20 or epoch % 5 == 4:
197 |             test()
198 |         scheduler.step()
199 | 
--------------------------------------------------------------------------------
/train_hignn.py:
--------------------------------------------------------------------------------
1 | r"""
2 | This script trains an HIGNN for resource allocation (RA) in heterogeneous wireless networks.
3 | """
4 | 
5 | # Import functional tools.
6 | import copy
7 | import pickle
8 | from tqdm import tqdm
9 | # Import packages for mathematics.
10 | import numpy as np
11 | # Import packages for ML.
12 | import torch
13 | import torch.nn as nn
14 | from torch.utils.data import DataLoader
15 | import dgl
16 | # Import user-defined modules.
17 | import utils, nn_modules as nn_mods
18 | 
19 | 
20 | def compute_statistics(samples):
21 |     """
22 |     Computes the mean/variance of the given samples.
23 | 
24 |     Arguments:
25 |         samples (list): a list of sample dicts as produced by `gen_data.py`
26 | 
27 |     Returns:
28 |         stat_h_ii (dict): the mean and variance of the direct links h_ii by node type
29 |         stat_h_ij (dict): the mean and variance of the interference links h_ij by relation
30 |     """
31 |     num_links, num_tx_ants = samples[0]['num_links'], samples[0]['num_tx_ants']
32 |     stat_h_ii, stat_h_ij = {}, {}
33 |     # Statistics are computed for each relation.
34 |     # Since the distances of direct links differ from those of interfering links, direct links are treated separately.
35 |     for rtype in num_links.keys():
36 |         stat_h_ii[rtype] = {}
37 |         for ttype in num_links.keys():
38 |             stat_h_ij[(rtype, ttype)] = {}
39 | 
40 |             # Obtain the concatenated channel matrix h_rel from h[(rtype, ttype)] of all samples.
41 |             h_rel = torch.stack(
42 |                 [torch.tensor(samples[i]['h'][(rtype, ttype)], dtype=torch.cfloat) for i in range(len(samples))])
43 | 
44 |             if rtype == ttype:
45 |                 # Create a mask to select h_ii.
46 |                 mask_diag = torch.cat([torch.eye(num_links[rtype]).unsqueeze(-1) for i in range(num_tx_ants[rtype])],
47 |                                       dim=-1)
48 |                 # Create a mask to select h_ij.
49 |                 mask_off_diag = torch.ones(num_links[rtype], num_links[rtype], num_tx_ants[rtype]) - mask_diag
50 |                 # Compute the statistics of h_ii.
51 |                 stat_h_ii[rtype]['mean'] = torch.view_as_real(h_rel * mask_diag).sum() / (
52 |                     len(samples) * num_links[rtype] * num_tx_ants[rtype] * 2)
53 |                 stat_h_ii[rtype]['var'] = torch.view_as_real((h_rel - stat_h_ii[rtype]['mean']) * mask_diag).pow(
54 |                     2).sum() / (len(samples) * num_links[rtype] * num_tx_ants[rtype] * 2)
55 |                 # Compute the statistics of h_ij.
56 |                 stat_h_ij[(rtype, ttype)]['mean'] = torch.view_as_real(h_rel * mask_off_diag).sum() / (
57 |                     len(samples) * num_links[rtype] * (num_links[ttype] - 1) * num_tx_ants[ttype] * 2)
58 |                 stat_h_ij[(rtype, ttype)]['var'] = torch.view_as_real(
59 |                     (h_rel - stat_h_ij[(rtype, ttype)]['mean']) * mask_off_diag).pow(2).sum() / (
60 |                         len(samples) * num_links[rtype] * (
61 |                             num_links[ttype] - 1) * num_tx_ants[ttype] * 2)
62 |             else:
63 |                 h_rel = torch.view_as_real(h_rel)  # A (n_samples, n_rx, n_tx, n_tx_ants, 2) real Tensor
64 |                 stat_h_ij[(rtype, ttype)]['mean'] = h_rel.mean()
65 |                 stat_h_ij[(rtype, ttype)]['var'] = (h_rel - h_rel.mean()).pow(2).sum() / h_rel.numel()
66 |     return stat_h_ii, stat_h_ij
67 | 
68 | 
69 | def build_heterograph(sample):
70 |     """
71 |     Builds a heterograph describing the IFC from the given sample.
72 | 
73 |     Arguments:
74 |         sample (dict): inputs and outputs of the closed-form FP algorithm
75 | 
76 |     Returns:
77 |         a DGLGraph storing the selected features of the sample
78 |     """
79 | 
80 |     # Extract the params from the sample.
81 |     num_links, num_tx_ants = sample['num_links'], sample['num_tx_ants']
82 |     p_max, var_awgn, weight = copy.deepcopy(sample['p_max']), copy.deepcopy(sample['var_awgn']), copy.deepcopy(sample['weight'])
83 |     h = copy.deepcopy(sample['h'])
84 |     for rtype in num_links.keys():
85 |         for ttype in num_links.keys():
86 |             h[(rtype, ttype)] = torch.view_as_real(torch.tensor(h[(rtype, ttype)], dtype=torch.cfloat))
87 | 
88 |     h_ii = {}  # Channel response of the direct links
89 |     for ltype in num_links.keys():
90 |         p_max[ltype] = torch.as_tensor(p_max[ltype], dtype=torch.float).unsqueeze(-1)
91 |         var_awgn[ltype] = torch.as_tensor(var_awgn[ltype], dtype=torch.float).unsqueeze(-1)
92 |         weight[ltype] = torch.as_tensor(weight[ltype], dtype=torch.float).unsqueeze(-1)
93 |         h_ii[ltype] = (h[(ltype, ltype)][np.arange(num_links[ltype]), np.arange(num_links[ltype])] - stat_h_ii[ltype]['mean']) / stat_h_ii[ltype]['var'].sqrt()
94 | 
95 |     graph_data = {}  # Source IDs and destination IDs of all edges
96 |     h_ij = {}  # Channel response of the interfering links
97 |     for stype in num_links.keys():
98 |         for dtype in num_links.keys():
99 |             graph_data[(stype, '-interfered-by-', dtype)] = ([], [])  # Two lists hold the source and destination IDs of the specified edge type.
100 |             h_ij[(stype, '-interfered-by-', dtype)] = []
101 | 
102 |     # Add an edge if the norm of the channel coefficient lies above the threshold.
103 |     threshold = 0.05
104 |     for stype in num_links.keys():
105 |         for dtype in num_links.keys():
106 |             for i in range(h[(stype, dtype)].size(0)):
107 |                 for j in range(h[(stype, dtype)].size(1)):
108 |                     if np.linalg.norm(h[(stype, dtype)][i, j]) > threshold and (i != j or stype != dtype):
109 |                         graph_data[(stype, '-interfered-by-', dtype)][0].append(i)
110 |                         graph_data[(stype, '-interfered-by-', dtype)][1].append(j)
111 |                         edge_feat = torch.cat([(h[(stype, dtype)][i, j].flatten() - stat_h_ij[(stype, dtype)]['mean']) / stat_h_ij[(stype, dtype)]['var'].sqrt(),
112 |                                                (h[(dtype, stype)][j, i].flatten() - stat_h_ij[(dtype, stype)]['mean']) / stat_h_ij[(dtype, stype)]['var'].sqrt()], dim=0)
113 |                         h_ij[(stype, '-interfered-by-', dtype)].append(edge_feat)
114 |             h_ij[(stype, '-interfered-by-', dtype)] = torch.stack(h_ij[(stype, '-interfered-by-', dtype)])
115 | 
116 |     # Build the heterograph from the edges defined in graph_data.
117 |     g = dgl.heterograph(graph_data, num_nodes_dict=num_links)
118 |     # Assign the node attributes.
119 |     g.ndata['p_max'], g.ndata['weight'], g.ndata['var_awgn'], g.ndata['h'] = p_max, weight, var_awgn, h_ii
120 |     # Assign the edge attributes.
121 |     g.edata['h'] = h_ij
122 |     return g
123 | 
124 | 
125 | def proc_data(samples, requires_label):
126 |     """
127 |     Builds heterographs from a list of samples and creates a dataset.
128 | 
129 |     Arguments:
130 |         samples (list): a list of samples, each a dict holding the channel info
131 |         requires_label (bool): whether optimal beamforming vectors are included
132 | 
133 |     Returns:
134 |         a list of tuples holding (graph, labels)
135 |     """
136 |     graphs, labels = [], []
137 |     for i in tqdm(range(len(samples))):
138 |         # Build a heterograph from the sample.
139 |         graphs.append(build_heterograph(samples[i]))
140 |         h_ = {}
141 |         for k in samples[i]['h'].keys():
142 |             h_[k] = torch.as_tensor(samples[i]['h'][k], dtype=torch.cfloat)
143 |         # Add the optimal beamforming vectors from FP if required.
144 |         if requires_label:
145 |             labels.append({'wsr_targ': samples[i]['wsr_targ'], 'h': h_})
146 |         else:
147 |             labels.append({'h': h_})
148 | 
149 |     # Display the statistics of h_ii and h_ij.
150 |     bg = dgl.batch(graphs)
151 |     for ntype in bg.ntypes:
152 |         print("bg.nodes[{}].data['h'].mean() = {}".format(ntype, bg.nodes[ntype].data['h'].mean()))
153 |         print("bg.nodes[{}].data['h'].var() = {}".format(ntype, bg.nodes[ntype].data['h'].var()))
154 |     for c_etype in bg.canonical_etypes:
155 |         print("bg.edges[{}].data['h'].mean() = {}".format(c_etype, bg.edges[c_etype].data['h'].mean()))
156 |         print("bg.edges[{}].data['h'].var() = {}".format(c_etype, bg.edges[c_etype].data['h'].var()))
157 |     return list(zip(graphs, labels))
158 | 
159 | 
160 | def collate(data):
161 |     """collate_fn for the training set"""
162 |     # Encapsulate a batch of graphs/labels into two lists.
163 |     graphs, labels = map(list, zip(*data))
164 |     # Batch the collection of graphs into one graph.
165 |     bg = dgl.batch(graphs)
166 |     # Build a block-diagonal channel matrix per relation from all graphs in the batch.
167 |     h = {}
168 |     for stype, etype, dtype in bg.canonical_etypes:
169 |         # For each relation, put the channel matrices from different graphs into a list.
170 |         h_rel = [item['h'][(stype, dtype)] for item in labels]
171 |         # Construct the block-diagonal matrix.
172 |         h[(stype, dtype)] = utils.build_diag_block(h_rel)
173 |     # Seal the reconstructed h and the concatenated wsr_targ into batched labels.
174 |     bl = {'h': h}
175 |     if 'wsr_targ' in labels[0].keys():
176 |         bl['wsr_targ'] = [item['wsr_targ'] for item in labels]
177 |     return bg, bl
178 | 
179 | 
180 | def get_in_feats(g):
181 |     """
182 |     Builds the inputs from the node features.
183 |     The direct channel response constitutes a (2 * num_tx_ants,) real Tensor for each node.
184 |     """
185 |     return {ntype: g.nodes[ntype].data['h'].flatten(start_dim=1) for ntype in g.ntypes}
186 | 
187 | 
188 | def train_per_epoch():
189 |     """Executes one epoch of training."""
190 |     # Set the model to training mode.
191 |     model.train()
192 |     for i, data in enumerate(train_loader):
193 |         # Extract data from the train_loader.
194 |         batched_graph, batched_labels = data
195 |         # Reset the optimizer.
196 |         optimizer.zero_grad()
197 |         # Pass the batched data to the model.
198 |         in_feats = get_in_feats(batched_graph)
199 |         outputs = model(batched_graph, in_feats)
200 |         # Compute the negative WSR as the loss.
201 |         loss = utils.weighted_sum_rate(batched_labels['h'], outputs, batched_graph.ndata['var_awgn'], batched_graph.ndata['weight']).neg()
202 |         # Call back-propagation and execute one step of optimization.
203 |         loss.backward()
204 |         optimizer.step()
205 | 
206 | 
207 | def test():
208 |     """Tests the performance of the trained model."""
209 |     # Set the model to evaluation mode.
210 |     model.eval()
211 |     wsr_model, wsr_targ = 0., 0.
212 |     with torch.no_grad():
213 |         for i, data in enumerate(test_loader):
214 |             # Extract data from the test_loader.
215 |             batched_graph, batched_labels = data
216 |             # Pass the batched data to the model.
217 |             in_feats = get_in_feats(batched_graph)
218 |             outputs = model(batched_graph, in_feats)
219 |             # Accumulate the utilities (WSR).
220 |             wsr_model += utils.weighted_sum_rate(batched_labels['h'], outputs, batched_graph.ndata['var_awgn'], batched_graph.ndata['weight']).item()
221 |             wsr_targ += torch.tensor(batched_labels['wsr_targ']).sum().item()
222 | 
223 |     # Record and display the average performance.
224 |     acc.append(wsr_model / wsr_targ)
225 |     test_epochs.append(epoch)
226 |     if (epoch >= 1) and (wsr_model / wsr_targ == np.max(np.array(acc))):
227 |         torch.save(model.state_dict(), model_path)
228 |         print("epoch: {}, acc: {:.2%}, model is saved.".format(epoch, wsr_model / wsr_targ))
229 |     else:
230 |         print("epoch: {}, acc: {:.2%}.".format(epoch, wsr_model / wsr_targ))
231 | 
232 | 
233 | if __name__ == '__main__':
234 |     # Read the training/test data and specifications.
235 |     PATH = './datasets/d2d_12links/'  # PATH should be consistent with the one used in `gen_data.py`.
236 |     print("Path of datasets: {}".format(PATH))
237 |     dir_trainset = PATH + 'train.pickle'
238 |     dir_testset = PATH + 'test.pickle'
239 |     dir_specs = PATH + 'specs.pickle'
240 |     model_path = 'model_hignn.pth'
241 | 
242 |     with open(dir_trainset, 'rb') as file:
243 |         train_data = pickle.load(file)
244 |     with open(dir_testset, 'rb') as file:
245 |         test_data = pickle.load(file)
246 |     with open(dir_specs, 'rb') as file:
247 |         specs = pickle.load(file)
248 | 
249 |     train_data = train_data[:5000]  # Select the actual size of the training set.
250 |     print("Size of training set: {}".format(len(train_data)))
251 |     print("Size of test set: {}".format(len(test_data)))
252 |     print("Specs: {}".format(specs))
253 | 
254 |     # Process the data, build the datasets and create the DataLoaders.
255 |     stat_h_ii, stat_h_ij = compute_statistics(train_data)
256 |     print("Test the statistics of the processed train_data...")
257 |     trainset = proc_data(train_data, requires_label=False)
258 |     train_loader = DataLoader(trainset, batch_size=64, collate_fn=collate, shuffle=True)
259 |     print("Test the statistics of the processed test_data...")
260 |     testset = proc_data(test_data, requires_label=True)
261 |     test_loader = DataLoader(testset, batch_size=256, collate_fn=collate, shuffle=True)
262 | 
263 |     # Create an instance of the HIGNN.
264 |     model = nn_mods.HIGNN(num_tx_ants=specs['num_tx_ants'], p_max=specs['p_max'])
265 |     # model.load_state_dict(torch.load(model_path))
266 |     # print("model: \n{}".format(model))
267 |     optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)  # Adam optimizer
268 |     scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)  # Learning rate decay
269 | 
270 |     # Training loop.
271 |     num_epochs = 200  # Number of epochs
272 |     test_epochs, acc = [], []  # Record of test results
273 |     epoch = -1
274 |     test()
275 |     for epoch in range(num_epochs):
276 |         # Train the model.
277 |         train_per_epoch()
278 |         # Test the model every epoch for the first 20 epochs, then every 5 epochs.
279 |         if epoch < 20 or epoch % 5 == 4:
280 |             test()
281 |         scheduler.step()
282 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | r"""
2 | This script includes functions shared by both `train_hignn.py` and `train_dnn.py`.
3 | """
4 | 
5 | # Import packages for mathematics.
6 | import numpy as np
7 | # Import packages for ML.
8 | import torch
9 | from torch import Tensor
10 | 
11 | 
12 | def cmp2real(x):
13 |     """Converts a complex ndarray to a real one with an additional trailing dimension of size 2."""
14 |     return np.concatenate([np.expand_dims(x.real, -1), np.expand_dims(x.imag, -1)], axis=-1)
15 | 
16 | 
17 | def cmp_matmul(a, b):
18 |     """Element-wise complex multiplication of two (complex) Tensors."""
19 |     assert isinstance(a, Tensor)
20 |     if a.dtype == torch.cfloat:
21 |         a_real, a_imag = a.real, a.imag
22 |     else:
23 |         a_real, a_imag = a[:, :, 0], a[:, :, 1]
24 |     b_real, b_imag = b[:, :, 0], b[:, :, 1]
25 |     return [a_real * b_real - a_imag * b_imag, a_real * b_imag + a_imag * b_real]
26 | 
27 | 
28 | def build_diag_block(blk_list):
29 |     """Builds a block-diagonal Tensor from a list of Tensors."""
30 |     if blk_list[0].ndim == 2:
31 |         return torch.block_diag(*blk_list)
32 |     elif blk_list[0].ndim == 3:
33 |         blks = []
34 |         for idx_ant in range(blk_list[0].shape[-1]):
35 |             blks_per_ant = []
36 |             for idx_link in range(len(blk_list)):
37 |                 blks_per_ant.append(blk_list[idx_link][:, :, idx_ant])
38 |             blks.append(torch.block_diag(*blks_per_ant))
39 |         return torch.dstack(blks)
40 |     else:
41 |         raise Exception("Invalid input dimension")
42 | 
43 | 
44 | def weighted_sum_rate(h, x, var_awgn, w):
45 |     """
46 |     Computes the weighted sum rate from the outputs of the models.
47 |     Note that this function is different from the one in 'envs.py'.
48 |     """
49 |     rx_power = {}
50 |     for rtype, ttype in h.keys():
51 |         rx_power[(rtype, ttype)] = cmp_matmul(h[(rtype, ttype)], x[ttype])[0].sum(-1).pow(2) + \
52 |                                    cmp_matmul(h[(rtype, ttype)], x[ttype])[1].sum(-1).pow(2)
53 |     # Compute the interference level at each receiver.
54 |     interference = {}
55 |     for rtype in x.keys():
56 |         interference[rtype] = - torch.diag(rx_power[(rtype, rtype)])
57 |         for ttype in x.keys():
58 |             interference[rtype] += rx_power[(rtype, ttype)].sum(-1)
59 |     # Compute the signal-to-interference-plus-noise ratio (SINR) and achievable rate.
60 |     sinr, rate = {}, {}
61 |     for rtype in x.keys():
62 |         sinr[rtype] = torch.diag(rx_power[(rtype, rtype)]) / (interference[rtype] + var_awgn[rtype].squeeze())
63 |         rate[rtype] = torch.log(1 + sinr[rtype])
64 | 
65 |     # Accumulate the WSR over link types.
66 |     wsr = 0.  # Weighted sum rate
67 |     for ltype in w.keys():
68 |         wsr += w[ltype].squeeze().dot(rate[ltype])
69 |     return wsr
--------------------------------------------------------------------------------
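
A note on `utils.weighted_sum_rate`: the training loss relies on this torch implementation, while the FP labels are produced with the numpy `weighted_sum_rate` in `envs.py`, so the two must agree numerically. The following standalone snippet is not part of the repository; it is a minimal sketch of such a consistency check, assuming `envs.py` and `utils.py` are importable from the working directory, with all variable names (`env`, `x_np`, `h_t`, etc.) purely illustrative.

```python
# Hypothetical consistency check (not part of the repo): the torch WSR in utils.py
# should reproduce the numpy WSR in envs.py up to floating-point precision.
import numpy as np
import torch

import envs
import utils

# A small single-type IFC keeps the tensor shapes easy to follow.
env = envs.InterferenceChannel(num_links={'siso': 4}, num_tx_ants={'siso': 1},
                               p_max={'siso': 1.}, var_awgn={'siso': 1.},
                               weight={'siso': np.ones(4)})
h = env.channel_response()

# Random complex beamformers, one (num_tx_ants,) vector per link.
x_np = {'siso': (np.random.randn(4, 1) + 1j * np.random.randn(4, 1)) / np.sqrt(2)}
wsr_np = envs.weighted_sum_rate(h, x_np, env.var_awgn, env.weight)

# The torch version consumes cfloat channels and beamformers stored as
# (num_links, num_tx_ants, 2) real/imaginary pairs, as produced by the models.
h_t = {k: torch.as_tensor(v, dtype=torch.cfloat) for k, v in h.items()}
x_t = {'siso': torch.view_as_real(torch.as_tensor(x_np['siso'], dtype=torch.cfloat))}
wsr_t = utils.weighted_sum_rate(h_t, x_t,
                                {'siso': torch.ones(4)}, {'siso': torch.ones(4)})

print(wsr_np, wsr_t.item())  # The two values should agree up to single precision.
```

Running such a check after any edit to either WSR implementation guards against the loss used in training silently drifting away from the metric the FP labels were computed with.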