├── Notice.txt ├── README.md ├── RL ├── config.py ├── env.py ├── env_config.py ├── log.py ├── net.py ├── replymemory.py ├── rl.py ├── test.py └── train.py ├── mininet ├── generate_matrices.py ├── generate_nodes_topo.py ├── iperf_script.py ├── links_info │ └── links_info.xml └── topologies │ ├── topology-anonymised.xml │ └── topology2.xml └── ryu ├── arp_handler.py ├── network_delay.py ├── network_monitor.py ├── network_structure.py ├── setting.py └── shortest_path_forwarding.py /Notice.txt: -------------------------------------------------------------------------------- 1 | Dear Researcher 2 | Thank you for your interest in our work. 3 | However, in 2024 we found that a research group of graduate students had used our open-source code to participate in the China ## Network Technology Competition and won a prize with almost no changes to our work and code. In 2023, while reviewing a manuscript for a journal, we also found that a manuscript submitted by a graduate student from Sichuan Province had plagiarized our unpublished manuscript and code with almost no changes. Therefore, we have withdrawn the code from publication, and only part of it remains available. Interested research groups can reimplement our code by following the detailed steps described in our paper. 4 | We strongly condemn this immoral act. 5 | Authors. 6 | November, 2024 7 | 8 | 9 | 声明 10 | 11 | 由于发现24年有研究课题小组的研究生用我们开源的代码参加中国##学生网络技术比赛并获奖,几乎没有任何的内容设计和代码上的改动,我们课题组在23年为某一中文核心期刊审稿时发现四川省某高校的一个研究生的投稿稿件也是盗用了我们课题组未上网的中文稿件和代码,因此我们暂停公开所有代码,有兴趣的课题组可以根据我们论文介绍的详细步骤实现我们的代码。 12 | 我们严厉谴责这种不道德的行为。 13 | 2024年11月 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DRL-M4MR: An intelligent multicast routing approach based on DQN deep reinforcement learning in SDN 2 | 3 | Traditional multicast routing methods have some problems in constructing a multicast tree. 
These problems include limited access to network state information, poor adaptability to dynamic and complex changes in the network, and inflexible data forwarding. To address these defects, the optimal multicast routing problem in software-defined networking (SDN) is tailored as a multiobjective optimization problem, and DRL-M4MR, an intelligent multicast routing algorithm based on the deep Q network (DQN) deep reinforcement learning (DRL) method is designed to construct a multicast tree in a software-defined network. First, combining the characteristics of SDN global network-aware information, the multicast tree state matrix, link bandwidth matrix, link delay matrix and link packet loss rate matrix are designed as the state space of the reinforcement learning agent to solve the problem that traditional methods cannot make full use of network state information. Second, the action space of the agent is all the links in the network, and the action selection strategy is designed to add the links to the current multicast tree in four cases. Third, single-step and final reward function forms are designed to guide the agent to make decisions to construct the optimal multicast tree. The double network architectures, dueling network architectures and prioritized experience replay are adopted to improve the learning efficiency and convergence of the agent. Finally, after the DRL-M4MR agent is trained, the SDN controller installs the multicast flow entries by reversely traversing the multicast tree to the SDN switches to implement intelligent multicast routing. The experimental results show that, compared with existing algorithms, the multicast tree constructed by DRL-M4MR can obtain better bandwidth, delay, and packet loss rate performance after training, and it can make more intelligent multicast routing decisions in a dynamic network environment. 
sys_platform = platform.system()


class Config:
    """Class-level container for the DQN, environment and path settings.

    All values are class attributes and are changed through the ``set_*``
    classmethods, so every module that imports ``Config`` sees the same state.
    """
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    TIME = time.strftime("%Y%m%d%H%M%S", time.localtime())

    # dtype
    NUMPY_TYPE = numpy.float32
    TORCH_TYPE = torch.float32

    # Net
    NUM_STATES = 14 ** 2 * 4
    NUM_ACTIONS = 14 ** 2

    PKL_STEP = 3
    PKL_NUM = 120
    PKL_START = 10  # index starts from 0
    PKL_CUT_NUM = PKL_START + PKL_NUM * PKL_STEP  # end index

    # DQN
    BATCH_SIZE = 8

    # memory pool
    MEMORY_CAPACITY = 1024 * 2

    # nsteps
    N_STEPS = 1

    # hyper-parameters
    LR = 1e-5  # learning rate
    REWARD_DECAY = 0.9  # gamma

    USE_DECAY = True
    E_GREEDY_START = 1  # epsilon
    E_GREEDY_FINAL = 0.01  # epsilon
    E_GREEDY_DECAY = 700  # epsilon
    E_GREEDY = [E_GREEDY_START, E_GREEDY_FINAL, E_GREEDY_DECAY]

    E_GREEDY_ORI = 0.2

    TAU = 1

    EPISODES = 4000  # number of training episodes
    UPDATE_TARGET_FREQUENCY = 10

    # done_reward, step_reward, alley_reward, hell_reward
    REWARD_DEFAULT = [1.0, 0.1, -0.1, -1]

    DISCOUNT = 0.9

    BETA1 = 1
    BETA2 = 1
    BETA3 = 1

    A_STEP = 0
    B_STEP = None

    A_PATH = 0
    B_PATH = 1

    START_SYMBOL = 1
    END_SYMBOL = 2
    STEP_SYMBOL = 1
    BRANCH_SYMBOL = 2

    # control
    # START_LEARN = 200
    CLAMP = False  # gradient clipping

    # file path and pkl path
    if sys_platform == "Windows":
        xml_topology_path = Path(
            r'D:\WorkSpace\Hello_Myself\Hello_Multicast\RLMulticastProject\mininet\topologies\topology2.xml')
        pkl_weight_path = Path(
            r"D:\WorkSpace\Hello_Myself\Hello_Multicast\RLMulticastProject\ryu\pickle\2022-03-11-19-40-21")
    else:
        xml_topology_path = Path(r'/home/dell/RLMulticastProject/mininet/topologies/topology2.xml')
        pkl_weight_path = Path(r"/home/dell/RLMulticastProject/ryu/pickle/2022-03-11-19-40-21")

    # nodes
    start_node = 12
    end_nodes = [2, 4, 11]

    @classmethod
    def set_num_states_actions(cls, state_space_num, action_space_num):
        """Override the network input size and the number of actions."""
        cls.NUM_STATES = state_space_num
        cls.NUM_ACTIONS = action_space_num

    @classmethod
    def log_params(cls, logger):
        """Write the current hyper-parameters to ``logger.info`` as one message.

        :param logger: any object exposing ``info(str)``
        """
        sections = [
            "\n===rewards===\n" + f" REWARD_DEFAULT:{cls.REWARD_DEFAULT}\n",
            "===LR===\n" + f" LR:{cls.LR}\n",
            "===EPISODES===\n" + f" EPISODES:{cls.EPISODES}\n",
            "===BATCH_SIZE===\n" + f" BATCH_SIZE:{cls.BATCH_SIZE}\n",
            "===UPDATE_TARGET_FREQUENCY===\n" +
            f" UPDATE_TARGET_FREQUENCY:{cls.UPDATE_TARGET_FREQUENCY}\n",
            "===gamma==\n" + f" REWARD_DECAY:{cls.REWARD_DECAY}\n",
            "===nsteps===\n" + f" N_STEPS:{cls.N_STEPS}\n",
            "===egreedy===\n" +
            f" [E_GREEDY_START, E_GREEDY_FINAL, E_GREEDY_DECAY]:{cls.E_GREEDY}\n",
            # Values are emitted in the same order as the label names them.
            "===Pickle Param===\n" +
            f" PKL_START, PKL_NUM, PKL_STEP:{cls.PKL_START}, {cls.PKL_NUM}, {cls.PKL_STEP}\n",
            "===ENV===\n" +
            f" DISCOUNT:{cls.DISCOUNT},\n BETA1, BETA2, BETA3:{cls.BETA1},{cls.BETA2},{cls.BETA3}\n",
        ]
        logger.info("".join(sections))

    @classmethod
    def set_lr(cls, lr):
        """Set the learning rate."""
        cls.LR = lr

    @classmethod
    def set_nsteps(cls, nsteps):
        """Set the n-step return length."""
        cls.N_STEPS = nsteps

    @classmethod
    def set_batchsize(cls, batchsize):
        """Set the replay batch size."""
        cls.BATCH_SIZE = batchsize

    @classmethod
    def set_egreedy(cls, egreedy):
        """Set the epsilon decay constant and refresh the derived E_GREEDY list."""
        cls.E_GREEDY_DECAY = egreedy
        # E_GREEDY is built once at class creation; rebuild it so that readers of
        # the list (including log_params) do not see the stale decay value.
        cls.E_GREEDY = [cls.E_GREEDY_START, cls.E_GREEDY_FINAL, cls.E_GREEDY_DECAY]

    @classmethod
    def set_gamma(cls, gamma):
        """Set the reward decay (gamma)."""
        cls.REWARD_DECAY = gamma

    @classmethod
    def set_update_frequency(cls, update_frequency):
        """Set how often the target network is updated."""
        cls.UPDATE_TARGET_FREQUENCY = update_frequency

    @classmethod
    def set_tau(cls, tau):
        """Set TAU."""
        cls.TAU = tau

    @classmethod
    def set_rewards(cls, rewards):
        """Replace the default reward list."""
        cls.REWARD_DEFAULT = rewards

    @classmethod
    def print_cls_dict(cls):
        """Print the raw class ``__dict__`` for debugging."""
        print(cls.__dict__)
not None else None 40 | self.numpy_type = numpy_type 41 | 42 | self.tree_nodes = set() # 树节点 43 | self.route_graph = nx.Graph() # 路由列表 44 | self.branches = set() # 枝列表 45 | self.branches_graph = nx.Graph() 46 | self.mask = [] 47 | 48 | self.n_actions = None 49 | 50 | self.start = None 51 | self.ends = None 52 | self.ends_constant = None 53 | 54 | self.adj_matrix = None 55 | self.route_matrix = None # 路由矩阵 56 | self.state_matrix = None 57 | self.bw_matrix = None 58 | self.delay_matrix = None 59 | self.loss_matrix = None 60 | 61 | self.normal_bw_matrix = None 62 | self.normal_delay_matrix = None 63 | self.normal_loss_matrix = None 64 | 65 | self.step_normal_bw_matrix = None 66 | self.step_normal_delay_matrix = None 67 | self.step_normal_loss_matrix = None 68 | 69 | self.step_num = 1 70 | self.alley_num = 0 71 | 72 | self.max_bw = np.finfo(np.float32).eps 73 | self.min_bw = 0 74 | 75 | self.max_delay = np.finfo(np.float32).eps 76 | self.min_delay = 0 77 | 78 | self.max_loss = np.finfo(np.float32).eps 79 | self.min_loss = 0 80 | 81 | self.all_link_delay_sum = 0 82 | self.all_link_non_loss_prob = 0 83 | 84 | if graph is not None: 85 | # self.parse_graph_edge_data() # 解析图 86 | self.set_nodes() 87 | self.set_edges() 88 | self.set_adj_matrix() # 设置邻接矩阵 89 | 90 | self.done_reward = None 91 | self.step_reward = None 92 | self.hell_reward = None 93 | self.alley_reward = None 94 | 95 | self.discount = Config.DISCOUNT 96 | 97 | self.beta1 = Config.BETA1 98 | self.beta2 = Config.BETA2 99 | self.beta3 = Config.BETA3 100 | 101 | self.scale = normalize 102 | self.a_step = Config.A_STEP 103 | self.b_step = Config.B_STEP 104 | 105 | self.a_path = Config.A_PATH 106 | self.b_path = Config.B_PATH 107 | 108 | self.start_symbol = Config.START_SYMBOL 109 | self.end_symbol = Config.END_SYMBOL 110 | self.step_symbol = Config.STEP_SYMBOL 111 | self.branch_symbol = Config.BRANCH_SYMBOL 112 | 113 | self.set_reward_conf(Config.REWARD_DEFAULT) 114 | self.set_a_b(0, self.step_reward, 0, 
self.done_reward) 115 | 116 | def initialize_params(self): 117 | """ 118 | 重置 为空 119 | :return: None 120 | """ 121 | self.tree_nodes = set() # 树节点 122 | self.route_graph = nx.Graph() # 树 123 | self.branches = set() # 枝列表 124 | self.branches_graph = nx.Graph() 125 | self.mask = [] 126 | 127 | self.n_actions = None 128 | 129 | self.start = None 130 | self.ends = None 131 | self.ends_constant = None 132 | 133 | self.adj_matrix = None 134 | self.route_matrix = None # 路由矩阵 135 | self.state_matrix = None 136 | self.bw_matrix = None 137 | self.delay_matrix = None 138 | self.loss_matrix = None 139 | 140 | self.normal_bw_matrix = None 141 | self.normal_delay_matrix = None 142 | self.normal_loss_matrix = None 143 | 144 | self.step_normal_bw_matrix = None 145 | self.step_normal_delay_matrix = None 146 | self.step_normal_loss_matrix = None 147 | 148 | self.step_num = 1 149 | self.alley_num = 0 150 | 151 | self.max_bw = np.finfo(np.float32).eps 152 | self.min_bw = 0 153 | 154 | self.max_delay = np.finfo(np.float32).eps 155 | self.min_delay = 0 156 | 157 | self.max_loss = np.finfo(np.float32).eps 158 | self.min_loss = 0 159 | 160 | self.all_link_delay_sum = 0 161 | self.all_link_non_loss_prob = 0 162 | 163 | def set_params(self): 164 | """ 165 | 设置值 166 | :return: None 167 | """ 168 | self.set_nodes() 169 | self.set_edges() 170 | self.set_adj_matrix() 171 | self.parse_graph_edge_data() 172 | 173 | def parse_graph_edge_data(self): 174 | """ 175 | 解析边的 bw, delay, loss 存入矩阵 176 | :return: bw_matrix, delay_matrix, loss_matrix 177 | """ 178 | _num = len(self.nodes) 179 | # 全 - 1 矩阵 180 | bw_matrix = - np.ones((_num, _num), dtype=self.numpy_type) 181 | delay_matrix = - np.ones((_num, _num), dtype=self.numpy_type) 182 | loss_matrix = - np.ones((_num, _num), dtype=self.numpy_type) 183 | 184 | all_link_delay_sum, all_link_non_loss_prob = 0, 1 185 | for edge in self.graph.edges.data(): 186 | _start, _end, _data = edge 187 | bw_matrix[_start - 1][_end - 1] = _data['bw'] 188 | 
bw_matrix[_end - 1][_start - 1] = _data['bw'] 189 | 190 | delay_matrix[_start - 1][_end - 1] = _data['delay'] 191 | delay_matrix[_end - 1][_start - 1] = _data['delay'] 192 | all_link_delay_sum += _data['delay'] 193 | 194 | loss_matrix[_start - 1][_end - 1] = _data['loss'] 195 | loss_matrix[_end - 1][_start - 1] = _data['loss'] 196 | all_link_non_loss_prob *= 1 - _data['loss'] 197 | 198 | maximum_spt_delay = list(nx.maximum_spanning_edges(self.graph, weight='delay')) 199 | self.all_link_delay_sum = reduce(lambda x, y: x + y, [link[2]['delay'] for link in maximum_spt_delay]) 200 | maximum_spt_loss = list(nx.maximum_spanning_edges(self.graph, weight='loss')) 201 | self.all_link_non_loss_prob = reduce(lambda x, y: x * y, [1 - link[2]['loss'] for link in maximum_spt_loss]) 202 | 203 | self.max_bw, self.min_bw, bw_matrix = self.get_non_minus_one_max_min(bw_matrix) 204 | self.max_delay, self.min_delay, delay_matrix = self.get_non_minus_one_max_min(delay_matrix) 205 | self.max_loss, self.min_loss, loss_matrix = self.get_non_minus_one_max_min(loss_matrix) 206 | 207 | self.bw_matrix = bw_matrix 208 | self.delay_matrix = delay_matrix 209 | self.loss_matrix = loss_matrix 210 | 211 | self.normal_bw_matrix, self.step_normal_bw_matrix = self.normalize_param_matrix(bw_matrix, self.max_bw, 212 | self.min_bw) 213 | self.normal_delay_matrix, self.step_normal_delay_matrix = self.normalize_param_matrix(delay_matrix, 214 | self.max_delay, 215 | self.min_delay) 216 | self.normal_loss_matrix, self.step_normal_loss_matrix = self.normalize_param_matrix(loss_matrix, self.max_loss, 217 | self.min_loss) 218 | return bw_matrix, delay_matrix, loss_matrix, all_link_delay_sum 219 | 220 | def normalize_param_matrix(self, matrix, max_value, min_value): 221 | """ 222 | 将矩阵按最大最小值正则 值在0到1之间 223 | """ 224 | normal_m = (matrix - min_value) / (max_value + 1e-6) 225 | normal_m -= normal_m * np.identity(len(self.nodes)) 226 | 227 | step_normal_m = self.a_step + normal_m * (self.b_step - self.a_step) 228 | 
return normal_m, step_normal_m 229 | 230 | def get_adjacent_link_params_max_min(self, node): 231 | """ 232 | 获得当前节点的周围链路的 参数 的最大最小值 233 | :param node: 当前节点 234 | :return: [(max_bw, min_bw), (max_delay, min_delay), (max_loss, min_loss)] 235 | """ 236 | adj_index = self.get_node_adj_index_list(node) 237 | max_bw = self.bw_matrix[self.node_to_index(node), adj_index].max() 238 | min_bw = self.bw_matrix[self.node_to_index(node), adj_index].min() 239 | 240 | max_delay = self.delay_matrix[self.node_to_index(node), adj_index].max() 241 | min_delay = self.delay_matrix[self.node_to_index(node), adj_index].min() 242 | 243 | max_loss = self.loss_matrix[self.node_to_index(node), adj_index].max() 244 | min_loss = self.loss_matrix[self.node_to_index(node), adj_index].min() 245 | 246 | return [(min_bw, max_bw), (min_delay, max_delay), (min_loss, max_loss)] 247 | 248 | @staticmethod 249 | def get_non_minus_one_max_min(matrix): 250 | """ 251 | 返回除去-1的最大最小值 252 | :param matrix: 要求最大最小的矩阵 253 | :return: max, _min, matrix 【最大最小值, matrix将-1改为0】 254 | """ 255 | non_minus_one_mask = np.where(matrix != -1) 256 | _max = matrix[non_minus_one_mask].max() 257 | _min = matrix[non_minus_one_mask].min() 258 | matrix[np.where(matrix == -1)] = 0 259 | return _max, _min, matrix 260 | 261 | def modify_graph(self, graph: nx.Graph): 262 | """ 263 | 修改 属性 并返回 graph 264 | :param graph: networkx的图 265 | :return: self.graph 266 | """ 267 | self.graph = graph 268 | return self.graph 269 | 270 | def set_nodes(self): 271 | """ 272 | 修改并返回 nodes 273 | :return: self.nodes 274 | """ 275 | 276 | self.nodes = sorted(self.graph.nodes) 277 | return self.nodes 278 | 279 | def set_edges(self): 280 | """ 281 | 设置 边 282 | :return: None 283 | """ 284 | # [(1, 3), (1, 4), (1, 5), (1, 9), (1, 11), 285 | # (2, 4), (2, 5), (2, 6), (4, 5), (4, 6), (4, 7), 286 | # (4, 9), (4, 11), (5, 6), (5, 7), (5, 8), (7, 8), 287 | # (7, 9), (8, 14), (9, 10), (9, 13), (12, 13), (12, 14)] 288 | edges = [(e[0], e[1]) if e[0] < e[1] else 
(e[1], e[0]) for e in self.graph.edges] 289 | self.edges = sorted(edges, key=lambda x: (x[0], x[1])) 290 | 291 | def set_adj_matrix(self): 292 | """ 293 | 设置并返回邻接矩阵 294 | :return: adj_m 295 | """ 296 | adj_m = nx.adjacency_matrix(self.graph, self.nodes).todense() 297 | self.adj_matrix = np.array(adj_m, dtype=self.numpy_type) 298 | return self.adj_matrix 299 | 300 | def read_pickle_and_modify(self, path): 301 | """ 302 | 读取图graph的pickle文件 303 | 初始化所有参数 304 | :param path: pickle文件路径 305 | :return: 图 nx.graph 306 | """ 307 | pkl_graph = nx.read_gpickle(path) 308 | self.modify_graph(pkl_graph) 309 | 310 | self.initialize_params() 311 | self.set_params() 312 | 313 | return pkl_graph 314 | 315 | @staticmethod 316 | def node_to_index(node): 317 | """ 318 | 节点从1起, 索引从0起,将节点号转为索引号 319 | :param node: 节点号 320 | :return: 索引号 321 | """ 322 | return node - 1 323 | 324 | @staticmethod 325 | def index_to_node(index): 326 | """ 327 | 索引号转为节点号 328 | :param index: 索引号 329 | :return: 节点号 330 | """ 331 | return index + 1 332 | 333 | def add_tree_node(self, node): 334 | """ 335 | 向树中添加节点 336 | """ 337 | self.tree_nodes.add(node) 338 | 339 | def get_node_adj_index_list(self, node): 340 | """ 341 | 根据当前节点获得邻居节点的索引号 342 | :param node: 节点序号 343 | :return: 邻居节点的索引号列表 344 | """ 345 | index = self.node_to_index(node) 346 | adj_ids = np.nonzero(self.adj_matrix[index]) # return tuple 347 | adj_index_list = adj_ids[0].tolist() 348 | return adj_index_list 349 | 350 | def modify_branches(self, node): 351 | """ 352 | 修改枝列表 353 | 修改之前需要先 将满足条件的node添加到tree node, self.add_tree_node(node) 354 | prim算法中有可添加链路方法 355 | 2022/3/17 大改 356 | :return: None 357 | """ 358 | # A. 修改 branch 359 | # 取单行, 即为node的邻居节点行, 是邻居节点则为1, 否则是0 360 | adj_ids_list = self.get_node_adj_index_list(node) 361 | adj_nodes = set(map(self.index_to_node, adj_ids_list)) # 序号转为node 362 | # 1. 更新branches中的节点 363 | self.branches.update(adj_nodes) 364 | # 2. 
移除branches中的tree集合的节点 365 | self.branches.difference_update(self.tree_nodes) 366 | 367 | # B. 修改 branches_graph 368 | # 1. 枝节点的周围节点添加进去 369 | for u, v in zip_longest([node], list(adj_nodes), fillvalue=node): # zip_longest 370 | if v in self.tree_nodes: 371 | self.branches_graph.remove_edge(u, v) 372 | else: 373 | self.branches_graph.add_edge(u, v) 374 | 375 | def get_branches_matrix(self): 376 | """ 377 | 获得当前枝的邻接矩阵表示 378 | """ 379 | branches_m = np.zeros((len(self.nodes), len(self.nodes)), dtype=self.numpy_type) 380 | for e in self.branches_graph.edges: 381 | u, v = e 382 | u = self.node_to_index(u) 383 | v = self.node_to_index(v) 384 | branches_m[u][v] = self.branch_symbol 385 | branches_m[v][u] = self.branch_symbol 386 | return branches_m 387 | 388 | def add_to_route_graph(self, link): 389 | """ 390 | 添加link到 self.route_graph中 391 | :param link: 链路 392 | :return: None 393 | """ 394 | self.route_graph.add_edge(*link) 395 | 396 | def get_mask_of_current_tree_branch(self): 397 | """ 398 | 获得是当前树的branch的mask 399 | (1, 0, 0, 1, ...) 
400 | :return: mask 401 | """ 402 | _branches = self.branches_graph.edges 403 | edges = self.edges # 这个list要保证不变 404 | mask = np.zeros(len(self.graph.edges), dtype=bool) 405 | for edge in _branches: 406 | if edge in edges: 407 | # 将索引位置设置为True 408 | mask[edges.index(edge)] = True 409 | elif edge[::-1] in edges: 410 | mask[edges.index(edge[::-1])] = True 411 | else: 412 | raise IndexError("edge not in branches_graph") 413 | 414 | self.mask = mask 415 | 416 | return mask 417 | 418 | def _judge_link(self, link): 419 | """ 420 | 弃用,link[1]错误问题 421 | 判断link是否满足当前的情况,是否是当前树的枝 422 | :param link: 下一步链路 423 | :return: True or False 424 | """ 425 | # 是否是树枝,否则返回None 426 | end_node = link[1] 427 | if end_node in self.branches: 428 | return True 429 | else: 430 | return False 431 | 432 | def judge_end_node(self, link): 433 | """ 434 | 判断 link 中哪个是 长出来的新枝,并修正link方向 435 | :param link: 下一跳的link 436 | :return: tree_node, branch_node 437 | """ 438 | tree_node = self.tree_nodes & set(link) # 取交 439 | if len(tree_node) == 1: 440 | branch_node = set(link) - self.tree_nodes # link中包含而tree_nodes中不包含 441 | 442 | try: 443 | return list(tree_node)[0], list(branch_node)[0] 444 | except IndexError: 445 | raise IndexError(f"({link}, {tree_node}, {branch_node}, {self.tree_nodes})") 446 | else: 447 | return None 448 | 449 | def judge_link(self, link): 450 | """ 451 | 判断 link 是否是枝 452 | :param link: 下一跳的link 453 | :return: True or False 454 | """ 455 | if link in self.branches_graph.edges: 456 | return True 457 | else: 458 | return False 459 | 460 | def reset(self, start, ends): 461 | """ 462 | 构建初始的路径矩阵, 463 | :param start: 1, 2,... 
    def step(self, link):
        """Advance the environment by one action: try to add ``link`` to the tree.

        The endpoint that is not yet in the tree is treated as the new node,
        regardless of the order of the two nodes in ``link``.

        :param link: link chosen by the agent, e.g. (1, 2)
        :return: (state_, reward_score, route_done, _judge_flag) where
                 ``_judge_flag`` is one of "ALL", "PART", "NOT", "HELL";
                 ``state_`` is None when the flag is "ALL", otherwise it is
                 ``self.route_matrix`` itself (not a copy)
        """
        assert self.hell_reward is not None
        # Orient the link as (tree_node, branch_node); None means the link does
        # not connect the tree to a new node.
        link = self.judge_end_node(link)
        if link:
            tree_node, branch_node = link
            # 1. Mark the link symmetrically in the route matrix.
            self.route_matrix[self.node_to_index(tree_node)][self.node_to_index(branch_node)] = self.step_symbol
            self.route_matrix[self.node_to_index(branch_node)][self.node_to_index(tree_node)] = self.step_symbol
            # 2.1 Add the new node to the set of tree nodes.
            self.add_tree_node(branch_node)
            # 2.2 Update the branch set / branch graph around the new node.
            self.modify_branches(branch_node)
            # 2.3 Refresh the mask of currently valid branch links.
            self.get_mask_of_current_tree_branch()
            # 3. Record the link in the route graph; this must happen before
            #    calculate_path_score(), which reads self.route_graph.
            self.add_to_route_graph((tree_node, branch_node))

            # _judge_end removes branch_node from self.ends when it is a destination.
            _judge_flag = self._judge_end(branch_node)
            if _judge_flag == 'ALL':
                # Every destination has been reached: the episode ends.
                route_done = True
                state_ = None  # terminal: no state follows this action
                reward_score = self.calculate_path_score()
                alley, _ = self.find_blind_alley()
                self.alley_num = alley

            elif _judge_flag == 'PART':
                # A destination was reached, but others remain.
                route_done = False
                state_ = self.route_matrix
                reward_score = self.calculate_link_reward(link)
            elif _judge_flag == "NOT":
                # The new node is not a destination.
                route_done = False
                state_ = self.route_matrix
                reward_score = self.calculate_link_reward(link)
            else:
                raise Exception("link judge error")
        else:
            # Invalid action: the state is unchanged and hell_reward is returned.
            _judge_flag = "HELL"
            route_done = False
            state_ = self.route_matrix
            reward_score = self.hell_reward

        self.step_num += 1
        return state_, reward_score, route_done, _judge_flag
