├── LICENSE
├── README.md
├── agents
│   └── ppo_agent.py
├── configs
│   ├── general.gin
│   └── ppo.gin
├── environment
│   └── environment.py
├── eval.py
├── grid_search.py
├── lib
│   ├── actor.py
│   ├── critic.py
│   └── run_experiment.py
├── requirements.txt
├── run.py
└── utils
    ├── defo_process_results.py
    ├── functions.py
    └── tf_logs.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 BNN-UPC
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MARL+GNN for Traffic Engineering Optimization
2 | #### Code of paper "Is Machine Learning Ready for Traffic Engineering Optimization?"
3 | #### To appear as a Main Conference Paper at IEEE ICNP 2021
4 | #### Guillermo Bernárdez, José Suárez-Varela, Albert Lopez, Bo Wu, Shihan Xiao, Xiangle Cheng, Pere Barlet-Ros, Albert Cabellos-Aparicio.
5 | #### Links to paper: [[ArXiv](https://arxiv.org/abs/2109.01445)]
6 | #### Download datasets [here](https://bnn.upc.edu/download/marl-gnn-te_datasets/)
7 | 
8 | 
9 | ## Abstract
10 | Traffic Engineering (TE) is a basic building block of the Internet. In this paper, we analyze whether modern Machine Learning (ML) methods are ready to be used for TE optimization. We address this open question through a comparative analysis between the state of the art in ML and the state of the art in TE. To this end, we first present a novel distributed system for TE that leverages the latest advancements in ML. Our system implements a novel architecture that combines Multi-Agent Reinforcement Learning (MARL) and Graph Neural Networks (GNN) to minimize network congestion. In our evaluation, we compare our MARL+GNN system with DEFO, a network optimizer based on Constraint Programming that represents the state of the art in TE. Our experimental results show that the proposed MARL+GNN solution achieves equivalent performance to DEFO in a wide variety of network scenarios including three real-world network topologies. At the same time, we show that MARL+GNN can achieve significant reductions in execution time (from the scale of minutes with DEFO to a few seconds with our solution).
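## Running an experiment

Experiments are driven through `run.py` with gin bindings, as done by `grid_search.py` and `eval.py`. A minimal launch sketch, assuming the datasets are unpacked under `datasets/`; the binding values below are illustrative and taken from `configs/general.gin` and `configs/ppo.gin`:

```python
# Minimal launch sketch, mirroring how grid_search.py and eval.py drive run.py.
# Binding values are illustrative; see configs/general.gin and configs/ppo.gin
# for the defaults used in the experiments.
import os

cmd = (
    "python ./run.py"
    " --gin_bindings='Environment.env_type = \"NSFNet+GEANT2\"'"
    " --gin_bindings='Environment.traffic_profile = \"gravity_1\"'"
    " --gin_bindings='Environment.routing = \"sp\"'"
    " --gin_bindings='PPOAgent.gamma = 0.95'"
    " --gin_bindings='PPOAgent.clip_param = 0.2'"
)
os.system(cmd)
```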
11 | -------------------------------------------------------------------------------- /agents/ppo_agent.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | import tensorflow as tf 3 | import tensorflow_probability as tfp 4 | from environment.environment import Environment 5 | from lib.actor import Actor 6 | from lib.critic import Critic 7 | from utils.functions import linearly_decaying_epsilon 8 | from utils.defo_process_results import get_traffic_matrix 9 | import utils.tf_logs as tf_logs 10 | import copy 11 | import numpy as np 12 | import random 13 | import os 14 | import logging 15 | import time 16 | import csv 17 | 18 | import gin.tf 19 | 20 | 21 | @gin.configurable 22 | class PPOAgent(object): 23 | '''An implementation of a GNN-based PPO Agent''' 24 | 25 | def __init__(self, 26 | env, 27 | eval_env_type=['GBN','NSFNet','GEANT2'], 28 | num_eval_samples=10, 29 | clip_param=0.2, 30 | critic_loss_factor=0.5, 31 | entropy_loss_factor=0.001, 32 | normalize_advantages=True, 33 | max_grad_norm=1.0, 34 | gamma=0.99, 35 | gae_lambda=0.95, 36 | horizon=None, 37 | batch_size=25, 38 | epochs=3, 39 | last_training_sample=99, 40 | eval_period=50, 41 | max_evals=100, 42 | select_max_action=False, 43 | optimizer=tf.keras.optimizers.Adam( 44 | learning_rate=0.0003, 45 | beta_1=0.9, 46 | epsilon=0.00001), 47 | change_traffic = False, 48 | change_traffic_period = 1, 49 | base_dir='logs', 50 | checkpoint_base_dir='checkpoints', 51 | save_checkpoints=True): 52 | 53 | self.env = env 54 | self.eval_env_type = eval_env_type 55 | self.num_eval_samples = num_eval_samples 56 | self.clip_param = clip_param 57 | 58 | self._get_actor_critic_functions() 59 | self.num_actions = self.env.n_links 60 | 61 | self.optimizer = optimizer 62 | self.critic_loss_factor = critic_loss_factor 63 | self.entropy_loss_factor = entropy_loss_factor 64 | self.normalize_advantages = normalize_advantages 65 | self.max_grad_norm = max_grad_norm 66 | 67 | self.gamma = gamma 68 | self.gae_lambda = gae_lambda 69 | self.given_horizon = horizon 70 | self.define_horizon() 71 | self.epochs = epochs 72 | self.batch_size = batch_size 73 | self.last_training_sample = last_training_sample 74 | self.eval_period = eval_period 75 | self.max_evals = max_evals 76 | self.select_max_action = select_max_action 77 | self.change_traffic = change_traffic 78 | self.change_traffic_period = change_traffic_period 79 | self.eval_step = 0 80 | self.eval_episode = 0 81 | self.base_dir= base_dir 82 | self.checkpoint_base_dir = checkpoint_base_dir 83 | self.save_checkpoints = save_checkpoints 84 | self.reload_model = False 85 | self.change_sample = False 86 | 87 | def _get_actor_critic_functions(self): 88 | self.actor = Actor(self.env.G, num_features=self.env.num_features) 89 | self.actor.build() 90 | self.critic = Critic(self.env.G, num_features=self.env.num_features) 91 | self.critic.build() 92 | 93 | def define_horizon(self): 94 | if self.given_horizon is not None: 95 | self.horizon = self.given_horizon 96 | elif self.env.network == 'NSFNet': 97 | self.horizon = 100 98 | elif self.env.network == 'GBN': 99 | self.horizon = 150 100 | elif self.env.network == 'GEANT2': 101 | self.horizon = 200 102 | else: 103 | self.horizon = 200 104 | 105 | def reset_env(self): 106 | self.env.reset(change_sample=self.change_sample) 107 | if self.change_sample and len(self.env.env_type) > 1: 108 | actor_model = copy.deepcopy(self.actor.trainable_variables) 109 | critic_model = 
copy.deepcopy(self.critic.trainable_variables) 110 | self._get_actor_critic_functions() 111 | self.load_model(actor_model, critic_model) 112 | self.define_horizon() 113 | self.change_sample = False 114 | 115 | 116 | def gae_estimation(self, rewards, values, last_value): 117 | last_gae_lambda = 0 118 | advantages = np.zeros_like(values, dtype=np.float32) 119 | for i in reversed(range(self.horizon)): 120 | if i == self.horizon - 1: 121 | next_value = last_value 122 | else: 123 | next_value = values[i+1] 124 | delta = rewards[i] + self.gamma * next_value - values[i] 125 | advantages[i] = last_gae_lambda = delta + self.gamma * self.gae_lambda * last_gae_lambda 126 | returns = values + advantages 127 | if self.normalize_advantages: 128 | advantages = (advantages - np.mean(advantages)) / (np.std(advantages) + 1e-8) 129 | return returns, advantages 130 | 131 | 132 | def run_episode(self): 133 | # reset state at the beginning of each iteration 134 | self.reset_env() 135 | state = self.env.get_state() 136 | states = np.zeros((self.horizon, self.env.n_links*self.actor.num_features), dtype=np.float32) 137 | actions = np.zeros(self.horizon, dtype=np.float32) 138 | rewards = np.zeros(self.horizon, dtype=np.float32) 139 | log_probs = np.zeros(self.horizon, dtype=np.float32) 140 | values = np.zeros(self.horizon, dtype=np.float32) 141 | 142 | for t in range(self.horizon): 143 | action, log_prob = self.act(state) 144 | value = self.run_critic(state) 145 | next_state, reward = self.env.step(action.numpy()) 146 | states[t] = state 147 | actions[t] = action 148 | rewards[t] = reward 149 | log_probs[t] = log_prob 150 | values[t] = value.numpy()[0] 151 | state = next_state 152 | value = self.run_critic(state) 153 | last_value = value.numpy()[0] 154 | 155 | return states, actions, rewards, log_probs, values, last_value 156 | 157 | def run_update(self, training_episode, states, actions, returns, advantages, log_probs): 158 | actor_losses, critic_losses, losses = [], [], [] 159 | inds = np.arange(self.horizon) 160 | for _ in range(self.epochs): 161 | np.random.shuffle(inds) 162 | for start in range(0, self.horizon, self.batch_size): 163 | end = start + self.batch_size 164 | minibatch_ind = inds[start:end] 165 | actor_loss, critic_loss, loss, grads = self.compute_losses_and_grads(states[minibatch_ind], 166 | actions[minibatch_ind], returns[minibatch_ind], 167 | advantages[minibatch_ind], log_probs[minibatch_ind]) 168 | self.apply_grads(grads) 169 | actor_losses.append(actor_loss.numpy()) 170 | critic_losses.append(critic_loss.numpy()) 171 | losses.append(loss.numpy()) 172 | 173 | return actor_losses, critic_losses, losses 174 | 175 | 176 | def train_and_evaluate(self): 177 | training_episode = -1 178 | while not (self.env.num_sample == self.last_training_sample and self.change_sample == True): 179 | training_episode += 1 180 | print('Episode ', training_episode, '...') 181 | states, actions, rewards, log_probs, values, last_value = self.run_episode() 182 | returns, advantages = self.gae_estimation(rewards, values, last_value) 183 | 184 | actor_losses, critic_losses, losses = self.run_update(training_episode, states, actions, returns, advantages, log_probs) 185 | tf_logs.training_episode_logs(self.writer, self.env, training_episode, states, rewards, losses, actor_losses, critic_losses) 186 | 187 | if (training_episode+1) % self.eval_period == 0: 188 | self.training_eval() 189 | if self.save_checkpoints: 190 | self.actor._set_inputs(states[0]) 191 | self.critic._set_inputs(states[0]) 192 | 
self.save_model(os.path.join(self.checkpoint_dir, 'episode'+str(self.eval_episode))) 193 | if self.change_traffic and self.eval_episode % self.change_traffic_period == 0: 194 | self.change_sample = True 195 | 196 | 197 | def only_evaluate(self): 198 | self.env.initialize_environment(num_sample=self.last_training_sample+1) 199 | for _ in range(self.max_evals): 200 | self.evaluation() 201 | self.change_sample = True 202 | 203 | def generate_eval_env(self): 204 | self.eval_envs = {} 205 | for eval_env_type in self.eval_env_type: 206 | self.eval_envs[eval_env_type] = Environment(env_type=eval_env_type, 207 | traffic_profile=self.env.traffic_profile, 208 | routing=self.env.routing) 209 | 210 | def generate_eval_actor_critic_functions(self): 211 | self.eval_actor = {} 212 | self.eval_critic = {} 213 | for eval_env_type in self.eval_env_type: 214 | self.eval_actor[eval_env_type] = Actor(self.eval_envs[eval_env_type].G, num_features=self.env.num_features) 215 | self.eval_actor[eval_env_type].build() 216 | self.eval_critic[eval_env_type] = Critic(self.eval_envs[eval_env_type].G, num_features=self.env.num_features) 217 | self.eval_critic[eval_env_type].build() 218 | 219 | def update_eval_actor_critic_functions(self): 220 | for eval_env_type in self.eval_env_type: 221 | for w_model, w_eval_actor in zip(self.actor.trainable_variables, 222 | self.eval_actor[eval_env_type].trainable_variables): 223 | w_eval_actor.assign(w_model) 224 | for w_model, w_eval_critic in zip(self.critic.trainable_variables, 225 | self.eval_critic[eval_env_type].trainable_variables): 226 | w_eval_critic.assign(w_model) 227 | 228 | def training_eval(self): 229 | # Evaluation phase 230 | print('\n\tEvaluation ' + str(self.eval_episode) + '...\n') 231 | 232 | if self.eval_episode == 0: 233 | self.generate_eval_env() 234 | self.generate_eval_actor_critic_functions() 235 | 236 | self.update_eval_actor_critic_functions() 237 | 238 | for eval_env_type in self.eval_env_type: 239 | self.eval_envs[eval_env_type].define_num_sample(100) 240 | 241 | total_min_max = [] 242 | mini_eval_episode = self.eval_episode * self.num_eval_samples 243 | for _ in range(self.num_eval_samples): 244 | self.eval_envs[eval_env_type].reset(change_sample=True) 245 | state = self.eval_envs[eval_env_type].get_state() 246 | tf_logs.eval_step_logs(self.writer, self.eval_envs[eval_env_type], self.eval_step, state) 247 | if self.eval_envs[eval_env_type].link_traffic_to_states: 248 | max_link_utilization = [np.max(state[:self.eval_envs[eval_env_type].n_links])] 249 | mean_link_utilization = [np.mean(state[:self.eval_envs[eval_env_type].n_links])] 250 | probs, values = [], [] 251 | 252 | for i in range(self.horizon): 253 | self.eval_step += 1 254 | action, log_prob = self.eval_act(self.eval_actor[eval_env_type], state, select_max=self.select_max_action) 255 | value = self.eval_critic[eval_env_type](state) 256 | next_state, reward = self.eval_envs[eval_env_type].step(action.numpy()) 257 | probs.append(np.exp(log_prob)) 258 | values.append(value.numpy()[0]) 259 | state = next_state 260 | if self.eval_envs[eval_env_type].link_traffic_to_states: 261 | max_link_utilization.append(np.max(state[:self.eval_envs[eval_env_type].n_links])) 262 | #mean_link_utilization.append(np.mean(state[:self.eval_envs[eval_env_type].n_links])) 263 | 264 | tf_logs.eval_step_logs(self.writer, self.eval_envs[eval_env_type], self.eval_step, state) 265 | 266 | if self.env.link_traffic_to_states: 267 | total_min_max.append(np.min(max_link_utilization)) 268 | #tf_logs.eval_final_log(self.writer, 
mini_eval_episode, max_link_utilization, eval_env_type) 269 | mini_eval_episode += 1 270 | 271 | tf_logs.eval_top_log(self.writer, self.eval_episode, total_min_max, eval_env_type) 272 | self.eval_episode += 1 273 | 274 | 275 | def evaluation(self): 276 | # Evaluation phase 277 | print('\n\tEvaluation ' + str(self.eval_episode) + '...\n') 278 | self.reset_env() 279 | state = self.env.get_state() 280 | tf_logs.eval_step_logs(self.writer, self.env, self.eval_step, state) 281 | if self.env.link_traffic_to_states: 282 | max_link_utilization = [np.max(state[:self.env.n_links])] 283 | mean_link_utilization = [np.mean(state[:self.env.n_links])] 284 | probs, values = [], [] 285 | 286 | for i in range(self.horizon): 287 | self.eval_step += 1 288 | action, log_prob = self.act(state, select_max=self.select_max_action) 289 | value = self.run_critic(state) 290 | next_state, reward = self.env.step(action.numpy()) 291 | probs.append(np.exp(log_prob)) 292 | values.append(value.numpy()[0]) 293 | state = next_state 294 | if self.env.link_traffic_to_states: 295 | max_link_utilization.append(np.max(state[:self.env.n_links])) 296 | mean_link_utilization.append(np.mean(state[:self.env.n_links])) 297 | 298 | tf_logs.eval_step_logs(self.writer, self.env, self.eval_step, state, reward, probs[i], values[i]) 299 | 300 | if self.env.link_traffic_to_states: 301 | tf_logs.eval_final_log(self.writer, self.eval_episode, max_link_utilization, ('+').join(self.env.env_type)) 302 | if self.only_eval: self.write_eval_results(self.eval_episode, np.min(max_link_utilization)) 303 | 304 | #self.eval_step += 10 305 | self.eval_episode += 1 306 | 307 | @tf.function 308 | def compute_actor_loss(self, new_log_probs, old_log_probs, advantages): 309 | ratio = tf.exp(new_log_probs - old_log_probs) 310 | pg_loss_1 = - advantages * ratio 311 | pg_loss_2 = - advantages * tf.clip_by_value(ratio, 1.0 - self.clip_param, 1.0 + self.clip_param) 312 | actor_loss = tf.reduce_mean(tf.maximum(pg_loss_1, pg_loss_2)) 313 | return actor_loss 314 | 315 | @tf.function 316 | def get_new_log_prob_and_entropy(self, state, action): 317 | logits = self.actor(state, training=True) 318 | probs = tfp.distributions.Categorical(logits=logits) 319 | return (probs.log_prob(action), probs.entropy()) 320 | 321 | @tf.function 322 | def compute_losses_and_grads(self, states, actions, returns, advantages, old_log_probs): 323 | with tf.GradientTape(persistent=True) as tape: 324 | new_log_probs, entropy = tf.map_fn(lambda x: self.get_new_log_prob_and_entropy(x[0], x[1]), (states, actions), fn_output_signature=(tf.float32, tf.float32)) 325 | 326 | values = tf.map_fn(lambda x: self.critic(x, training=True), states, fn_output_signature=tf.float32) 327 | values = tf.reshape(values, [-1]) 328 | 329 | critic_loss = tf.reduce_mean(tf.square(returns - values)) 330 | entropy_loss = tf.reduce_mean(entropy) 331 | actor_loss = self.compute_actor_loss(new_log_probs, old_log_probs, advantages) 332 | loss = actor_loss - self.entropy_loss_factor*entropy_loss + self.critic_loss_factor*critic_loss 333 | 334 | grads = tape.gradient(loss, self.actor.trainable_variables+self.critic.trainable_variables) 335 | if self.max_grad_norm is not None: 336 | # Clip the gradients (normalize) 337 | grads, _grad_norm = tf.clip_by_global_norm(grads, self.max_grad_norm) 338 | 339 | return actor_loss, critic_loss, loss, grads 340 | 341 | def apply_grads(self, grads): 342 | self.optimizer.apply_gradients(zip(grads, self.actor.trainable_variables+self.critic.trainable_variables)) 343 | 344 | @tf.function 345 | 
def act(self, state, select_max=False): 346 | logits = self.actor(state) 347 | probs = tfp.distributions.Categorical(logits=logits) 348 | if select_max: 349 | action = tf.argmax(logits) 350 | else: 351 | action = probs.sample() 352 | 353 | return action, probs.log_prob(action) 354 | 355 | @tf.function 356 | def eval_act(self, actor, state, select_max=False): 357 | logits = actor(state) 358 | probs = tfp.distributions.Categorical(logits=logits) 359 | if select_max: 360 | action = tf.argmax(logits) 361 | else: 362 | action = probs.sample() 363 | 364 | return action, probs.log_prob(action) 365 | 366 | @tf.function 367 | def run_critic(self, state): 368 | return self.critic(state) 369 | 370 | 371 | def save_model(self, checkpoint_dir): 372 | self.actor.save(checkpoint_dir+'/actor') 373 | self.critic.save(checkpoint_dir+'/critic') 374 | 375 | 376 | def load_model(self, actor_model, critic_model): 377 | for w_model, w_actor in zip(actor_model, 378 | self.actor.trainable_variables): 379 | w_actor.assign(w_model) 380 | for w_model, w_critic in zip(critic_model, 381 | self.critic.trainable_variables): 382 | w_critic.assign(w_model) 383 | 384 | 385 | def load_saved_model(self, model_dir, only_eval): 386 | model = keras.models.load_model(model_dir+'/actor') 387 | for w_model, w_actor in zip(model.trainable_variables, 388 | self.actor.trainable_variables): 389 | w_actor.assign(w_model) 390 | if not only_eval: 391 | model = keras.models.load_model(model_dir+'/critic') 392 | for w_model, w_critic in zip(model.trainable_variables, 393 | self.critic.trainable_variables): 394 | w_critic.assign(w_model) 395 | self.model_dir = model_dir 396 | self.reload_model = True 397 | 398 | def write_eval_results(self, step, value): 399 | csv_dir = os.path.join('./notebooks/logs', self.experiment_identifier) 400 | if not os.path.exists(csv_dir): 401 | os.makedirs(csv_dir) 402 | with open(csv_dir+'/results.csv', "a") as csv_file: 403 | writer = csv.writer(csv_file, delimiter=',') 404 | writer.writerow([step,value]) 405 | 406 | def set_experiment_identifier(self, only_eval): 407 | self.only_eval = only_eval 408 | mode = 'eval' if only_eval else 'training' 409 | 410 | if mode == 'training': 411 | #ENVIRONMENT 412 | network = '+'.join([str(elem) for elem in self.env.env_type]) 413 | traffic_profile = self.env.traffic_profile 414 | routing = self.env.routing 415 | env_folder = ('-').join([network,traffic_profile,routing]) 416 | 417 | #PPOAGENT 418 | batch = 'batch'+str(self.batch_size) 419 | gae_lambda = 'gae'+str(self.gae_lambda) 420 | lr = 'lr'+str(self.optimizer.get_config()['learning_rate']) 421 | epsilon = 'epsilon'+str(self.optimizer.epsilon) 422 | clip = 'clip'+str(self.clip_param) 423 | gamma = 'gamma'+str(self.gamma) 424 | period = 'period'+str(self.eval_period) 425 | epoch = 'epoch'+str(self.epochs) 426 | agent_folder = ('-').join([batch,lr,epsilon,gae_lambda,clip,gamma,period,epoch]) 427 | 428 | #ACTOR-CRITIC 429 | state_size = 'size'+str(self.actor.link_state_size) 430 | iters = 'iters'+str(self.actor.message_iterations) 431 | aggregation = self.actor.aggregation 432 | nn_size = 'nnsize'+str(self.actor.final_hidden_layer_size) 433 | dropout = 'drop'+str(self.actor.dropout_rate) 434 | activation = self.actor.activation_fn 435 | function_folder = ('-').join([state_size,iters,aggregation,nn_size,dropout,activation]) 436 | 437 | self.experiment_identifier = os.path.join(mode, env_folder, agent_folder, function_folder) 438 | 439 | else: 440 | model_dir = self.model_dir 441 | 442 | network = '+'.join([str(elem) for 
elem in self.env.env_type]) 443 | traffic_profile = self.env.traffic_profile 444 | routing = self.env.routing 445 | eval_env_folder = ('-').join([network,traffic_profile,routing]) 446 | 447 | #RELOADED MODEL 448 | env_folder = os.path.join(model_dir.split('/')[3]) 449 | agent_folder = os.path.join(model_dir.split('/')[4]) 450 | function_folder = os.path.join(model_dir.split('/')[5]) 451 | episode = os.path.join(model_dir.split('/')[6]) 452 | 453 | self.experiment_identifier = os.path.join(mode, eval_env_folder, env_folder, agent_folder, function_folder, episode) 454 | 455 | return self.experiment_identifier 456 | 457 | 458 | def set_writer_and_checkpoint_dir(self, writer_dir, checkpoint_dir): 459 | self.writer_dir = writer_dir 460 | self.checkpoint_dir = checkpoint_dir 461 | self.writer = tf.summary.create_file_writer(self.writer_dir) -------------------------------------------------------------------------------- /configs/general.gin: -------------------------------------------------------------------------------- 1 | import lib.run_experiment 2 | import environment.environment 3 | import gin.tf.external_configurables 4 | 5 | Runner.algorithm = 'PPO' 6 | Runner.reload_model = False 7 | Runner.model_dir = 'checkpoints/training/gravity_1/NSFNet+GEANT2/PPO_ecmp_agg_period100/clip0.25/gamma0.97/episode41' 8 | Runner.only_eval = False 9 | Runner.save_checkpoints = True 10 | 11 | Environment.env_type = 'NSFNet+GEANT2' 12 | Environment.traffic_profile = 'gravity_1' 13 | Environment.routing = 'sp' 14 | Environment.probs_to_states = False 15 | Environment.seed_init_weights = 1 -------------------------------------------------------------------------------- /configs/ppo.gin: -------------------------------------------------------------------------------- 1 | import agents.ppo_agent 2 | import lib.actor 3 | import lib.critic 4 | import gin.tf.external_configurables 5 | 6 | PPOAgent.gamma = 0.95 7 | PPOAgent.clip_param = 0.2 8 | PPOAgent.batch_size = 25 9 | PPOAgent.select_max_action = False 10 | PPOAgent.epochs = 3 11 | PPOAgent.gae_lambda = 0.9 12 | PPOAgent.horizon = None 13 | PPOAgent.eval_period = 50 14 | PPOAgent.change_traffic = True 15 | PPOAgent.change_traffic_period = 1 16 | PPOAgent.last_training_sample = 50 17 | PPOAgent.max_evals = 50 18 | PPOAgent.eval_env_type = ['GBN','NSFNet','GEANT2'] 19 | PPOAgent.num_eval_samples = 10 20 | 21 | PPOAgent.critic_loss_factor=0.5 22 | PPOAgent.entropy_loss_factor=0.001 23 | PPOAgent.normalize_advantages=True 24 | PPOAgent.max_grad_norm=1.0 25 | PPOAgent.optimizer = @tf.keras.optimizers.Adam() 26 | tf.keras.optimizers.Adam.learning_rate=0.0003 27 | tf.keras.optimizers.Adam.beta_1=0.9 28 | tf.keras.optimizers.Adam.epsilon=0.1 29 | 30 | Actor.link_state_size = 16 31 | Actor.aggregation = 'min_max' 32 | Actor.first_hidden_layer_size = 128 33 | Actor.dropout_rate = 0.15 34 | Actor.final_hidden_layer_size = 64 35 | Actor.message_iterations = 8 36 | Actor.activation_fn = 'tanh' 37 | 38 | Critic.link_state_size = 16 39 | Critic.aggregation = 'min_max' 40 | Critic.first_hidden_layer_size = 128 41 | Critic.dropout_rate = 0.15 42 | Critic.final_hidden_layer_size = 64 43 | Critic.message_iterations = 8 44 | Critic.activation_fn = 'tanh' -------------------------------------------------------------------------------- /environment/environment.py: -------------------------------------------------------------------------------- 1 | from utils.functions import pairwise_iteration 2 | 3 | import networkx as nx 4 | import numpy as np 5 | import copy 6 | import os 7 | 
import gin.tf 8 | 9 | 10 | DEFAULT_EDGE_ATTRIBUTES = { 11 | 'increments': 1, 12 | 'reductions': 1, 13 | 'weight': 0.0, 14 | 'traffic': 0.0 15 | } 16 | 17 | 18 | @gin.configurable 19 | class Environment(object): 20 | 21 | def __init__(self, 22 | env_type='NSFNet', 23 | traffic_profile='uniform', 24 | routing='ecmp', 25 | init_sample=0, 26 | seed_init_weights=1, 27 | min_weight=1.0, 28 | max_weight=4.0, 29 | weight_change=1.0, 30 | weight_update='sum', 31 | weigths_to_states=True, 32 | link_traffic_to_states=True, 33 | probs_to_states=False, 34 | reward_magnitude='link_traffic', 35 | base_reward='min_max', 36 | reward_computation='change', 37 | base_dir='datasets'): 38 | 39 | env_type = [env for env in env_type.split('+')]#env_type if type(env_type) == list else [env_type] 40 | self.env_type = env_type 41 | self.traffic_profile = traffic_profile 42 | self.routing = routing 43 | 44 | self.num_sample = init_sample-1 45 | self.seed_init_weights = seed_init_weights 46 | self.min_weight = min_weight 47 | self.max_weight = max_weight 48 | self.weight_change = weight_change 49 | self.weight_update = weight_update 50 | 51 | num_features = 0 52 | self.weigths_to_states = weigths_to_states 53 | if self.weigths_to_states: num_features += 1 54 | self.link_traffic_to_states = link_traffic_to_states 55 | if self.link_traffic_to_states: num_features += 1 56 | self.probs_to_states = probs_to_states 57 | if self.probs_to_states: num_features += 2 58 | self.num_features = num_features 59 | self.reward_magnitude = reward_magnitude 60 | self.base_reward = base_reward 61 | self.reward_computation = reward_computation 62 | 63 | self.base_dir = base_dir 64 | self.dataset_dirs = [] 65 | for env in env_type: 66 | self.dataset_dirs.append(os.path.join(base_dir, env, traffic_profile)) 67 | 68 | self.initialize_environment() 69 | self.get_weights() 70 | self._generate_routing() 71 | self._get_link_traffic() 72 | self.reward_measure = self.compute_reward_measure() 73 | self.set_target_measure() 74 | 75 | 76 | def load_topology_object(self): 77 | try: 78 | nx_file = os.path.join(self.base_dir, self.network, 'graph_attr.txt') 79 | self.topology_object = nx.DiGraph(nx.read_gml(nx_file, destringizer=int)) 80 | except: 81 | self.topology_object = nx.DiGraph() 82 | capacity_file = os.path.join(self.dataset_dir, 'capacities', 'graph.txt') 83 | with open(capacity_file) as fd: 84 | for line in fd: 85 | if 'Link_' in line: 86 | camps = line.split(" ") 87 | self.topology_object.add_edge(int(camps[1]),int(camps[2])) 88 | self.topology_object[int(camps[1])][int(camps[2])]['bandwidth'] = int(camps[4]) 89 | 90 | def load_capacities(self): 91 | if self.traffic_profile == 'gravity_full': 92 | capacity_file = os.path.join(self.dataset_dir, 'capacities', 'graph-TM-'+str(self.num_sample)+'.txt') 93 | else: 94 | capacity_file = os.path.join(self.dataset_dir, 'capacities', 'graph.txt') 95 | with open(capacity_file) as fd: 96 | for line in fd: 97 | if 'Link_' in line: 98 | camps = line.split(" ") 99 | self.G[int(camps[1])][int(camps[2])]['capacity'] = int(camps[4]) 100 | 101 | def load_traffic_matrix(self): 102 | tm_file = os.path.join(self.dataset_dir, 'TM', 'TM-'+str(self.num_sample)) 103 | self.traffic_demand = np.zeros((self.n_nodes,self.n_nodes)) 104 | with open(tm_file) as fd: 105 | fd.readline() 106 | fd.readline() 107 | for line in fd: 108 | camps = line.split(" ") 109 | self.traffic_demand[int(camps[1]),int(camps[2])] = float(camps[3]) 110 | self.get_link_probs() 111 | 112 | def initialize_environment(self, num_sample=None, 
random_env=True): 113 | if num_sample is not None: 114 | self.num_sample = num_sample 115 | else: 116 | self.num_sample += 1 117 | if random_env: 118 | num_env = np.random.randint(0,len(self.env_type)) 119 | else: 120 | num_env = self.num_sample % len(self.env_type) 121 | self.network = self.env_type[num_env] 122 | self.dataset_dir = self.dataset_dirs[num_env] 123 | 124 | self.load_topology_object() 125 | self.generate_graph() 126 | self.load_capacities() 127 | self.load_traffic_matrix() 128 | 129 | def next_sample(self): 130 | if len(self.env_type) > 1: 131 | self.initialize_environment() 132 | else: 133 | self.num_sample += 1 134 | self._reset_edge_attributes() 135 | self.load_capacities() 136 | self.load_traffic_matrix() 137 | 138 | def define_num_sample(self, num_sample): 139 | self.num_sample = num_sample - 1 140 | 141 | def reset(self, change_sample=False): 142 | if change_sample: 143 | self.next_sample() 144 | else: 145 | if self.seed_init_weights is None: self._define_init_weights() 146 | self._reset_edge_attributes() 147 | self.get_weights() 148 | self._generate_routing() 149 | self._get_link_traffic() 150 | self.reward_measure = self.compute_reward_measure() 151 | self.set_target_measure() 152 | 153 | 154 | def generate_graph(self): 155 | G = copy.deepcopy(self.topology_object) 156 | self.n_nodes = G.number_of_nodes() 157 | self.n_links = G.number_of_edges() 158 | self._define_init_weights() 159 | idx = 0 160 | link_ids_dict = {} 161 | for (i,j) in G.edges(): 162 | G[i][j]['id'] = idx 163 | G[i][j]['increments'] = 1 164 | G[i][j]['reductions'] = 1 165 | G[i][j]['weight'] = copy.deepcopy(self.init_weights[idx]) 166 | link_ids_dict[idx] = (i,j) 167 | G[i][j]['capacity'] = G[i][j]['bandwidth'] 168 | G[i][j]['traffic'] = 0.0 169 | idx += 1 170 | self.G = G 171 | incoming_links, outcoming_links = self._generate_link_indices_and_adjacencies() 172 | self.G.add_node('graph_data', link_ids_dict=link_ids_dict, incoming_links=incoming_links, outcoming_links=outcoming_links) 173 | 174 | 175 | def set_target_measure(self): 176 | self.target_sp_routing = copy.deepcopy(self.sp_routing) 177 | self.target_reward_measure = copy.deepcopy(self.reward_measure) 178 | self.target_link_traffic = copy.deepcopy(self.link_traffic) 179 | self.get_weights() 180 | self.target_weights = copy.deepcopy(self.raw_weights) 181 | 182 | 183 | def get_weights(self, normalize=True): 184 | weights = [0.0]*self.n_links 185 | for i,j in self.G.edges(): 186 | weights[self.G[i][j]['id']] = copy.deepcopy(self.G[i][j]['weight']) 187 | self.raw_weights = weights 188 | max_weight = self.max_weight*3 189 | self.weights = [weight/max_weight for weight in weights] 190 | 191 | def get_state(self): 192 | state = [] 193 | link_traffic = copy.deepcopy(self.link_traffic) 194 | weights = copy.deepcopy(self.weights) 195 | if self.link_traffic: 196 | state += link_traffic 197 | if self.weigths_to_states: 198 | state += weights 199 | if self.probs_to_states: 200 | state += self.p_in 201 | state += self.p_out 202 | return np.array(state, dtype=np.float32) 203 | 204 | def define_weight(self, link, weight): 205 | i, j = link 206 | self.G[i][j]['weight'] = weight 207 | self._generate_routing() 208 | self._get_link_traffic() 209 | 210 | def update_weights(self, link, action_value, step_back=False): 211 | i, j = link 212 | if self.weight_update == 'min_max': 213 | if action_value == 0: 214 | self.G[i][j]['weight'] = max(self.G[i][j]['weight']-self.weight_change, self.min_weight) 215 | elif action_value == 1: 216 | self.G[i][j]['weight'] = 
min(self.G[i][j]['weight']+self.weight_change, self.max_weight) 217 | else: 218 | if self.weight_update == 'increment_reduction': 219 | if action_value == 0: 220 | self.G[i][j]['reductions'] += 1 221 | elif action_value == 1: 222 | self.G[i][j]['increments'] += 1 223 | self.G[i][j]['weight'] = self.G[i][j]['increments'] / self.G[i][j]['reductions'] 224 | elif self.weight_update == 'sum': 225 | if step_back: 226 | self.G[i][j]['weight'] -= self.weight_change 227 | else: 228 | self.G[i][j]['weight'] += self.weight_change 229 | 230 | def reinitialize_weights(self, seed_init_weights=-1, min_weight=None, max_weight=None): 231 | if seed_init_weights != -1: 232 | self.seed_init_weights = seed_init_weights 233 | if min_weight: self.min_weight = min_weight 234 | if max_weight: self.max_weight = max_weight 235 | 236 | self.generate_graph() 237 | self.get_weights() 238 | self._generate_routing() 239 | self._get_link_traffic() 240 | 241 | def reinitialize_routing(self, routing): 242 | self.routing = routing 243 | self._get_link_traffic() 244 | 245 | def step(self, action, step_back=False): 246 | #link_id, action_value = action 247 | link = self.G.nodes()['graph_data']['link_ids_dict'][action] 248 | #self.update_weights(link, action_value, step_back) 249 | self.update_weights(link, 0, step_back) 250 | self.get_weights() 251 | self._generate_routing() 252 | self._get_link_traffic() 253 | state = self.get_state() 254 | reward = self._compute_reward() 255 | return state, reward 256 | 257 | def step_back(self, action): 258 | state, reward = self.step(action, step_back=True) 259 | return state, reward 260 | 261 | 262 | # in the q_function we want to use info on the complete path (src_node, next_hop, n3, n4, ..., dst_node) 263 | # this function returns the indices of links in the path 264 | def get_complete_link_path(self, node_path): 265 | link_path = [] 266 | for i, j in pairwise_iteration(node_path): 267 | link_path.append(self.G[i][j]['id']) 268 | # pad the path until "max_length" (implementation is easier if all paths have same size) 269 | link_path = link_path + ([-1] * (self.n_links-len(link_path))) 270 | return link_path 271 | 272 | 273 | 274 | """ 275 | **************************************************************************** 276 | PRIVATE FUNCTIONS OF THE ENVIRONMENT CLASS 277 | **************************************************************************** 278 | """ 279 | 280 | def _define_init_weights(self): 281 | np.random.seed(seed=self.seed_init_weights) 282 | self.init_weights = np.random.randint(self.min_weight,self.max_weight+1,self.n_links) 283 | np.random.seed(seed=None) 284 | 285 | 286 | # generates indices for links in the network 287 | def _generate_link_indices_and_adjacencies(self): 288 | # for the q_function, we want to have info on link-link connection points 289 | # there is a link-link connection between link A and link B if link A 290 | # is an incoming link of node C and link B is an outcoming node of node C. 
291 | # For connection "i", the incoming link is incoming_links[i] and the 292 | # outcoming link is outcoming_links[i] 293 | incoming_links = [] 294 | outcoming_links = [] 295 | # iterate through all links 296 | for i in self.G.nodes(): 297 | for j in self.G.neighbors(i): 298 | incoming_link_id = self.G[i][j]['id'] 299 | # for each link, search its outcoming links 300 | for k in self.G.neighbors(j): 301 | outcoming_link_id = self.G[j][k]['id'] 302 | incoming_links.append(incoming_link_id) 303 | outcoming_links.append(outcoming_link_id) 304 | 305 | return incoming_links, outcoming_links 306 | 307 | def _reset_edge_attributes(self, attributes=None): 308 | if attributes is None: 309 | attributes = list(DEFAULT_EDGE_ATTRIBUTES.keys()) 310 | if type(attributes) != list: attributes = [attributes] 311 | for (i,j) in self.G.edges(): 312 | for attribute in attributes: 313 | if attribute == 'weight': 314 | self.G[i][j][attribute] = copy.deepcopy(self.init_weights[self.G[i][j]['id']]) 315 | else: 316 | self.G[i][j][attribute] = copy.deepcopy(DEFAULT_EDGE_ATTRIBUTES[attribute]) 317 | 318 | def _normalize_traffic(self): 319 | for (i,j) in self.G.edges(): 320 | self.G[i][j]['traffic'] /= self.G[i][j]['capacity'] 321 | 322 | def _generate_routing(self, next_hop=None): 323 | self.sp_routing = dict(nx.all_pairs_dijkstra_path(self.G)) 324 | #self.path_lengths = dict(nx.all_pairs_dijkstra_path_length(self.G)) 325 | 326 | 327 | def successive_equal_cost_multipaths(self, src, dst, traffic): 328 | new_srcs = self.next_hop_dict[src][dst] 329 | traffic /= len(new_srcs) 330 | for new_src in new_srcs: 331 | self.G[src][new_src]['traffic'] += traffic 332 | if new_src != dst: 333 | self.successive_equal_cost_multipaths(new_src, dst, traffic) 334 | 335 | 336 | # returns a list of traffic volumes of each link 337 | def _distribute_link_traffic(self, routing=None): 338 | self._reset_edge_attributes('traffic') 339 | if self.routing == 'sp': 340 | if routing is None: routing = self.sp_routing 341 | for i in self.G.nodes(): 342 | if i=='graph_data': continue 343 | for j in self.G.nodes(): 344 | if j=='graph_data' or i == j: continue 345 | traffic = self.traffic_demand[i][j] 346 | for u,v in pairwise_iteration(routing[i][j]): 347 | self.G[u][v]['traffic'] += traffic 348 | elif self.routing == 'ecmp': 349 | visited_pairs = set() 350 | self.next_hop_dict = {i : {j : set() for j in range(self.G.number_of_nodes()-1) if j != i} for i in range(self.G.number_of_nodes()-1)} 351 | for src in range(self.G.number_of_nodes()-1): 352 | for dst in range(self.G.number_of_nodes()-1): 353 | if src == dst: continue 354 | if (src,dst) not in visited_pairs: 355 | routings = set([item for sublist in [[(routing[i],routing[i+1]) for i in range(len(routing)-1)] for routing in nx.all_shortest_paths(self.G, src, dst, 'weight')] for item in sublist]) 356 | for (new_src,next_hop) in routings: 357 | self.next_hop_dict[new_src][dst].add(next_hop) 358 | visited_pairs.add((new_src,dst)) 359 | traffic = self.traffic_demand[src][dst] 360 | self.successive_equal_cost_multipaths(src, dst, traffic) 361 | 362 | self._normalize_traffic() 363 | 364 | def _get_link_traffic(self, routing=None): 365 | self._distribute_link_traffic(routing) 366 | link_traffic = [0]*self.n_links 367 | for i,j in self.G.edges(): 368 | link_traffic[self.G[i][j]['id']] = self.G[i][j]['traffic'] 369 | self.link_traffic = link_traffic 370 | self.mean_traffic = np.mean(link_traffic) 371 | self.get_weights() 372 | 373 | def get_link_traffic(self): 374 | link_traffic = [0]*self.n_links 375 | 
for i,j in self.G.edges(): 376 | link_traffic[self.G[i][j]['id']] = self.G[i][j]['traffic'] 377 | return link_traffic 378 | 379 | def get_link_probs(self): 380 | traffic_in = np.sum(self.traffic_demand, axis=0) 381 | traffic_out = np.sum(self.traffic_demand, axis=1) 382 | node_p_in = list(traffic_in / np.sum(traffic_in)) 383 | node_p_out = list(traffic_out / np.sum(traffic_out)) 384 | self.p_in = [0]*self.n_links 385 | self.p_out = [0]*self.n_links 386 | for i,j in self.G.edges(): 387 | self.p_in[self.G[i][j]['id']] = node_p_out[i] 388 | self.p_out[self.G[i][j]['id']] = node_p_in[j] 389 | 390 | # reward function is currently quite simple 391 | def compute_reward_measure(self, measure=None): 392 | if measure is None: 393 | if self.reward_magnitude == 'link_traffic': 394 | measure = self.link_traffic 395 | elif self.reward_magnitude == 'weights': 396 | measure = self.raw_weights 397 | 398 | if self.base_reward == 'mean_times_std': 399 | return np.mean(measure) * np.std(measure) 400 | elif self.base_reward == 'mean': 401 | return np.mean(measure) 402 | elif self.base_reward == 'std': 403 | return np.std(measure) 404 | elif self.base_reward == 'diff_min_max': 405 | return np.max(measure) - np.min(measure) 406 | elif self.base_reward == 'min_max': 407 | return np.max(measure) 408 | 409 | def _compute_reward(self, current_reward_measure=None): 410 | if current_reward_measure is None: 411 | current_reward_measure = self.compute_reward_measure() 412 | 413 | if self.reward_computation == 'value': 414 | reward = - current_reward_measure 415 | elif self.reward_computation == 'change': 416 | reward = self.reward_measure - current_reward_measure 417 | 418 | self.reward_measure = current_reward_measure 419 | 420 | return reward 421 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import os 4 | 5 | #ENVIRONMENT 6 | 
networks=['TopologyZoo/Bandcon','TopologyZoo/Ans','TopologyZoo/Aarnet','TopologyZoo/Spiralight','TopologyZoo/Cesnet2001','TopologyZoo/HostwayInternational','TopologyZoo/Rhnet','TopologyZoo/Integra','TopologyZoo/Marnet','TopologyZoo/Darkstrand','TopologyZoo/Packetexchange','TopologyZoo/GtsRomania','TopologyZoo/Noel','TopologyZoo/Restena','TopologyZoo/GtsHungary','TopologyZoo/Aconet','TopologyZoo/Amres','TopologyZoo/Arpanet19719','TopologyZoo/Arpanet19728','TopologyZoo/Garr199905','TopologyZoo/Psinet','TopologyZoo/Nsfnet','TopologyZoo/HurricaneElectric','TopologyZoo/HiberniaUs','TopologyZoo/Garr199904','TopologyZoo/Arn','TopologyZoo/Oxford','TopologyZoo/Abvt','TopologyZoo/Twaren','TopologyZoo/Renater1999','TopologyZoo/Xeex','TopologyZoo/Renater2004','TopologyZoo/Vinaren','TopologyZoo/Ilan','TopologyZoo/VisionNet','TopologyZoo/Sago','TopologyZoo/Ibm','TopologyZoo/Fatman','TopologyZoo/EliBackbone','TopologyZoo/Garr200404','TopologyZoo/KentmanJul2005','TopologyZoo/Shentel','TopologyZoo/WideJpn','TopologyZoo/Cynet','TopologyZoo/Nordu2010','TopologyZoo/Navigata','TopologyZoo/Claranet','TopologyZoo/Biznet','TopologyZoo/BtEurope','TopologyZoo/Arpanet19723','TopologyZoo/Quest','TopologyZoo/Gambia','TopologyZoo/Garr200112','TopologyZoo/Cesnet200304','TopologyZoo/Geant2001','TopologyZoo/KentmanAug2005','TopologyZoo/Peer1','TopologyZoo/Bbnplanet','TopologyZoo/Garr200109','TopologyZoo/Istar','TopologyZoo/Ernet','TopologyZoo/Jgn2Plus','TopologyZoo/Savvis','TopologyZoo/Janetbackbone','TopologyZoo/Agis','TopologyZoo/Uran','TopologyZoo/BtAsiaPac','TopologyZoo/HiberniaUk','TopologyZoo/Sprint','TopologyZoo/Grena','TopologyZoo/Compuserve','TopologyZoo/Atmnet','TopologyZoo/York','TopologyZoo/Goodnet','TopologyZoo/Renater2001'] 7 | traffics=['uniform'] 8 | routings=['ecmp'] 9 | model_dirs=['./checkpoints/training/NSFNet+GEANT2-uniform_paul-ecmp/batch25-lr0.0003-epsilon0.1-gae0.9-clip0.3-gamma0.97-period50-epoch3/size16-iters8-min_max-nnsize64-drop0.25-selu/episode57'] 10 | #model_dirs=['./checkpoints/training/NSFNet+GEANT2-uniform_paul-ecmp/batch25-lr0.0003-epsilon0.1-gae0.9-clip0.3-gamma0.97-period50-epoch3/size16-iters8-min_max-nnsize64-drop0.25-tanh/episode35'] 11 | 12 | #PPOAGENT 13 | batches = [25] 14 | gae_lambdas = [0.9] 15 | lrs = [0.0003] 16 | epsilons = [0.1] 17 | periods=[50] 18 | gammas=[0.95] 19 | clips=[0.2] 20 | epochs=[3] 21 | 22 | #ACTOR-CRITIC 23 | link_state_sizes=[16] 24 | message_iterations=[8] 25 | aggregations=['min_max'] 26 | nn_sizes = [[64,128]] 27 | dropouts = [0.15] 28 | activations = ['selu'] 29 | 30 | 31 | for activation in activations: 32 | for dropout in dropouts: 33 | for nn_size in nn_sizes: 34 | for aggregation in aggregations: 35 | for message_iteration in message_iterations: 36 | for link_state_size in link_state_sizes: 37 | for routing in routings: 38 | for model_dir in model_dirs: 39 | for traffic in traffics: 40 | for network in networks: 41 | for period in periods: 42 | for gamma in gammas: 43 | for clip in clips: 44 | for epoch in epochs: 45 | for batch in batches: 46 | for gae_lambda in gae_lambdas: 47 | for lr in lrs: 48 | for epsilon in epsilons: 49 | activation = 'selu' if 'selu' in model_dir else 'tanh' 50 | dropout = 0.15 if 'drop0.15' in model_dir else 0.25 51 | cmd = "python ./run.py" \ 52 | " --gin_bindings='Runner.model_dir = \""+model_dir+"\"'" \ 53 | " --gin_bindings='Runner.reload_model = True'" \ 54 | " --gin_bindings='Runner.only_eval = True'" \ 55 | " --gin_bindings='Runner.save_checkpoints = False'" \ 56 | " --gin_bindings='Environment.env_type = 
\""+network+"\"'" \ 57 | " --gin_bindings='Environment.traffic_profile = \""+traffic+"\"'" \ 58 | " --gin_bindings='Environment.routing = \""+routing+"\"'" \ 59 | " --gin_bindings='PPOAgent.eval_period = "+str(period)+"'" \ 60 | " --gin_bindings='PPOAgent.gamma = "+str(gamma)+"'" \ 61 | " --gin_bindings='PPOAgent.clip_param = "+str(clip)+"'" \ 62 | " --gin_bindings='PPOAgent.epochs = "+str(epoch)+"'" \ 63 | " --gin_bindings='PPOAgent.batch_size = "+str(batch)+"'" \ 64 | " --gin_bindings='PPOAgent.gae_lambda = "+str(gae_lambda)+"'" \ 65 | " --gin_bindings='tf.keras.optimizers.Adam.learning_rate = "+str(lr)+"'" \ 66 | " --gin_bindings='tf.keras.optimizers.Adam.epsilon = "+str(epsilon)+"'" \ 67 | " --gin_bindings='Actor.link_state_size = "+str(link_state_size)+"'" \ 68 | " --gin_bindings='Actor.aggregation = \""+aggregation+"\"'" \ 69 | " --gin_bindings='Actor.first_hidden_layer_size = "+str(nn_size[1])+"'" \ 70 | " --gin_bindings='Actor.final_hidden_layer_size = "+str(nn_size[0])+"'" \ 71 | " --gin_bindings='Actor.dropout_rate = "+str(dropout)+"'" \ 72 | " --gin_bindings='Actor.message_iterations = "+str(message_iteration)+"'" \ 73 | " --gin_bindings='Actor.activation_fn = \""+activation+"\"'" \ 74 | " --gin_bindings='Critic.link_state_size = "+str(link_state_size)+"'" \ 75 | " --gin_bindings='Critic.aggregation = \""+aggregation+"\"'" \ 76 | " --gin_bindings='Critic.first_hidden_layer_size = "+str(nn_size[1])+"'" \ 77 | " --gin_bindings='Critic.final_hidden_layer_size = "+str(nn_size[0])+"'" \ 78 | " --gin_bindings='Critic.dropout_rate = "+str(dropout)+"'" \ 79 | " --gin_bindings='Critic.message_iterations = "+str(message_iteration)+"'" \ 80 | " --gin_bindings='Critic.activation_fn = \""+activation+"\"' &" 81 | 82 | os.system(cmd) -------------------------------------------------------------------------------- /grid_search.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import os 4 | 5 | #ENVIRONMENT 6 | networks=['NSFNet+GEANT2'] 
#['TopologyZoo/Bandcon','TopologyZoo/Ans','TopologyZoo/Aarnet','TopologyZoo/Spiralight','TopologyZoo/Cesnet2001','TopologyZoo/HostwayInternational','TopologyZoo/Rhnet','TopologyZoo/Integra','TopologyZoo/Marnet','TopologyZoo/Darkstrand','TopologyZoo/Packetexchange','TopologyZoo/GtsRomania','TopologyZoo/Noel','TopologyZoo/Restena','TopologyZoo/GtsHungary','TopologyZoo/Aconet','TopologyZoo/Amres','TopologyZoo/Arpanet19719','TopologyZoo/Arpanet19728','TopologyZoo/Garr199905','TopologyZoo/Psinet','TopologyZoo/Nsfnet','TopologyZoo/HurricaneElectric','TopologyZoo/HiberniaUs','TopologyZoo/Garr199904','TopologyZoo/Arn','TopologyZoo/Oxford','TopologyZoo/Abvt','TopologyZoo/Twaren','TopologyZoo/Renater1999','TopologyZoo/Xeex','TopologyZoo/Renater2004','TopologyZoo/Vinaren','TopologyZoo/Ilan','TopologyZoo/VisionNet','TopologyZoo/Sago','TopologyZoo/Ibm','TopologyZoo/Fatman','TopologyZoo/EliBackbone','TopologyZoo/Garr200404','TopologyZoo/KentmanJul2005','TopologyZoo/Shentel','TopologyZoo/WideJpn','TopologyZoo/Cynet','TopologyZoo/Nordu2010','TopologyZoo/Navigata','TopologyZoo/Claranet','TopologyZoo/Biznet','TopologyZoo/BtEurope','TopologyZoo/Arpanet19723','TopologyZoo/Quest','TopologyZoo/Gambia','TopologyZoo/Garr200112','TopologyZoo/Cesnet200304','TopologyZoo/Geant2001','TopologyZoo/KentmanAug2005','TopologyZoo/Peer1','TopologyZoo/Bbnplanet','TopologyZoo/Garr200109','TopologyZoo/Istar','TopologyZoo/Ernet','TopologyZoo/Jgn2Plus','TopologyZoo/Savvis','TopologyZoo/Janetbackbone','TopologyZoo/Agis','TopologyZoo/Uran','TopologyZoo/BtAsiaPac','TopologyZoo/HiberniaUk','TopologyZoo/Sprint','TopologyZoo/Grena','TopologyZoo/Compuserve','TopologyZoo/Atmnet','TopologyZoo/York','TopologyZoo/Goodnet','TopologyZoo/Renater2001'] 7 | traffics=['gravity_1'] 8 | routings=['sp'] 9 | 10 | #PPOAGENT 11 | batches = [25] 12 | gae_lambdas = [0.9,0.95] 13 | lrs = [0.0003] 14 | epsilons = [0.1,0.01] 15 | periods=[50] 16 | gammas=[0.95,0.97] 17 | clips=[0.2] 18 | epochs=[3] 19 | 20 | #ACTOR-CRITIC 21 | link_state_sizes=[16] 22 | message_iterations=[8] 23 | aggregations=['min_max'] 24 | nn_sizes = [[64,128]] 25 | dropouts = [0.15] 26 | activations = ['tanh', 'selu'] 27 | 28 | 29 | for activation in activations: 30 | for dropout in dropouts: 31 | for nn_size in nn_sizes: 32 | for aggregation in aggregations: 33 | for message_iteration in message_iterations: 34 | for link_state_size in link_state_sizes: 35 | for network in networks: 36 | for traffic in traffics: 37 | for routing in routings: 38 | for period in periods: 39 | for gamma in gammas: 40 | for clip in clips: 41 | for epoch in epochs: 42 | for batch in batches: 43 | for gae_lambda in gae_lambdas: 44 | for lr in lrs: 45 | for epsilon in epsilons: 46 | cmd = "python ./run.py" \ 47 | " --gin_bindings='Environment.env_type = \""+network+"\"'" \ 48 | " --gin_bindings='Environment.traffic_profile = \""+traffic+"\"'" \ 49 | " --gin_bindings='Environment.routing = \""+routing+"\"'" \ 50 | " --gin_bindings='PPOAgent.eval_period = "+str(period)+"'" \ 51 | " --gin_bindings='PPOAgent.gamma = "+str(gamma)+"'" \ 52 | " --gin_bindings='PPOAgent.clip_param = "+str(clip)+"'" \ 53 | " --gin_bindings='PPOAgent.epochs = "+str(epoch)+"'" \ 54 | " --gin_bindings='PPOAgent.batch_size = "+str(batch)+"'" \ 55 | " --gin_bindings='PPOAgent.gae_lambda = "+str(gae_lambda)+"'" \ 56 | " --gin_bindings='tf.keras.optimizers.Adam.learning_rate = "+str(lr)+"'" \ 57 | " --gin_bindings='tf.keras.optimizers.Adam.epsilon = "+str(epsilon)+"'" \ 58 | " --gin_bindings='Actor.link_state_size = 
"+str(link_state_size)+"'" \ 59 | " --gin_bindings='Actor.aggregation = \""+aggregation+"\"'" \ 60 | " --gin_bindings='Actor.first_hidden_layer_size = "+str(nn_size[1])+"'" \ 61 | " --gin_bindings='Actor.final_hidden_layer_size = "+str(nn_size[0])+"'" \ 62 | " --gin_bindings='Actor.dropout_rate = "+str(dropout)+"'" \ 63 | " --gin_bindings='Actor.message_iterations = "+str(message_iteration)+"'" \ 64 | " --gin_bindings='Actor.activation_fn = \""+activation+"\"'" \ 65 | " --gin_bindings='Critic.link_state_size = "+str(link_state_size)+"'" \ 66 | " --gin_bindings='Critic.aggregation = \""+aggregation+"\"'" \ 67 | " --gin_bindings='Critic.first_hidden_layer_size = "+str(nn_size[1])+"'" \ 68 | " --gin_bindings='Critic.final_hidden_layer_size = "+str(nn_size[0])+"'" \ 69 | " --gin_bindings='Critic.dropout_rate = "+str(dropout)+"'" \ 70 | " --gin_bindings='Critic.message_iterations = "+str(message_iteration)+"'" \ 71 | " --gin_bindings='Critic.activation_fn = \""+activation+"\"' &" 72 | 73 | os.system(cmd) -------------------------------------------------------------------------------- /lib/actor.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | import numpy as np 4 | import networkx 5 | 6 | import gin.tf 7 | 8 | @gin.configurable 9 | class Actor(keras.Model): 10 | def __init__(self, 11 | graph, 12 | num_actions = 1, 13 | num_features = 2, 14 | link_state_size=16, 15 | #message_hidden_layer_size=64, 16 | aggregation='min_max', 17 | first_hidden_layer_size=128, 18 | dropout_rate=0.5, 19 | final_hidden_layer_size=64, 20 | message_iterations = 8, 21 | activation_fn='tanh', 22 | final_activation_fn='linear'): 23 | 24 | super(Actor, self).__init__() 25 | # HYPERPARAMETERS 26 | self.num_actions = num_actions 27 | self.num_features = num_features 28 | self.n_links = graph.number_of_edges() 29 | self.link_state_size = link_state_size 30 | self.message_hidden_layer_size = final_hidden_layer_size 31 | self.aggregation = aggregation 32 | self.message_iterations = message_iterations 33 | 34 | 35 | # FIXED INPUTS 36 | # for a link-link connection "i", self.incoming_links[i] is the incoming link 37 | # and self.outcoming_links[i] is the outcoming_link 38 | # see environment class function "_generate_link_indices_and_adjacencies()" for details 39 | self.incoming_links = graph.nodes()['graph_data']['incoming_links'] 40 | self.outcoming_links = graph.nodes()['graph_data']['outcoming_links'] 41 | 42 | # NEURAL NETWORKS 43 | self.hidden_layer_initializer = tf.keras.initializers.Orthogonal(gain=np.sqrt(2)) 44 | self.final_layer_initializer = tf.keras.initializers.Orthogonal(gain=0.01) 45 | self.kernel_regularizer = None #keras.regularizers.l2(0.01) 46 | #keras.initializers.VarianceScaling(scale=1.0 / np.sqrt(3.0),mode='fan_in', distribution='uniform') 47 | self.activation_fn = activation_fn 48 | self.final_hidden_layer_size = final_hidden_layer_size 49 | self.first_hidden_layer_size = first_hidden_layer_size 50 | self.dropout_rate = dropout_rate 51 | self.final_activation_fn = final_activation_fn 52 | self.define_network() 53 | 54 | 55 | def define_network(self): 56 | # message 57 | self.create_message = keras.models.Sequential(name='create_message') 58 | self.create_message.add(keras.layers.Dense(self.message_hidden_layer_size, 59 | kernel_initializer=self.hidden_layer_initializer, activation=self.activation_fn)) 60 | self.create_message.add(keras.layers.Dense(self.link_state_size, 61 | 
kernel_initializer=self.hidden_layer_initializer, activation=self.activation_fn)) 62 | 63 | # link update 64 | self.link_update = keras.models.Sequential(name='link_update') 65 | self.link_update.add(keras.layers.Dense(self.first_hidden_layer_size, 66 | kernel_initializer=self.hidden_layer_initializer, activation=self.activation_fn)) 67 | self.link_update.add(keras.layers.Dense(self.final_hidden_layer_size, 68 | kernel_initializer=self.hidden_layer_initializer, activation=self.activation_fn)) 69 | self.link_update.add(keras.layers.Dense(self.link_state_size, 70 | kernel_initializer=self.hidden_layer_initializer, activation=self.activation_fn)) 71 | 72 | # readout 73 | self.readout = keras.models.Sequential(name='readout') 74 | self.readout.add(keras.layers.Dense(self.first_hidden_layer_size, kernel_initializer=self.hidden_layer_initializer, 75 | kernel_regularizer=self.kernel_regularizer, activation=self.activation_fn)) 76 | self.readout.add(keras.layers.Dropout(self.dropout_rate)) 77 | self.readout.add(keras.layers.Dense(self.final_hidden_layer_size, kernel_initializer=self.hidden_layer_initializer, 78 | kernel_regularizer=self.kernel_regularizer, activation=self.activation_fn)) 79 | self.readout.add(keras.layers.Dropout(self.dropout_rate)) 80 | self.readout.add(keras.layers.Dense(self.num_actions, kernel_initializer=self.final_layer_initializer, 81 | kernel_regularizer=self.kernel_regularizer, activation=self.final_activation_fn)) 82 | 83 | 84 | 85 | def build(self, input_shape=None): 86 | #del input_shape 87 | self.create_message.build(input_shape = [None, 2 * self.link_state_size]) 88 | if self.aggregation == 'sum': 89 | self.link_update.build(input_shape = [None, 2 * self.link_state_size]) 90 | elif self.aggregation == 'min_max': 91 | self.link_update.build(input_shape = [None, 3 * self.link_state_size]) 92 | self.readout.build(input_shape = [None, self.link_state_size]) 93 | self.built = True 94 | 95 | @tf.function 96 | def message_passing(self, input): 97 | input_tensor = tf.convert_to_tensor(input) 98 | link_states = tf.reshape(input_tensor, [self.num_features,self.n_links]) 99 | link_states = tf.transpose(link_states) 100 | padding = [[0,0],[0,self.link_state_size-self.num_features]] 101 | link_states = tf.pad(link_states, padding) 102 | 103 | # message passing part 104 | # links exchange information with their neighbors to update their states 105 | for _ in range(self.message_iterations): 106 | incoming_link_states = tf.gather(link_states, self.incoming_links) 107 | outcoming_link_states = tf.gather(link_states, self.outcoming_links) 108 | message_inputs = tf.cast(tf.concat([incoming_link_states, outcoming_link_states], axis=1), tf.float32) 109 | messages = self.create_message(message_inputs) 110 | 111 | aggregated_messages = self.message_aggregation(messages) 112 | link_update_input = tf.cast(tf.concat([link_states, aggregated_messages], axis=1), tf.float32) 113 | link_states = self.link_update(link_update_input) 114 | 115 | return link_states 116 | 117 | @tf.function 118 | def message_aggregation(self, messages): 119 | if self.aggregation == 'sum': 120 | aggregated_messages = tf.math.unsorted_segment_sum(messages, self.outcoming_links, num_segments=self.n_links) 121 | elif self.aggregation == 'min_max': 122 | agg_max = tf.math.unsorted_segment_max(messages, self.outcoming_links, num_segments=self.n_links) 123 | agg_min = tf.math.unsorted_segment_min(messages, self.outcoming_links, num_segments=self.n_links) 124 | aggregated_messages = tf.concat([agg_max, agg_min], axis=1) 125 | 
return aggregated_messages 126 | 127 | @tf.function 128 | def call(self, input): 129 | link_states = self.message_passing(input) 130 | 131 | policy = self.readout(link_states) 132 | policy = tf.reshape(policy, [-1]) 133 | 134 | return policy 135 | -------------------------------------------------------------------------------- /lib/critic.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | import numpy as np 4 | import networkx 5 | 6 | import gin.tf 7 | 8 | 9 | @gin.configurable 10 | class Critic(keras.Model): 11 | def __init__(self, 12 | graph, 13 | num_features = 2, 14 | link_state_size=16, 15 | #message_hidden_layer_size=64, 16 | aggregation='min_max', 17 | first_hidden_layer_size=128, 18 | dropout_rate=0.5, 19 | final_hidden_layer_size=64, 20 | message_iterations = 8, 21 | activation_fn='tanh', 22 | final_activation_fn='linear'): 23 | super(Critic, self).__init__() 24 | # HYPERPARAMETERS 25 | self.num_features = num_features 26 | self.n_links = graph.number_of_edges() 27 | self.link_state_size = link_state_size 28 | self.message_hidden_layer_size = final_hidden_layer_size 29 | self.aggregation = aggregation 30 | self.message_iterations = message_iterations 31 | 32 | self.num_readout_input_aggregations = 4 33 | 34 | # FIXED INPUTS 35 | # for a link-link connection "i", self.incoming_links[i] is the incoming link 36 | # and self.outcoming_links[i] is the outcoming_link 37 | # see environment class function "_generate_link_indices_and_adjacencies()" for details 38 | self.incoming_links = graph.nodes()['graph_data']['incoming_links'] 39 | self.outcoming_links = graph.nodes()['graph_data']['outcoming_links'] 40 | 41 | # NEURAL NETWORKS 42 | self.hidden_layer_initializer = tf.keras.initializers.Orthogonal(gain=np.sqrt(2)) 43 | self.final_layer_initializer = tf.keras.initializers.Orthogonal(gain=1) 44 | self.kernel_regularizer = None #keras.regularizers.l2(0.01) 45 | #keras.initializers.VarianceScaling(scale=1.0 / np.sqrt(3.0),mode='fan_in', distribution='uniform') 46 | self.activation_fn = activation_fn 47 | self.final_hidden_layer_size = final_hidden_layer_size 48 | self.first_hidden_layer_size = first_hidden_layer_size 49 | self.dropout_rate = dropout_rate 50 | self.final_activation_fn = final_activation_fn 51 | self.define_network() 52 | 53 | 54 | def define_network(self): 55 | # message 56 | self.create_message = keras.models.Sequential(name='create_message') 57 | self.create_message.add(keras.layers.Dense(self.message_hidden_layer_size, 58 | kernel_initializer=self.hidden_layer_initializer, activation=self.activation_fn)) 59 | self.create_message.add(keras.layers.Dense(self.link_state_size, 60 | kernel_initializer=self.hidden_layer_initializer, activation=self.activation_fn)) 61 | 62 | # link update 63 | self.link_update = keras.models.Sequential(name='link_update') 64 | self.link_update.add(keras.layers.Dense(self.first_hidden_layer_size, 65 | kernel_initializer=self.hidden_layer_initializer, activation=self.activation_fn)) 66 | self.link_update.add(keras.layers.Dense(self.final_hidden_layer_size, 67 | kernel_initializer=self.hidden_layer_initializer, activation=self.activation_fn)) 68 | self.link_update.add(keras.layers.Dense(self.link_state_size, 69 | kernel_initializer=self.hidden_layer_initializer, activation=self.activation_fn)) 70 | 71 | # readout 72 | self.readout = keras.models.Sequential(name='readout') 73 | self.readout.add(keras.layers.Dense(self.first_hidden_layer_size, 
kernel_initializer=self.hidden_layer_initializer, 74 | kernel_regularizer=self.kernel_regularizer, activation=self.activation_fn)) 75 | self.readout.add(keras.layers.Dropout(self.dropout_rate)) 76 | self.readout.add(keras.layers.Dense(self.final_hidden_layer_size, kernel_initializer=self.hidden_layer_initializer, 77 | kernel_regularizer=self.kernel_regularizer, activation=self.activation_fn)) 78 | self.readout.add(keras.layers.Dropout(self.dropout_rate)) 79 | self.readout.add(keras.layers.Dense(1, kernel_initializer=self.final_layer_initializer, 80 | kernel_regularizer=self.kernel_regularizer, activation=self.final_activation_fn)) 81 | 82 | 83 | def build(self, input_shape=None): 84 | del input_shape 85 | self.create_message.build(input_shape = [None, 2 * self.link_state_size]) 86 | if self.aggregation == 'sum': 87 | self.link_update.build(input_shape = [None, 2 * self.link_state_size]) 88 | elif self.aggregation == 'min_max': 89 | self.link_update.build(input_shape = [None, 3 * self.link_state_size]) 90 | self.readout.build(input_shape = [None, self.link_state_size*self.num_readout_input_aggregations]) 91 | self.built = True 92 | 93 | @tf.function 94 | def message_passing(self, input): 95 | input_tensor = tf.convert_to_tensor(input) 96 | link_states = tf.reshape(input_tensor, [self.num_features,self.n_links]) 97 | link_states = tf.transpose(link_states) 98 | padding = [[0,0],[0,self.link_state_size-self.num_features]] 99 | link_states = tf.pad(link_states, padding) 100 | 101 | # message passing part 102 | # links exchange information with their neighbors to update their states 103 | for _ in range(self.message_iterations): 104 | incoming_link_states = tf.gather(link_states, self.incoming_links) 105 | outcoming_link_states = tf.gather(link_states, self.outcoming_links) 106 | message_inputs = tf.cast(tf.concat([incoming_link_states, outcoming_link_states], axis=1), tf.float32) 107 | messages = self.create_message(message_inputs) 108 | 109 | aggregated_messages = self.message_aggregation(messages) 110 | link_update_input = tf.cast(tf.concat([link_states, aggregated_messages], axis=1), tf.float32) 111 | link_states = self.link_update(link_update_input) 112 | 113 | return link_states 114 | 115 | @tf.function 116 | def message_aggregation(self, messages): 117 | if self.aggregation == 'sum': 118 | aggregated_messages = tf.math.unsorted_segment_sum(messages, self.outcoming_links, num_segments=self.n_links) 119 | elif self.aggregation == 'min_max': 120 | agg_max = tf.math.unsorted_segment_max(messages, self.outcoming_links, num_segments=self.n_links) 121 | agg_min = tf.math.unsorted_segment_min(messages, self.outcoming_links, num_segments=self.n_links) 122 | aggregated_messages = tf.concat([agg_max, agg_min], axis=1) 123 | return aggregated_messages 124 | 125 | @tf.function 126 | def generate_readout_input(self, link_states): 127 | ls_mean = tf.reduce_mean(link_states, axis=0) 128 | ls_max = tf.reduce_max(link_states, axis=0) 129 | ls_min = tf.reduce_min(link_states, axis=0) 130 | ls_std = tf.math.reduce_std(link_states, axis=0) 131 | 132 | readout_input = tf.concat([ls_mean,ls_max,ls_min,ls_std], axis=0) 133 | readout_input = tf.expand_dims(readout_input, axis=0) 134 | return readout_input 135 | 136 | @tf.function 137 | def call(self, input): 138 | 139 | link_states = self.message_passing(input) 140 | #link_states = tf.expand_dims(link_states, axis=0) 141 | 142 | readout_input = self.generate_readout_input(link_states) 143 | 144 | V = self.readout(readout_input) 145 | V = tf.reshape(V, [-1]) 146 
| 147 | return V 148 | 149 | -------------------------------------------------------------------------------- /lib/run_experiment.py: -------------------------------------------------------------------------------- 1 | 2 | from environment.environment import Environment 3 | from agents.ppo_agent import PPOAgent 4 | 5 | import os 6 | import logging 7 | import tensorflow as tf 8 | 9 | import gin.tf 10 | 11 | @gin.configurable 12 | class Runner(object): 13 | 14 | def __init__(self, 15 | algorithm='PPO', 16 | reload_model=False, 17 | model_dir=None, 18 | only_eval=False, 19 | base_dir='logs', 20 | checkpoint_base_dir='checkpoints', 21 | save_checkpoints=True): 22 | 23 | env = Environment() 24 | self.save_checkpoints = save_checkpoints 25 | if algorithm == 'PPO': 26 | if reload_model: 27 | old_actor = False #if '_agg' in model_dir else True 28 | self.agent = PPOAgent(env, old_actor=old_actor, save_checkpoints=save_checkpoints) 29 | else: 30 | self.agent = PPOAgent(env, save_checkpoints=save_checkpoints) 31 | else: 32 | #Insert here your customized RL algorithm 33 | assert (False), 'RL Algorithm %s is not implemented' %algorithm 34 | 35 | self.base_dir= base_dir 36 | self.checkpoint_base_dir = checkpoint_base_dir 37 | 38 | self.only_eval = only_eval 39 | if reload_model or self.only_eval: 40 | self.agent.load_saved_model(model_dir, only_eval) 41 | self.set_logs_and_checkpoints() 42 | 43 | def run_experiment(self): 44 | if self.only_eval: 45 | self.agent.only_evaluate() 46 | else: 47 | self.agent.train_and_evaluate() 48 | 49 | def set_logs_and_checkpoints(self): 50 | experiment_identifier = self.agent.set_experiment_identifier(self.only_eval) 51 | 52 | writer_dir = os.path.join(self.base_dir, experiment_identifier) 53 | if not os.path.exists(writer_dir): 54 | os.makedirs(writer_dir) 55 | 56 | checkpoint_dir = os.path.join(self.checkpoint_base_dir, experiment_identifier) 57 | if self.save_checkpoints and (not os.path.exists(checkpoint_dir)): 58 | os.makedirs(checkpoint_dir) 59 | 60 | self.agent.set_writer_and_checkpoint_dir(writer_dir, checkpoint_dir) 61 | 62 | f = open(os.path.join(writer_dir, 'out.log'), 'w+') 63 | f.close() 64 | fh = logging.FileHandler(os.path.join(writer_dir, 'out.log')) 65 | fh.setLevel(logging.DEBUG) # or any level you want 66 | tf.get_logger().addHandler(fh) 67 | 68 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | appnope==0.1.0 3 | astor==0.8.1 4 | astroid==2.5.7 5 | astunparse==1.6.3 6 | attrs==19.3.0 7 | backcall==0.1.0 8 | bleach==3.1.0 9 | cached-property==1.5.2 10 | cachetools==4.1.0 11 | certifi==2020.4.5.1 12 | chardet==3.0.4 13 | clang==5.0 14 | cloudpickle==1.6.0 15 | cycler==0.10.0 16 | decorator==4.4.1 17 | defusedxml==0.6.0 18 | dm-tree==0.1.5 19 | entrypoints==0.3 20 | flatbuffers==1.12 21 | gast==0.4.0 22 | gin-config==0.3.0 23 | google-auth==1.14.1 24 | google-auth-oauthlib==0.4.1 25 | google-pasta==0.2.0 26 | grpcio==1.39.0 27 | h5py==3.1.0 28 | idna==2.9 29 | ignnition==1.0.2 30 | importlib==1.0.4 31 | importlib-metadata==1.6.0 32 | importlib-resources==5.1.4 33 | ipykernel==5.1.4 34 | ipython==7.12.0 35 | ipython-genutils==0.2.0 36 | ipywidgets==7.5.1 37 | isort==4.3.21 38 | jedi==0.16.0 39 | Jinja2==2.11.1 40 | joblib==0.15.1 41 | jsonschema==3.2.0 42 | jupyter==1.0.0 43 | jupyter-client==5.3.4 44 | jupyter-console==6.1.0 45 | jupyter-core==4.6.1 46 | keras==2.6.0 47 | Keras-Applications==1.0.8 48 
| Keras-Preprocessing==1.1.2 49 | kiwisolver==1.1.0 50 | lazy-object-proxy==1.4.3 51 | Markdown==3.2.1 52 | MarkupSafe==1.1.1 53 | matplotlib==3.1.3 54 | mccabe==0.6.1 55 | mistune==0.8.4 56 | nbconvert==5.6.1 57 | nbformat==5.0.4 58 | networkx==2.4 59 | notebook==6.0.3 60 | numpy==1.19.5 61 | oauthlib==3.1.0 62 | opt-einsum==3.3.0 63 | pandas==1.0.4 64 | pandocfilters==1.4.2 65 | parso==0.6.1 66 | pexpect==4.8.0 67 | pickle5==0.0.11 68 | pickleshare==0.7.5 69 | prometheus-client==0.7.1 70 | prompt-toolkit==3.0.3 71 | protobuf==3.11.3 72 | ptyprocess==0.6.0 73 | pyasn1==0.4.8 74 | pyasn1-modules==0.2.8 75 | Pygments==2.5.2 76 | pylint==2.8.2 77 | pyparsing==2.4.6 78 | pyrsistent==0.16.0 79 | python-dateutil==2.8.1 80 | pytz==2020.1 81 | PyYAML==5.4.1 82 | pyzmq==18.1.1 83 | qtconsole==4.6.0 84 | requests==2.23.0 85 | requests-oauthlib==1.3.0 86 | rsa==4.0 87 | scikit-learn==0.23.1 88 | scipy==1.4.1 89 | seaborn==0.10.1 90 | Send2Trash==1.5.0 91 | six==1.15.0 92 | sklearn==0.0 93 | tensorboard==2.6.0 94 | tensorboard-data-server==0.6.1 95 | tensorboard-plugin-wit==1.7.0 96 | tensorboardX==1.9 97 | tensorflow==2.6.0 98 | tensorflow-estimator==2.6.0 99 | tensorflow-probability==0.11.1 100 | termcolor==1.1.0 101 | terminado==0.8.3 102 | testpath==0.4.4 103 | threadpoolctl==2.1.0 104 | toml==0.10.2 105 | tornado==6.0.3 106 | traitlets==4.3.3 107 | typed-ast==1.4.1 108 | typing-extensions==3.7.4.3 109 | urllib3==1.25.9 110 | wcwidth==0.1.8 111 | webencodings==0.5.1 112 | Werkzeug==1.0.1 113 | widgetsnbextension==3.5.1 114 | wrapt==1.12.1 115 | zipp==3.1.0 116 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import numpy as np 5 | import tensorflow as tf 6 | from agents.ppo_agent import PPOAgent 7 | from lib.run_experiment import Runner 8 | from utils.functions import load_gin_configs 9 | 10 | from absl import app 11 | from absl import flags 12 | 13 | 14 | flags.DEFINE_string('base_dir', 'logs', 15 | 'Base directory to host all required sub-directories.') 16 | flags.DEFINE_string('network', 'logs', 17 | 'Base directory to host all required sub-directories.') 18 | flags.DEFINE_multi_string( 19 | 'gin_files', ["configs/general.gin", "configs/ppo.gin"], 'List of paths to gin configuration files') 20 | flags.DEFINE_multi_string( 21 | 'gin_bindings', [], 22 | 'Gin bindings to override the values set in the config files ') 23 | 24 | FLAGS = flags.FLAGS 25 | 26 | 27 | def main(unused_argv): 28 | """Main method. 29 | Args: 30 | unused_argv: Arguments (unused). 
31 | """ 32 | load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings) 33 | 34 | orig_stdout = sys.stdout 35 | orig_stderr = sys.stderr 36 | 37 | runner = Runner() 38 | 39 | f = open(os.path.join(runner.agent.writer_dir, 'out.txt'), 'w+') 40 | sys.stdout = f 41 | sys.stderr = f 42 | 43 | runner.run_experiment() 44 | 45 | sys.stdout = orig_stdout 46 | sys.stderr = orig_stderr 47 | f.close() 48 | 49 | 50 | if __name__ == '__main__': 51 | app.run(main) 52 | -------------------------------------------------------------------------------- /utils/defo_process_results.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python3 2 | 3 | import numpy as np 4 | import re 5 | import sys 6 | sys.path.append('../') 7 | 8 | node_to_index_dic = {} 9 | index_to_node_lst = [] 10 | 11 | def index_to_node(n): 12 | return(index_to_node_lst[n]) 13 | 14 | def node_to_index(node): 15 | return(node_to_index_dic[node]) 16 | 17 | 18 | class Defo_results: 19 | 20 | net_size = 0 21 | MP_matrix = None 22 | ecmp_routing_matrix = None 23 | routing_matrix = None 24 | links_bw = None 25 | 26 | def __init__(self, graph_file, results_file): 27 | self.graph_file = graph_file 28 | self.results_file = results_file 29 | 30 | self.process_graph_file() 31 | 32 | self.process() 33 | 34 | def process_graph_file(self): 35 | with open(self.graph_file) as fd: 36 | line = fd.readline() 37 | camps = line.split(" ") 38 | self.net_size = int(camps[1]) 39 | # Remove : label x y 40 | line = fd.readline() 41 | 42 | for i in range (self.net_size): 43 | line = fd.readline() 44 | node = line[0:line.find(" ")] 45 | node_to_index_dic[node] = i 46 | index_to_node_lst.append(node) 47 | 48 | self.links_bw = [] 49 | for i in range(self.net_size): 50 | self.links_bw.append({}) 51 | for line in fd: 52 | if (not line.startswith("Link_") and not line.startswith("edge_")): 53 | continue 54 | camps = line.split(" ") 55 | src = int(camps[1]) 56 | dst = int(camps[2]) 57 | bw = float(camps[4]) 58 | self.links_bw[src][dst] = bw 59 | 60 | 61 | 62 | 63 | def process (self): 64 | with open(self.results_file) as fd: 65 | while (True): 66 | line = fd.readline() 67 | if (line == ""): 68 | break 69 | if (line.startswith("*")): 70 | if (line == "***Next hops priority 2 (sr paths)***\n"): 71 | self._read_middle_points(fd) 72 | if (line == "***Next hops priority 3 (ecmp paths)***\n"): 73 | self._read_ecmp_routing(fd) 74 | break 75 | self._gen_routing_matrix() 76 | 77 | def _read_middle_points(self,fd): 78 | self.MP_matrix = np.zeros((self.net_size,self.net_size),dtype="object") 79 | while (True): 80 | pos = fd.tell() 81 | line = fd.readline() 82 | #print (line) 83 | if (line.startswith("*")): 84 | fd.seek(pos) 85 | return 86 | if (not line.startswith("seq")): 87 | continue 88 | line = line[line.find(": ")+2:] 89 | if (line[-1]=='\n'): 90 | line = line[:-1] 91 | 92 | ptr = 0 93 | mp_path = [] 94 | while (True): 95 | prev_ptr = ptr 96 | ptr = line.find(" -> ",ptr) 97 | if (ptr == -1): 98 | mp_path.append(line[prev_ptr:]) 99 | break 100 | else: 101 | mp_path.append(line[prev_ptr:ptr]) 102 | ptr += 4 103 | src = node_to_index(mp_path[0]) 104 | dst = node_to_index(mp_path[-1]) 105 | self.MP_matrix[src,dst] = mp_path 106 | 107 | 108 | def _read_ecmp_routing(self,fd): 109 | self.ecmp_routing_matrix = np.zeros((self.net_size,self.net_size),dtype="object") 110 | next_node_matrix = np.zeros((self.net_size,self.net_size),dtype="object") 111 | dst_node = None 112 | while (True): 113 | line = fd.readline() 114 | if (line == ""): 115 | 
break 116 | if (line.startswith("Destination")): 117 | dst_node_str = line[line.find(" ")+1:-1] 118 | dst_node = node_to_index(dst_node_str) 119 | if (line.startswith("node")): 120 | src_node_str = line[6:line.find(", ")] 121 | src_node = node_to_index(src_node_str) 122 | sub_line = line[line.find("[")+1:line.find("]")] 123 | ptr = 0 124 | next_node_lst = [] 125 | while (True): 126 | prev_ptr = ptr 127 | ptr = sub_line.find(", ",ptr) 128 | if (ptr == -1): 129 | next_node_lst.append(sub_line[prev_ptr:]) 130 | break 131 | else: 132 | next_node_lst.append(sub_line[prev_ptr:ptr]) 133 | ptr += 2 134 | 135 | next_node_matrix[src_node,dst_node] = next_node_lst 136 | 137 | for i in range (self.net_size): 138 | for j in range (self.net_size): 139 | end_paths = [] 140 | paths_info = [{"path":[index_to_node(i)],"proportion":1.0}] 141 | while (len(paths_info) != 0): 142 | for path_info in paths_info: 143 | path = path_info["path"] 144 | if (node_to_index(path[-1]) == j): 145 | paths_info.remove(path_info) 146 | end_paths.append(path_info) 147 | continue 148 | next_lst = next_node_matrix[node_to_index(path[-1]),j] 149 | num_next_hops = len(next_lst) 150 | if (num_next_hops > 1): 151 | for next_node in next_lst: 152 | new_path = list(path) 153 | new_path.append(next_node) 154 | paths_info.append({"path":new_path,"proportion":path_info["proportion"]/num_next_hops}) 155 | paths_info.remove(path_info) 156 | else: 157 | path.append(next_lst[0]) 158 | self.ecmp_routing_matrix[i,j] = end_paths 159 | 160 | def _gen_routing_matrix(self): 161 | self.routing_matrix = np.zeros((self.net_size,self.net_size),dtype="object") 162 | for i in range(self.net_size): 163 | for j in range(self.net_size): 164 | if (i == j): 165 | continue 166 | #print(self.MP_matrix) 167 | end_path_info_list = [] 168 | mp_path = self.MP_matrix[i,j] 169 | #print (i,j,mp_path) 170 | if type(mp_path) is not list: 171 | continue 172 | src_mp = mp_path[0] 173 | for mp in mp_path: 174 | dst_mp = mp 175 | sub_path_info_lst = self.ecmp_routing_matrix[node_to_index(src_mp),node_to_index(dst_mp)] 176 | if (len(end_path_info_list) == 0): 177 | for sub_path_info in sub_path_info_lst: 178 | end_path_info_list.append({"path":sub_path_info["path"][:-1],"proportion":sub_path_info["proportion"]}) 179 | elif (len(sub_path_info_lst) > 1): 180 | aux_end_path_list = [] 181 | for path_info in end_path_info_list: 182 | for sub_path_info in sub_path_info_lst: 183 | new_path = list(path_info["path"]) 184 | new_path.extend(sub_path_info["path"][:-1]) 185 | aux_end_path_list.append({"path":new_path,"proportion":path_info["proportion"]*sub_path_info["proportion"]}) 186 | end_path_info_list = aux_end_path_list 187 | else: 188 | for path_info in end_path_info_list: 189 | path_info["path"].extend(sub_path_info_lst[0]["path"][:-1]) 190 | src_mp = dst_mp 191 | for path_info in end_path_info_list: 192 | path_info["path"].append(dst_mp) 193 | self.routing_matrix[i,j] = end_path_info_list 194 | 195 | def _get_traffic_matrix (self,traffic_file): 196 | tm = np.zeros((self.net_size,self.net_size)) 197 | with open(traffic_file) as fd: 198 | fd.readline() 199 | fd.readline() 200 | for line in fd: 201 | camps = line.split(" ") 202 | tm[int(camps[1]),int(camps[2])] = float(camps[3]) 203 | return (tm) 204 | 205 | def _link_utilization(self, routing_matrix, traffic_file): 206 | link_utilization = [] 207 | traffic_matrix = self._get_traffic_matrix(traffic_file) 208 | for i in range(self.net_size): 209 | link_utilization.append({}) 210 | for i in range(self.net_size): 211 | for j in 
range (self.net_size): 212 | if (i==j): 213 | continue 214 | traffic_all_path = traffic_matrix[i,j] 215 | routings_lst = routing_matrix[i,j] 216 | if routings_lst == 0: 217 | continue 218 | for path_info in routings_lst: 219 | path = path_info["path"] 220 | traffic = traffic_all_path*path_info["proportion"] 221 | n0 = path[0] 222 | for n1 in path[1:]: 223 | N0 = node_to_index(n0) 224 | N1 = node_to_index(n1) 225 | if N1 in link_utilization[N0]: 226 | link_utilization[N0][N1] += traffic 227 | else: 228 | link_utilization[N0][N1] = traffic 229 | n0 = n1 230 | max_lu = 0 231 | 232 | for i in range(self.net_size): 233 | for j in link_utilization[i].keys(): 234 | link_traffic = link_utilization[i][j] 235 | link_capacity = self.links_bw[i][j] 236 | link_utilization[i][j] = link_traffic / link_capacity 237 | if (link_utilization[i][j] > max_lu): 238 | max_lu = link_utilization[i][j] 239 | #print ("max link utilization :",max_lu) 240 | return (link_utilization) 241 | 242 | def get_opt_link_utilization(self,traffic_file): 243 | return (self._link_utilization(self.routing_matrix,traffic_file)) 244 | 245 | def get_direct_link_utilization(self,traffic_file): 246 | return (self._link_utilization(self.ecmp_routing_matrix,traffic_file)) 247 | 248 | 249 | def get_traffic_matrix(tm_file, net_size=14): 250 | tm = np.zeros((net_size,net_size)) 251 | with open(tm_file) as fd: 252 | fd.readline() 253 | fd.readline() 254 | for line in fd: 255 | camps = line.split(" ") 256 | tm[int(camps[1]),int(camps[2])] = float(camps[3]) 257 | return (tm) 258 | 259 | def get_link_utilization(results): 260 | return [item for sublist in [[value for value in results[i].values()] for i in range(len(results))] for item in sublist] 261 | 262 | 263 | if (__name__ == "__main__"): 264 | 265 | network = 'NSFNet' 266 | defo_dir = "./datasets/defo_" + network + '_dataset/' 267 | init_seed = 100 268 | final_seed = 199 269 | 270 | ecmp, defo = [], [] 271 | graph_file = defo_dir + network + '_defo.txt' 272 | for i in range(init_seed, final_seed+1): 273 | results_file = defo_dir +'Results-' + str(i) 274 | tm_file = defo_dir + 'TM-' + str(i) 275 | results = Defo_results(graph_file,results_file) 276 | results_ecmp = results.get_direct_link_utilization(tm_file) 277 | lu_ecmp = get_link_utilization(results_ecmp) 278 | results_defo = results.get_opt_link_utilization(tm_file) 279 | lu_defo = get_link_utilization(results_defo) 280 | ecmp.append(np.max(lu_ecmp)) 281 | defo.append(np.max(lu_defo)) 282 | 283 | #print ("============== Direct =====================") 284 | #print (results.get_direct_link_utilization(tm_file)) 285 | #print ("============== Optim =====================") 286 | #print (results.get_opt_link_utilization(tm_file)) 287 | -------------------------------------------------------------------------------- /utils/functions.py: -------------------------------------------------------------------------------- 1 | import gin.tf 2 | from itertools import tee 3 | import numpy as np 4 | 5 | 6 | def load_gin_configs(gin_files, gin_bindings): 7 | """Loads gin configuration files. 8 | Args: 9 | gin_files: list, of paths to the gin configuration files for this 10 | experiment. 11 | gin_bindings: list, of gin parameter bindings to override the values in 12 | the config files. 
13 | """ 14 | gin.parse_config_files_and_bindings(gin_files, 15 | bindings=gin_bindings, 16 | skip_unknown=False) 17 | 18 | 19 | def linearly_decaying_epsilon(decay_period, step, warmup_steps, epsilon): 20 | """Returns the current epsilon for the agent's epsilon-greedy policy. 21 | This follows the Nature DQN schedule of a linearly decaying epsilon (Mnih et 22 | al., 2015). The schedule is as follows: 23 | Begin at 1. until warmup_steps steps have been taken; then 24 | Linearly decay epsilon from 1. to epsilon in decay_period steps; and then 25 | Use epsilon from there on. 26 | Args: 27 | decay_period: float, the period over which epsilon is decayed. 28 | step: int, the number of training steps completed so far. 29 | warmup_steps: int, the number of steps taken before epsilon is decayed. 30 | epsilon: float, the final value to which to decay the epsilon parameter. 31 | Returns: 32 | A float, the current epsilon value computed according to the schedule. 33 | """ 34 | steps_left = decay_period + warmup_steps - step 35 | bonus = (1.0 - epsilon) * steps_left / decay_period 36 | bonus = np.clip(bonus, 0., 1. - epsilon) 37 | return epsilon + bonus 38 | 39 | 40 | def pairwise_iteration(iterable): 41 | "s -> (s0,s1), (s1,s2), (s2, s3), ..." 42 | a, b = tee(iterable) 43 | next(b, None) 44 | return zip(a, b) 45 | 46 | 47 | def find_min_max_path_length(path_lengths): 48 | max_length = 0.0 49 | min_length = 100.0 50 | for source in path_lengths.keys(): 51 | for dest in path_lengths[source].keys(): 52 | length = path_lengths[source][dest] 53 | if length > max_length: 54 | max_length = length 55 | elif length < min_length: 56 | min_length = length 57 | return min_length, max_length 58 | 59 | 60 | def get_traffic_matrix(tm_file, nodes): 61 | tm = np.zeros((nodes,nodes)) 62 | with open(tm_file) as fd: 63 | fd.readline() 64 | fd.readline() 65 | for line in fd: 66 | camps = line.split(" ") 67 | tm[int(camps[1]),int(camps[2])] = float(camps[3]) 68 | return (tm) 69 | -------------------------------------------------------------------------------- /utils/tf_logs.py: -------------------------------------------------------------------------------- 1 | import gin.tf 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | 6 | def training_step_logs(writer, env, training_step, loss, action, state, weights): 7 | with writer.as_default(): 8 | with tf.name_scope('Training'): 9 | tf.summary.scalar("Loss", loss, step=training_step) 10 | tf.summary.scalar("Selected Link", action, step=training_step) 11 | if env.link_traffic_to_states: 12 | link_utilization = np.mean(state[:env.n_links]) 13 | tf.summary.scalar("Mean Link Utilization", link_utilization, step=training_step) 14 | tf.summary.scalar("Max Link Utilization", np.max(state[:env.n_links]), step=training_step) 15 | tf.summary.scalar("Min Link Utilization", np.min(state[:env.n_links]), step=training_step) 16 | if env.weigths_to_states: 17 | tf.summary.scalar("Weight mean", np.mean(weights), step=training_step) 18 | tf.summary.scalar("Weight max", np.max(weights), step=training_step) 19 | tf.summary.scalar("Weight min", np.min(weights), step=training_step) 20 | tf.summary.scalar("Weight std", np.std(weights), step=training_step) 21 | writer.flush() 22 | 23 | 24 | def training_episode_logs(writer, env, episode, states, assigned_rewards, losses=None, actor_losses=None, critic_losses=None): 25 | with writer.as_default(): 26 | with tf.name_scope('Training'): 27 | tf.summary.scalar("Reward mean", np.mean(assigned_rewards), step=episode) 28 | #tf.summary.scalar("Reward 
max", np.max(assigned_rewards), step=episode) 29 | #tf.summary.scalar("Reward min", np.min(assigned_rewards), step=episode) 30 | #if env.link_traffic_to_states: 31 | #mean_link_utilization = [np.mean(elem[:env.n_links]) for elem in states] 32 | #tf.summary.scalar("Mean Link Utilization mean", np.mean(mean_link_utilization), step=episode) 33 | #tf.summary.scalar("Mean Link Utilization max", np.max(mean_link_utilization), step=episode) 34 | #tf.summary.scalar("Mean Link Utilization min", np.min(mean_link_utilization), step=episode) 35 | if losses is not None: 36 | tf.summary.scalar("Loss mean", np.mean(losses), step=episode) 37 | if actor_losses is not None: 38 | tf.summary.scalar("Actor Loss mean", np.mean(actor_losses), step=episode) 39 | if critic_losses is not None: 40 | tf.summary.scalar("Critic Loss mean", np.mean(critic_losses), step=episode) 41 | writer.flush() 42 | 43 | 44 | def eval_step_logs(writer, env, eval_step, state, reward=None, prob=None, value=None): 45 | network = ('+').join(env.env_type) 46 | with writer.as_default(): 47 | with tf.name_scope('Eval'): 48 | if reward is not None: 49 | tf.summary.scalar("reward", reward, step=eval_step) 50 | if prob is not None: 51 | tf.summary.scalar("Prob", prob, step=eval_step) 52 | if value is not None: 53 | tf.summary.scalar("Value", value, step=eval_step) 54 | 55 | traffic = state[:env.n_links] 56 | #tf.summary.scalar("traffic mean", np.mean(traffic), step=eval_step) 57 | #tf.summary.scalar("traffic std", np.std(traffic), step=eval_step) 58 | #tf.summary.scalar(network + " - Max Traffic", np.max(traffic), step=eval_step) 59 | #tf.summary.scalar("traffic min", np.min(traffic), step=eval_step) 60 | 61 | weights = env.raw_weights 62 | #tf.summary.scalar("weights mean", np.mean(weights), step=eval_step) 63 | #tf.summary.scalar("weights std", np.std(weights), step=eval_step) 64 | #tf.summary.scalar("weights max", np.max(weights), step=eval_step) 65 | #tf.summary.scalar("weights min", np.min(weights), step=eval_step) 66 | tf.summary.scalar(network + " - Weights Diff Min Max", np.max(weights) - np.min(weights), step=eval_step) 67 | 68 | writer.flush() 69 | 70 | def eval_final_log(writer, eval_episode, max_link_utilization, network): 71 | with writer.as_default(): 72 | with tf.name_scope('Eval'): 73 | #tf.summary.scalar("Number of Nodes", num_nodes, step=eval_episode) 74 | #tf.summary.scalar("Number of Sample", num_sample, step=eval_episode) 75 | tf.summary.scalar(network + " - Starting Max LU", max_link_utilization[0], step=eval_episode) 76 | #idx_min_max = np.argmin(max_link_utilization) 77 | #tf.summary.scalar("Min Max LU", max_link_utilization[idx_min_max], step=eval_episode) 78 | #tf.summary.scalar("Min Max LU - Mean LU", mean_link_utilization[idx_min_max], step=eval_episode) 79 | #episode_length = len(max_link_utilization) 80 | idx_min_max = np.argmin(max_link_utilization) 81 | tf.summary.scalar(network + " - Min Max LU", max_link_utilization[idx_min_max], step=eval_episode) 82 | #idx_min_mean = np.argmin(mean_link_utilization) 83 | #tf.summary.scalar("Min Mean LU", mean_link_utilization[idx_min_mean], step=eval_episode) 84 | #tf.summary.scalar("Min Mean LU - Max LU", max_link_utilization[idx_min_mean], step=eval_episode) 85 | writer.flush() 86 | 87 | 88 | def eval_top_log(writer, eval_episode, min_max, network): 89 | with writer.as_default(): 90 | with tf.name_scope('Eval'): 91 | mean_min_max = np.mean(min_max) 92 | tf.summary.scalar(network + " - MEAN Min Max LU", mean_min_max, step=eval_episode) 93 | 94 | writer.flush() 95 | 
--------------------------------------------------------------------------------