586 | step_reward * e**(1/nodes_num * step_num) 587 | :return: score 588 | """ 589 | score = self.step_reward * exp(1 / len(self.nodes) * self.step_num - 1) 590 | return score 591 | 592 | def discount_reward(self, score): 593 | """ 594 | 对reward进行discount 处理 score = score * discount ** step_num 595 | :param score: reward 596 | :return: score 597 | """ 598 | score *= self.discount ** (self.step_num - 1) 599 | return score 600 | 601 | def link_reward_func(self, bw, delay, loss, b): 602 | """ 603 | 计算reward 604 | R = β1 * bw + β2 * (1 - delay) + β3 * (1 - loss) 605 | :param bw: 带宽 606 | :param delay: 时延 607 | :param loss: 丢包率 608 | :param b: 上限 609 | """ 610 | return self.beta1 * bw + self.beta2 * (b - delay) + self.beta3 * (b - loss) 611 | # return -(self.beta1 * (b - bw) + self.beta2 * delay + self.beta3 * loss) 612 | 613 | def path_reward_func(self, bw, delay, loss, b): 614 | return self.beta1 * bw + self.beta2 * (b - delay) + self.beta3 * loss 615 | # return -(self.beta1 * (b - bw) + self.beta2 * delay + self.beta3 * (b - loss)) 616 | 617 | def calculate_link_reward(self, link): 618 | """ 619 | 计算每步选择link reward R 620 | :return: reward 621 | """ 622 | e0, e1 = self.node_to_index(link[0]), self.node_to_index(link[1]) 623 | bw_hat = self.step_normal_bw_matrix[e0, e1] 624 | delay_hat = self.step_normal_delay_matrix[e0, e1] 625 | loss_hat = self.step_normal_loss_matrix[e0, e1] 626 | r = self.link_reward_func(bw_hat, delay_hat, loss_hat, self.b_step) 627 | return r 628 | 629 | def calculate_path_score(self): 630 | """ 631 | 计算路径的回报 632 | :return: reward 633 | """ 634 | non_losses = np.array([]) 635 | delays = np.array([]) 636 | 637 | e2e_bw, e2e_delay, e2e_bw_hat, e2e_delay_hat = self.find_end_to_end_max_bw_delay(self.route_graph, self.start, 638 | self.ends_constant) 639 | 640 | for link in self.route_graph.edges: 641 | delay = self.graph[link[0]][link[1]]['delay'] 642 | loss = self.graph[link[0]][link[1]]['loss'] 643 | delays = np.append(delays, delay) 644 | 
# NOTE: the definitions below are methods of the multicast environment class;
# the class header lies outside this excerpt.


def find_end_to_end_max_bw(self, route_graph, start, ends):
    """
    Return the bottleneck residual bandwidth from ``start`` to each destination.

    For every node in ``ends`` the unweighted shortest path inside
    ``route_graph`` is taken and the smallest ``'bw'`` attribute of its links
    (looked up in ``self.graph``) is recorded, capped at ``self.max_bw``.

    :param route_graph: multicast tree the paths are searched in
    :param start: source node
    :param ends: destination nodes
    :return: float64 array holding one bottleneck bandwidth per destination
    """
    link_attrs = self.graph.edges
    bottlenecks = []
    for end in ends:
        path = nx.shortest_path(route_graph, source=start, target=end)
        bw = self.max_bw
        # Walk consecutive node pairs, i.e. the links of the path.
        for u, v in zip(path, path[1:]):
            link_bw = link_attrs[u, v]['bw']
            if link_bw < bw:
                bw = link_bw
        bottlenecks.append(bw)
    return np.array(bottlenecks, dtype=np.float64)


def find_end_to_end_max_bw_delay(self, route_graph, start, ends):
    """
    Compute bottleneck bandwidth and accumulated delay for each reachable destination.

    Destinations that are not (yet) nodes of ``route_graph`` are ignored.

    :param route_graph: multicast tree the paths are searched in
    :param start: source node
    :param ends: destination nodes
    :return: ``(bws, delays, bws_hat, delays_hat)`` where ``bws`` and ``delays``
        are per-destination arrays, ``bws_hat`` is the normalised mean
        bandwidth and ``delays_hat`` is the largest normalised path delay
    :raises ValueError: if none of ``ends`` is part of ``route_graph`` (the
        maximum of an empty array is undefined)
    """
    link_attrs = self.graph.edges
    # Lower normalisation bound: the minimum link delay times the number of links.
    min_total_delay = self.min_delay * len(self.edges)

    bws, delays, delays_hat = [], [], []
    for end in ends:
        if end not in route_graph.nodes:
            continue
        path = nx.shortest_path(route_graph, source=start, target=end)
        bw = self.max_bw
        delay = 0
        for u, v in zip(path, path[1:]):
            link = link_attrs[u, v]
            if link['bw'] < bw:
                bw = link['bw']
            delay += link['delay']

        bws.append(bw)
        delays.append(delay)
        delays_hat.append(
            self.min_max_normalize(delay, min_total_delay, self.all_link_delay_sum,
                                   a=self.a_path, b=self.b_path))

    bws = np.array(bws, dtype=np.float64)
    delays = np.array(delays, dtype=np.float64)
    bws_hat = self.min_max_normalize(bws.mean(), self.min_bw, self.max_bw,
                                     a=self.a_path, b=self.b_path)
    delays_hat = np.array(delays_hat, dtype=np.float64).max()

    return bws, delays, bws_hat, delays_hat


def min_max_normalize(self, x, min_value, max_value, a=None, b=None):
    """
    Min-max scale ``x`` from ``[min_value, max_value]`` into ``[a, b]``.

    :param x: value to normalise
    :param min_value: lower bound of the input range
    :param max_value: upper bound of the input range
    :param a: lower bound of the output range, defaults to ``self.a_step``
    :param b: upper bound of the output range, defaults to ``self.b_step``
    :return: the scaled value cast to ``self.numpy_type``
    """
    lower = self.a_step if a is None else a
    upper = self.b_step if b is None else b

    # A tiny epsilon keeps the division defined when max_value == min_value.
    span = max_value - min_value + np.finfo(np.float32).eps
    scaled = lower + (x - min_value) / span * (upper - lower)
    return scaled.astype(self.numpy_type)


def get_route_params(self, mode='train'):
    """
    Summarise the current multicast tree.

    Link delay and loss are averaged over the tree with its dead-end branches
    removed (see ``find_blind_alley``); bandwidth is the mean end-to-end
    bottleneck bandwidth over ``self.ends_constant``.

    :param mode: ``'train'`` counts the links of the full tree, ``'test'``
        counts the links of the pruned tree
    :return: ``(mean bandwidth, mean delay, mean loss, link count, alley count)``
    :raises ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``
    :raises ZeroDivisionError: if the pruned tree contains no link
    """
    # Reject bad modes up front; previously this surfaced only at the very end
    # as an UnboundLocalError on ``length``.
    if mode not in ('train', 'test'):
        raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

    alley, pruned_tree = self.find_blind_alley()

    delay, loss, num = 0, 0, 0
    for u, v in pruned_tree.edges:
        link = self.graph[u][v]
        delay += link["delay"]
        loss += link["loss"]
        num += 1

    bws = self.find_end_to_end_max_bw(self.route_graph, self.start, self.ends_constant)
    counted_tree = self.route_graph if mode == 'train' else pruned_tree
    length = len(counted_tree.edges)

    return bws.mean(), delay / num, loss / num, length, alley


def map_action(self, action):
    """
    Translate an action index into the link it selects.

    :param action: index into ``self.edges``
    :return: the link stored at that index
    """
    return self.edges[action]


def find_blind_alley(self):
    """
    Prune dead-end branches from a copy of the multicast tree.

    A dead end is a degree-1 node that is neither the source nor one of the
    destinations. Removal is repeated until no such node is left, so whole
    useless branches disappear.

    :return: ``(alley, r_graph)`` - number of removed nodes and the pruned copy
    """
    alley = 0
    r_graph = self.route_graph.copy()
    while True:
        dead_ends = [
            node for node, degree in r_graph.degree
            if degree == 1 and node not in self.ends_constant and node != self.start
        ]
        if not dead_ends:
            break
        alley += len(dead_ends)
        for node in dead_ends:
            r_graph.remove_node(node)

    return alley, r_graph


def set_reward_conf(self, rewards_list):
    """
    Set the reward constants.

    :param rewards_list: ``[done_reward, step_reward, alley_reward, hell_reward]``
    """
    (self.done_reward,
     self.step_reward,
     self.alley_reward,
     self.hell_reward) = rewards_list


def set_a_b(self, a_step, b_step, a_path, b_path):
    """
    Set the output ranges used by ``min_max_normalize``.

    The upper bounds are stored as absolute values.
    """
    self.a_step, self.a_path = a_step, a_path
    self.b_step, self.b_path = abs(b_step), abs(b_path)
class MyLog:
    """
    Thin wrapper that wires a logger to console and/or file output.

    Usage::

        mylog = MyLog(Path(__file__))
        logger = mylog.logger
        logger.info(...)
        logger.warning(...)

    NOTE(review): every instance attaches its handlers to the same logger
    (``logging.getLogger(__name__)``), so creating several instances in one
    process emits each record once per attached handler.
    """

    def __init__(self, path: Path, filesave=False, consoleprint=True, name=None):
        """
        :param path: path of the running script, e.g. ``Path(__file__)``; logs
            are written to a ``Logs`` directory next to it
        :param filesave: attach a file handler (``<stem><suffix>.log``)
        :param consoleprint: attach a console handler
        :param name: prefix of every log line and suffix of the log file name;
            when omitted, lines are not prefixed and the file name gets a
            timestamp suffix instead
        """
        # Only prefix messages when a name was given; formatting ``None`` into
        # the pattern used to print a literal "None: " in front of every line.
        pattern = "%(message)s" if name is None else f"{name}: %(message)s"
        self.formatter = logging.Formatter(pattern)
        self.logger = logging.getLogger(name=__name__)
        self.set_log_level()

        self.log_path = Path.joinpath(path.parent, 'Logs')
        suffix = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()) if name is None else name
        self.log_file = os.path.join(self.log_path, path.stem + suffix)

        self.mk_log_dir()

        if filesave:
            self.file_handler()

        if consoleprint:
            self.console_handler()

    def set_log_level(self, level=logging.DEBUG):
        """Set the threshold of the underlying logger."""
        self.logger.setLevel(level)

    def mk_log_dir(self):
        """
        Ensure the log directory exists and drop empty files left by earlier runs.

        Only regular files are removed; sub-directories are left alone even if
        the file system reports a size of zero for them.
        """
        self.log_path.mkdir(parents=True, exist_ok=True)
        for child in self.log_path.iterdir():
            if child.is_file() and child.stat().st_size == 0:
                child.unlink()

    def file_handler(self):
        """Attach an INFO-level handler that (over)writes ``<log_file>.log``."""
        fh = logging.FileHandler(self.log_file + '.log',
                                 mode='w', encoding='utf-8', )
        fh.setLevel(logging.INFO)
        fh.setFormatter(self.formatter)
        self.logger.addHandler(fh)

    def console_handler(self):
        """Attach a DEBUG-level stream handler."""
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        ch.setFormatter(self.formatter)
        self.logger.addHandler(ch)

    def pd_to_csv(self, dataframe):
        """Write ``dataframe`` to ``<log_file>.csv`` and log the fact."""
        dataframe.to_csv(self.log_file + '.csv')
        self.logger.info("csv saved")
def weight_init(m):
    """
    Xavier-normal initialise Conv2d / Linear weights and zero their biases.

    Meant to be passed to ``nn.Module.apply``. Modules of any other kind are
    left untouched, as are missing (``bias=False``) biases.
    """
    classname = m.__class__.__name__
    if 'Conv2d' in classname or 'Linear' in classname:
        torch.nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.0)


def _dueling_q(state_value, advantage):
    """
    Combine the two dueling streams into Q-values.

    The advantage is centred per sample (mean over the action dimension).
    Taking the mean over the whole tensor, as before, made the Q-values of one
    sample depend on the other samples of the batch.
    """
    return state_value + (advantage - advantage.mean(dim=1, keepdim=True))


class MyMulticastNet(nn.Module):
    """
    Plain Q-network: 1x1 convolution followed by fully connected layers.

    ``fc1`` expects ``32 * 14 * 14`` features, i.e. a 14x14 state matrix.
    """

    def __init__(self, states_channel, action_num):
        """
        :param states_channel: number of input channels of the state tensor
        :param action_num: number of actions (size of the output)
        """
        super().__init__()
        self.conv1 = nn.Conv2d(states_channel, 32, kernel_size=(1, 1))
        self.relu1 = nn.ReLU()

        self.fc1 = nn.Linear(32 * 14 * 14, 512)
        self.fc1_relu = nn.ReLU()

        self.adv1 = nn.Linear(512, 256)
        self.adv_relu = nn.ReLU()
        self.adv2 = nn.Linear(256, action_num)

        self.apply(weight_init)

    def forward(self, x):
        """Return the Q-values, shape ``(batch, action_num)``."""
        x = self.relu1(self.conv1(x))
        x = x.view(x.shape[0], -1)
        x = self.fc1_relu(self.fc1(x))

        hidden = self.adv_relu(self.adv1(x))
        return self.adv2(hidden)


class MyMulticastNet2(nn.Module):
    """
    Dueling Q-network with a single 1x1 convolution.

    ``fc1`` expects 6272 (= 32 * 14 * 14) features, i.e. a 14x14 state matrix.
    """

    def __init__(self, states_channel, action_num):
        """
        :param states_channel: number of input channels of the state tensor
        :param action_num: number of actions (size of the output)
        """
        super().__init__()
        self.conv1 = nn.Conv2d(states_channel, 32, kernel_size=(1, 1))
        self.relu1 = nn.ReLU()

        self.fc1 = nn.Linear(6272, 512)
        self.fc1_relu = nn.ReLU()

        self.fc2 = nn.Linear(512, 256)
        self.fc2_relu = nn.ReLU()

        self.adv = nn.Linear(256, action_num)
        self.val = nn.Linear(256, 1)

        self.apply(weight_init)

    def forward(self, x):
        """Return the Q-values, shape ``(batch, action_num)``."""
        x = self.relu1(self.conv1(x))
        x = x.view(x.shape[0], -1)
        x = self.fc1_relu(self.fc1(x))
        x = self.fc2_relu(self.fc2(x))

        return _dueling_q(self.val(x), self.adv(x))


class MyMulticastNet3(nn.Module):
    """
    Dueling Q-network with two convolution branches.

    One branch convolves along the rows (5x1 kernels), the other along the
    columns (1x5 kernels); their flattened outputs are concatenated. ``fc1``
    expects 5376 features, which is what a 14x14 state matrix produces.
    """

    def __init__(self, states_channel, action_num):
        """
        :param states_channel: number of input channels of the state tensor
        :param action_num: number of actions (size of the output)
        """
        super().__init__()
        self.conv1_1 = nn.Conv2d(states_channel, 32, kernel_size=(5, 1))
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=(5, 1))

        self.conv2_1 = nn.Conv2d(states_channel, 32, kernel_size=(1, 5))
        self.conv2_2 = nn.Conv2d(32, 32, kernel_size=(1, 5))

        self.fc1 = nn.Linear(5376, 512)
        self.fc2 = nn.Linear(512, 256)

        self.adv = nn.Linear(256, action_num)
        self.val = nn.Linear(256, 1)

        self.apply(weight_init)

    def forward(self, x):
        """Return the Q-values, shape ``(batch, action_num)``."""
        batch = x.shape[0]
        rows = F.leaky_relu(self.conv1_2(F.leaky_relu(self.conv1_1(x))))
        cols = F.leaky_relu(self.conv2_2(F.leaky_relu(self.conv2_1(x))))

        x = torch.cat([rows.view(batch, -1), cols.view(batch, -1)], dim=1)
        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))

        return _dueling_q(self.val(x), self.adv(x))


class MyMulticastNet4(nn.Module):
    """
    Recurrent dueling Q-network.

    Two three-layer convolution branches (3x1 and 1x3 kernels) are flattened
    and summed, then fed through an ``LSTMCell``. The cell expects 3584
    features, which is what a 14x14 state matrix produces.
    """

    def __init__(self, states_channel, action_num):
        """
        :param states_channel: number of input channels of the state tensor
        :param action_num: number of actions (size of the output)
        """
        super().__init__()
        self.conv1_1 = nn.Conv2d(states_channel, 32, kernel_size=(3, 1))
        self.conv1_2 = nn.Conv2d(32, 32, kernel_size=(3, 1))
        self.conv1_3 = nn.Conv2d(32, 32, kernel_size=(3, 1))

        self.conv2_1 = nn.Conv2d(states_channel, 32, kernel_size=(1, 3))
        self.conv2_2 = nn.Conv2d(32, 32, kernel_size=(1, 3))
        self.conv2_3 = nn.Conv2d(32, 32, kernel_size=(1, 3))

        self.lstm = nn.LSTMCell(3584, 512)
        self.fc1 = nn.Linear(512, 256)

        self.adv = nn.Linear(256, action_num)
        self.val = nn.Linear(256, 1)

        self.apply(weight_init)
        # weight_init skips the LSTM cell, so its biases are zeroed explicitly.
        self.lstm.bias_ih.data.fill_(0)
        self.lstm.bias_hh.data.fill_(0)

    def forward(self, inputs):
        """
        :param inputs: ``(x, (hx, cx))`` - state tensor plus the LSTM hidden
            and cell state
        :return: ``(q_values, (hx, cx))`` with the updated recurrent state
        """
        x, (hx, cx) = inputs
        batch = x.shape[0]

        rows = x
        for conv in (self.conv1_1, self.conv1_2, self.conv1_3):
            rows = F.leaky_relu(conv(rows))

        cols = x
        for conv in (self.conv2_1, self.conv2_2, self.conv2_3):
            cols = F.leaky_relu(conv(cols))

        features = rows.view(batch, -1) + cols.view(batch, -1)

        hx, cx = self.lstm(features, (hx, cx))
        x = F.leaky_relu(self.fc1(hx))

        return _dueling_q(self.val(x), self.adv(x)), (hx, cx)
# One replay-buffer entry: (state, action, reward, next_state).
Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state'))


class ExperienceReplayMemory:
    """Fixed-capacity ring buffer with uniform sampling."""

    def __init__(self, capacity, torch_type) -> None:
        """
        :param capacity: maximum number of stored transitions
        :param torch_type: dtype used for the stored reward tensors
        """
        self.capacity = capacity
        self.memory = []
        self.position = 0
        self.torch_type = torch_type

    def push(self, *args):
        """
        Store one transition ``(s, a, r, s_)``, overwriting the oldest when full.

        The action is reshaped to ``(1, -1)`` and the reward is wrapped in a
        ``(1, 1)`` tensor so batches can be built with ``torch.cat``.
        """
        s, a, r, s_ = args
        transition = Transition(s,
                                a.reshape(1, -1),
                                torch.tensor(r, dtype=self.torch_type).reshape(1, -1),
                                s_)
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """
        Draw ``batch_size`` distinct transitions uniformly at random.

        :return: ``(transitions, None, None)`` - the two ``None`` placeholders
            mirror the ``(indices, weights)`` returned by the prioritized memory
        """
        return random.sample(self.memory, batch_size), None, None

    def clean_memory(self):
        """Drop all transitions and rewind the write cursor."""
        self.memory = []
        # Without the rewind the next push would index past the emptied list.
        self.position = 0

    def __len__(self):
        return len(self.memory)


class SegmentTree:
    """
    Array-backed segment tree over ``capacity`` leaves (a power of two).

    Node 1 is the root; the leaves live at ``capacity .. 2 * capacity - 1``.
    """

    def __init__(self, capacity, operation, neutral_element):
        """
        :param capacity: number of leaves, must be a positive power of two
        :param operation: associative binary operation used to combine nodes
        :param neutral_element: initial value of every node
        """
        if capacity <= 0 or capacity & (capacity - 1) != 0:
            raise ValueError("capacity must be a positive power of two")
        self._capacity = capacity
        self._value = [neutral_element for _ in range(2 * capacity)]
        self._operation = operation

    def _reduce_helper(self, query_start, query_end, node, node_start, node_end):
        """Reduce the inclusive range ``[query_start, query_end]`` below ``node``."""
        if query_start == node_start and query_end == node_end:
            return self._value[node]
        mid = (node_start + node_end) // 2
        if query_end <= mid:
            return self._reduce_helper(query_start, query_end, 2 * node, node_start, mid)
        if mid + 1 <= query_start:
            return self._reduce_helper(query_start, query_end, 2 * node + 1, mid + 1, node_end)
        return self._operation(
            self._reduce_helper(query_start, mid, 2 * node, node_start, mid),
            self._reduce_helper(mid + 1, query_end, 2 * node + 1, mid + 1, node_end)
        )

    def reduce(self, start=0, end=None):
        """
        Reduce the leaves in ``[start, end)``.

        ``end=None`` means "up to the last leaf"; an ``end`` that is zero or
        negative is counted from the end of the tree.
        """
        if end is None:
            end = self._capacity
        if end <= 0:
            end += self._capacity
        end -= 1
        return self._reduce_helper(start, end, 1, 0, self._capacity - 1)

    def __setitem__(self, idx, val):
        """Set leaf ``idx`` and refresh all of its ancestors."""
        idx += self._capacity
        self._value[idx] = val
        idx //= 2
        while idx >= 1:
            self._value[idx] = self._operation(
                self._value[2 * idx],
                self._value[2 * idx + 1]
            )
            idx //= 2

    def __getitem__(self, idx):
        if not 0 <= idx < self._capacity:
            raise IndexError(f"index {idx} out of range for capacity {self._capacity}")
        return self._value[self._capacity + idx]


class SumSegmentTree(SegmentTree):
    """Segment tree whose reduction is the sum of the leaves."""

    def __init__(self, capacity):
        super().__init__(
            capacity=capacity,
            operation=operator.add,
            neutral_element=0.0
        )

    def sum(self, start=0, end=None):
        """Return the sum of the leaves in ``[start, end)``."""
        return super().reduce(start, end)

    def find_prefixsum_idx(self, prefixsum):
        """
        Return the first leaf index whose running sum exceeds ``prefixsum``.

        :raises ValueError: if ``prefixsum`` is negative, NaN or larger than
            the total (the process used to be terminated with ``exit()`` here)
        """
        if not 0 <= prefixsum <= self.sum() + np.finfo(np.float32).eps:
            raise ValueError(f"Prefix sum error: {prefixsum}")
        idx = 1
        while idx < self._capacity:
            left = 2 * idx
            if self._value[left] > prefixsum:
                idx = left
            else:
                prefixsum -= self._value[left]
                idx = left + 1
        return idx - self._capacity


class MinSegmentTree(SegmentTree):
    """Segment tree whose reduction is the minimum of the leaves."""

    def __init__(self, capacity):
        super().__init__(
            capacity=capacity,
            operation=min,
            neutral_element=float('inf')
        )

    def min(self, start=0, end=None):
        """Return the minimum of the leaves in ``[start, end)``."""
        return super().reduce(start, end)


class PrioritizedReplayMemory:
    """Proportional prioritized experience replay backed by segment trees."""

    def __init__(self, torch_type, size, alpha=0.6, beta_start=0.4, beta_frames=None):
        """
        :param torch_type: dtype of the stored rewards and the returned weights
        :param size: maximum number of stored transitions
        :param alpha: priority exponent, must be non-negative
        :param beta_start: initial importance-sampling exponent
        :param beta_frames: number of ``sample`` calls over which beta is
            annealed to 1.0; defaults to ``70000 * Config.PKL_NUM``, resolved
            when the memory is created
        :raises ValueError: if ``alpha`` is negative
        """
        if alpha < 0:
            raise ValueError("alpha must be non-negative")
        if beta_frames is None:
            beta_frames = 70000 * Config.PKL_NUM

        self.torch_type = torch_type
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self._storage = []
        self._maxsize = size
        self._next_idx = 0
        self._alpha = alpha

        self.beta_start = beta_start
        self.beta_frames = beta_frames
        self.frame = 1

        # The trees need a power-of-two number of leaves.
        it_capacity = 1
        while it_capacity < size:
            it_capacity *= 2

        self._it_sum = SumSegmentTree(it_capacity)
        self._it_min = MinSegmentTree(it_capacity)
        self._max_priority = 1.0

    def __len__(self):
        return len(self._storage)

    def beta_by_frame(self, frame_idx):
        """Linearly anneal beta from ``beta_start`` to 1.0 over ``beta_frames``."""
        return min(1.0, self.beta_start + frame_idx * (1.0 - self.beta_start) / self.beta_frames)

    def push(self, *data):
        """Store ``(s, a, r, s_)`` with the highest priority seen so far."""
        s, a, r, s_ = data
        transition = Transition(s,
                                a.reshape(1, -1),
                                torch.tensor(r, dtype=self.torch_type).reshape(1, -1),
                                s_)

        idx = self._next_idx
        if idx >= len(self._storage):
            self._storage.append(transition)
        else:
            self._storage[idx] = transition
        self._next_idx = (idx + 1) % self._maxsize

        priority = self._max_priority ** self._alpha
        self._it_sum[idx] = priority
        self._it_min[idx] = priority

    def _encode_sample(self, idxes):
        return [self._storage[i] for i in idxes]

    def _sample_proportional(self, batch_size):
        """Draw ``batch_size`` indices with probability proportional to priority."""
        res = []
        for _ in range(batch_size):
            mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
            res.append(self._it_sum.find_prefixsum_idx(mass))
        return res

    def sample(self, batch_size):
        """
        Sample a prioritized batch.

        :return: ``(transitions, indices, weights)`` where ``weights`` are the
            importance-sampling weights scaled by the largest possible weight
        """
        idxes = self._sample_proportional(batch_size)
        total = self._it_sum.sum()
        stored = len(self._storage)

        beta = self.beta_by_frame(self.frame)
        self.frame += 1

        p_min = self._it_min.min() / total
        max_weight = (p_min * stored) ** (-beta)

        weights = [
            ((self._it_sum[idx] / total) * stored) ** (-beta) / max_weight
            for idx in idxes
        ]
        weights = torch.tensor(weights, device=self.device, dtype=self.torch_type)
        return self._encode_sample(idxes), idxes, weights

    def update_priorities(self, idxes, priorities):
        """
        Assign new priorities to previously sampled transitions.

        :raises ValueError: if the two sequences differ in length
        :raises IndexError: if an index does not refer to a stored transition
        """
        if len(idxes) != len(priorities):
            raise ValueError("idxes and priorities must have the same length")
        eps = np.finfo(np.float32).eps
        for idx, priority in zip(idxes, priorities):
            if not 0 <= idx < len(self._storage):
                raise IndexError(f"index {idx} is outside the stored transitions")
            # The epsilon keeps zero-error transitions sampleable.
            priority = priority + eps
            self._it_sum[idx] = priority ** self._alpha
            self._it_min[idx] = priority ** self._alpha
            self._max_priority = max(self._max_priority, priority)
class DQN:
    """
    Deep Q-Network agent with n-step returns and a soft-updated target network.

    Typical use:
      * ``choose_action(state, epoch)`` picks an epsilon-greedy action from the
        policy network,
      * ``learn(state, action, reward, next_state, episode)`` stores the
        transition and performs one optimisation step,
      * ``finish_n_steps()`` flushes the n-step buffer at the end of an episode.
    """

    def __init__(self,
                 conf,
                 net):
        """
        :param conf: configuration object (``config.Config()`` instance)
        :param net: network class taking ``(state_num, action_num)``
        """
        self.state_num = conf.NUM_STATES
        self.action_num = conf.NUM_ACTIONS

        self.lr = conf.LR
        self.use_decay = conf.USE_DECAY
        self.gamma = conf.REWARD_DECAY
        self.epsilon_start, self.epsilon_final, self.epsilon_decay = conf.E_GREEDY
        self.epsilon = conf.E_GREEDY_ORI

        self.tau = conf.TAU

        self.device = conf.DEVICE
        self.batch_size = conf.BATCH_SIZE
        self.memory_capacity = conf.MEMORY_CAPACITY

        self.clamp = conf.CLAMP
        self.torch_type = conf.TORCH_TYPE

        self.experiment_time = conf.TIME

        self.n_steps = conf.N_STEPS
        self.n_step_buffer = []

        self.policy_net, self.target_net = self.initialize_net(net)

        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.lr)
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[2500, ], gamma=0.1)
        # Huber loss (SmoothL1 with the default beta of 1); kept per element so
        # the importance-sampling weights can be applied before averaging.
        self.loss_func = nn.SmoothL1Loss(reduction='none')
        self.memory_pool = self.initialize_replay_memory("PrioritizedReplayMemory")

        self.learning_step_count = 0
        self.update_target_count = 0
        self.update_count = 0
        self.update_target_frequency = conf.UPDATE_TARGET_FREQUENCY

        self.losses = []

        self.train_or_eval = 'train'
        self.change_model_mode('train')

        self.move_net_to_device()

        Path("./saved_agents").mkdir(exist_ok=True)

    def initialize_replay_memory(self, mode):
        """
        Create the replay memory.

        :param mode: ``'ExperienceReplayMemory'`` or ``'PrioritizedReplayMemory'``
        :return: the replay memory instance
        :raises ValueError: for any other ``mode``
        """
        if mode == 'ExperienceReplayMemory':
            return ExperienceReplayMemory(self.memory_capacity, self.torch_type)
        if mode == 'PrioritizedReplayMemory':
            return PrioritizedReplayMemory(self.torch_type, self.memory_capacity)
        raise ValueError(f"unknown replay memory mode: {mode!r}")

    def initialize_net(self, net):
        """
        Instantiate the policy and the target network.

        :param net: ``nn.Module`` subclass taking ``(state_num, action_num)``
        :return: ``(policy_net, target_net)``
        """
        policy_net = net(self.state_num, self.action_num)
        target_net = net(self.state_num, self.action_num)
        return policy_net, target_net

    def move_net_to_device(self):
        """Move both networks to ``self.device``."""
        self.policy_net.to(self.device)
        self.target_net.to(self.device)

    def change_model_mode(self, train_or_eval='train'):
        """
        Switch both networks between ``train()`` and ``eval()``.

        :param train_or_eval: ``'train'`` or ``'eval'``
        """
        assert train_or_eval in ['train', 'eval']
        if train_or_eval == 'train':
            self.policy_net.train()
            self.target_net.train()
        else:
            self.policy_net.eval()
            self.target_net.eval()

        self.train_or_eval = train_or_eval

    def decay_epsilon(self, epoch, mode='exp'):
        """
        Compute the exploration rate for ``epoch``.

        :param epoch: current epoch
        :param mode: ``'exp'`` for exponential decay from ``epsilon_start``
            towards ``epsilon_final``; any other value selects the alternative
            formula below
        :return: epsilon
        """
        if mode == 'exp':
            eps = self.epsilon_final + (self.epsilon_start - self.epsilon_final) * math.exp(
                -1. * epoch / self.epsilon_decay)
        else:
            # NOTE(review): this never exceeds epsilon_final, which looks
            # unintended for a linear schedule - confirm before relying on it.
            eps = min(abs(self.epsilon_start - self.epsilon_final * epoch), self.epsilon_final)
        return eps

    def choose_action(self, s, epoch):
        """
        Epsilon-greedy action selection.

        :param s: state tensor for a single sample (batch dimension of 1)
        :param epoch: current epoch, used when epsilon decay is enabled
        :return: 0-dim tensor holding the chosen action index
        """
        epsilon = self.decay_epsilon(epoch) if self.use_decay else self.epsilon

        if np.random.rand() >= epsilon:  # greedy policy
            with torch.no_grad():  # inference only, no graph needed
                action_value = self.policy_net(s.to(self.device))
            # shape [1, actions_num]; max -> (values, indices)
            act_node = torch.max(action_value, 1)[1]
        else:  # random policy
            act_node = torch.from_numpy(
                np.random.choice(np.arange(self.action_num), size=1)).to(self.device)
        return act_node[0]

    def choose_max_action(self, s):
        """
        Greedy action selection, used at test time.

        :param s: state tensor
        :return: tensor of action indices, one per sample in ``s``
        """
        with torch.no_grad():
            action_value = self.policy_net(s.to(self.device))
        return torch.max(action_value, 1)[1]

    def store_transition_in_memory(self, s, a, r, s_):
        """
        Feed one step into the n-step buffer and push matured transitions.

        A transition is pushed once ``n_steps`` steps are buffered or as soon
        as a terminal step (``s_ is None``) arrives. The pushed entry pairs
        the oldest buffered state/action with the discounted reward sum and
        the latest next state.

        :param s: state
        :param a: action
        :param r: reward
        :param s_: next state, ``None`` when the episode ended
        """
        self.n_step_buffer.append((s, a, r, s_))
        if len(self.n_step_buffer) < self.n_steps and s_ is not None:
            return
        # A terminal step may arrive before n_steps entries are buffered, so
        # only sum over what is actually there.
        horizon = min(self.n_steps, len(self.n_step_buffer))
        R = sum(self.n_step_buffer[i][2] * (self.gamma ** i) for i in range(horizon))
        s, a, _, _ = self.n_step_buffer.pop(0)

        self.memory_pool.push(s, a, R, s_)

    def finish_n_steps(self):
        """Flush the steps still in the n-step buffer as terminal transitions."""
        while len(self.n_step_buffer) > 0:
            R = sum(step[2] * (self.gamma ** i) for i, step in enumerate(self.n_step_buffer))
            s, a, _, _ = self.n_step_buffer.pop(0)
            self.memory_pool.push(s, a, R, None)

    def get_batch_vars(self):
        """
        Sample a batch and collate it into tensors.

        :return: ``(state_batch, action_batch, reward_batch,
            non_final_next_states, non_flag, non_final_mask, indices, weights)``
            where ``non_flag`` is ``True`` when every sampled transition is
            terminal (``non_final_next_states`` is then ``None``)
        """
        transitions, indices, weights = self.memory_pool.sample(self.batch_size)

        batch = Transition(*zip(*transitions))
        state_batch = torch.cat(batch.state).to(self.device)
        action_batch = torch.cat(batch.action).to(self.device)
        reward_batch = torch.cat(batch.reward).to(self.device)

        # Mask of the transitions that have a successor state.
        non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                      device=self.device, dtype=torch.bool)

        next_states = [s for s in batch.next_state if s is not None]
        non_flag = not next_states
        non_final_next_states = None if non_flag else torch.cat(next_states).to(self.device)

        return state_batch, action_batch, reward_batch, non_final_next_states, non_flag, non_final_mask, indices, weights

    def calculate_loss(self, batch_vars):
        """
        Compute the (importance-weighted) Huber loss for one batch.

        Also refreshes the priorities of the sampled transitions when the
        replay memory provided indices.

        :param batch_vars: tuple returned by ``get_batch_vars``
        :return: scalar loss tensor
        """
        (batch_state, batch_action, batch_reward, non_final_next_states,
         non_flag, non_final_mask, indices, weights) = batch_vars

        # Q(s, a) of the actions that were actually taken.
        q_values = self.policy_net(batch_state)
        state_action_values = torch.gather(q_values, 1, batch_action.type(torch.int64))

        with torch.no_grad():
            # Terminal transitions keep a bootstrap value of zero.
            max_next_state_values = torch.zeros(self.batch_size, dtype=self.torch_type,
                                                device=self.device).unsqueeze(1)
            if not non_flag:
                next_q = self.target_net(non_final_next_states)
                max_next_state_values[non_final_mask] = next_q.max(dim=1, keepdim=True)[0]
            # The stored reward already sums n_steps discounted rewards, so the
            # bootstrap term has to be discounted by gamma ** n_steps.
            discount = self.gamma ** self.n_steps
            expected_state_action_values = (max_next_state_values * discount) + batch_reward

        if indices is not None:  # uniform replay has no priorities to update
            td_errors = (state_action_values - expected_state_action_values).detach().squeeze(1).abs()
            self.memory_pool.update_priorities(indices, td_errors.cpu().numpy().tolist())

        loss = self.loss_func(state_action_values, expected_state_action_values)
        if weights is not None:
            loss = loss * weights.unsqueeze(1)
        return loss.mean()

    def learn(self, s, a, r, s_, episode):
        """
        Store a transition and run one optimisation step.

        Nothing happens in eval mode; the optimisation step is skipped until
        the replay memory holds at least one batch.

        :param s: state
        :param a: action
        :param r: reward
        :param s_: next state, ``None`` when the episode ended
        :param episode: current episode, forwarded to ``update_target_model``
        """
        if self.train_or_eval == 'eval':
            return

        self.store_transition_in_memory(s, a, r, s_)

        if len(self.memory_pool) < self.batch_size:
            return

        batch_vars = self.get_batch_vars()
        loss = self.calculate_loss(batch_vars)

        # Optimize the model
        self.optimizer.zero_grad()
        loss.backward()
        if self.clamp:
            for param in self.policy_net.parameters():
                if param.grad is not None:  # parameters unused in the forward pass
                    param.grad.data.clamp_(-1, 1)

        self.optimizer.step()

        self.update_target_model(episode)
        self.append_loss_item(loss.data)

    def update_target_model(self, episode):
        """
        Soft-update the target network every ``update_target_frequency`` calls.

        :param episode: current episode (currently unused)
        """
        self.update_target_count += 1
        self.update_count = self.update_target_count % self.update_target_frequency
        if self.update_count == 0:
            self._soft_update()

    def _soft_update(self):
        """Blend the policy weights into the target network with rate ``tau``."""
        for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):
            target_param.data.copy_(
                target_param.data * (1.0 - self.tau) + param.data * self.tau
            )

    def append_loss_item(self, loss):
        """
        Record a loss value.

        :param loss: loss to append to ``self.losses``
        """
        self.losses.append(loss)

    def save_weight(self, name):
        """
        Save the policy network, target network and optimizer state.

        :param name: suffix of the files written to ``./saved_agents``
        """
        torch.save(self.policy_net.state_dict(),
                   f'./saved_agents/policy_net-{name}.pt')
        torch.save(self.target_net.state_dict(),
                   f'./saved_agents/target_net-{name}.pt')
        torch.save(self.optimizer.state_dict(),
                   f'./saved_agents/optim-{name}.pt')

    def load_weight(self, fname_model, fname_optim):
        """
        Load network weights (into both networks) and optionally the optimizer.

        :param fname_model: path of the saved policy network state dict
        :param fname_optim: path of the saved optimizer state dict, or ``None``
        :raises FileNotFoundError: if ``fname_model`` does not exist
        """
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        if not os.path.isfile(fname_model):
            raise FileNotFoundError(fname_model)
        self.policy_net.load_state_dict(torch.load(fname_model, map_location=device))
        self.target_net.load_state_dict(self.policy_net.state_dict())

        if fname_optim is not None and os.path.isfile(fname_optim):
            # map_location lets a GPU-saved optimizer state load on CPU-only hosts.
            self.optimizer.load_state_dict(torch.load(fname_optim, map_location=device))

    def save_replay(self):
        """Pickle the replay memory into ``./saved_agents``."""
        with open(f'./saved_agents/exp_replay_agent-{self.experiment_time}.pkl', 'wb') as f:
            pickle.dump(self.memory_pool, f)

    def load_replay(self, fname):
        """
        Restore a pickled replay memory; silently does nothing if ``fname`` is missing.

        SECURITY: unpickling executes arbitrary code - only load files that
        were produced by ``save_replay`` on a trusted machine.
        """
        if os.path.isfile(fname):
            with open(fname, 'rb') as f:
                self.memory_pool = pickle.load(f)
self.optimizer.load_state_dict(torch.load(fname_optim)) 348 | 349 | def save_replay(self): 350 | """ 351 | 保存经验池的数据, pickle形式 352 | :return: None 353 | """ 354 | pickle.dump(self.memory_pool, open(f'./saved_agents/exp_replay_agent-{self.experiment_time}.pkl', 'wb')) 355 | 356 | def load_replay(self, fname): 357 | """ 358 | 加载经验池的数据 359 | :return: Nones 360 | """ 361 | if os.path.isfile(fname): 362 | self.memory_pool = pickle.load(open(fname, 'rb')) 363 | -------------------------------------------------------------------------------- /RL/test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : test.py 3 | # @Date : 2022-05-18 4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来- 5 | # @From : 6 | import os 7 | import pickle 8 | from pathlib import Path 9 | from collections import namedtuple 10 | import random 11 | from collections.abc import Iterable 12 | import time 13 | from itertools import tee 14 | 15 | import networkx 16 | import xml.etree.ElementTree as ET 17 | import numpy as np 18 | import torch 19 | from torch.utils.tensorboard import SummaryWriter 20 | import networkx as nx 21 | from networkx.algorithms.approximation import steiner_tree 22 | import matplotlib.pyplot as plt 23 | 24 | from log import MyLog 25 | from rl import DQN 26 | from env import MulticastEnv 27 | from net import * 28 | from config import Config 29 | from env_config import * 30 | 31 | np.random.seed(2022) 32 | torch.random.manual_seed(2022) 33 | random.seed(2022) 34 | 35 | mylog = MyLog(Path(__file__), filesave=True, consoleprint=True) 36 | logger = mylog.logger 37 | RouteParams = namedtuple("RouteParams", ('bw', 'delay', 'loss')) 38 | plt.style.use("seaborn-whitegrid") 39 | 40 | 41 | class Train: 42 | def __init__(self, Conf, Env, RL, Net, name, mode): 43 | """ 44 | Conf: class, 配置类 45 | Env: class, 环境类 46 | RL: class, 强化学习类 47 | Net: class, 神经网络类 48 | """ 49 | self.graph = None 50 | self.nodes_num = None 51 | 
self.edges_num = None 52 | 53 | self.state_channel = 4 54 | 55 | self.record_dict = {} # 记录训练时所有 56 | 57 | self.config = self.set_initial_config(Conf) # 配置 58 | self.config.log_params(logger) 59 | 60 | # 1. 设置 env 61 | self.env = self.set_initial_env(Env) 62 | # 2. 初始化 图、节点数、边数 63 | self.set_init_topology() 64 | # 3. 设置config中的NUM_STATES NUM_ACTIONS 65 | self.set_num_states_actions() 66 | # 4. 设置 RL 67 | self.rl = self.set_initial_rl(RL, Net) 68 | 69 | self.writer = SummaryWriter(f"./runs/{name}") if mode != 'eval' else None 70 | 71 | self.reward_list_idx = "" 72 | 73 | def set_init_topology(self): 74 | """ 75 | 1. 解析xml文件 76 | 2. 设置 self.graph 77 | self.nodes_num 78 | self.edges_num 79 | """ 80 | graph, nodes_num, edges_num = self.parse_xml_topology(self.config.xml_topology_path) 81 | self.graph = graph 82 | self.nodes_num = nodes_num 83 | self.edges_num = edges_num 84 | 85 | def set_num_states_actions(self): 86 | """ 87 | 根据解析topo设置config中的状态空间和动作空间,供深度网络使用 88 | """ 89 | state_space_num = self.state_channel # input channel 90 | action_space_num = self.edges_num # output channel 91 | self.config.set_num_states_actions(state_space_num, action_space_num) 92 | 93 | def set_initial_env(self, Env): 94 | """ 95 | Env类初始化 96 | :param Env: 环境类 97 | :return : Env的实例 98 | """ 99 | env = Env(self.graph, self.config.NUMPY_TYPE) 100 | return env 101 | 102 | def set_initial_rl(self, RL, Net): 103 | """ 104 | RL类初始化 105 | :param RL: RL类 如DQN 106 | :param Net: Net类 107 | :return : RL的实例 108 | """ 109 | rl = RL(self.config, Net) 110 | return rl 111 | 112 | def set_initial_config(self, Config): 113 | """ 114 | 配置初始化 115 | :param Config: 配置类 116 | :return: conf实例 117 | """ 118 | conf = Config() 119 | return conf 120 | 121 | @staticmethod 122 | def parse_xml_topology(topology_path): 123 | """ 124 | parse topology from topology.xml 125 | :param topology_path: 拓扑的xml文件路径 126 | :return: topology graph, networkx.Graph() 127 | nodes_num, int 128 | edges_num, int 129 | """ 130 | tree = 
ET.parse(topology_path) 131 | root = tree.getroot() 132 | topo_element = root.find("topology") 133 | graph = networkx.Graph() 134 | for child in topo_element.iter(): 135 | # parse nodes 136 | if child.tag == 'node': 137 | node_id = int(child.get('id')) 138 | graph.add_node(node_id) 139 | # parse link 140 | elif child.tag == 'link': 141 | from_node = int(child.find('from').get('node')) 142 | to_node = int(child.find('to').get('node')) 143 | graph.add_edge(from_node, to_node) 144 | 145 | nodes_num = len(graph.nodes) 146 | edges_num = len(graph.edges) 147 | 148 | print('nodes: ', nodes_num, '\n', graph.nodes, '\n', 149 | 'edges: ', edges_num, '\n', graph.edges) 150 | return graph, nodes_num, edges_num 151 | 152 | def pkl_file_path_yield(self, pkl_dir, n: int = 3000, step: int = 1): 153 | """ 154 | 生成保存的pickle文件的路径, 按序号递增的方式生成 155 | :param pkl_dir: Path, pkl文件的目录 156 | :param n: 截取 157 | :param step: 间隔 158 | """ 159 | a = os.listdir(pkl_dir) 160 | assert n < len(a), "n should small than len(a)" 161 | b = sorted(a, key=lambda x: int(x.split('-')[0])) 162 | for p in b[:n:step]: 163 | yield pkl_dir / p 164 | 165 | def record_route_params(self, episode, bw, delay, loss): 166 | """ 167 | 将route信息存入字典中 {episode: RouteParams} 168 | :param episode: 训练的第几代, 作为key 169 | :param bw: 剩余带宽和 170 | :param delay: 时延和 171 | :param loss: 丢包率和 172 | :return: None 173 | """ 174 | self.record_dict.setdefault(episode, RouteParams(bw, delay, loss)) 175 | with open("./recorder.pkl", "wb+") as f: 176 | pickle.dump(self.record_dict, f) 177 | 178 | @staticmethod 179 | def record_reward(r): 180 | with open('./reward.pkl', 'wb') as f: 181 | pickle.dump(r, f) 182 | 183 | def choice_multicast_nodes(self): 184 | """ 185 | 随机选择 源节点和目的节点 186 | source 至少 1 187 | destination 至少 2 188 | 所有节点至多 nodes_num - 1 189 | :return: multicast nodes 190 | """ 191 | _a = list(self.graph.nodes()) 192 | _size = np.random.randint(3, self.nodes_num) 193 | # 无放回抽取 194 | multicast_nodes = np.random.choice(_a, size=_size, 
replace=False) 195 | multicast_nodes = list(multicast_nodes) 196 | start_node = multicast_nodes.pop(0) 197 | end_nodes = multicast_nodes 198 | 199 | return start_node, end_nodes 200 | 201 | @staticmethod 202 | def flat_and_combine_state(route, bw, delay, loss): 203 | """ 204 | 将二维矩阵展平成一维, 将多个一维拼接 205 | :param route: 路由矩阵 206 | :param bw: 剩余带宽矩阵 207 | :param delay: 时延矩阵 208 | :param loss: 丢包率矩阵 209 | :return: 展平的tensor 210 | """ 211 | if route is not None: 212 | flatten_route = torch.flatten(torch.from_numpy(route)).unsqueeze(0) 213 | flatten_bw = torch.flatten(torch.from_numpy(bw)).unsqueeze(0) 214 | flatten_delay = torch.flatten(torch.from_numpy(delay)).unsqueeze(0) 215 | flatten_loss = torch.flatten(torch.from_numpy(loss)).unsqueeze(0) 216 | combine_state = torch.cat([flatten_route, flatten_bw, flatten_delay, flatten_loss], dim=1) 217 | return combine_state 218 | else: 219 | return None 220 | 221 | def combine_state(self, route): 222 | """ 223 | 组合 多channel 224 | :param route: 路由矩阵 225 | :return: 多channel tensor 226 | """ 227 | if route is not None: 228 | if self.state_channel == 1: 229 | combine_state = torch.stack( 230 | [torch.from_numpy(route) + torch.from_numpy(self.env.get_branches_matrix())], dim=0) 231 | elif self.state_channel == 2: 232 | combine_state = torch.stack( 233 | [ 234 | torch.from_numpy(route), 235 | torch.from_numpy(self.env.get_branches_matrix()), 236 | ], 237 | dim=0) 238 | elif self.state_channel == 4: 239 | combine_state = torch.stack( 240 | [ 241 | torch.from_numpy(route) + torch.from_numpy(self.env.get_branches_matrix()), 242 | torch.from_numpy(self.env.normal_bw_matrix), 243 | torch.from_numpy(self.env.normal_delay_matrix), 244 | torch.from_numpy(self.env.normal_loss_matrix) 245 | ], 246 | dim=0) 247 | else: 248 | combine_state = torch.stack( 249 | [ 250 | torch.from_numpy(route), 251 | torch.from_numpy(self.env.get_branches_matrix()), 252 | torch.from_numpy(self.env.normal_bw_matrix), 253 | 
def get_tree_params(self, tree, graph):
    """
    Compute the bw, delay and loss figures of a tree.

    :param tree: tree to evaluate, either an ``nx.Graph`` or an iterable of
                 edges whose first two items are the end nodes
    :param graph: graph whose edge attributes ``delay`` and ``loss`` are read
    :return: (bw_mean, delay per edge, loss per edge, number of edges), where
             bw_mean is the mean of what ``env.find_end_to_end_max_bw``
             returns for the tree
    :raises ValueError: if ``tree`` is neither a graph nor an iterable, or if
                        it contains no edge
    """
    # nx.Graph must be tested first: a graph is itself iterable (over nodes).
    if isinstance(tree, nx.Graph):
        edges = list(tree.edges)
    elif isinstance(tree, Iterable):
        edges = list(tree)
    else:
        raise ValueError("tree param error")

    num = len(edges)
    if num == 0:
        # Without this guard the averages below would divide by zero.
        raise ValueError("tree has no edges")

    delay = sum(graph[edge[0]][edge[1]]["delay"] for edge in edges)
    loss = sum(graph[edge[0]][edge[1]]["loss"] for edge in edges)

    bw_mean = self.env.find_end_to_end_max_bw(tree, self.env.start, self.env.ends_constant).mean()
    # ``num`` instead of ``len(tree.edges)``: an edge iterable has no ``.edges``.
    return bw_mean, delay / num, loss / num, num
def get_spanning_tree_params(self, graph):
    """
    Evaluate four spanning trees and keep one figure from each:

    * tree weighted by the modified bw  -> its bw value
    * tree weighted by delay            -> its delay value
    * tree weighted by loss             -> its loss value
    * unweighted (hop count) tree       -> its length

    :param graph: graph the spanning trees are built from
    :return: bw, delay, loss, length
    """
    bw_weighted_graph = self.modify_bw_weight(graph)
    # (graph to span, edge weight, index of the wanted value in the
    #  tuple returned by get_tree_params)
    plan = (
        (bw_weighted_graph, 'bw', 0),
        (graph, 'delay', 1),
        (graph, 'loss', 2),
        (graph, None, 3),
    )
    picked = []
    for source_graph, weight, position in plan:
        tree = self.spanning_tree(source_graph, weight=weight)
        picked.append(self.get_tree_params(tree, graph)[position])
    return tuple(picked)
while True 直到跑出path 396 | 397 | 2022/3/17 修改link方向BUG 398 | """ 399 | pkl_cut_num = self.config.PKL_CUT_NUM 400 | pkl_step = self.config.PKL_STEP 401 | loss_step = 0 402 | for episode in range(self.config.EPISODES): 403 | # start_node, end_nodes = self.choice_multicast_nodes() 404 | start_node = 12 405 | end_nodes = [2, 4, 11] 406 | 407 | logger.info(f"[{episode}] start_node: {start_node}") 408 | logger.info(f"[{episode}] end_nodes: {end_nodes}") 409 | 410 | episode_reward = np.array([]) 411 | episode_bw, kmb_bw, spanning_bw = np.array([]), np.array([]), np.array([]) 412 | episode_delay, kmb_delay, spanning_delay = np.array([]), np.array([]), np.array([]) 413 | episode_loss, kmb_loss, spanning_loss = np.array([]), np.array([]), np.array([]) 414 | episode_length, kmb_length, spanning_length = np.array([]), np.array([]), np.array([]) 415 | episode_steps = np.array([]) 416 | 417 | for index, pkl_path in enumerate( 418 | self.pkl_file_path_yield(self.config.pkl_weight_path, n=pkl_cut_num, step=pkl_step)): 419 | 420 | self.env.read_pickle_and_modify(pkl_path) 421 | state = self.env.reset(start_node, end_nodes) 422 | 423 | reward_temp = 0 424 | while True: 425 | # 1. 二维 426 | combine_state = self.combine_state(state) 427 | # 2.1 动作选择 428 | action = self.rl.choose_action(combine_state, episode) 429 | # 2.2 将动作映射为链路 430 | link = self.env.map_action(action) 431 | # 3. 环境交互 432 | new_state, reward, done, flag = self.env.step(link) 433 | # 4. 下一个状态 434 | combine_new_state = self.combine_state(new_state) 435 | # 5. 
RL学习 436 | self.rl.learn(combine_state, action, reward, combine_new_state) 437 | reward_temp += reward 438 | if len(self.rl.losses) > 0: 439 | self.writer.add_scalar("Optim/Loss", self.rl.losses.pop(0), loss_step) 440 | loss_step += 1 441 | if done: 442 | self.rl.finish_n_steps() 443 | 444 | if flag == "ALL": 445 | bw, delay, loss, length = self.env.get_route_params() # 获得路径的所以链路bw和,delay和,loss和 446 | # 添加到数组中 447 | episode_bw = np.append(episode_bw, bw) 448 | episode_delay = np.append(episode_delay, delay) 449 | episode_loss = np.append(episode_loss, loss) 450 | episode_length = np.append(episode_length, length) 451 | 452 | # kmb 算法 453 | bw, delay, loss, length = self.get_kmb_params(self.env.graph, start_node, end_nodes) 454 | kmb_bw = np.append(kmb_bw, bw) 455 | kmb_delay = np.append(kmb_delay, delay) 456 | kmb_loss = np.append(kmb_loss, loss) 457 | kmb_length = np.append(kmb_length, length) 458 | 459 | episode_reward = np.append(episode_reward, reward_temp) 460 | episode_steps = np.append(episode_steps, self.env.step_num - 1) 461 | self.print_train_info(episode, index, reward, link) 462 | break 463 | 464 | # 6. 
状态改变 465 | state = new_state 466 | 467 | # self.writer.add_scalars('Episode/reward', 468 | # {"reward": episode_reward.mean(), "reward_max": episode_reward.max(initial=0)}, 469 | # episode) 470 | # self.writer.add_scalars('Episode/steps', 471 | # {"reward": episode_steps.mean(), "reward_max": episode_steps.max(initial=0)}, 472 | # episode) 473 | # 474 | # self.writer.add_scalars('Episode/bw', {"rl": episode_bw.mean(), "kmb_bw": kmb_bw.mean(), 475 | # }, episode) 476 | # self.writer.add_scalars('Episode/delay', {"rl": episode_delay.mean(), "kmb_delay": kmb_delay.mean(), 477 | # }, episode) 478 | # self.writer.add_scalars('Episode/loss', {"rl": episode_loss.mean(), "kmb_loss": kmb_loss.mean(), 479 | # }, episode) 480 | # self.writer.add_scalars('Episode/length', {"rl": episode_length.mean(), "kmb_length": kmb_length.mean(), 481 | # }, episode) 482 | self.writer.add_scalar('Episode/reward', episode_reward.mean(), episode) 483 | self.writer.add_scalar('Episode/steps', episode_steps.mean(), episode) 484 | self.writer.add_scalar('Episode/bw', episode_bw.mean(), episode) 485 | self.writer.add_scalar('Episode/delay', episode_delay.mean(), episode) 486 | self.writer.add_scalar('Episode/loss', episode_loss.mean(), episode) 487 | self.writer.add_scalar('Episode/length', episode_length.mean(), episode) 488 | self.writer.add_scalar("learning_rate", self.rl.optimizer.param_groups[0]['lr'], episode) 489 | self.rl.scheduler.step() 490 | self.rl.save_weight() 491 | 492 | logger.debug('train over') 493 | 494 | def set_reward_list_idx(self, idx): 495 | self.reward_list_idx = idx 496 | 497 | def compare_test(self, 498 | weight_file): 499 | self.rl.change_model_mode('eval') 500 | 501 | pkl_cut_num = self.config.PKL_CUT_NUM 502 | pkl_step = self.config.PKL_STEP 503 | start_node = 12 504 | end_nodes = [2, 4, 11] 505 | 506 | episode_bw, kmb_bw, spanning_bw = [], [], [] 507 | episode_delay, kmb_delay, spanning_delay = [], [], [] 508 | episode_loss, kmb_loss, spanning_loss = [], [], [] 509 
| episode_length, kmb_length, spanning_length = [], [], [] 510 | 511 | self.rl.load_weight(weight_file, None) 512 | for index, pkl_path in enumerate( 513 | self.pkl_file_path_yield(self.config.pkl_weight_path, n=pkl_cut_num, step=pkl_step)): 514 | 515 | self.env.read_pickle_and_modify(pkl_path) 516 | state = self.env.reset(start_node, end_nodes) 517 | 518 | while True: 519 | # 1. 二维 520 | combine_state = self.combine_state(state) 521 | # 2.1 动作选择 522 | action = self.rl.choose_max_action(combine_state) 523 | # 2.2 将动作映射为链路 524 | link = self.env.map_action(action) 525 | # 3. 环境交互 526 | new_state, reward, done, flag = self.env.step(link) 527 | if done: 528 | if flag == "ALL": 529 | bw, delay, loss, length = self.env.get_route_params() # 获得路径的所以链路bw和,delay和,loss和 530 | # 添加到数组中 531 | episode_bw.append(bw) 532 | episode_delay.append(delay) 533 | episode_loss.append(loss) 534 | episode_length.append(length) 535 | 536 | # kmb 算法 537 | bw, delay, loss, length = self.get_kmb_params(self.env.graph, start_node, end_nodes) 538 | kmb_bw.append(bw) 539 | kmb_delay.append(delay) 540 | kmb_loss.append(loss) 541 | kmb_length.append(length) 542 | break 543 | 544 | # 6. 
def plot_compare_figure(self, rl_result, kmb_result, x_label, y_label, title):
    """
    Draw a grouped bar chart comparing the RL result with three KMB results
    and save it as ``./images/<title>.png``.

    :param rl_result: one value per sample for the RL agent
    :param kmb_result: one row per sample; items 0, 1 and 2 of each row are
                       plotted as 'kmb_bw', 'kmb_delay' and 'kmb_loss'
    :param x_label: label of the x axis
    :param y_label: label of the y axis
    :param title: used only as the name of the saved image
    """
    width = 0.18
    positions = range(len(kmb_result))

    # Start from a fresh figure. When plt.show() does not block (e.g. a
    # non-interactive backend) the current figure stays alive, and a second
    # call would otherwise draw its bars on top of the previous chart.
    fig = plt.figure()

    plt.bar(positions, rl_result, width, label='rl')
    series = (('kmb_bw', 0), ('kmb_delay', 1), ('kmb_loss', 2))
    for offset, (label, column) in enumerate(series, start=1):
        plt.bar([x + offset * width for x in positions],
                [row[column] for row in kmb_result],
                width, label=label)

    plt.xticks(positions, positions, rotation=0, fontsize='small')
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend(bbox_to_anchor=(0., 1.0), loc='lower left', ncol=4, )

    # mkdir(exist_ok=True) already copes with an existing directory, so no
    # separate exists() branch is needed.
    out_dir = Path('./images')
    out_dir.mkdir(exist_ok=True)
    plt.savefig(out_dir / f'{title}.png')
    plt.show()
    plt.close(fig)
| 589 | def ax_plot(i, kmb_result, rl_result): 590 | kmb_bw, kmb_delay, kmb_loss = get_bw_delay_loss(kmb_result) 591 | ax[i].bar(range(len(rl_result)), rl_result, width, label='rl') 592 | ax[i].bar([x + width for x in range(len(kmb_bw))], kmb_bw, width, label='kmb_bw') 593 | ax[i].bar([x + 2 * width for x in range(len(kmb_bw))], kmb_delay, width, label='kmb_delay') 594 | ax[i].bar([x + 3 * width for x in range(len(kmb_bw))], kmb_loss, width, label='kmb_loss') 595 | 596 | # bw 597 | ax_plot(0, kmb_bw_result, rl_bw_result) 598 | ax[0].set_ylabel("mean bw") 599 | ax_plot(1, kmb_delay_result, rl_delay_result) 600 | ax[1].set_ylabel("mean delay") 601 | ax_plot(2, kmb_loss_result, rl_loss_result) 602 | ax[2].set_ylabel("mean loss") 603 | 604 | plt.xticks(range(len(rl_bw_result)), range(len(rl_bw_result)), rotation=0, fontsize='small') 605 | # plt.title(title) 606 | plt.legend(bbox_to_anchor=(0., 1.0), loc='lower left', ncol=4, ) 607 | 608 | _path = Path('./images') 609 | if _path.exists(): 610 | plt.savefig(_path / 'result.png') 611 | else: 612 | _path.mkdir(exist_ok=True) 613 | plt.savefig(_path / f'result.png') 614 | plt.show() 615 | 616 | def read_data_path(self, data_path="./data"): 617 | data_path = Path(data_path) 618 | data_path_dict = {"lr": [], "nsteps": [], "batchsize": [], 'egreedy': [], "gamma": [], "update": [], 619 | "rewardslist": [], "tau": []} 620 | 621 | for data_doc in data_path.iterdir(): 622 | if data_doc.match('*_lr_*'): 623 | data_path_dict['lr'].append(data_doc) 624 | elif data_doc.match('*_nsteps_*'): 625 | data_path_dict['nsteps'].append(data_doc) 626 | elif data_doc.match('*_batchsize_*'): 627 | data_path_dict['batchsize'].append(data_doc) 628 | elif data_doc.match('*_egreedy_*'): 629 | data_path_dict['egreedy'].append(data_doc) 630 | elif data_doc.match('*_gamma_*'): 631 | data_path_dict['gamma'].append(data_doc) 632 | elif data_doc.match('*_updatefrequency_*'): 633 | data_path_dict['update'].append(data_doc) 634 | elif 
def plot_line_chart(self, data_list, name):
    """
    Plot every series of ``data_list`` as a line and save the chart.

    :param data_list: iterable of sequences; each one becomes a line whose
                      x values are 0..len-1
    :param name: base name of the image, written as ``./<name>.png``
    """
    for idx, data in enumerate(data_list):
        # A label per line gives the legend something to show.
        plt.plot(range(len(data)), data, label=str(idx))
    plt.legend()
    # pyplot offers ``savefig``; there is no ``plt.save``, so the former call
    # always raised AttributeError. ``name`` is now used for the file name.
    plt.savefig(f'./{name}.png')
def np_save(file_data, file_name):
    """
    Store ``file_data`` as ``./tm_statistic/<file_name>.npy``.

    The ``tm_statistic`` directory is created in the working directory when
    it does not exist yet; a short confirmation is printed afterwards.

    :param file_data: array-like data handed to ``np.save``
    :param file_name: file name without the ``.npy`` suffix
    """
    out_dir = Path('./tm_statistic')
    out_dir.mkdir(exist_ok=True)
    target = f'./tm_statistic/{file_name}.npy'
    np.save(target, file_data)
    print(f'save {file_name}')
if __name__ == '__main__':
    # Every numeric option declares its ``type``: argparse hands command-line
    # values over as strings unless told otherwise, while the defaults below
    # are numbers. Non-string defaults are not converted, so they are unchanged.
    parser = argparse.ArgumentParser(description="Generate traffic matrices")
    parser.add_argument("--seed", type=int, default=2020, help="random seed")
    parser.add_argument("--num_nodes", type=int, default=14, help="number of nodes of network")
    parser.add_argument("--num_tms", type=int, default=24, help="total number of matrices")
    # 1.55 * 1e3 * 0.75
    parser.add_argument("--mean_traffic", type=float, default=5 * 10 ** 3 * 0.75,
                        help="mean volume of traffic (Kbps)")
    parser.add_argument("--pm_ratio", type=float, default=1.5, help="peak-to-mean ratio")
    parser.add_argument("--t_ratio", type=float, default=0.75, help="trough-to-mean ratio")
    parser.add_argument("--diurnal_freq", type=float, default=1 / 24, help="Frequency of modulation")
    parser.add_argument("--spatial_variance", type=float, default=500,
                        help="Variance on the volume of traffic between origin-destination pairs")
    parser.add_argument("--temporal_variance", type=float, default=0.03, help="Variance on the volume in time")
    parser.add_argument("--communicate_ratio", type=float, default=0.7, help="percentage of nodes to communicate")
    # The statistics file used to be hard-coded; the old location stays the
    # default so existing runs behave the same, but it can now be overridden.
    parser.add_argument("--mean_tm_file",
                        default=r"D:\WorkSpace\Hello_Myself\Hello_Multicast\RLMulticastProject\mininet\tm_statistic\tm_statistic\mean_time_tm.npy",
                        help="path of the mean_time_tm.npy file to plot")
    args = parser.parse_args()

    # Generation is switched off here; only the stored statistics are plotted.
    # set_seed()
    # generate_tm()

    mean_time_tm = np.load(args.mean_tm_file)
    plot_tm_mean(mean_time_tm)
def generate_port(node_idx1, node_idx2):
    """
    Build a port number by writing both node indices next to each other.

    The indices are separated by a single ``0`` when both are greater than 9
    and by ``00`` otherwise, e.g. (1, 2) -> 1002 and (10, 11) -> 10011.

    :param node_idx1: index of the first node
    :param node_idx2: index of the second node
    :return: the port number as int
    """
    separator = "0" if (node_idx2 > 9 and node_idx1 > 9) else "00"
    return int(f"{node_idx1}{separator}{node_idx2}")
def run_corresponding_sh_script(devices: dict, label_path):
    """
    Run on every device the shell script that belongs to its index.

    The script of device ``i`` is ``<label_path><ii>.sh`` where ``ii`` is the
    index zero-padded to two digits (1 -> '01', 9 -> '09', 12 -> '12').

    :param devices: {idx: device}; every device must provide ``cmd(str)``
    :param label_path: path prefix shared by the scripts; the padded index
                       and ``.sh`` are appended to it
    """
    for idx, device in devices.items():
        # Build the path afresh for each device. Formatting one shared
        # template in place consumed its placeholder on the first pass, so
        # every later device ran the first device's script. ``:02d`` also
        # pads index 9, which the former ``i < 9`` test left unpadded.
        script = f'{label_path}{idx:02d}.sh'
        device.cmd(f'bash {script}')
    print(f"---> complete run {label_path}")
def write_iperf_time(finish_file):
    """
    Mark the iperf run as finished inside the JSON status file.

    The existing JSON object is loaded, ``iperf_finish_time`` (current local
    time, ``YYYY-mm-dd HH:MM:SS``) and ``finish_flag = True`` are merged into
    it, and the file is rewritten with the merged object.

    :param finish_file: path of the JSON status file; it must already exist
    """
    with open(finish_file, "r+") as source:
        record = json.load(source)

    finished_at = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    record.update(iperf_finish_time=finished_at, finish_flag=True)

    with open(finish_file, "w+") as target:
        json.dump(record, target)
net_h1_ping_others(net): 225 | hosts = net.hosts 226 | for h in hosts[1:]: 227 | net.ping((hosts[0], h)) 228 | 229 | 230 | class GEANT23nodesTopo(Topo): 231 | def __init__(self, graph): 232 | super(GEANT23nodesTopo, self).__init__() 233 | self.node_idx = graph.nodes 234 | self.edges_pairs = graph.edges 235 | self.bw1 = 100 # Gbps -> M 236 | self.bw2 = 25 # Gbps -> M 237 | self.bw3 = 1.15 # Mbps 238 | self.bw4 = 100 # host -- switch 239 | self.delay = 20 240 | self.loss = 10 241 | 242 | self.host_port = 9 243 | self.snooper_port = 10 244 | 245 | self.bw1_links = [(12, 22), (12, 10), (12, 2), (13, 17), (4, 2), (4, 16), (1, 3), (1, 7), (1, 16), (3, 10), 246 | (3, 21), (10, 16), (10, 17), (7, 17), (7, 2), (7, 21), (16, 9), (20, 17)] 247 | self.bw2_links = [(13, 19), (13, 2), (19, 7), (23, 17), (23, 2), (8, 5), (8, 9), (18, 2), (18, 21), (5, 16), 248 | (3, 11), (10, 11), (22, 20), (20, 15), (9, 15)] 249 | self.bw3_links = [(13, 14), (19, 6), (3, 14), (7, 6)] 250 | 251 | def _return_bw(link: tuple): 252 | if link in self.bw1_links: 253 | return self.bw1 254 | elif link in self.bw2_links: 255 | return self.bw2 256 | elif link in self.bw3_links: 257 | return self.bw3 258 | else: 259 | raise ValueError 260 | 261 | # 添加交换机 262 | switches = {} 263 | for s in self.node_idx: 264 | switches.setdefault(s, self.addSwitch('s{0}'.format(s))) 265 | print('添加交换机:', s) 266 | 267 | switch_port_dict = generate_switch_port(graph) 268 | links_info = {} 269 | # 添加链路 270 | for l in self.edges_pairs: 271 | port1 = switch_port_dict[l[0]].pop(0) + 1 272 | port2 = switch_port_dict[l[1]].pop(0) + 1 273 | bw = _return_bw(l) 274 | 275 | _d = str(random.randint(0, self.delay)) + 'ms' 276 | _l = random.randint(0, self.loss) 277 | 278 | self.addLink(switches[l[0]], switches[l[1]], port1=port1, port2=port2, 279 | bw=bw, delay=_d, loss=_l) 280 | 281 | links_info.setdefault(l, {"port1": port1, "port2": port2, "bw": bw, "delay": _d, "loss": _l}) 282 | 283 | create_topo_links_info_xml(links_info_xml_path, 
class Nodes14Topo(Topo):
    """
    Mininet topology built from a parsed graph: one switch and one host per
    node and one switch-to-switch link per edge, with randomly drawn link
    parameters that are also written to the links-info XML file.
    """

    def __init__(self, graph):
        """
        :param graph: graph providing ``nodes``, ``edges`` and ``degree``;
                      node ids are used in the switch/host names and in the
                      host IP and MAC addresses
        """
        super().__init__()
        self.node_idx = graph.nodes
        self.edges_pairs = graph.edges

        self.random_bw = 30  # upper bound of the random per-link bw (lower bound is 5)
        self.bw4 = 50  # bw of every host -- switch link

        self.delay = 20  # upper bound of the random link delay, ms
        self.loss = 10  # upper bound of the random link loss, %

        self.host_port = 9  # switch port the host is attached to
        self.snooper_port = 10  # not used in this class

        # add the switches
        switches = {}
        for s in self.node_idx:
            switches[s] = self.addSwitch('s{0}'.format(s))
            print('添加交换机:', s)

        # Ports still free on each switch; consumed front to back per link.
        switch_port_dict = generate_switch_port(graph)
        links_info = {}
        # add the switch-to-switch links
        for l in self.edges_pairs:
            port1 = switch_port_dict[l[0]].pop(0) + 1
            port2 = switch_port_dict[l[1]].pop(0) + 1

            # Draw order (bw, delay, loss) is kept so a seeded ``random``
            # reproduces the same link parameters as before.
            _bw = random.randint(5, self.random_bw)
            _d = str(random.randint(1, self.delay)) + 'ms'
            _l = random.randint(0, self.loss)

            self.addLink(switches[l[0]], switches[l[1]], port1=port1, port2=port2,
                         bw=_bw, delay=_d, loss=_l)

            links_info[l] = {"port1": port1, "port2": port2, "bw": _bw, "delay": _d, "loss": _l}

        # links_info_xml_path is a module-level name set by the script entry point.
        create_topo_links_info_xml(links_info_xml_path, links_info)

        # add the hosts
        for i in self.node_idx:
            # Colon-separated MAC with a two-hex-digit last octet. The former
            # '00.00.00.00.00.0{i}' used dots and produced a three-character
            # last group for ids >= 10, which is not a valid MAC address.
            _h = self.addHost(f'h{i}', ip=f'10.0.0.{i}', mac=f'00:00:00:00:00:{i:02x}')
            self.addLink(_h, switches[i], port1=0, port2=self.host_port,
                         bw=self.bw4)
if __name__ == '__main__':
    # Script entry point: build the topology from the XML description and
    # start the Mininet network.
    #
    # NOTE: the names assigned here are module-level globals.  The topology
    # class reads ``links_info_xml_path`` while it is being constructed and
    # ``main`` reads ``iperf_path``, so they must be assigned before
    # ``Nodes14Topo(graph)`` / ``main(...)`` are called.
    xml_topology_path = r'./topologies/topology2.xml'
    links_info_xml_path = r'./links_info/links_info.xml'
    iperf_path = "./iperfTM"
    iperf_interval = 0
    finish_file = './finish_time.json'

    graph, nodes_num, edges_num = parse_xml_topology(xml_topology_path)
    # topo = GEANT23nodesTopo(graph)
    topo = Nodes14Topo(graph)

    setLogLevel('info')
    main(graph, topo, finish_file)
def create_script(tms, time_duration=None, out_dir='./iperfTM'):
    """Generate iperf3 client/server shell scripts for every traffic matrix.

    For traffic matrix number ``k`` the directory ``<out_dir>/TM-k`` is
    created with two sub-directories:

    * ``Clients/client_<src>.sh`` -- starts one UDP iperf3 client per
      destination ``dst != src`` at the rate ``tm[src-1][dst-1]`` (kbps).
    * ``Servers/server_<dst>.sh`` -- starts one one-shot (``-1``) iperf3
      server per source ``src != dst``.

    Nodes are numbered from 1 in file names, IP addresses (``10.0.0.<n>``)
    and port numbers.  The input array is only read, never modified.

    :param tms: array indexed ``[src, dst, matrix]``; it is transposed with
        axes ``(2, 0, 1)`` so that iteration yields one matrix at a time.
    :param time_duration: value passed to ``iperf3 -t``.  When ``None`` the
        module-level ``args.time_duration`` (set by the script entry point)
        is used, which is what the function always did before.
    :param out_dir: root directory that receives the ``TM-<k>`` folders.
    """
    if time_duration is None:
        time_duration = args.time_duration
    root = Path(out_dir)

    # The diagonal (src == dst) is never read below, so there is no need to
    # zero it -- doing so used to write through the transpose view into the
    # caller's array.
    for label, tm in enumerate(np.transpose(tms, (2, 0, 1))):
        name_tm = root / f'TM-{label}'
        clients_dir = name_tm / 'Clients'
        servers_dir = name_tm / 'Servers'
        print('------', name_tm)
        clients_dir.mkdir(parents=True, exist_ok=True)
        servers_dir.mkdir(parents=True, exist_ok=True)

        node_ids = range(1, len(tm[0]) + 1)

        # CLIENT SIDE: iperf3 -c <ip> -p <port> -u -b <rate>k -t <seconds>
        for src in node_ids:
            lines = ["#!/bin/bash \necho Generating traffic..."]
            for dst in node_ids:
                if src == dst:
                    continue
                throughput = float(tm[src - 1][dst - 1])
                # The trailing '&' runs the client in the background.
                lines.append(
                    f'\niperf3 -c 10.0.0.{dst} -p {_iperf_port(src, dst)} '
                    f'-u -b {throughput:.2f}k -t {time_duration}'
                    ' >/dev/null 2>&1 &\nsleep 0.4'
                )
            with open(clients_dir / f'client_{src}.sh', 'w') as file_client:
                file_client.write(''.join(lines))

        # SERVER SIDE: iperf3 -s -p <port> -1
        for dst in node_ids:
            lines = ['#!/bin/bash \necho Initializing server listening...']
            for src in node_ids:
                if src == dst:
                    continue
                # The trailing '&' runs the server in the background.
                lines.append(
                    f'\niperf3 -s -p {_iperf_port(src, dst)} -1'
                    ' >/dev/null 2>&1 &\nsleep 0.3'
                )
            with open(servers_dir / f'server_{dst}.sh', 'w') as file_server:
                file_server.write(''.join(lines))


def _iperf_port(src, dst):
    """Return the port used for the 1-based flow ``src -> dst`` as a string."""
    if dst > 9:
        return '{0}0{1}'.format(src, dst)
    return '{0}00{1}'.format(src, dst)
1)1410ms2(12, 14)(1, 2)2814ms6(12, 13)(2, 2)2410ms2 -------------------------------------------------------------------------------- /mininet/topologies/topology-anonymised.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Topology generated by geant-TM-all.pl 5 | Steve Uhlig 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 | 270 | 271 | 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 
| 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | 356 | 357 | 358 | 359 | 360 | 361 | 362 | 363 | 364 | 365 | 366 | 367 | 368 | 369 | 370 | 371 | 372 | 373 | 374 | 375 | 376 | 377 | 378 | 379 | 380 | 381 | 382 | 383 | 384 | 385 | 386 | 387 | 388 | 389 | 390 | 391 | 392 | 393 | 394 | 395 | 396 | 397 | 398 | 399 | 400 | 401 | 402 | 403 | 404 | 405 | 406 | 407 | 408 | 409 | 410 | 411 | 412 | 413 | 414 | 415 | 416 | 417 | 418 | 419 | 420 | 421 | 422 | 423 | 424 | 425 | 426 | 427 | 428 | 429 | 430 | 431 | 432 | 433 | 434 | 435 | 436 | 437 | 438 | 439 | 440 | 441 | 442 | 443 | 444 | 445 | 446 | 447 | 448 | 449 | 450 | 451 | 452 | 453 | 454 | 455 | 456 | 457 | 458 | 459 | 460 | 461 | 462 | 463 | 464 | 465 | 466 | 467 | 468 | 469 | 470 | 471 | 472 | 473 | 474 | 475 | 476 | 477 | 478 | 479 | 480 | 481 | 482 | 483 | 484 | 485 | 486 | 487 | 488 | 489 | 490 | 491 | 492 | 493 | 494 | 495 | 496 | 497 | 498 | 499 | 500 | 501 | 502 | 503 | 504 | 505 | 506 | 507 | 508 | 509 | 510 | 511 | 512 | 513 | 514 | 515 | 516 | 517 | 518 | 519 | 520 | 521 | 522 | 523 | 524 | 525 | 526 | 527 | 528 | 529 | 530 | 531 | 532 | 533 | 534 | 535 | 536 | 537 | 538 | 539 | 540 | 541 | 542 | 543 | 544 | 545 | 546 | 547 | 548 | 549 | 550 | 551 | 552 | 553 | 554 | 555 | 556 | 557 | 558 | 559 | 560 | 561 | 562 | 563 | 564 | 565 | 566 | 567 | 568 | 569 | 570 | 571 | 572 | 573 | 574 | 575 | 576 | 577 | 578 | 579 | 580 | 581 | 582 | 583 | 584 | 585 | 586 | 587 | 588 | 589 | 590 | 591 | 592 | 593 | 594 | 595 | 596 | 597 | 598 | 599 | 600 | 601 | 602 | 603 | 604 | 605 | 606 | 607 | 608 | 609 | 610 | 611 | 612 | 613 | 614 | 615 | 616 | 617 | 618 | 619 | 620 | 621 | 622 | 623 | 624 | 625 | 626 | 627 | 628 | 629 | 630 | 631 | 632 | 633 | 
634 | 635 | 636 | 637 | 638 | 639 | 640 | 641 | 642 | 643 | 644 | 645 | 646 | 647 | 648 | 649 | 650 | 651 | 652 | 653 | 654 | 655 | 656 | 657 | 658 | 659 | 660 | 661 | 662 | 663 | 664 | 665 | 666 | 667 | 668 | 669 | 670 | 671 | 672 | 673 | 674 | 675 | 676 | 677 | 678 | 679 | 680 | 681 | 682 | 683 | -------------------------------------------------------------------------------- /mininet/topologies/topology2.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | -------------------------------------------------------------------------------- /ryu/arp_handler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @File : arp_handler.py 3 | # @Date : 2022-03-09 4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来- 5 | from ryu.base import app_manager 6 | from ryu.base.app_manager import lookup_service_brick 7 | from ryu.ofproto import ofproto_v1_3 8 | from ryu.controller import ofp_event 9 | from ryu.controller.handler import set_ev_cls, MAIN_DISPATCHER 10 | from ryu.lib.packet import packet 11 | from ryu.lib.packet import arp, ipv4, ethernet 12 | 13 | ETHERNET = ethernet.ethernet.__name__ 14 | ETHERNET_MULTICAST = "ff:ff:ff:ff:ff:ff" 15 | ARP = 
class ArpHandler(app_manager.RyuApp):
    """ARP proxy for the controller.

    The app learns ``IP -> MAC`` bindings from every ARP packet it sees.
    ARP requests for a known IP are answered directly by the controller;
    otherwise the packet is forwarded to the host's attachment point when
    the ``discovery`` app knows it, or sent out of every port that is not a
    known access port.
    """
    # Ryu reads ``OFP_VERSIONS`` (plural); the former ``OFP_VERSION`` was
    # silently ignored.
    OFP_VERSIONS = [ofproto_v1_3.OFP_VERSION]
    # Kept so that code reading the old attribute name keeps working.
    OFP_VERSION = OFP_VERSIONS

    def __init__(self, *args, **kwargs):
        super(ArpHandler, self).__init__(*args, **kwargs)
        self.discovery = lookup_service_brick('discovery')
        self.monitor = lookup_service_brick('monitor')

        self.arp_table = {}  # {ip: mac}
        self.sw = {}  # {(dpid, eth_src, arp_dst_ip): in_port}
        self.mac_to_port = {}

    @set_ev_cls(ofp_event.EventOFPPacketIn, MAIN_DISPATCHER)
    def _packet_in_handler(self, ev):
        """Handle PacketIn events that carry an ARP packet.

        Non-ARP packets are ignored.  For ARP packets the sender binding is
        recorded, then ``arp_handler`` gets a chance to reply or drop; if it
        does neither, the packet is delivered to the known host location or
        flooded.
        """
        msg = ev.msg
        datapath = msg.datapath
        ofproto = datapath.ofproto
        in_port = msg.match['in_port']

        pkt = packet.Packet(msg.data)
        arp_pkt = pkt.get_protocol(arp.arp)
        if not isinstance(arp_pkt, arp.arp):
            # Everything below dereferences ``arp_pkt``; other traffic is
            # not this app's business.
            return

        eth = pkt.get_protocols(ethernet.ethernet)[0]
        self.arp_table[arp_pkt.src_ip] = eth.src

        header_list = {p.protocol_name: p for p in pkt.protocols if not isinstance(p, bytes)}
        if self.arp_handler(header_list, datapath, in_port, msg.buffer_id):
            # True: the request was answered or dropped.
            return

        # The sibling apps may not have been instantiated yet when
        # ``__init__`` ran, in which case the lookups returned None.
        if self.discovery is None:
            self.discovery = lookup_service_brick('discovery')
        if self.monitor is None:
            self.monitor = lookup_service_brick('monitor')

        location = self.discovery.get_host_ip_location(arp_pkt.dst_ip)
        if location:  # the destination host's attachment point is known
            dpid_dst, out_port = location
            datapath = self.monitor.datapaths_table[dpid_dst]
            out = self._build_packet_out(datapath, ofproto.OFP_NO_BUFFER, ofproto.OFPP_CONTROLLER,
                                         out_port, msg.data)
            datapath.send_msg(out)
            return

        print("shortest--->_packet_in_handler: ==Flooding")
        for dpid in self.discovery.switch_all_ports_table:
            for port in self.discovery.switch_all_ports_table[dpid]:
                # Skip ports that are already known access ports.
                if (dpid, port) not in self.discovery.access_table.keys():
                    datapath = self.monitor.datapaths_table[dpid]
                    out = self._build_packet_out(datapath, ofproto.OFP_NO_BUFFER,
                                                 ofproto.OFPP_CONTROLLER, port, msg.data)
                    datapath.send_msg(out)
        return

    def arp_handler(self, header_list, datapath, in_port, msg_buffer_id):
        """Try to answer or suppress an ARP packet.

        :param header_list: ``{protocol_name: protocol}`` of the packet.
        :param datapath: switch that raised the PacketIn.
        :param in_port: port the packet arrived on.
        :param msg_buffer_id: buffer id of the PacketIn (unused).
        :return: ``True`` if the packet was dropped (repeated broadcast) or
            answered from ``arp_table``; ``False`` if the caller still has
            to forward it.
        """
        eth_dst = header_list[ETHERNET].dst
        eth_src = header_list[ETHERNET].src

        if eth_dst == ETHERNET_MULTICAST and ARP in header_list:
            arp_dst_ip = header_list[ARP].dst_ip
            key = (datapath.id, eth_src, arp_dst_ip)
            # NOTE(review): any repeat of the same broadcast on this switch
            # is dropped, regardless of the port it arrives on -- confirm
            # that retransmissions from the host itself should be dropped.
            if key in self.sw:  # break loop
                out = datapath.ofproto_parser.OFPPacketOut(
                    datapath=datapath,
                    buffer_id=datapath.ofproto.OFP_NO_BUFFER,
                    in_port=in_port,
                    actions=[], data=None
                )
                datapath.send_msg(out)
                return True
            self.sw[key] = in_port

        if ARP in header_list:
            arp_header = header_list[ARP]
            arp_src_ip = arp_header.src_ip
            arp_dst_ip = arp_header.dst_ip

            if arp_header.opcode == arp.ARP_REQUEST and arp_dst_ip in self.arp_table:
                # The requested MAC is known: reply on behalf of the owner.
                actions = [datapath.ofproto_parser.OFPActionOutput(in_port)]

                arp_reply = packet.Packet()
                arp_reply.add_protocol(ethernet.ethernet(ethertype=header_list[ETHERNET].ethertype,
                                                         dst=eth_src,
                                                         src=self.arp_table[arp_dst_ip]))
                arp_reply.add_protocol(arp.arp(opcode=arp.ARP_REPLY,
                                               src_mac=self.arp_table[arp_dst_ip],
                                               src_ip=arp_dst_ip,
                                               dst_mac=eth_src,
                                               dst_ip=arp_src_ip))
                arp_reply.serialize()

                out = datapath.ofproto_parser.OFPPacketOut(
                    datapath=datapath,
                    buffer_id=datapath.ofproto.OFP_NO_BUFFER,
                    in_port=datapath.ofproto.OFPP_CONTROLLER,
                    actions=actions,
                    data=arp_reply.data
                )
                datapath.send_msg(out)
                return True
        return False

    def _build_packet_out(self, datapath, buffer_id, src_port, dst_port, data):
        """Build an ``OFPPacketOut`` message.

        :param dst_port: output port; when falsy the message carries no
            action.
        :param data: raw packet, required when ``buffer_id`` is
            ``OFP_NO_BUFFER``.
        :return: the message, or ``None`` when ``buffer_id`` is
            ``OFP_NO_BUFFER`` and ``data`` is ``None``.
        """
        actions = []
        if dst_port:
            actions.append(datapath.ofproto_parser.OFPActionOutput(dst_port))

        msg_data = None
        if buffer_id == datapath.ofproto.OFP_NO_BUFFER:
            if data is None:
                return None
            msg_data = data

        return datapath.ofproto_parser.OFPPacketOut(datapath=datapath, buffer_id=buffer_id,
                                                    data=msg_data, in_port=src_port, actions=actions)
class NetworkDelayDetector(app_manager.RyuApp):
    """Measure link delay from LLDP transit times and echo round trips.

    A background greenlet periodically sends an echo request to every
    switch and writes the estimated per-link delay (in ms) into the
    ``delay`` attribute of the edges of the discovery app's graph.
    """
    # Ryu reads ``OFP_VERSIONS`` (plural); the former ``OFP_VERSION`` was
    # silently ignored.
    OFP_VERSIONS = [ofproto_v1_3.OFP_VERSION]
    # Kept so that code reading the old attribute name keeps working.
    OFP_VERSION = OFP_VERSIONS
    _CONTEXTS = {'switches': Switches}

    def __init__(self, *args, **kwargs):
        super(NetworkDelayDetector, self).__init__(*args, **kwargs)
        self.name = 'detector'

        self.network_structure = lookup_service_brick('discovery')
        self.network_monitor = lookup_service_brick('monitor')
        # The Switches instance is injected through ``_CONTEXTS``.
        self.switch_module = kwargs['switches']

        self.echo_delay_table = {}  # {dpid: ryu_ofps_delay}
        self.lldp_delay_table = {}  # {src_dpid: {dst_dpid: delay}}
        self.echo_interval = 0.05

        self._delay_thread = hub.spawn(self.scheduler)

    def scheduler(self):
        """Run one measurement cycle every ``setting.DELAY_PERIOD`` seconds."""
        while True:
            self._send_echo_request()
            self.create_delay_graph()

            # if setting.PRINT_SHOW:
            #     self.show_delay_stats()

            hub.sleep(setting.DELAY_PERIOD)

    def _send_echo_request(self):
        """Send an echo request carrying the send time to every datapath."""
        if self.network_monitor is None:
            # The monitor app may not have existed yet when __init__ ran.
            self.network_monitor = lookup_service_brick('monitor')
        if self.network_monitor is None:
            return

        for datapath in list(self.network_monitor.datapaths_table.values()):
            parser = datapath.ofproto_parser
            echo_req = parser.OFPEchoRequest(datapath, b"%.12f" % time.time())
            datapath.send_msg(echo_req)
            hub.sleep(self.echo_interval)  # pace the requests so replies are not lost

    @set_ev_cls(ofp_event.EventOFPEchoReply, MAIN_DISPATCHER)
    def _ehco_reply_handler(self, ev):
        """Store the controller<->switch round-trip time of an echo reply."""
        now_timestamp = time.time()
        try:
            # The payload arrives from the network: parse it as a number
            # instead of eval()-ing it.
            sent_timestamp = float(ev.msg.data)
        except (TypeError, ValueError):
            return  # not one of our timestamped echo requests
        self.echo_delay_table[ev.msg.datapath.id] = now_timestamp - sent_timestamp

    @set_ev_cls(ofp_event.EventOFPPacketIn, MAIN_DISPATCHER)
    def _packet_in_handler(self, ev):
        """Record the transit time of an LLDP packet; ignore anything else."""
        recv_timestamp = time.time()
        msg = ev.msg
        try:
            src_dpid, src_port_no = LLDPPacket.lldp_parse(msg.data)
        except LLDPPacket.LLDPUnknownFormat:
            return
        dpid = msg.datapath.id

        for port in self.switch_module.ports.keys():
            if src_dpid == port.dpid and src_port_no == port.port_no:
                send_timestamp = self.switch_module.ports[port].timestamp
                # Without a send timestamp there is nothing to measure;
                # storing an unbound/stale ``delay`` was a bug.
                if send_timestamp:
                    self.lldp_delay_table.setdefault(src_dpid, {})[dpid] = recv_timestamp - send_timestamp
                break

    def create_delay_graph(self):
        """Write the current delay estimate (ms) onto every measured edge."""
        if self.network_structure is None:
            self.network_structure = lookup_service_brick('discovery')
        if self.network_structure is None:
            return

        graph = self.network_structure.graph
        for src, dst in graph.edges:
            try:
                delay = self.calculate_delay(src, dst)
            except KeyError:
                # LLDP/echo samples for this link have not arrived yet; an
                # uncaught KeyError here used to kill the scheduler thread.
                continue
            graph[src][dst]['delay'] = delay * 1000  # ms

    def calculate_delay(self, src, dst):
        """Estimate the one-way delay of the link ``src -- dst`` in seconds.

                    ┌------Ryu------┐
                    |               |
    src echo latency|               |dst echo latency
                    |               |
                SwitchA------------SwitchB
                     --->fwd_delay--->
                     <---reply_delay<---

        :raises KeyError: if an LLDP or echo measurement is still missing.
        """
        fwd_delay = self.lldp_delay_table[src][dst]
        reply_delay = self.lldp_delay_table[dst][src]
        ryu_ofps_src_delay = self.echo_delay_table[src]
        ryu_ofps_dst_delay = self.echo_delay_table[dst]

        delay = (fwd_delay + reply_delay - ryu_ofps_src_delay - ryu_ofps_dst_delay) / 2
        return max(delay, 0)

    def show_delay_stats(self):
        """Log the raw LLDP delay table."""
        self.logger.info("==============================DDDD delay=================================")
        self.logger.info("src    dst :    delay")
        for src in self.lldp_delay_table.keys():
            for dst in self.lldp_delay_table[src].keys():
                delay = self.lldp_delay_table[src][dst]
                self.logger.info("%s <---> %s : %s", src, dst, delay)
self.flow_stats_table = {} 36 | # {(dpid, port_no): [speed, .....]} 37 | self.port_speed_table = {} 38 | # {dpid: {(in_port, ipv4_dsts, out_port): speed}} 39 | self.flow_speed_table = {} 40 | 41 | self.port_flow_dpid_stats = {'port': {}, 'flow': {}} 42 | # {dpid: {port_no: curr_bw}} 43 | self.port_curr_speed = {} 44 | 45 | self.port_loss = {} 46 | 47 | self.discovery = lookup_service_brick("discovery") # 创建一个NetworkStructure的实例 48 | 49 | # self.monitor_thread = hub.spawn(self._monitor) 50 | self.monitor_thread = hub.spawn(self.scheduler) 51 | self.save_thread = hub.spawn(self.save_bw_loss_graph) 52 | 53 | def print_parameters(self): 54 | # print("monitor---> self.datapaths_table", self.datapaths_table) 55 | # print("monitor---> self.dpid_port_fueatures_table", self.dpid_port_fueatures_table) 56 | # print("monitor---> self.port_stats_table", self.port_stats_table) 57 | # print("monitor---> self.flow_stats_table", self.flow_stats_table) 58 | # print("monitor---> self.port_speed_table", self.port_speed_table) 59 | # print("monitor---> self.flow_speed_table", self.flow_speed_table) 60 | # print("monitor---> self.port_curr_speed", self.port_curr_speed) 61 | 62 | logger = self.logger.info if setting.LOGGER else print 63 | 64 | # {dpid: datapath} 65 | # print_pretty_table(self.datapaths_table, ['dpid', 'datapath'], [6, 64], 'MMMM datapaths_table', 66 | # logger) 67 | 68 | # # {dpid:{port_no: (config, state, curr_speed, max_speed)}} 69 | # print_pretty_table(self.dpid_port_fueatures_table, 70 | # ['dpid', 'port_no:(config, state, curr_speed, max_speed)'], 71 | # [6, 40], 'MMMM dpid_port_fueatures_table', logger) 72 | 73 | # {(dpid, port_no): (stat.tx_bytes, stat.rx_bytes, stat.rx_errors, stat.duration_sec, 74 | # stat.duration_nsec, stat.tx_packets, stat.rx_packets)} 75 | # print_pretty_table(self.port_stats_table, 76 | # ['(dpid, port_no)', 77 | # '(stat.tx_bytes, stat.rx_bytes, stat.rx_errors, stat.duration_sec, stat.duration_nsec, stat.tx_packets, stat.rx_packets)'], 78 
def print_parameters_(self):
    """Print every instance attribute of the app, one per entry (debug aid).

    The header shows ``self.name``; each attribute is printed as
    ``name: value`` exactly as stored in ``self.__dict__``.
    """
    # The name has to be interpolated with '%'; passing it as a second
    # print() argument left a literal '%s' in the output.
    print("monitor---------- %s ----------" % self.name)
    for attr, value in self.__dict__.items():
        # The operands must be a tuple: '"..." % attr, value' raised
        # "TypeError: not enough arguments for format string".
        print("monitor\n---> %s: %s" % (attr, value))
    print("monitor===================================")
@set_ev_cls(ofp_event.EventOFPPortDescStatsReply, MAIN_DISPATCHER)
def port_desc_stats_reply_handler(self, ev):
    """Store the description of every port reported by a switch.

    For each port except ``OFPP_LOCAL`` the tuple
    ``(config, state, curr_speed, max_speed)`` is saved in
    ``self.dpid_port_fueatures_table[dpid][port_no]``.  ``config`` and
    ``state`` are human-readable labels; ``'Up'`` means that none of the
    known flags is set.
    """
    msg = ev.msg
    dpid = msg.datapath.id
    ofproto = msg.datapath.ofproto

    config_flags = ((ofproto.OFPPC_PORT_DOWN, 'Port Down'),
                    (ofproto.OFPPC_NO_RECV, 'No Recv'),
                    (ofproto.OFPPC_NO_FWD, 'No Forward'),
                    (ofproto.OFPPC_NO_PACKET_IN, 'No Pakcet-In'))

    state_flags = ((ofproto.OFPPS_LINK_DOWN, "Link Down"),
                   (ofproto.OFPPS_BLOCKED, "Blocked"),
                   (ofproto.OFPPS_LIVE, "Live"))

    def describe(bits, flags):
        # ``config``/``state`` are bitmaps: test each flag separately.  The
        # former exact-match lookup reported any combination of flags
        # (e.g. NO_RECV | NO_FWD) as 'Up'.  A single flag yields the same
        # label as before.
        labels = [label for flag, label in flags if bits & flag]
        return ', '.join(labels) if labels else 'Up'

    # setdefault: do not depend on the request path having created the entry.
    port_table = self.dpid_port_fueatures_table.setdefault(dpid, {})
    for ofport in msg.body:
        if ofport.port_no == ofproto_v1_3.OFPP_LOCAL:  # 0xfffffffe 4294967294
            continue
        port_table[ofport.port_no] = (describe(ofport.config, config_flags),
                                      describe(ofport.state, state_flags),
                                      ofport.curr_speed,
                                      ofport.max_speed)
存储端口统计信息, 见OFPPortStats, 发送bytes、接收bytes、生效时间duration_sec等 189 | Replay message content: 190 | (stat.port_no, 191 | stat.rx_packets, stat.tx_packets, 192 | stat.rx_bytes, stat.tx_bytes, 193 | stat.rx_dropped, stat.tx_dropped, 194 | stat.rx_errors, stat.tx_errors, 195 | stat.rx_frame_err, stat.rx_over_err, 196 | stat.rx_crc_err, stat.collisions, 197 | stat.duration_sec, stat.duration_nsec)) 198 | """ 199 | # print("MMMM---> EventOFPPortStatsReply") 200 | body = ev.msg.body 201 | dpid = ev.msg.datapath.id 202 | self.port_flow_dpid_stats['port'][dpid] = body 203 | # self.port_curr_speed.setdefault(dpid, {}) 204 | 205 | for stat in sorted(body, key=attrgetter("port_no")): 206 | port_no = stat.port_no 207 | if port_no != ofproto_v1_3.OFPP_LOCAL: 208 | key = (dpid, port_no) 209 | value = (stat.tx_bytes, stat.rx_bytes, stat.rx_errors, 210 | stat.duration_sec, stat.duration_nsec, stat.tx_packets, stat.rx_packets) 211 | self._save_stats(self.port_stats_table, key, value, 5) # 保存信息,最多保存前5次 212 | 213 | pre_bytes = 0 214 | # delta_time = setting.MONITOR_PERIOD 215 | delta_time = setting.SCHEDULE_PERIOD 216 | stats = self.port_stats_table[key] # 获得已经存了的统计信息 217 | 218 | if len(stats) > 1: # 有两次以上的信息 219 | pre_bytes = stats[-2][0] + stats[-2][1] 220 | delta_time = self._calculate_delta_time(stats[-1][3], stats[-1][4], 221 | stats[-2][3], stats[-2][4]) # 倒数第一个统计信息,倒数第二个统计信息 222 | 223 | speed = self._calculate_speed(stats[-1][0] + stats[-1][1], 224 | pre_bytes, delta_time) 225 | self._save_stats(self.port_speed_table, key, speed, 5) 226 | self._calculate_port_speed(dpid, port_no, speed) 227 | 228 | self.calculate_loss_of_link() 229 | 230 | @set_ev_cls(ofp_event.EventOFPFlowStatsReply, MAIN_DISPATCHER) 231 | def _flow_stats_reply_handler(self, ev): 232 | """ 存储flow的状态,算这个干啥。。""" 233 | msg = ev.msg 234 | body = msg.body 235 | datapath = msg.datapath 236 | dpid = datapath.id 237 | 238 | self.port_flow_dpid_stats['flow'][dpid] = body 239 | # print("MMMM---> body", body) 240 | 241 | for 
stat in sorted([flowstats for flowstats in body if flowstats.priority == 1], 242 | key=lambda flowstats: (flowstats.match.get('in_port'), flowstats.match.get('ipv4_dst'))): 243 | # print("MMMM---> stat.match", stat.match) 244 | # print("MMMM---> stat", stat) 245 | key = (stat.match['in_port'], stat.match['ipv4_dst'], 246 | stat.instructions[0].actions[0].port) 247 | value = (stat.packet_count, stat.byte_count, stat.duration_sec, stat.duration_nsec) 248 | self._save_stats(self.flow_stats_table[dpid], key, value, 5) 249 | 250 | pre_bytes = 0 251 | # delta_time = setting.MONITOR_PERIOD 252 | delta_time = setting.SCHEDULE_PERIOD 253 | value = self.flow_stats_table[dpid][key] 254 | if len(value) > 1: 255 | pre_bytes = value[-2][1] 256 | # print("MMMM---> _flow_stats_reply_handler delta_time: now", value[-1][2], value[-1][3], "pre", 257 | # value[-2][2], 258 | # value[-2][3]) 259 | delta_time = self._calculate_delta_time(value[-1][2], value[-1][3], 260 | value[-2][2], value[-2][3]) 261 | speed = self._calculate_speed(self.flow_stats_table[dpid][key][-1][1], pre_bytes, delta_time) 262 | self.flow_speed_table.setdefault(dpid, {}) 263 | self._save_stats(self.flow_speed_table[dpid], key, speed, 5) 264 | 265 | # 存多次数据,比如一个端口存上一次的统计信息和这一次的统计信息 266 | @staticmethod 267 | def _save_stats(_dict, key, value, keep): 268 | if key not in _dict: 269 | _dict[key] = [] 270 | _dict[key].append(value) 271 | 272 | if len(_dict[key]) > keep: 273 | _dict[key].pop(0) # 弹出最早的数据 274 | 275 | def _calculate_delta_time(self, now_sec, now_nsec, pre_sec, pre_nsec): 276 | """ 计算统计时间, 即两个消息时间差""" 277 | return self._calculate_seconds(now_sec, now_nsec) - self._calculate_seconds(pre_sec, pre_nsec) 278 | 279 | @staticmethod 280 | def _calculate_seconds(sec, nsec): 281 | """ 计算 sec + nsec 的和,单位为 seconds""" 282 | return sec + nsec / 10 ** 9 283 | 284 | @staticmethod 285 | def _calculate_speed(now_bytes, pre_bytes, delta_time): 286 | """ 计算统计流量速度""" 287 | if delta_time: 288 | 289 | return (now_bytes - 
pre_bytes) / delta_time 290 | else: 291 | return 0 292 | 293 | def _calculate_port_speed(self, dpid, port_no, speed): 294 | curr_bw = speed * 8 / 10 ** 6 # MBit/s 295 | # print(f"monitorMMMM---> _calculate_port_speed: {curr_bw} MBits/s", ) 296 | self.port_curr_speed.setdefault(dpid, {}) 297 | self.port_curr_speed[dpid][port_no] = curr_bw 298 | 299 | @set_ev_cls(ofp_event.EventOFPPortStatus, MAIN_DISPATCHER) 300 | def _port_status_handler(self, ev): 301 | """ 处理端口状态: ADD, DELETE, MODIFIED""" 302 | msg = ev.msg 303 | dp = msg.datapath 304 | ofp = dp.ofproto 305 | 306 | if msg.reason == ofp.OFPPR_ADD: 307 | reason = 'ADD' 308 | elif msg.reason == ofp.OFPPR_DELETE: 309 | reason = 'DELETE' 310 | elif msg.reason == ofp.OFPPR_MODIFY: 311 | reason = 'MODIFY' 312 | else: 313 | reason = 'unknown' 314 | 315 | print('MMMM---> _port_status_handler OFPPortStatus received: reason=%s desc=%s' % (reason, msg.desc)) 316 | 317 | # 通过获得的网络拓扑,更新其bw权重 318 | def create_bandwidth_graph(self): 319 | # print("MMMM---> create bandwidth graph") 320 | for link in self.discovery.link_port_table: 321 | src_dpid, dst_dpid = link 322 | src_port, dst_port = self.discovery.link_port_table[link] 323 | 324 | if src_dpid in self.port_curr_speed.keys() and dst_dpid in self.port_curr_speed.keys(): 325 | src_port_bw = self.port_curr_speed[src_dpid][src_port] 326 | dst_port_bw = self.port_curr_speed[dst_dpid][dst_port] 327 | src_dst_bandwidth = min(src_port_bw, dst_port_bw) # bottleneck bandwidth 328 | 329 | # print(f"monitor--> dst[{src_dpid}]_port[{src_port}]_bw: %.5f" % dst_port_bw) 330 | # print(f"monitor---> src[{dst_dpid}]_port[{dst_port}]_bw: %.5f" % src_port_bw) 331 | # print("monitor---> src_dst_bandwidth: %.5f" % src_dst_bandwidth) 332 | 333 | # 对图的edge设置 可用bw 值 334 | capacity = self.discovery.m_graph[src_dpid][dst_dpid]['bw'] 335 | self.discovery.graph[src_dpid][dst_dpid]['bw'] = max(capacity - src_dst_bandwidth, 0) 336 | 337 | else: 338 | self.logger.info( 339 | "MMMM---> 
create_bandwidth_graph: [{}] [{}] not in port_free_bandwidth ".format(src_dpid, 340 | dst_dpid)) 341 | self.discovery.graph[src_dpid][dst_dpid]['bw'] = -1 342 | 343 | # print("MMMM---> ", self.discovery.graph.edges(data=True)) 344 | # print("MMMM---> " * 2, self.discovery.count + 1) 345 | 346 | # calculate loss tx - rx / tx 347 | def calculate_loss_of_link(self): 348 | """ 349 | 发端口 和 收端口 ,端口loss 350 | """ 351 | for link, port in self.discovery.link_port_table.items(): 352 | src_dpid, dst_dpid = link 353 | src_port, dst_port = port 354 | if (src_dpid, src_port) in self.port_stats_table.keys() and \ 355 | (dst_dpid, dst_port) in self.port_stats_table.keys(): 356 | # {(dpid, port_no): (stat.tx_bytes, stat.rx_bytes, stat.rx_errors, stat.duration_sec, 357 | # stat.duration_nsec, stat.tx_packets, stat.rx_packets)} 358 | # 1. 顺向 2022/3/11 packets modify--> bytes 359 | tx = self.port_stats_table[(src_dpid, src_port)][-1][0] # tx_bytes 360 | rx = self.port_stats_table[(dst_dpid, dst_port)][-1][1] # rx_bytes 361 | loss_ratio = abs(float(tx - rx) / tx) * 100 362 | self._save_stats(self.port_loss, link, loss_ratio, 5) 363 | # print(f"MMMM--->[{link}]({dst_dpid}, {dst_port}) rx: ", rx, "tx: ", tx, 364 | # "loss_ratio: ", loss_ratio) 365 | 366 | # 2. 
逆项 367 | tx = self.port_stats_table[(dst_dpid, dst_port)][-1][0] # tx_bytes 368 | rx = self.port_stats_table[(src_dpid, src_port)][-1][1] # rx_bytes 369 | loss_ratio = abs(float(tx - rx) / tx) * 100 370 | self._save_stats(self.port_loss, link[::-1], loss_ratio, 5) 371 | 372 | # print(f"MMMM--->[{link[::-1]}]({dst_dpid}, {dst_port}) rx: ", rx, "tx: ", tx, 373 | # "loss_ratio: ", loss_ratio) 374 | else: 375 | self.logger.info("MMMM---> calculate_loss_of_link error", ) 376 | 377 | # update graph loss 378 | def update_graph_loss(self): 379 | """从1 往2 和 从2 往1,取最大作为链路loss """ 380 | for link in self.discovery.link_port_table: 381 | src_dpid = link[0] 382 | dst_dpid = link[1] 383 | if link in self.port_loss.keys() and link[::-1] in self.port_loss.keys(): 384 | src_loss = self.port_loss[link][-1] # 1-->2 -1取最新的那个 385 | dst_loss = self.port_loss[link[::-1]][-1] # 2-->1 386 | link_loss = max(src_loss, dst_loss) # 百分比 max loss between port1 and port2 387 | self.discovery.graph[src_dpid][dst_dpid]['loss'] = link_loss 388 | 389 | # print(f"MMMM---> update_graph_loss link[{link}]_loss: ", link_loss) 390 | else: 391 | self.discovery.graph[src_dpid][dst_dpid]['loss'] = 100 392 | 393 | def create_loss_graph(self): 394 | """ 395 | 在graph中更新边的loss值 396 | """ 397 | # self.calculate_loss_of_link() 398 | self.update_graph_loss() 399 | -------------------------------------------------------------------------------- /ryu/network_structure.py: -------------------------------------------------------------------------------- 1 | # network_structure.py 2 | import copy 3 | import time 4 | import xml.etree.ElementTree as ET 5 | 6 | from ryu.base import app_manager 7 | from ryu.ofproto import ofproto_v1_3 8 | from ryu.controller import ofp_event 9 | from ryu.controller.handler import set_ev_cls, MAIN_DISPATCHER, CONFIG_DISPATCHER 10 | from ryu.lib import hub 11 | from ryu.lib import igmplib, mac 12 | from ryu.lib.dpid import str_to_dpid 13 | from ryu.lib.packet import packet, arp, ethernet, ipv4, 
igmp 14 | from ryu.topology import event 15 | from ryu.topology.api import get_switch, get_link, get_host 16 | 17 | import networkx as nx 18 | import matplotlib.pyplot as plt 19 | 20 | import setting 21 | from setting import print_pretty_table, print_pretty_list 22 | 23 | 24 | class NetworkStructure(app_manager.RyuApp): 25 | """ 26 | 发现网络拓扑,保存网络结构 27 | """ 28 | OFP_VERSION = [ofproto_v1_3.OFP_VERSION] 29 | 30 | # _CONTEXTS = {'igmplib': igmplib.IgmpLib} 31 | 32 | def __init__(self, *args, **kwargs): 33 | super(NetworkStructure, self).__init__(*args, **kwargs) 34 | self.start_time = time.time() 35 | self.name = 'discovery' 36 | # self._snoop = kwargs['igmplib'] 37 | # self._snoop.set_querier_mode(dpid=str_to_dpid('000000000000001e'), server_port=2) 38 | self.topology_api_app = self 39 | self.link_info_xml = setting.LINKS_INFO # xml file path of links info 40 | self.m_graph = self.parse_topo_links_info() # 解析mininet构建的topo链路信息 41 | 42 | self.graph = nx.Graph() 43 | self.pre_graph = nx.Graph() 44 | 45 | self.access_table = {} # {(dpid, in_port): (src_ip, src_mac)} 46 | self.switch_all_ports_table = {} # {dpid: {port_no, ...}} 47 | self.all_switches_dpid = {} # dict_key[dpid] 48 | self.switch_port_table = {} # {dpid: {port, ...} 49 | self.link_port_table = {} # {(src.dpid, dst.dpid): (src.port_no, dst.port_no)} 50 | self.not_use_ports = {} # {dpid: {port, ...}} 交换机之间没有用来连接的port 51 | self.shortest_path_table = {} # {(src.dpid, dst.dpid): [path]} 52 | self.arp_table = {} # {(dpid, eth_src, arp_dst_ip): in_port} 53 | self.arp_src_dst_ip_table = {} 54 | 55 | # self.multiple_access_table = {} # {(dpid, in_port): (src_ip, src_mac)} 56 | # self.group_table = {} # {group_address: [(dpid, in_port), ...]} 57 | 58 | # self._discover_thread = hub.spawn(self._discover_network_structure) 59 | # self._show_graph = hub.spawn(self.show_graph_plt()) 60 | self.initiation_delay = setting.INIT_TIME 61 | self.first_flag = True 62 | self.cal_path_flag = False 63 | 64 | self._structure_thread 
= hub.spawn(self.scheduler) 65 | self._shortest_path_thread = hub.spawn(self.cal_shortest_path_thread) 66 | 67 | def print_parameters(self): 68 | # self.logger.info("discovery---> access_table: %s", self.access_table) 69 | # self.logger.info("discovery---> link_port_table: %s", self.link_port_table) 70 | # self.logger.info("discovery---> not_use_ports: %s", self.not_use_ports) 71 | # self.logger.info("discovery---> shortest_path_table: %s", self.shortest_path_table) 72 | logger = self.logger.info if setting.LOGGER else print 73 | # 图 74 | # logger("============================= SSSS graph edges==============================") 75 | # logger('SSSS---> graph edges:\n', self.graph.edges) 76 | # logger("=============================end SSSS graph edges=============================") 77 | 78 | # 交换机dpid: {交换机所有port号} 79 | # {dpid: {port_no, ...}} 80 | print_pretty_table(self.switch_all_ports_table, ['dpid', 'port_no'], [10, 10], 81 | 'SSSS switch_all_ports_table', logger) 82 | 83 | # 交换机id: lldp发现的端口 84 | # {dpid: {port, ...} 85 | print_pretty_table(self.switch_port_table, ['dpid', 'port_no'], [10, 10], 'SSSS switch_port_table', 86 | logger) 87 | 88 | # {(dpid, in_port): (src_ip, src_mac)} 89 | print_pretty_table(self.access_table, ['(dpid, in_port)', '(src_ip, src_mac)'], [10, 40], 'SSSS access_table', 90 | logger) 91 | 92 | # {dpid: {port, ...}} 93 | print_pretty_table(self.not_use_ports, ['dpid', 'not_use_ports'], [10, 30], 'SSSS not_use_ports', logger) 94 | 95 | def scheduler(self): 96 | i = 0 97 | while True: 98 | if i == 3: 99 | self.get_topology(None) 100 | i = 0 101 | hub.sleep(setting.DISCOVERY_PERIOD) 102 | if setting.PRINT_SHOW: 103 | self.print_parameters() 104 | i += 1 105 | 106 | def cal_shortest_path_thread(self): 107 | self.cal_path_flag = False 108 | while True: 109 | if self.cal_path_flag: 110 | self.calculate_all_nodes_shortest_paths(weight=setting.WEIGHT) 111 | # print("*****discovery---> self.shortest_path_table:\n", self.shortest_path_table) 112 | 
hub.sleep(setting.DISCOVERY_PERIOD) 113 | 114 | # Flow mod and Table miss 115 | @set_ev_cls(ofp_event.EventOFPSwitchFeatures, CONFIG_DISPATCHER) 116 | def switch_features_handler(self, ev): 117 | datapath = ev.msg.datapath 118 | ofproto = datapath.ofproto 119 | parser = datapath.ofproto_parser 120 | 121 | self.logger.info("discovery---> switch: %s connected", datapath.id) 122 | 123 | # install table miss flow entry 124 | match = parser.OFPMatch() # match all 125 | actions = [parser.OFPActionOutput(ofproto.OFPP_CONTROLLER, 126 | ofproto.OFPCML_NO_BUFFER)] 127 | 128 | self.add_flow(datapath, 0, match, actions) 129 | 130 | def add_flow(self, datapath, priority, match, actions): 131 | inst = [datapath.ofproto_parser.OFPInstructionActions(datapath.ofproto.OFPIT_APPLY_ACTIONS, 132 | actions)] 133 | mod = datapath.ofproto_parser.OFPFlowMod(datapath=datapath, priority=priority, 134 | match=match, instructions=inst) 135 | datapath.send_msg(mod) 136 | 137 | # Packet In 138 | @set_ev_cls(ofp_event.EventOFPPacketIn, MAIN_DISPATCHER) 139 | def _packet_in_handler(self, ev): 140 | # print("discovery---> discovery PacketIn") 141 | msg = ev.msg 142 | datapath = msg.datapath 143 | 144 | # 输入端口号 145 | in_port = msg.match['in_port'] 146 | pkt = packet.Packet(msg.data) 147 | arp_pkt = pkt.get_protocol(arp.arp) 148 | 149 | if isinstance(arp_pkt, arp.arp): 150 | # print("SSSS---> _packet_in_handler: arp packet") 151 | arp_src_ip = arp_pkt.src_ip 152 | src_mac = arp_pkt.src_mac 153 | self.storage_access_info(datapath.id, in_port, arp_src_ip, src_mac) 154 | # print("discovery---> access_table:\n ", self.access_table) 155 | 156 | # 将packet-in解析的arp的网络通路信息存储 157 | def storage_access_info(self, dpid, in_port, src_ip, src_mac): 158 | # print(f"SSSS--->storage_access_info, self.access_table: {self.access_table}") 159 | if in_port in self.not_use_ports[dpid]: 160 | # print("discovery--->", dpid, in_port, src_ip, src_mac) 161 | if (dpid, in_port) in self.access_table: 162 | if 
self.access_table[(dpid, in_port)] == (src_ip, src_mac): 163 | return 164 | else: 165 | self.access_table[(dpid, in_port)] = (src_ip, src_mac) 166 | return 167 | else: 168 | self.access_table.setdefault((dpid, in_port), None) 169 | self.access_table[(dpid, in_port)] = (src_ip, src_mac) 170 | return 171 | 172 | # 利用topology库获取拓扑信息 173 | events = [event.EventSwitchEnter, event.EventSwitchLeave, 174 | event.EventPortAdd, event.EventPortDelete, event.EventPortModify, 175 | event.EventLinkAdd, event.EventLinkDelete] 176 | 177 | @set_ev_cls(events) 178 | def get_topology(self, ev): 179 | present_time = time.time() 180 | if present_time - self.start_time < self.initiation_delay: # Set to 30s 181 | print(f'SSSS--->get_topology: need to WAIT {self.initiation_delay - (present_time - self.start_time):.2f}s') 182 | return 183 | elif self.first_flag: 184 | self.first_flag = False 185 | print("SSSS--->get_topology: complete WAIT") 186 | 187 | # self.logger.info("discovery---> EventSwitch/Port/Link") 188 | self.logger.info("[Topology Discovery Ok]") 189 | # 事件发生时,获得swicth列表 190 | switch_list = get_switch(self.topology_api_app, None) 191 | # 将swicth添加到self.switch_all_ports_table 192 | for switch in switch_list: 193 | dpid = switch.dp.id 194 | self.switch_all_ports_table.setdefault(dpid, set()) 195 | self.switch_port_table.setdefault(dpid, set()) 196 | self.not_use_ports.setdefault(dpid, set()) 197 | # print("discovery---> ",switch, switch.ports) 198 | for p in switch.ports: 199 | self.switch_all_ports_table[dpid].add(p.port_no) 200 | 201 | self.all_switches_dpid = self.switch_all_ports_table.keys() 202 | 203 | # time.sleep(0.5) 204 | # 获得link 205 | link_list = get_link(self.topology_api_app, None) 206 | # print("discovery---> ",len(link_list)) 207 | 208 | # 将link添加到self.link_table 209 | for link in link_list: 210 | src = link.src # 实际是个port实例,我找了半天 211 | dst = link.dst 212 | self.link_port_table[(src.dpid, dst.dpid)] = (src.port_no, dst.port_no) 213 | 214 | if src.dpid in 
self.all_switches_dpid: 215 | self.switch_port_table[src.dpid].add(src.port_no) 216 | if dst.dpid in self.all_switches_dpid: 217 | self.switch_port_table[dst.dpid].add(dst.port_no) 218 | 219 | # 统计没使用的端口 220 | for sw_dpid in self.switch_all_ports_table.keys(): 221 | all_ports = self.switch_all_ports_table[sw_dpid] 222 | linked_port = self.switch_port_table[sw_dpid] 223 | # print("discovery---> all_ports, linked_port", all_ports, linked_port) 224 | self.not_use_ports[sw_dpid] = all_ports - linked_port 225 | 226 | # 建立拓扑 bw和delay未定 227 | self.build_topology_between_switches() 228 | self.cal_path_flag = True 229 | 230 | def build_topology_between_switches(self, bw=0, delay=0, loss=0): 231 | """ 根据 src_dpid 和 dst_dpid 建立拓扑,bw 和 delay 信息还未定""" 232 | # networkxs使用已有Link的src_dpid和dst_dpid信息建立拓扑 233 | _graph = nx.Graph() 234 | 235 | # self.graph.clear() 236 | for (src_dpid, dst_dpid) in self.link_port_table.keys(): 237 | # 建立switch之间的连接,端口可以通过查link_port_table获得 238 | _graph.add_edge(src_dpid, dst_dpid, bw=bw, delay=delay, loss=loss) 239 | if _graph.edges == self.graph.edges: 240 | return 241 | else: 242 | self.graph = _graph 243 | 244 | def calculate_weight(self, node1, node2, weight_dict): 245 | """ 计算路径时,weight可以调用函数,该函数根据因子计算 bw * factor - delay * (1 - factor) 后的weight""" 246 | # weight可以调用的函数 247 | assert 'bw' in weight_dict and 'delay' in weight_dict, "edge weight should have bw and delay" 248 | try: 249 | weight = weight_dict['bw'] * setting.FACTOR - weight_dict['delay'] * (1 - setting.FACTOR) 250 | return weight 251 | except TypeError: 252 | print("discovery ERROR---> weight_dict['bw']: ", weight_dict['bw']) 253 | print("discovery ERROR---> weight_dict['delay']: ", weight_dict['delay']) 254 | return None 255 | 256 | def get_shortest_paths(self, src_dpid, dst_dpid, weight=None): 257 | """ 计算src到dst的最短路径,存在self.shortest_path_table中""" 258 | graph = self.graph.copy() 259 | # print(graph.edges) 260 | # print("SSSS--->get_shortest_paths ==calculate shortest path %s to 
%s" % (src_dpid, dst_dpid)) 261 | self.shortest_path_table[(src_dpid, dst_dpid)] = nx.shortest_path(graph, 262 | source=src_dpid, 263 | target=dst_dpid, 264 | weight=weight, 265 | method=setting.METHOD) 266 | # print("SSSS--->get_shortest_paths ==[PATH] %s <---> %s: %s" % ( 267 | # src_dpid, dst_dpid, self.shortest_path_table[(src_dpid, dst_dpid)])) 268 | 269 | def calculate_all_nodes_shortest_paths(self, weight=None): 270 | """ 根据已构建的图,计算所有nodes间的最短路径,weight为权值,可以为calculate_weight()该函数""" 271 | self.shortest_path_table = {} # 先清空,再计算 272 | for src in self.graph.nodes(): 273 | for dst in self.graph.nodes(): 274 | if src != dst: 275 | self.get_shortest_paths(src, dst, weight=weight) 276 | else: 277 | continue 278 | 279 | def get_host_ip_location(self, host_ip): 280 | """ 281 | 通过host_ip查询 self.access_table: {(dpid, in_port): (src_ip, src_mac)} 282 | 获得(dpid, in_port) 283 | """ 284 | if host_ip == "0.0.0.0" or host_ip == "255.255.255.255": 285 | return None 286 | 287 | for key in self.access_table.keys(): # {(dpid, in_port): (src_ip, src_mac)} 288 | if self.access_table[key][0] == host_ip: 289 | # print("discovery--->zzzz---> key", key) 290 | return key 291 | print("SSS--->get_host_ip_location: %s location is not found" % host_ip) 292 | return None 293 | 294 | def get_ip_by_dpid(self, dpid): 295 | """ 296 | 通过 dpid 查询 {(dpid, in_port): (src_ip, src_mac)} 297 | 获得 ip src_ip 298 | """ 299 | for key, value in self.access_table.items(): 300 | if key[0] == dpid: 301 | return value[0] 302 | print("SSS--->get_ip_by_dpid: %s ip is not found" % dpid) 303 | return None 304 | 305 | def parse_topo_links_info(self): 306 | m_graph = nx.Graph() 307 | parser = ET.parse(self.link_info_xml) 308 | root = parser.getroot() 309 | 310 | # links_info_element = root.find("links_info") 311 | 312 | def _str_tuple2int_list(s: str): 313 | s = s.strip() 314 | assert s.startswith('(') and s.endswith(")"), '应该为str的元组,如 "(1, 2)"' 315 | s_ = s[1: -1].split(', ') 316 | return [int(i) for i in s_] 317 
| 318 | node1, node2, port1, port2, bw, delay, loss = None, None, None, None, None, None, None 319 | for e in root.iter(): 320 | if e.tag == 'links': 321 | node1, node2 = _str_tuple2int_list(e.text) 322 | elif e.tag == 'ports': 323 | port1, port2 = _str_tuple2int_list(e.text) 324 | elif e.tag == 'bw': 325 | bw = float(e.text) 326 | elif e.tag == 'delay': 327 | delay = float(e.text[:-2]) 328 | elif e.tag == 'loss': 329 | loss = float(e.text) 330 | else: 331 | print(e.tag) 332 | continue 333 | m_graph.add_edge(node1, node2, port1=port1, port2=port2, bw=bw, delay=delay, loss=loss) 334 | 335 | for edge in m_graph.edges(data=True): 336 | print(edge) 337 | return m_graph 338 | -------------------------------------------------------------------------------- /ryu/setting.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from functools import reduce 3 | 4 | WORK_DIR = Path.cwd().parent 5 | 6 | # setting.py 7 | FACTOR = 0.9 # the coefficient of 'bw' , 1 - FACTOR is the coefficient of 'delay' 8 | 9 | METHOD = 'dijkstra' # the calculation method of shortest path 10 | 11 | DISCOVERY_PERIOD = 10 # discover network structure's period, the unit is seconds. 

MONITOR_PERIOD = 5  # monitor period, bw

DELAY_PERIOD = 1.3  # detector period, delay

SCHEDULE_PERIOD = 6  # shortest forwarding network awareness period

PRINT_SHOW = False  # show or not show print

INIT_TIME = 30  # wait init for awareness

PRINT_NUM_OF_LINE = 8  # print 8 values per line

LOGGER = True  # whether output goes through the app logger (otherwise print)

LINKS_INFO = WORK_DIR / "mininet/links_info/links_info.xml"  # path of the links-info xml file

# SRC_IP = "10.0.0.1"
# DST_MULTICAST_IP = {'224.1.1.1': 1, }  # multicast address: label (index into the list below)
# DST_GROUP_IP = [["10.0.0.2", "10.0.0.3", "10.0.0.4"], ]  # group members' ips (indexed by the label above)

DST_MULTICAST_IP = {'224.1.1.1': ["10.0.0.2", "10.0.0.4", "10.0.0.11"], }  # multicast address: group members' ips

WEIGHT = 'bw'
# FINAL_SWITCH_FLOW_IDEA = 1

finish_time_file = WORK_DIR / "mininet/finish_time.json"


def list_insert_one_by_one(list1, list2):
    """Interleave two sequences pairwise: ``[a0, b0, a1, b1, ...]``.

    Pairing follows ``zip``, so it stops at the end of the shorter input.

    :param list1: sequence supplying the even-indexed items
    :param list2: sequence supplying the odd-indexed items
    :return: a new list with the items of both inputs alternating
    """
    return [item for pair in zip(list1, list2) for item in pair]


def gen_format_str(num):
    """Build a two-stage ``str.format`` template with ``num`` columns.

    The result is ``'{{:<{}}}'`` repeated ``num`` times.  A first
    ``.format(*widths)`` fills in the column widths and collapses the doubled
    braces, giving ``'{:<w0}{:<w1}...'``; a second ``.format(*values)`` then
    renders one left-aligned row.

    :param num: number of columns
    :return: the template string (empty when ``num`` <= 0)
    """
    return '{{:<{}}}' * num


def _print_table_name(emit, table_name, total_width):
    """Emit the table name centred in a rule of ``=`` of ``total_width`` chars.

    Falls back to ``"=" + table_name + "="`` when the name leaves fewer than
    two characters of padding.
    """
    padding = total_width - len(table_name)
    if padding > 1:
        left = padding // 2
        emit("=" * left + table_name + "=" * (padding - left))
    else:
        emit("=" + table_name + "=")


def print_pretty_table(param, titles, widths, table_name='zzlong', logger=None):
    """Print a dict as a simple left-aligned text table.

    Only two-column ``key: value`` tables are supported: every row is rendered
    from ``str(key)`` and ``str(value)``, so ``titles``/``widths`` are expected
    to have two entries (more placeholders than that make ``str.format`` raise
    ``IndexError``).  When ``param`` is not a dict only the name banner and the
    closing rule are printed.

    :param param: the dict to print
    :param titles: title of each column
    :param widths: width of each column
    :param table_name: name shown in the banner line
    :param logger: callable used for output, e.g. ``print`` or ``logger.info``;
        ``print`` is used when it is falsy
    :return: None
    """
    emit = logger if logger else print
    total_width = sum(widths)
    cut_line = "=" * total_width

    # table name
    _print_table_name(emit, table_name, total_width)

    # table body
    if isinstance(param, dict):
        # first format() fixes the widths, the second one fills in the values
        row_fmt = gen_format_str(len(titles)).format(*widths)
        emit(row_fmt.format(*titles))
        emit(cut_line)
        for key, value in param.items():
            emit(row_fmt.format(str(key), str(value)))

    # closing rule
    emit(cut_line + '\n')


def print_pretty_list(param, num, width=10, table_name='zzlong', logger=None):
    """Print the values of a list, a fixed number per line.

    This function must exist at module level: ``network_structure.py`` imports
    it with ``from setting import print_pretty_table, print_pretty_list`` and
    would fail with ``ImportError`` otherwise.

    :param param: the list to print
    :param num: how many values go on each line (must be positive)
    :param width: width reserved for each value; only used to size the rules
    :param table_name: name shown in the banner line
    :param logger: callable used for output, e.g. ``print`` or ``logger.info``;
        ``print`` is used when it is falsy
    :return: None
    :raises ValueError: if ``num`` is not positive
    """
    if num <= 0:
        raise ValueError("num must be a positive integer")

    emit = logger if logger else print
    total_width = num * width

    # table name
    _print_table_name(emit, table_name, total_width)

    # full chunks of ``num`` values, then whatever is left over
    for start in range(0, len(param), num):
        emit(param[start: start + num])

    # closing rule
    emit("=" * total_width + '\n')


if __name__ == '__main__':
    # a = {'test1': [11, 12, 13], 'test2': [21, 22, 23], 'test3': [31, 32, 33]}
    # print_pretty_table(a, ['my_test', 'values'], [10, 14], 'test_table', print)
    #
    # b = list(range(30))
    # print_pretty_list(b, 10, 10)
    print(WORK_DIR)
    print(LINKS_INFO)
    print(finish_time_file)
# --------------------------------------------------------------------------------