├── Notice.txt
├── README.md
├── RL
│   ├── config.py
│   ├── env.py
│   ├── env_config.py
│   ├── log.py
│   ├── net.py
│   ├── replymemory.py
│   ├── rl.py
│   ├── test.py
│   └── train.py
├── mininet
│   ├── generate_matrices.py
│   ├── generate_nodes_topo.py
│   ├── iperf_script.py
│   ├── links_info
│   │   └── links_info.xml
│   └── topologies
│       ├── topology-anonymised.xml
│       └── topology2.xml
└── ryu
    ├── arp_handler.py
    ├── network_delay.py
    ├── network_monitor.py
    ├── network_structure.py
    ├── setting.py
    └── shortest_path_forwarding.py
/Notice.txt:
--------------------------------------------------------------------------------
1 | Dear Researcher,
2 | Thank you for your interest in our work.
3 | In 2024 we found that a research group of graduate students entered the China ## Network Technology Competition using our open-source code, with almost no changes to the design or the code, and won a prize. Earlier, in 2023, while reviewing a manuscript for a journal, we found that a submission by a graduate student from Sichuan Province had likewise plagiarized our manuscript and code with almost no changes. We have therefore withdrawn the code from publication, and only part of it remains available. Interested research groups can reimplement our code by following the detailed steps described in our paper.
4 | We strongly condemn this immoral behavior.
5 | The Authors
6 | November 2024
7 |
8 |
9 | Statement
10 |
11 | In 2024 we found that a graduate research group used our open-source code, with almost no changes to the content or the code, to enter the China ## Student Network Technology Competition and win a prize. In 2023, while reviewing for a Chinese core journal, we also found that a submission by a graduate student at a university in Sichuan Province plagiarized our group's unpublished Chinese manuscript and code. We have therefore suspended the public release of all our code; interested research groups can implement it by following the detailed steps described in our paper.
12 | We strongly condemn this immoral behavior.
13 | November 2024
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DRL-M4MR: An intelligent multicast routing approach based on DQN deep reinforcement learning in SDN
2 |
3 | Traditional multicast routing methods have several problems in constructing a multicast tree: limited access to network state information, poor adaptability to dynamic and complex changes in the network, and inflexible data forwarding. To address these defects, the optimal multicast routing problem in software-defined networking (SDN) is formulated as a multiobjective optimization problem, and DRL-M4MR, an intelligent multicast routing algorithm based on the deep Q network (DQN) deep reinforcement learning (DRL) method, is designed to construct a multicast tree in SDN. First, exploiting SDN's global awareness of the network state, the multicast tree state matrix, link bandwidth matrix, link delay matrix and link packet loss rate matrix are designed as the state space of the reinforcement learning agent, addressing the problem that traditional methods cannot make full use of network state information. Second, the agent's action space is the set of all links in the network, and the action selection strategy is designed to add links to the current multicast tree under four cases. Third, single-step and final reward functions are designed to guide the agent's decisions toward the optimal multicast tree. A double network architecture, a dueling network architecture and prioritized experience replay are adopted to improve the learning efficiency and convergence of the agent. Finally, after the DRL-M4MR agent is trained, the SDN controller installs the multicast flow entries into the SDN switches by traversing the multicast tree in reverse, implementing intelligent multicast routing. Experimental results show that, compared with existing algorithms, the multicast tree constructed by DRL-M4MR achieves better bandwidth, delay and packet loss rate performance after training, and the agent makes more intelligent multicast routing decisions in a dynamic network environment.
4 |
5 |
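As a quick orientation, here is a minimal hedged sketch of how the `RL/` modules fit together; it is not the project's actual `train.py`. Assumptions not confirmed by the published code: the state fed to the network is the stack of the four matrices (multicast-tree state, bandwidth, delay, packet loss) as a `(1, 4, N, N)` tensor, `Config.set_num_states_actions` is called with 4 channels and one action per link, and the snapshot file name `0.pkl` is hypothetical.

```python
import numpy as np
import torch

from config import Config
from env import MulticastEnv
from net import MyMulticastNet2
from rl import DQN


def state_tensor(env):
    """Stack the four state matrices into a (1, 4, N, N) float tensor."""
    s = np.stack([env.route_matrix, env.normal_bw_matrix,
                  env.normal_delay_matrix, env.normal_loss_matrix])
    return torch.from_numpy(s.astype(Config.NUMPY_TYPE)).unsqueeze(0)


env = MulticastEnv(None)
env.read_pickle_and_modify(Config.pkl_weight_path / "0.pkl")  # hypothetical snapshot name
Config.set_num_states_actions(4, len(env.edges))  # 4 state channels, one action per link
agent = DQN(Config, MyMulticastNet2)

for episode in range(Config.EPISODES):
    env.read_pickle_and_modify(Config.pkl_weight_path / "0.pkl")  # reload a traffic snapshot
    env.reset(Config.start_node, Config.end_nodes)
    s = state_tensor(env)
    done = False
    while not done:
        action = agent.choose_action(s, episode)
        _, reward, done, _ = env.step(env.map_action(action.item()))
        s_ = None if done else state_tensor(env)
        agent.learn(s, action, reward, s_, episode)
        s = s_
    agent.finish_n_steps()
```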
--------------------------------------------------------------------------------
/RL/config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : config.py
3 | # @Date : 2022-05-18
4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来-
5 | # @From :
6 | import time
7 | from pathlib import Path
8 | import platform
9 | import torch
10 | import numpy
11 |
12 | sys_platform = platform.system()
13 |
14 |
15 | class Config:
16 | DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
17 | TIME = time.strftime("%Y%m%d%H%M%S", time.localtime())
18 |
19 | # dtype
20 | NUMPY_TYPE = numpy.float32
21 | TORCH_TYPE = torch.float32
22 |
23 | # Net
24 |     NUM_STATES = 14 ** 2 * 4  # see set_num_states_actions
25 |     NUM_ACTIONS = 14 ** 2  # see set_num_states_actions
26 |
27 |     PKL_STEP = 3
28 |     PKL_NUM = 120
29 |     PKL_START = 10  # starting index (indices begin at 0)
30 |     PKL_CUT_NUM = PKL_START + PKL_NUM * PKL_STEP  # end index
31 |
32 | # DQN
33 | BATCH_SIZE = 8
34 |
35 | # memory pool
36 | MEMORY_CAPACITY = 1024 * 2
37 |
38 | # nsteps
39 | N_STEPS = 1
40 |
41 | # hyper-parameters
42 |     LR = 1e-5  # learning rate
43 |     REWARD_DECAY = 0.9  # gamma
44 |
45 |     USE_DECAY = True
46 |     E_GREEDY_START = 1  # initial epsilon
47 |     E_GREEDY_FINAL = 0.01  # final epsilon
48 |     E_GREEDY_DECAY = 700  # epsilon decay constant (episodes)
49 |     E_GREEDY = [E_GREEDY_START, E_GREEDY_FINAL, E_GREEDY_DECAY]
50 |
51 |     E_GREEDY_ORI = 0.2
52 |
53 | TAU = 1
54 |
55 |     EPISODES = 4000  # number of training episodes
56 | UPDATE_TARGET_FREQUENCY = 10
57 |
58 | REWARD_DEFAULT = [1.0, 0.1, -0.1, -1]
59 |
60 | DISCOUNT = 0.9
61 |
62 | BETA1 = 1
63 | BETA2 = 1
64 | BETA3 = 1
65 |
66 | A_STEP = 0
67 | B_STEP = None
68 |
69 | A_PATH = 0
70 | B_PATH = 1
71 |
72 | START_SYMBOL = 1
73 | END_SYMBOL = 2
74 | STEP_SYMBOL = 1
75 | BRANCH_SYMBOL = 2
76 |
77 | # control
78 | # START_LEARN = 200
79 |     CLAMP = False  # gradient clipping
80 |
81 | # file path and pkl path
82 | if sys_platform == "Windows":
83 | xml_topology_path = Path(
84 | r'D:\WorkSpace\Hello_Myself\Hello_Multicast\RLMulticastProject\mininet\topologies\topology2.xml')
85 | pkl_weight_path = Path(
86 | r"D:\WorkSpace\Hello_Myself\Hello_Multicast\RLMulticastProject\ryu\pickle\2022-03-11-19-40-21")
87 |
88 |
89 | else:
90 | xml_topology_path = Path(r'/home/dell/RLMulticastProject/mininet/topologies/topology2.xml')
91 | pkl_weight_path = Path(r"/home/dell/RLMulticastProject/ryu/pickle/2022-03-11-19-40-21")
92 |
93 | # nodes
94 | start_node = 12
95 | end_nodes = [2, 4, 11]
96 |
97 | @classmethod
98 | def set_num_states_actions(cls, state_space_num, action_space_num):
99 | cls.NUM_STATES = state_space_num
100 | cls.NUM_ACTIONS = action_space_num
101 |
102 | @classmethod
103 | def log_params(cls, logger):
104 | rewards_info = "\n===rewards===\n" + \
105 | f" REWARD_DEFAULT:{cls.REWARD_DEFAULT}\n"
106 | lr_info = "===LR===\n" + \
107 | f" LR:{cls.LR}\n"
108 | episodes_info = "===EPISODES===\n" + \
109 | f" EPISODES:{cls.EPISODES}\n"
110 | batchsize_info = "===BATCH_SIZE===\n" + \
111 | f" BATCH_SIZE:{cls.BATCH_SIZE}\n"
112 |
113 |         update_info = "===UPDATE_TARGET_FREQUENCY===\n" + \
114 | f" UPDATE_TARGET_FREQUENCY:{cls.UPDATE_TARGET_FREQUENCY}\n"
115 |         gamma_info = "===gamma===\n" + \
116 | f" REWARD_DECAY:{cls.REWARD_DECAY}\n"
117 | nsteps_info = "===nsteps===\n" + \
118 | f" N_STEPS:{cls.N_STEPS}\n"
119 | egreedy_info = "===egreedy===\n" + \
120 |                        f" [E_GREEDY_START, E_GREEDY_FINAL, E_GREEDY_DECAY]:{cls.E_GREEDY}\n"
121 | pickle_info = "===Pickle Param===\n" + \
122 |                       f" PKL_START, PKL_NUM, PKL_STEP:{cls.PKL_START}, {cls.PKL_NUM}, {cls.PKL_STEP}\n"
123 |
124 | env_info = "===ENV===\n" + \
125 | f" DISCOUNT:{cls.DISCOUNT},\n BETA1, BETA2, BETA3:{cls.BETA1},{cls.BETA2},{cls.BETA3}\n"
126 | logger.info(
127 |             rewards_info + lr_info + episodes_info + batchsize_info + update_info + gamma_info + nsteps_info + egreedy_info + pickle_info + env_info)
128 |
129 | # logger.info(cls.__dict__)
130 |
131 | @classmethod
132 | def set_lr(cls, lr):
133 | cls.LR = lr
134 |
135 | @classmethod
136 | def set_nsteps(cls, nsteps):
137 | cls.N_STEPS = nsteps
138 |
139 | @classmethod
140 | def set_batchsize(cls, batchsize):
141 | cls.BATCH_SIZE = batchsize
142 |
143 | @classmethod
144 | def set_egreedy(cls, egreedy):
145 | cls.E_GREEDY_DECAY = egreedy
146 |
147 | @classmethod
148 | def set_gamma(cls, gamma):
149 | cls.REWARD_DECAY = gamma
150 |
151 | @classmethod
152 | def set_update_frequency(cls, update_frequency):
153 | cls.UPDATE_TARGET_FREQUENCY = update_frequency
154 |
155 | @classmethod
156 | def set_tau(cls, tau):
157 | cls.TAU = tau
158 |
159 | @classmethod
160 | def set_rewards(cls, rewards):
161 | cls.REWARD_DEFAULT = rewards
162 |
163 | @classmethod
164 | def print_cls_dict(cls):
165 | print(cls.__dict__)
166 |
--------------------------------------------------------------------------------
/RL/env.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : env.py
3 | # @Date : 2022-05-18
4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来-
5 | # @From :
6 | import copy
7 | import math
8 | from itertools import zip_longest
9 | from math import exp
10 | from itertools import tee
11 | from functools import reduce
12 |
13 | import numpy as np
14 | import networkx as nx
15 |
16 | # from env_config import *
17 | from config import Config
18 |
19 |
20 | class MulticastEnv:
21 | """
22 |     The constructor requires an nx.Graph.
23 |     To change the graph, call self.modify_graph(graph), which updates self.graph.
24 |     Before using the environment, call self.reset(start, ends) to reset its attributes, lists and matrices.
25 |     Use self.step(link) to advance one step: it updates the current tree, the branches, and the routing matrix/list.
26 |
27 | ---
28 | env = MulticastEnv(graph)
29 | env.reset(start_node, end_nodes)
30 | env.step(link)
31 | """
32 |
33 | def __init__(self, graph, numpy_type=Config.NUMPY_TYPE, normalize=True):
34 | """
35 | :param graph: nx.graph
36 | """
37 | self.graph = graph
38 | self.nodes = sorted(graph.nodes) if graph is not None else None
39 | self.edges = sorted(graph.edges, key=lambda x: (x[0], x[1])) if graph is not None else None
40 | self.numpy_type = numpy_type
41 |
42 |         self.tree_nodes = set()  # nodes already in the tree
43 |         self.route_graph = nx.Graph()  # routing graph (the multicast tree)
44 |         self.branches = set()  # candidate branch nodes
45 | self.branches_graph = nx.Graph()
46 | self.mask = []
47 |
48 | self.n_actions = None
49 |
50 | self.start = None
51 | self.ends = None
52 | self.ends_constant = None
53 |
54 | self.adj_matrix = None
55 |         self.route_matrix = None  # routing matrix
56 | self.state_matrix = None
57 | self.bw_matrix = None
58 | self.delay_matrix = None
59 | self.loss_matrix = None
60 |
61 | self.normal_bw_matrix = None
62 | self.normal_delay_matrix = None
63 | self.normal_loss_matrix = None
64 |
65 | self.step_normal_bw_matrix = None
66 | self.step_normal_delay_matrix = None
67 | self.step_normal_loss_matrix = None
68 |
69 | self.step_num = 1
70 | self.alley_num = 0
71 |
72 | self.max_bw = np.finfo(np.float32).eps
73 | self.min_bw = 0
74 |
75 | self.max_delay = np.finfo(np.float32).eps
76 | self.min_delay = 0
77 |
78 | self.max_loss = np.finfo(np.float32).eps
79 | self.min_loss = 0
80 |
81 | self.all_link_delay_sum = 0
82 | self.all_link_non_loss_prob = 0
83 |
84 | if graph is not None:
85 |             # self.parse_graph_edge_data()  # parse the graph
86 |             self.set_nodes()
87 |             self.set_edges()
88 |             self.set_adj_matrix()  # build the adjacency matrix
89 |
90 | self.done_reward = None
91 | self.step_reward = None
92 | self.hell_reward = None
93 | self.alley_reward = None
94 |
95 | self.discount = Config.DISCOUNT
96 |
97 | self.beta1 = Config.BETA1
98 | self.beta2 = Config.BETA2
99 | self.beta3 = Config.BETA3
100 |
101 | self.scale = normalize
102 | self.a_step = Config.A_STEP
103 | self.b_step = Config.B_STEP
104 |
105 | self.a_path = Config.A_PATH
106 | self.b_path = Config.B_PATH
107 |
108 | self.start_symbol = Config.START_SYMBOL
109 | self.end_symbol = Config.END_SYMBOL
110 | self.step_symbol = Config.STEP_SYMBOL
111 | self.branch_symbol = Config.BRANCH_SYMBOL
112 |
113 | self.set_reward_conf(Config.REWARD_DEFAULT)
114 | self.set_a_b(0, self.step_reward, 0, self.done_reward)
115 |
116 | def initialize_params(self):
117 | """
118 |         Reset everything to its initial empty value
119 | :return: None
120 | """
121 |         self.tree_nodes = set()  # nodes already in the tree
122 |         self.route_graph = nx.Graph()  # the multicast tree
123 |         self.branches = set()  # candidate branch nodes
124 | self.branches_graph = nx.Graph()
125 | self.mask = []
126 |
127 | self.n_actions = None
128 |
129 | self.start = None
130 | self.ends = None
131 | self.ends_constant = None
132 |
133 | self.adj_matrix = None
134 |         self.route_matrix = None  # routing matrix
135 | self.state_matrix = None
136 | self.bw_matrix = None
137 | self.delay_matrix = None
138 | self.loss_matrix = None
139 |
140 | self.normal_bw_matrix = None
141 | self.normal_delay_matrix = None
142 | self.normal_loss_matrix = None
143 |
144 | self.step_normal_bw_matrix = None
145 | self.step_normal_delay_matrix = None
146 | self.step_normal_loss_matrix = None
147 |
148 | self.step_num = 1
149 | self.alley_num = 0
150 |
151 | self.max_bw = np.finfo(np.float32).eps
152 | self.min_bw = 0
153 |
154 | self.max_delay = np.finfo(np.float32).eps
155 | self.min_delay = 0
156 |
157 | self.max_loss = np.finfo(np.float32).eps
158 | self.min_loss = 0
159 |
160 | self.all_link_delay_sum = 0
161 | self.all_link_non_loss_prob = 0
162 |
163 | def set_params(self):
164 | """
165 |         Populate nodes, edges, the adjacency matrix and the link-parameter matrices
166 | :return: None
167 | """
168 | self.set_nodes()
169 | self.set_edges()
170 | self.set_adj_matrix()
171 | self.parse_graph_edge_data()
172 |
173 | def parse_graph_edge_data(self):
174 | """
175 |         Parse each edge's bw, delay and loss into matrices
176 | :return: bw_matrix, delay_matrix, loss_matrix
177 | """
178 | _num = len(self.nodes)
179 |         # matrices initialized to -1
180 | bw_matrix = - np.ones((_num, _num), dtype=self.numpy_type)
181 | delay_matrix = - np.ones((_num, _num), dtype=self.numpy_type)
182 | loss_matrix = - np.ones((_num, _num), dtype=self.numpy_type)
183 |
184 | all_link_delay_sum, all_link_non_loss_prob = 0, 1
185 | for edge in self.graph.edges.data():
186 | _start, _end, _data = edge
187 | bw_matrix[_start - 1][_end - 1] = _data['bw']
188 | bw_matrix[_end - 1][_start - 1] = _data['bw']
189 |
190 | delay_matrix[_start - 1][_end - 1] = _data['delay']
191 | delay_matrix[_end - 1][_start - 1] = _data['delay']
192 | all_link_delay_sum += _data['delay']
193 |
194 | loss_matrix[_start - 1][_end - 1] = _data['loss']
195 | loss_matrix[_end - 1][_start - 1] = _data['loss']
196 | all_link_non_loss_prob *= 1 - _data['loss']
197 |
198 | maximum_spt_delay = list(nx.maximum_spanning_edges(self.graph, weight='delay'))
199 | self.all_link_delay_sum = reduce(lambda x, y: x + y, [link[2]['delay'] for link in maximum_spt_delay])
200 | maximum_spt_loss = list(nx.maximum_spanning_edges(self.graph, weight='loss'))
201 | self.all_link_non_loss_prob = reduce(lambda x, y: x * y, [1 - link[2]['loss'] for link in maximum_spt_loss])
202 |
203 | self.max_bw, self.min_bw, bw_matrix = self.get_non_minus_one_max_min(bw_matrix)
204 | self.max_delay, self.min_delay, delay_matrix = self.get_non_minus_one_max_min(delay_matrix)
205 | self.max_loss, self.min_loss, loss_matrix = self.get_non_minus_one_max_min(loss_matrix)
206 |
207 | self.bw_matrix = bw_matrix
208 | self.delay_matrix = delay_matrix
209 | self.loss_matrix = loss_matrix
210 |
211 | self.normal_bw_matrix, self.step_normal_bw_matrix = self.normalize_param_matrix(bw_matrix, self.max_bw,
212 | self.min_bw)
213 | self.normal_delay_matrix, self.step_normal_delay_matrix = self.normalize_param_matrix(delay_matrix,
214 | self.max_delay,
215 | self.min_delay)
216 | self.normal_loss_matrix, self.step_normal_loss_matrix = self.normalize_param_matrix(loss_matrix, self.max_loss,
217 | self.min_loss)
218 | return bw_matrix, delay_matrix, loss_matrix, all_link_delay_sum
219 |
220 | def normalize_param_matrix(self, matrix, max_value, min_value):
221 | """
222 |         Min-max scale the matrix so that its values lie in [0, 1)
223 |         """
224 |         normal_m = (matrix - min_value) / (max_value + 1e-6)  # note: divides by max, not (max - min)
225 | normal_m -= normal_m * np.identity(len(self.nodes))
226 |
227 | step_normal_m = self.a_step + normal_m * (self.b_step - self.a_step)
228 | return normal_m, step_normal_m
229 |
230 | def get_adjacent_link_params_max_min(self, node):
231 | """
232 |         Get the max/min of each link parameter over the links adjacent to the given node
233 |         :param node: current node
234 | :return: [(max_bw, min_bw), (max_delay, min_delay), (max_loss, min_loss)]
235 | """
236 | adj_index = self.get_node_adj_index_list(node)
237 | max_bw = self.bw_matrix[self.node_to_index(node), adj_index].max()
238 | min_bw = self.bw_matrix[self.node_to_index(node), adj_index].min()
239 |
240 | max_delay = self.delay_matrix[self.node_to_index(node), adj_index].max()
241 | min_delay = self.delay_matrix[self.node_to_index(node), adj_index].min()
242 |
243 | max_loss = self.loss_matrix[self.node_to_index(node), adj_index].max()
244 | min_loss = self.loss_matrix[self.node_to_index(node), adj_index].min()
245 |
246 | return [(min_bw, max_bw), (min_delay, max_delay), (min_loss, max_loss)]
247 |
248 | @staticmethod
249 | def get_non_minus_one_max_min(matrix):
250 | """
251 |         Return the max and min of the matrix, ignoring -1 entries
252 |         :param matrix: matrix whose max/min are wanted
253 |         :return: _max, _min, matrix (max and min values; -1 entries of matrix set to 0)
254 | """
255 | non_minus_one_mask = np.where(matrix != -1)
256 | _max = matrix[non_minus_one_mask].max()
257 | _min = matrix[non_minus_one_mask].min()
258 | matrix[np.where(matrix == -1)] = 0
259 | return _max, _min, matrix
260 |
261 | def modify_graph(self, graph: nx.Graph):
262 | """
263 |         Replace the graph attribute and return it
264 |         :param graph: networkx graph
265 | :return: self.graph
266 | """
267 | self.graph = graph
268 | return self.graph
269 |
270 | def set_nodes(self):
271 | """
272 |         Set and return the sorted node list
273 | :return: self.nodes
274 | """
275 |
276 | self.nodes = sorted(self.graph.nodes)
277 | return self.nodes
278 |
279 | def set_edges(self):
280 | """
281 |         Set the sorted edge list (each edge ordered with the smaller node first)
282 | :return: None
283 | """
284 | # [(1, 3), (1, 4), (1, 5), (1, 9), (1, 11),
285 | # (2, 4), (2, 5), (2, 6), (4, 5), (4, 6), (4, 7),
286 | # (4, 9), (4, 11), (5, 6), (5, 7), (5, 8), (7, 8),
287 | # (7, 9), (8, 14), (9, 10), (9, 13), (12, 13), (12, 14)]
288 | edges = [(e[0], e[1]) if e[0] < e[1] else (e[1], e[0]) for e in self.graph.edges]
289 | self.edges = sorted(edges, key=lambda x: (x[0], x[1]))
290 |
291 | def set_adj_matrix(self):
292 | """
293 |         Build and return the adjacency matrix
294 | :return: adj_m
295 | """
296 | adj_m = nx.adjacency_matrix(self.graph, self.nodes).todense()
297 | self.adj_matrix = np.array(adj_m, dtype=self.numpy_type)
298 | return self.adj_matrix
299 |
300 | def read_pickle_and_modify(self, path):
301 | """
302 |         Read a pickled graph
303 |         and reinitialize all parameters
304 |         :param path: path to the pickle file
305 |         :return: the graph, an nx.Graph
306 | """
307 | pkl_graph = nx.read_gpickle(path)
308 | self.modify_graph(pkl_graph)
309 |
310 | self.initialize_params()
311 | self.set_params()
312 |
313 | return pkl_graph
314 |
315 | @staticmethod
316 | def node_to_index(node):
317 | """
318 |         Nodes are numbered from 1 and indices from 0; convert a node number to an index
319 |         :param node: node number
320 |         :return: index
321 | """
322 | return node - 1
323 |
324 | @staticmethod
325 | def index_to_node(index):
326 | """
327 |         Convert an index to a node number
328 |         :param index: index
329 |         :return: node number
330 | """
331 | return index + 1
332 |
333 | def add_tree_node(self, node):
334 | """
335 |         Add a node to the tree
336 | """
337 | self.tree_nodes.add(node)
338 |
339 | def get_node_adj_index_list(self, node):
340 | """
341 |         Get the indices of the given node's neighbors
342 |         :param node: node number
343 |         :return: list of neighbor indices
344 | """
345 | index = self.node_to_index(node)
346 | adj_ids = np.nonzero(self.adj_matrix[index]) # return tuple
347 | adj_index_list = adj_ids[0].tolist()
348 | return adj_index_list
349 |
350 | def modify_branches(self, node):
351 | """
352 |         Update the branch set.
353 |         Before calling this, add the qualifying node to the tree via self.add_tree_node(node).
354 |         Analogous to the candidate-link step in Prim's algorithm.
355 |         2022/3/17 major revision
356 | :return: None
357 | """
358 |         # A. Update the branch set
359 |         # Take the node's row of the adjacency matrix: 1 marks a neighbor, 0 otherwise
360 |         adj_ids_list = self.get_node_adj_index_list(node)
361 |         adj_nodes = set(map(self.index_to_node, adj_ids_list))  # indices -> node numbers
362 |         # 1. Add the neighbor nodes to branches
363 |         self.branches.update(adj_nodes)
364 |         # 2. Remove nodes already in the tree from branches
365 |         self.branches.difference_update(self.tree_nodes)
366 |
367 |         # B. Update branches_graph
368 |         # 1. Add edges from the node to its surrounding branch nodes
369 | for u, v in zip_longest([node], list(adj_nodes), fillvalue=node): # zip_longest
370 | if v in self.tree_nodes:
371 | self.branches_graph.remove_edge(u, v)
372 | else:
373 | self.branches_graph.add_edge(u, v)
374 |
375 | def get_branches_matrix(self):
376 | """
377 |         Get the adjacency-matrix representation of the current branches
378 | """
379 | branches_m = np.zeros((len(self.nodes), len(self.nodes)), dtype=self.numpy_type)
380 | for e in self.branches_graph.edges:
381 | u, v = e
382 | u = self.node_to_index(u)
383 | v = self.node_to_index(v)
384 | branches_m[u][v] = self.branch_symbol
385 | branches_m[v][u] = self.branch_symbol
386 | return branches_m
387 |
388 | def add_to_route_graph(self, link):
389 | """
390 |         Add link to self.route_graph
391 |         :param link: the link
392 | :return: None
393 | """
394 | self.route_graph.add_edge(*link)
395 |
396 | def get_mask_of_current_tree_branch(self):
397 | """
398 |         Get the boolean mask marking which edges are branches of the current tree,
399 |         e.g. (1, 0, 0, 1, ...)
400 | :return: mask
401 | """
402 | _branches = self.branches_graph.edges
403 |         edges = self.edges  # this list must stay in a fixed order
404 | mask = np.zeros(len(self.graph.edges), dtype=bool)
405 | for edge in _branches:
406 | if edge in edges:
407 |                 # set that index position to True
408 | mask[edges.index(edge)] = True
409 | elif edge[::-1] in edges:
410 | mask[edges.index(edge[::-1])] = True
411 | else:
412 |                 raise IndexError("branch edge not found in self.edges")
413 |
414 | self.mask = mask
415 |
416 | return mask
417 |
418 | def _judge_link(self, link):
419 | """
420 |         Deprecated: link[1] may be the wrong endpoint
421 |         Check whether link fits the current situation, i.e. whether it is a branch of the current tree
422 |         :param link: candidate next link
423 |         :return: True or False
424 |         """
425 |         # is the endpoint a branch node? otherwise return False
426 | end_node = link[1]
427 | if end_node in self.branches:
428 | return True
429 | else:
430 | return False
431 |
432 | def judge_end_node(self, link):
433 | """
434 |         Determine which endpoint of link is the newly grown branch node, and fix the link direction
435 |         :param link: the next-hop link
436 | :return: tree_node, branch_node
437 | """
438 |         tree_node = self.tree_nodes & set(link)  # set intersection
439 | if len(tree_node) == 1:
440 |             branch_node = set(link) - self.tree_nodes  # endpoint in link but not yet in tree_nodes
441 |
442 | try:
443 | return list(tree_node)[0], list(branch_node)[0]
444 | except IndexError:
445 | raise IndexError(f"({link}, {tree_node}, {branch_node}, {self.tree_nodes})")
446 | else:
447 | return None
448 |
449 | def judge_link(self, link):
450 | """
451 |         Check whether link is a branch
452 |         :param link: the next-hop link
453 | :return: True or False
454 | """
455 | if link in self.branches_graph.edges:
456 | return True
457 | else:
458 | return False
459 |
460 | def reset(self, start, ends):
461 | """
462 |         Build the initial routing matrix
463 |         :param start: source node, numbered from 1 (matrix indices start at 0)
464 | :param ends: list
465 | :return: route_matrix: numpy matrix
466 | """
467 | if self.graph is None:
468 | raise Exception("graph is None")
469 |
470 | self.start = copy.deepcopy(start)
471 | self.ends = copy.deepcopy(ends)
472 | self.ends_constant = copy.deepcopy(ends)
473 |
474 | _num = len(self.nodes)
475 |         # all-zero matrix
476 | route_matrix = np.zeros((_num, _num), dtype=self.numpy_type)
477 |         # mark the source node in the routing matrix with start_symbol (1)
478 | _idx = self.node_to_index(start)
479 | route_matrix[_idx, _idx] = self.start_symbol
480 |         # mark each destination node with end_symbol (2)
481 | for end in self.ends:
482 | _idx = self.node_to_index(end)
483 | route_matrix[_idx, _idx] = self.end_symbol
484 |
485 | self.route_matrix = route_matrix
486 | self.state_matrix = route_matrix
487 |
488 |         # add the source node to the tree
489 |         self.add_tree_node(self.start)
490 |         # add the source node's branches
491 |         self.modify_branches(self.start)
492 |         # build the branch mask of the current tree
493 | self.get_mask_of_current_tree_branch()
494 |
495 | return route_matrix
496 |
497 | def step(self, link):
498 | """
499 |         Advance one step: write the newly chosen link into route_matrix
500 |         2022/3/17: the endpoint not yet in the tree is used as end_node, rather than link[1]
501 |         :param link: the link, e.g. (1, 2)
502 | :return: state_: numpy.matrix, reward_score: float, route_done: bool
503 | """
504 | assert self.hell_reward is not None
505 |         # check whether link fits the current situation, i.e. is a branch of the current tree
506 | link = self.judge_end_node(link)
507 | if link:
508 | tree_node, branch_node = link
509 |             # 1. mark the link in the routing matrix
510 |             self.route_matrix[self.node_to_index(tree_node)][self.node_to_index(branch_node)] = self.step_symbol
511 |             self.route_matrix[self.node_to_index(branch_node)][self.node_to_index(tree_node)] = self.step_symbol
512 |             # 2.1 add the new node to the tree node set
513 |             self.add_tree_node(branch_node)
514 |             # 2.2 update the branch set
515 |             self.modify_branches(branch_node)
516 |             # 2.3 rebuild the branch mask of the current tree
517 |             self.get_mask_of_current_tree_branch()
518 |             # 3. add the link to the routing graph
519 | self.add_to_route_graph((tree_node, branch_node))
520 |
521 |             # decide whether routing has finished or just advanced one step
522 | _judge_flag = self._judge_end(branch_node)
523 | if _judge_flag == 'ALL':
524 | route_done = True
525 |                 state_ = None  # terminal: no successor state
526 | reward_score = self.calculate_path_score()
527 | # reward_score = self.calculate_link_reward(link)
528 | # reward_score = (self.discount ** self.step_num) * reward_score
529 | # reward_score = self.done_reward + self.step_num * self.step_reward
530 | alley, _ = self.find_blind_alley()
531 | # reward_score += alley * self.alley_reward
532 | self.alley_num = alley
533 |
534 | elif _judge_flag == 'PART':
535 | route_done = False
536 | state_ = self.route_matrix
537 | reward_score = self.calculate_link_reward(link)
538 | # reward_score = self.calculate_path_score()
539 | # alley = self.find_blind_alley()
540 | # reward_score += alley * self.alley_reward
541 | elif _judge_flag == "NOT":
542 | route_done = False
543 | state_ = self.route_matrix
544 | reward_score = self.calculate_link_reward(link)
545 | else:
546 | raise Exception("link judge error")
547 | else:
548 | _judge_flag = "HELL"
549 | route_done = False
550 | state_ = self.route_matrix
551 | reward_score = self.hell_reward
552 |
553 | self.step_num += 1
554 | return state_, reward_score, route_done, _judge_flag
555 |
556 | def _judge_end(self, next_node):
557 | """
558 |         Decide whether routing has finished.
559 |         If the destination list is empty, every destination has been reached
560 |         :param next_node: the node just added to the tree
561 | :return: "ALL", "PART", "NOT"
562 | """
563 | assert next_node is not None, "next_node is None"
564 | if next_node in self.ends:
565 | _i = self.ends.index(next_node)
566 | self.ends.pop(_i)
567 |
568 |             # the destination list is empty
569 | if len(self.ends) == 0:
570 | return "ALL"
571 | else:
572 | return "PART"
573 | return "NOT"
574 |
575 | def reward_score(self, link):
576 | """
577 |         Reward for a single link
578 |         :param link: the link
579 | :return: reward
580 | """
581 | score = self.calculate_link_reward(link)
582 | return score
583 |
584 | def step_reward_exp(self):
585 | """
586 | step_reward * e**(1/nodes_num * step_num)
587 | :return: score
588 | """
589 | score = self.step_reward * exp(1 / len(self.nodes) * self.step_num - 1)
590 | return score
591 |
592 | def discount_reward(self, score):
593 | """
594 |         Discount the reward: score = score * discount ** (step_num - 1)
595 | :param score: reward
596 | :return: score
597 | """
598 | score *= self.discount ** (self.step_num - 1)
599 | return score
600 |
601 | def link_reward_func(self, bw, delay, loss, b):
602 | """
603 |         Compute the per-link reward
604 |         R = β1 * bw + β2 * (b - delay) + β3 * (b - loss)
605 |         :param bw: normalized bandwidth
606 |         :param delay: normalized delay
607 |         :param loss: normalized loss rate
608 |         :param b: upper bound of the normalization range
609 | """
610 | return self.beta1 * bw + self.beta2 * (b - delay) + self.beta3 * (b - loss)
611 | # return -(self.beta1 * (b - bw) + self.beta2 * delay + self.beta3 * loss)
612 |
613 | def path_reward_func(self, bw, delay, loss, b):
614 |         return self.beta1 * bw + self.beta2 * (b - delay) + self.beta3 * loss  # 'loss' here receives the non-loss (delivery) probability, so larger is better
615 | # return -(self.beta1 * (b - bw) + self.beta2 * delay + self.beta3 * (b - loss))
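
    # Worked example (a sketch, assuming beta1 = beta2 = beta3 = 1 and b = 1):
    # a link with bw_hat = 0.8, delay_hat = 0.2, loss_hat = 0.1 gives
    # link_reward_func -> 0.8 + (1 - 0.2) + (1 - 0.1) = 2.5, and a path with
    # e2e_bw_hat = 0.8, delay_hat = 0.2, non_loss_hat = 0.9 gives
    # path_reward_func -> 0.8 + (1 - 0.2) + 0.9 = 2.5.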
616 |
617 | def calculate_link_reward(self, link):
618 | """
619 |         Compute the reward R for the link chosen at this step
620 | :return: reward
621 | """
622 | e0, e1 = self.node_to_index(link[0]), self.node_to_index(link[1])
623 | bw_hat = self.step_normal_bw_matrix[e0, e1]
624 | delay_hat = self.step_normal_delay_matrix[e0, e1]
625 | loss_hat = self.step_normal_loss_matrix[e0, e1]
626 | r = self.link_reward_func(bw_hat, delay_hat, loss_hat, self.b_step)
627 | return r
628 |
629 | def calculate_path_score(self):
630 | """
631 |         Compute the final reward of the constructed multicast tree
632 | :return: reward
633 | """
634 | non_losses = np.array([])
635 | delays = np.array([])
636 |
637 | e2e_bw, e2e_delay, e2e_bw_hat, e2e_delay_hat = self.find_end_to_end_max_bw_delay(self.route_graph, self.start,
638 | self.ends_constant)
639 |
640 | for link in self.route_graph.edges:
641 | delay = self.graph[link[0]][link[1]]['delay']
642 | loss = self.graph[link[0]][link[1]]['loss']
643 | delays = np.append(delays, delay)
644 | non_losses = np.append(non_losses, 1 - loss)
645 | # delay_hat = e2e_delay_hat
646 | num = len(self.ends_constant) + 1
647 | delay_hat = self.min_max_normalize(delays.sum(), self.min_delay * num, self.all_link_delay_sum, a=self.a_path,
648 | b=self.b_path)
649 |
650 | # prob sum
651 | _min = 1 - self.max_loss
652 | _max = 1 - self.min_loss
653 | non_loss_hat = non_losses.prod()
654 | # non_loss_hat = self.min_max_normalize(non_losses.prod(), _min ** num, _max, a=self.a_path, b=self.b_path)
655 |
656 | # if delays.sum() < self.min_delay * num:
657 | # raise ValueError("delays.sum() < self.min_delay * num")
658 |
659 | r = self.path_reward_func(e2e_bw_hat, delay_hat, non_loss_hat, self.b_path)
660 | return r
661 |
662 | def find_end_to_end_max_bw(self, route_graph, start, ends):
663 | """
664 |         Find the end-to-end bottleneck (maximum residual) bandwidth
665 |         :param route_graph: the multicast tree
666 |         :param start: source node
667 |         :param ends: destination nodes
668 |         :return bws: list of end-to-end bandwidths
669 | """
670 | bws = np.array([])
671 | for end in ends:
672 | p = nx.shortest_path(route_graph, source=start, target=end)
673 | bw = self.max_bw
674 |
675 | def pairwise(iterable):
676 | a, b = tee(iterable)
677 | next(b, None)
678 | return zip(a, b)
679 |
680 | for e in pairwise(p):
681 | if self.graph.edges[e[0], e[1]]['bw'] < bw:
682 | bw = self.graph.edges[e[0], e[1]]['bw']
683 | else:
684 | pass
685 | bws = np.append(bws, bw)
686 | return bws
687 |
688 | def find_end_to_end_max_bw_delay(self, route_graph, start, ends):
689 | """
690 |         For each path from the source to a destination, the bottleneck residual bandwidth and the delay
691 |         :param route_graph: the multicast tree
692 |         :param start: source node
693 |         :param ends: destination nodes
694 |         :return: bws, delays, bws_hat, delays_hat (bandwidth and delay lists and their normalized values)
695 | """
696 | bws = np.array([])
697 | delays = np.array([])
698 | delays_hat = np.array([])
699 | for end in ends:
700 | if end in route_graph.nodes:
701 | p = nx.shortest_path(route_graph, source=start, target=end)
702 | bw = self.max_bw
703 | delay = 0
704 |
705 | def pairwise(iterable):
706 | a, b = tee(iterable)
707 | next(b, None)
708 | return zip(a, b)
709 |
710 | for e in pairwise(p):
711 | if self.graph.edges[e[0], e[1]]['bw'] < bw:
712 | bw = self.graph.edges[e[0], e[1]]['bw']
713 | else:
714 | pass
715 | delay += self.graph.edges[e[0], e[1]]['delay']
716 |
717 | bws = np.append(bws, bw)
718 | delays = np.append(delays, delay)
719 |
720 | num = len(self.edges)
721 | delay_hat = self.min_max_normalize(delay, self.min_delay * num, self.all_link_delay_sum, a=self.a_path,
722 | b=self.b_path)
723 |
724 | delays_hat = np.append(delays_hat, delay_hat)
725 | else:
726 | continue
727 | bws_hat = self.min_max_normalize(bws.mean(), self.min_bw, self.max_bw, a=self.a_path, b=self.b_path)
728 | delays_hat = delays_hat.max()
729 |
730 | return bws, delays, bws_hat, delays_hat
731 |
732 | def min_max_normalize(self, x, min_value, max_value, a=None, b=None):
733 | """
734 |         Min-max normalization
735 |         :param x: the value to normalize
736 |         :param min_value: minimum
737 |         :param max_value: maximum
738 |         :param a: lower bound of the target range
739 |         :param b: upper bound of the target range
740 |         :return: the normalized x_hat
741 | """
742 | if a is None:
743 | a = self.a_step
744 | if b is None:
745 | b = self.b_step
746 |
747 |         # add a tiny epsilon to avoid division by zero
748 | x_normal = (x - min_value) / (max_value - min_value + np.finfo(np.float32).eps)
749 | x_hat = a + x_normal * (b - a)
750 | return x_hat.astype(self.numpy_type)
751 |
752 | def get_route_params(self, mode='train'):
753 | """
754 |         Compute the tree's residual bandwidth, delay and loss rate
755 |         :return: mean bottleneck bandwidth, mean delay, mean loss, tree size, blind-alley count
756 | """
757 | bw, delay, loss = 0, 0, 0
758 | num = 0
759 | alley, r_route_graph = self.find_blind_alley()
760 | for r in r_route_graph.edges:
761 | bw += self.graph[r[0]][r[1]]["bw"]
762 | delay += self.graph[r[0]][r[1]]["delay"]
763 | loss += self.graph[r[0]][r[1]]["loss"]
764 | num += 1
765 |
766 | bws = self.find_end_to_end_max_bw(self.route_graph, self.start, self.ends_constant)
767 | if mode == 'train':
768 | length = len(self.route_graph.edges)
769 | elif mode == 'test':
770 | length = len(r_route_graph.edges)
771 |
772 | return bws.mean(), delay / num, loss / num, length, alley
773 |
774 | def map_action(self, action):
775 | """
776 |         Map an action index to the corresponding link,
777 |         e.g. action=0 returns self.edges[0]
778 |
779 |         2022/3/17 large modification
780 |
781 |         :param action: action index
782 |         :return: the chosen link
783 | """
784 | edge = self.edges[action]
785 | return edge
786 |
787 | def find_blind_alley(self):
788 | """
789 |         Find blind alleys: iteratively prune degree-1 nodes that are neither the source nor a destination
790 | """
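        # Example: with start=1, ends=[4] and route edges {(1,2),(2,3),(2,4)},
        # node 3 is a degree-1 dead end, so alley == 1 and link (2,3) is pruned.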
791 | alley = 0
792 | r_graph = self.route_graph.copy()
793 | while True:
794 | del_node = []
795 | for pair in r_graph.degree:
796 | node, degree = pair
797 | if degree == 1 and node not in self.ends_constant and node != self.start:
798 | alley += 1
799 | del_node.append(node)
800 | for node in del_node:
801 | r_graph.remove_node(node)
802 |
803 | if not del_node:
804 | break
805 |
806 | return alley, r_graph
807 |
808 | def set_reward_conf(self, rewards_list):
809 | """
810 |         Set the reward values: [done_reward, step_reward, alley_reward, hell_reward]
811 | """
812 | done_reward, step_reward, alley_reward, hell_reward = rewards_list
813 | self.done_reward = done_reward
814 | self.step_reward = step_reward
815 | self.hell_reward = hell_reward
816 | self.alley_reward = alley_reward
817 |
818 | def set_a_b(self, a_step, b_step, a_path, b_path):
819 | """
820 |         Set the normalization bounds a and b
821 | """
822 |
823 | self.a_step = a_step
824 | self.a_path = a_path
825 | self.b_step = abs(b_step)
826 | self.b_path = abs(b_path)
827 |
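
if __name__ == '__main__':
    # Minimal smoke test (a sketch): assumes a pickled nx.Graph whose edges
    # carry 'bw', 'delay' and 'loss' attributes; the file name is hypothetical.
    env = MulticastEnv(None)
    env.read_pickle_and_modify('graph_snapshot.pkl')
    env.reset(Config.start_node, Config.end_nodes)
    state_, reward, done, flag = env.step(env.map_action(0))
    print(reward, done, flag)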
--------------------------------------------------------------------------------
/RL/env_config.py:
--------------------------------------------------------------------------------
1 | # # -*- coding: utf-8 -*-
2 | # # @File : env_config.py
3 | # # @Date : 2022-05-18
4 | # # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来-
5 | # # @From :
6 | # # done_reward, step_reward, alley_reward, hell_reward
7 | # # [1.0, 0.01, -0.001, -1], [1.0, 0.1, -0.001, -1], [1.0, 1.0, -0.001, -1]
8 | #
9 | # reward_default = [1.0, 0.1, -0.1, -1]
10 | #
11 | # DISCOUNT = 0.9
12 | #
13 | # BETA1 = 1
14 | # BETA2 = 1
15 | # BETA3 = 1
16 | #
17 | # A_STEP = 0
18 | # B_STEP = None
19 | #
20 | # A_PATH = 0
21 | # B_PATH = 1
22 | #
23 | # START_SYMBOL = 1
24 | # END_SYMBOL = 2
25 | # STEP_SYMBOL = 1
26 | # BRANCH_SYMBOL = 2
27 | #
28 | #
29 | # def traverse_reward_list(reward_list):
30 | # for idx in range(len(reward_list)):
31 | # yield reward_list[idx]
32 | #
33 | # def set_rewards(rewards):
34 | # reward_default = rewards
35 | # return
36 |
--------------------------------------------------------------------------------
/RL/log.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : log.py
3 | # @Date : 2022-05-18
4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来-
5 | # @From :
6 | import logging
7 | import time
8 | import os
9 | from pathlib import Path
10 |
11 |
12 | class MyLog:
13 | """
14 |     Logging helper.
15 |     mylog = MyLog(Path(__file__))
16 | logger = mylog.logger
17 | logger.info()
18 | logger.warning()
19 | """
20 | def __init__(self, path: Path, filesave=False, consoleprint=True, name=None):
21 | """
22 | log
23 |         :param path: the current script file, e.g. Path(__file__); logs go to its parent's Logs/ directory
24 |         :param filesave: whether to save the log to a file
25 |         :param consoleprint: whether to print the log to the console
26 | """
27 |
28 | self.formatter = logging.Formatter(
29 | f"{name}: %(message)s")
30 | self.logger = logging.getLogger(name=__name__)
31 | self.set_log_level()
32 | self.log_path = Path.joinpath(path.parent, 'Logs')
33 | if name is None:
34 | self.log_file = os.path.join(self.log_path, path.stem + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime()))
35 | else:
36 | self.log_file = os.path.join(self.log_path, path.stem + name)
37 |
38 | self.mk_log_dir()
39 | # log_path = os.path.join(os.getcwd(), 'Logs')
40 |
41 | if filesave:
42 | self.file_handler()
43 |
44 | if consoleprint:
45 | self.console_handler()
46 |
47 | def set_log_level(self, level=logging.DEBUG):
48 | self.logger.setLevel(level)
49 |
50 | def mk_log_dir(self):
51 | try:
52 | # os.mkdir(log_path)
53 | Path.mkdir(self.log_path)
54 | except FileExistsError:
55 | for child in self.log_path.iterdir():
56 | if child.stat().st_size == 0:
57 | Path.unlink(child)
58 |
59 | def file_handler(self):
60 | fh = logging.FileHandler(self.log_file + '.log',
61 | mode='w', encoding='utf-8', )
62 | fh.setLevel(logging.INFO)
63 |
64 | fh.setFormatter(self.formatter)
65 | self.logger.addHandler(fh)
66 |
67 | def console_handler(self):
68 | ch = logging.StreamHandler()
69 | ch.setLevel(logging.DEBUG)
70 |
71 | ch.setFormatter(self.formatter)
72 |
73 | self.logger.addHandler(ch)
74 |
75 | def pd_to_csv(self, dataframe):
76 | dataframe.to_csv(self.log_file + '.csv')
77 | self.logger.info("csv saved")
78 |
--------------------------------------------------------------------------------
/RL/net.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : net.py
3 | # @Date : 2021-12-06
4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来-
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | def weight_init(m):
11 | classname = m.__class__.__name__
12 | if classname.find('Conv2d') != -1:
13 | torch.nn.init.xavier_normal_(m.weight.data)
14 | torch.nn.init.constant_(m.bias.data, 0.0)
15 | elif classname.find('Linear') != -1:
16 | torch.nn.init.xavier_normal_(m.weight)
17 | torch.nn.init.constant_(m.bias, 0.0)
18 |
19 |
20 | class MyMulticastNet(nn.Module):
21 | def __init__(self, states_channel, action_num):
22 | super(MyMulticastNet, self).__init__()
23 | self.conv1 = nn.Conv2d(states_channel, 32, kernel_size=(1, 1))
24 | self.relu1 = nn.ReLU()
25 |
26 | self.fc1 = nn.Linear(32 * 14 * 14, 512)
27 | self.fc1_relu = nn.ReLU()
28 |
29 | # self.fc2 = nn.Linear(512, 256)
30 | # self.fc2_relu = nn.ReLU()
31 |
32 | self.adv1 = nn.Linear(512, 256)
33 | self.adv_relu = nn.ReLU()
34 | self.adv2 = nn.Linear(256, action_num)
35 |
36 | self.apply(weight_init)
37 |
38 | def forward(self, x):
39 | x = self.relu1(self.conv1(x))
40 |
41 | x = x.view(x.shape[0], -1)
42 | x = self.fc1_relu(self.fc1(x))
43 | # x = self.fc2_relu(self.fc2(x))
44 |
45 | adv = self.adv_relu(self.adv1(x))
46 | q_value = self.adv2(adv)
47 |
48 | return q_value
49 |
50 |
51 | class MyMulticastNet2(nn.Module):
52 | def __init__(self, states_channel, action_num):
53 | super(MyMulticastNet2, self).__init__()
54 | self.conv1 = nn.Conv2d(states_channel, 32, kernel_size=(1, 1))
55 | self.relu1 = nn.ReLU()
56 |
57 | self.fc1 = nn.Linear(6272, 512)
58 | self.fc1_relu = nn.ReLU()
59 |
60 | self.fc2 = nn.Linear(512, 256)
61 | self.fc2_relu = nn.ReLU()
62 |
63 | self.adv = nn.Linear(256, action_num)
64 | self.val = nn.Linear(256, 1)
65 |
66 | self.apply(weight_init)
67 |
68 | def forward(self, x):
69 | x = self.relu1(self.conv1(x))
70 |
71 | x = x.view(x.shape[0], -1)
72 | x = self.fc1_relu(self.fc1(x))
73 | x = self.fc2_relu(self.fc2(x))
74 |
75 |         state_value = self.val(x)
76 | advantage_function = self.adv(x)
77 |
78 |         return state_value + (advantage_function - advantage_function.mean())
79 |
80 |
81 | class MyMulticastNet3(nn.Module):
82 | def __init__(self, states_channel, action_num):
83 | super(MyMulticastNet3, self).__init__()
84 | self.conv1_1 = nn.Conv2d(states_channel, 32, kernel_size=(5, 1))
85 | self.conv1_2 = nn.Conv2d(32, 32, kernel_size=(5, 1))
86 |
87 | self.conv2_1 = nn.Conv2d(states_channel, 32, kernel_size=(1, 5))
88 | self.conv2_2 = nn.Conv2d(32, 32, kernel_size=(1, 5))
89 |
90 | self.fc1 = nn.Linear(5376, 512)
91 | self.fc2 = nn.Linear(512, 256)
92 |
93 | self.adv = nn.Linear(256, action_num)
94 | self.val = nn.Linear(256, 1)
95 |
96 | self.apply(weight_init)
97 |
98 | def forward(self, x):
99 | x1_1 = F.leaky_relu(self.conv1_1(x))
100 | x1_2 = F.leaky_relu(self.conv1_2(x1_1))
101 |
102 | x2_1 = F.leaky_relu(self.conv2_1(x))
103 | x2_2 = F.leaky_relu(self.conv2_2(x2_1))
104 |
105 | x1_3 = x1_2.view(x.shape[0], -1)
106 | x2_3 = x2_2.view(x.shape[0], -1)
107 | x = torch.cat([x1_3, x2_3], dim=1)
108 |
109 | x = F.leaky_relu(self.fc1(x))
110 | x = F.leaky_relu(self.fc2(x))
111 |
112 |         state_value = self.val(x)
113 | advantage_function = self.adv(x)
114 |
115 |         return state_value + (advantage_function - advantage_function.mean())
116 |
117 |
118 | class MyMulticastNet4(nn.Module):
119 | def __init__(self, states_channel, action_num):
120 | super(MyMulticastNet4, self).__init__()
121 | self.conv1_1 = nn.Conv2d(states_channel, 32, kernel_size=(3, 1))
122 | self.conv1_2 = nn.Conv2d(32, 32, kernel_size=(3, 1))
123 | self.conv1_3 = nn.Conv2d(32, 32, kernel_size=(3, 1))
124 |
125 | self.conv2_1 = nn.Conv2d(states_channel, 32, kernel_size=(1, 3))
126 | self.conv2_2 = nn.Conv2d(32, 32, kernel_size=(1, 3))
127 | self.conv2_3 = nn.Conv2d(32, 32, kernel_size=(1, 3))
128 |
129 | self.lstm = nn.LSTMCell(3584, 512)
130 | self.fc1 = nn.Linear(512, 256)
131 |
132 | self.adv = nn.Linear(256, action_num)
133 | self.val = nn.Linear(256, 1)
134 |
135 | self.apply(weight_init)
136 | self.lstm.bias_ih.data.fill_(0)
137 | self.lstm.bias_hh.data.fill_(0)
138 |
139 | def forward(self, inputs):
140 | x, (hx, cx) = inputs
141 | x1_1 = F.leaky_relu(self.conv1_1(x))
142 | x1_2 = F.leaky_relu(self.conv1_2(x1_1))
143 | x1_3 = F.leaky_relu(self.conv1_3(x1_2))
144 |
145 | x2_1 = F.leaky_relu(self.conv2_1(x))
146 | x2_2 = F.leaky_relu(self.conv2_2(x2_1))
147 | x2_3 = F.leaky_relu(self.conv2_3(x2_2))
148 |
149 | x1_3 = x1_3.view(x.shape[0], -1)
150 | x2_3 = x2_3.view(x.shape[0], -1)
151 | x = x1_3 + x2_3
152 |
153 | hx, cx = self.lstm(x, (hx, cx))
154 |
155 | x = hx
156 | x = F.leaky_relu(self.fc1(x))
157 |
158 |         state_value = self.val(x)
159 | advantage_function = self.adv(x)
160 |
161 |         return state_value + (advantage_function - advantage_function.mean()), (hx, cx)
162 |
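
if __name__ == '__main__':
    # Shape check (a sketch): with the 14-node topology2.xml the state is a stack
    # of four 14x14 matrices, and that topology has 23 links, i.e. 23 actions.
    x = torch.randn(2, 4, 14, 14)
    print(MyMulticastNet2(4, 23)(x).shape)  # torch.Size([2, 23])
    print(MyMulticastNet3(4, 23)(x).shape)  # torch.Size([2, 23])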
--------------------------------------------------------------------------------
/RL/replymemory.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : replymemory.py
3 | # @Date : 2022-05-18
4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来-
5 | # @From :
6 | import random
7 | import operator
8 | from collections import namedtuple
9 | import torch
10 | import numpy as np
11 |
12 | from config import Config
13 |
14 | # Use a namedtuple as a lightweight transition record
15 | Transition = namedtuple('Transition',
16 | ('state', 'action', 'reward', 'next_state'))
17 |
18 |
19 | class ExperienceReplayMemory:
20 | def __init__(self, capacity, torch_type) -> None:
21 | self.capacity = capacity
22 | self.memory = []
23 | self.position = 0
24 | self.torch_type = torch_type
25 |
26 | def push(self, *args):
27 | """保存变换"""
28 | if len(self.memory) < self.capacity:
29 | self.memory.append(None)
30 | s, a, r, s_ = args
31 | self.memory[self.position] = Transition(s,
32 | a.reshape(1, -1),
33 | torch.tensor(r, dtype=self.torch_type).reshape(1, -1),
34 | s_)
35 | # self.memory[self.position] = Transition(*torch.tensor(args))
36 | self.position = (self.position + 1) % self.capacity
37 |
38 | def sample(self, batch_size):
39 | return random.sample(self.memory, batch_size), None, None
40 |
41 | def clean_memory(self):
42 | self.memory = []
43 |
44 | def __len__(self):
45 | return len(self.memory)
46 |
47 |
48 | class SegmentTree:
49 | def __init__(self, capacity, operation, neutral_element):
50 | assert capacity > 0 and capacity & (capacity - 1) == 0
51 | self._capacity = capacity
52 | self._value = [neutral_element for _ in range(2 * capacity)]
53 | self._operation = operation
54 |
55 | def _reduce_helper(self, query_start, query_end, node, node_start, node_end):
56 | if query_start == node_start and query_end == node_end:
57 | return self._value[node]
58 | mid = (node_start + node_end) // 2
59 | if query_end <= mid:
60 | return self._reduce_helper(query_start, query_end, 2 * node, node_start, mid)
61 | else:
62 | if mid + 1 <= query_start:
63 | return self._reduce_helper(query_start, query_end, 2 * node + 1, mid + 1, node_end)
64 | else:
65 | return self._operation(
66 | self._reduce_helper(query_start, mid, 2 * node, node_start, mid),
67 | self._reduce_helper(mid + 1, query_end, 2 * node + 1, mid + 1, node_end)
68 | )
69 |
70 | def reduce(self, start=0, end=None):
71 | if end is None:
72 | end = self._capacity
73 | if end <= 0:
74 | end += self._capacity
75 | end -= 1
76 | return self._reduce_helper(start, end, 1, 0, self._capacity - 1)
77 |
78 | def __setitem__(self, idx, val):
79 | idx += self._capacity
80 | self._value[idx] = val
81 | idx //= 2
82 | while idx >= 1:
83 | self._value[idx] = self._operation(
84 | self._value[2 * idx],
85 | self._value[2 * idx + 1]
86 | )
87 | idx //= 2
88 |
89 | def __getitem__(self, idx):
90 | assert 0 <= idx < self._capacity
91 | return self._value[self._capacity + idx]
92 |
93 |
94 | class SumSegmentTree(SegmentTree):
95 | def __init__(self, capacity):
96 | super(SumSegmentTree, self).__init__(
97 | capacity=capacity,
98 | operation=operator.add,
99 | neutral_element=0.0
100 | )
101 |
102 | def sum(self, start=0, end=None):
103 | return super(SumSegmentTree, self).reduce(start, end)
104 |
105 | def find_prefixsum_idx(self, prefixsum):
106 | try:
107 | assert 0 <= prefixsum <= self.sum() + np.finfo(np.float32).eps
108 | except AssertionError:
109 | print(f"Prefix sum error: {prefixsum}")
110 | exit()
111 | idx = 1
112 | while idx < self._capacity:
113 | if self._value[2 * idx] > prefixsum:
114 | idx = 2 * idx
115 | else:
116 | prefixsum -= self._value[2 * idx]
117 | idx = 2 * idx + 1
118 | return idx - self._capacity
119 |
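# Quick worked example of proportional sampling with the sum tree (illustrative):
#     tree = SumSegmentTree(4)          # capacity must be a power of two
#     for i, p in enumerate([0.5, 1.0, 2.0, 0.5]):
#         tree[i] = p                   # per-slot priorities
#     tree.sum()                        # -> 4.0
#     tree.find_prefixsum_idx(1.6)      # -> 2, since 0.5 + 1.0 <= 1.6 < 3.5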
120 |
121 | class MinSegmentTree(SegmentTree):
122 | def __init__(self, capacity):
123 | super(MinSegmentTree, self).__init__(
124 | capacity=capacity,
125 | operation=min,
126 | neutral_element=float('inf')
127 | )
128 |
129 | def min(self, start=0, end=None):
130 | return super(MinSegmentTree, self).reduce(start, end)
131 |
132 |
133 | class PrioritizedReplayMemory:
134 | def __init__(self, torch_type, size, alpha=0.6, beta_start=0.4, beta_frames=70000*Config.PKL_NUM):
135 | self.torch_type = torch_type
136 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
137 | self._storage = []
138 | self._maxsize = size
139 | self._next_idx = 0
140 |
141 | assert alpha >= 0
142 | self._alpha = alpha
143 |
144 | self.beta_start = beta_start
145 | self.beta_frames = beta_frames
146 | self.frame = 1
147 |
148 | it_capacity = 1
149 | while it_capacity < size:
150 | it_capacity *= 2
151 |
152 | self._it_sum = SumSegmentTree(it_capacity)
153 | self._it_min = MinSegmentTree(it_capacity)
154 | self._max_priority = 1.0
155 |
156 | def __len__(self):
157 | return len(self._storage)
158 |
159 | def beta_by_frame(self, frame_idx):
160 | return min(1.0, self.beta_start + frame_idx * (1.0 - self.beta_start) / self.beta_frames)
161 |
162 | def push(self, *data):
163 | s, a, r, s_ = data
164 | data = Transition(s,
165 | a.reshape(1, -1),
166 | torch.tensor(r, dtype=self.torch_type).reshape(1, -1),
167 | s_)
168 |
169 | idx = self._next_idx
170 | if self._next_idx >= len(self._storage):
171 | self._storage.append(data)
172 | else:
173 | self._storage[self._next_idx] = data
174 | self._next_idx = (self._next_idx + 1) % self._maxsize
175 |
176 | self._it_sum[idx] = self._max_priority ** self._alpha
177 | self._it_min[idx] = self._max_priority ** self._alpha
178 |
179 | def _encode_sample(self, idxes):
180 | return [self._storage[i] for i in idxes]
181 |
182 | def _sample_proportional(self, batch_size):
183 | res = []
184 | for _ in range(batch_size):
185 | mass = random.random() * self._it_sum.sum(0, len(self._storage) - 1)
186 | idx = self._it_sum.find_prefixsum_idx(mass)
187 | res.append(idx)
188 | return res
189 |
190 | def sample(self, batch_size):
191 | idxes = self._sample_proportional(batch_size)
192 | weights = []
193 | p_min = self._it_min.min() / self._it_sum.sum()
194 |
195 | beta = self.beta_by_frame(self.frame)
196 |
197 | self.frame += 1
198 | max_weight = (p_min * len(self._storage)) ** (-beta)
199 |
200 | for idx in idxes:
201 | p_sample = self._it_sum[idx] / self._it_sum.sum()
202 | weight = (p_sample * len(self._storage)) ** (-beta)
203 | weights.append(weight / max_weight)
204 | weights = torch.tensor(weights, device=self.device, dtype=self.torch_type)
205 | encoded_sample = self._encode_sample(idxes)
206 | return encoded_sample, idxes, weights
207 |
208 | def update_priorities(self, idxes, priorities):
209 | assert len(idxes) == len(priorities)
210 | for idx, priority in zip(idxes, priorities):
211 | assert 0 <= idx < len(self._storage)
212 | self._it_sum[idx] = (priority + np.finfo(np.float32).eps) ** self._alpha
213 | self._it_min[idx] = (priority + np.finfo(np.float32).eps) ** self._alpha
214 | self._max_priority = max(self._max_priority, (priority + np.finfo(np.float32).eps))
215 |
--------------------------------------------------------------------------------
/RL/rl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : rl.py
3 | # @Date : 2022-05-18
4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来-
5 | # @From :
6 | import math
7 | import os
8 | import pickle
9 | import random
10 |
11 | from pathlib import Path
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 |
16 | from replymemory import Transition, ExperienceReplayMemory, PrioritizedReplayMemory
17 |
18 |
19 | class DQN:
20 | """
21 | Deep Q Network
22 |     Call self.flat_state(state) to flatten a 2-D state into 1-D
23 |     Call self.choose_action(state) to choose an action from the policy_net
24 |     Call self.learn(state, action, reward, next_state, episode) to learn
25 | """
26 |
27 | def __init__(self,
28 | conf,
29 | net):
30 | """
31 |         :params conf: a config.Config instance
32 |         :params net: a network class from net.py
33 | """
34 |
35 | self.state_num = conf.NUM_STATES
36 | self.action_num = conf.NUM_ACTIONS
37 |
38 | self.lr = conf.LR
39 | self.use_decay = conf.USE_DECAY
40 | self.gamma = conf.REWARD_DECAY
41 | self.epsilon_start, self.epsilon_final, self.epsilon_decay = conf.E_GREEDY
42 | self.epsilon = conf.E_GREEDY_ORI
43 |
44 | self.tau = conf.TAU
45 |
46 | self.device = conf.DEVICE
47 | self.batch_size = conf.BATCH_SIZE
48 | self.memory_capacity = conf.MEMORY_CAPACITY
49 |
50 | # self.start_learn = config.START_LEARN
51 | self.clamp = conf.CLAMP
52 | self.torch_type = conf.TORCH_TYPE
53 |
54 | self.experiment_time = conf.TIME
55 |
56 | self.n_steps = conf.N_STEPS
57 | self.n_step_buffer = []
58 |
59 | self.policy_net, self.target_net = self.initialize_net(net)
60 |
61 | self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.lr)
62 | self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[2500, ], gamma=0.1) # [2500, 5000]
63 | # self.loss_func = nn.MSELoss()
64 |         # Huber loss with delta=1 is equivalent to SmoothL1Loss
65 | self.loss_func = nn.SmoothL1Loss(reduction='none')
66 | self.memory_pool = self.initialize_replay_memory("PrioritizedReplayMemory")
67 |
68 | self.learning_step_count = 0
69 | self.update_target_count = 0
70 | self.update_target_frequency = conf.UPDATE_TARGET_FREQUENCY
71 |
72 | self.losses = []
73 |
74 | self.train_or_eval = 'train'
75 | self.change_model_mode('train')
76 |
77 | self.move_net_to_device()
78 |
79 | Path("./saved_agents").mkdir(exist_ok=True)
80 |
81 | def initialize_replay_memory(self, mode):
82 | """
83 |         Initialize the replay memory
84 |         :param mode: 'ExperienceReplayMemory' or 'PrioritizedReplayMemory'
85 |         :return: an instance of the chosen replay memory class
86 | """
87 | if mode == 'ExperienceReplayMemory':
88 | return ExperienceReplayMemory(self.memory_capacity, self.torch_type)
89 | elif mode == 'PrioritizedReplayMemory':
90 | return PrioritizedReplayMemory(self.torch_type, self.memory_capacity)
91 |
92 | def initialize_net(self, net):
93 | """
94 |         Initialize the policy and target networks
95 |         :param net: a network class (nn.Module) from net.py
96 |         :return: policy_net, target_net: instances of net
97 | """
98 | policy_net = net(self.state_num, self.action_num)
99 | target_net = net(self.state_num, self.action_num)
100 | return policy_net, target_net
101 |
102 | def move_net_to_device(self):
103 | """
104 |         Move both networks to the device
105 | """
106 | self.policy_net.to(self.device)
107 | self.target_net.to(self.device)
108 |
109 | def change_model_mode(self, train_or_eval='train'):
110 | """
111 |         Switch the models' mode: model.train() or model.eval()
112 | :param train_or_eval: ['train', 'eval']
113 | :return: None
114 | """
115 | assert train_or_eval in ['train', 'eval']
116 | if train_or_eval == 'train':
117 | self.policy_net.train()
118 | self.target_net.train()
119 |
120 | elif train_or_eval == 'eval':
121 | self.policy_net.eval()
122 | self.target_net.eval()
123 |
124 | self.train_or_eval = train_or_eval
125 |
126 | def decay_epsilon(self, epoch, mode='exp'):
127 | """
128 |         Decay the e-greedy epsilon
129 |         :param epoch: episode number
130 |         :param mode: decay mode ('exp' for exponential decay)
131 | :return: eps
132 | """
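        # With the default E_GREEDY = [1, 0.01, 700], the 'exp' schedule gives
        # eps = 1.0 at epoch 0, ~0.37 at epoch 700, and ~0.013 at epoch 4000.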
133 | if mode == 'exp':
134 | eps = self.epsilon_final + (self.epsilon_start - self.epsilon_final) * math.exp(
135 | -1. * epoch / self.epsilon_decay)
136 | else:
137 | eps = min(abs(self.epsilon_start - self.epsilon_final * epoch), self.epsilon_final)
138 | return eps
139 |
140 | def choose_action(self, s, epoch):
141 | """
142 |         Choose an action (e-greedy)
143 |         2022/3/20 with action_value * torch.from_numpy(mask): if the masked values are all 0, the argmax may pick an action that should not be chosen
144 |         :param s: state
145 |         :param epoch: episode number
146 | :return: action
147 | """
148 | if self.use_decay:
149 | epsilon = self.decay_epsilon(epoch)
150 | else:
151 | epsilon = self.epsilon
152 | # temp = random.random()
153 | temp = np.random.rand()
154 | if temp >= epsilon: # greedy policy
155 | action_value = self.policy_net.forward(s.to(self.device))
156 | # shape [1, actions_num]; max -> (values, indices)
157 |             act_node = torch.max(action_value, 1)[1]  # index of the maximum value
158 | else: # random policy
159 | act_node = torch.from_numpy(np.random.choice(np.array(range(self.action_num)), size=1)).to(self.device)
160 | # array([indice], dtype)
161 | act_node = act_node[0]
162 | return act_node
163 |
164 | def choose_max_action(self, s):
165 | """
166 |         Used at test time: always choose the action with the largest state-action value
167 |         :param s: state
168 | :return act_node: action index
169 | """
170 | action_value = self.policy_net.forward(s.to(self.device))
171 | # shape [1, actions_num]; max -> (values, indices)
172 |         act_node = torch.max(action_value, 1)[1]  # index of the maximum value
173 | return act_node
174 |
175 | def store_transition_in_memory(self, s, a, r, s_):
176 | """
177 |         Store a transition in the replay memory (with n-step accumulation)
178 |         :param s: state s
179 |         :param a: action a
180 |         :param r: reward r
181 |         :param s_: next state s'
182 | :return: None
183 | """
184 | self.n_step_buffer.append((s, a, r, s_))
185 | if len(self.n_step_buffer) < self.n_steps and s_ is not None:
186 | return
187 | R = sum([self.n_step_buffer[i][2] * (self.gamma ** i) for i in range(self.n_steps)])
188 | s, a, _, _ = self.n_step_buffer.pop(0)
189 |
190 | self.memory_pool.push(s, a, R, s_)
191 |
192 | def finish_n_steps(self):
193 | """
194 |         Flush the remaining buffered transitions (fewer than n_steps) into memory
195 | """
196 | while len(self.n_step_buffer) > 0:
197 | R = sum([self.n_step_buffer[i][2] * (self.gamma ** i) for i in range(len(self.n_step_buffer))])
198 | s, a, _, _ = self.n_step_buffer.pop(0)
199 | self.memory_pool.push(s, a, R, None)
200 |
201 | def get_batch_vars(self):
202 | """
203 |         Build a batch of state, action, reward and next_state
204 | :return: state_batch, action_batch, reward_batch, non_final_next_states, non_final_mask, indices, weight
205 | """
206 | transitions, indices, weights = self.memory_pool.sample(self.batch_size)
207 |
208 | batch = Transition(*zip(*transitions))
209 |         # stack vertically into shape (B, Hin)
210 | state_batch = torch.cat(batch.state).to(self.device)
211 | action_batch = torch.cat(batch.action).to(self.device)
212 | reward_batch = torch.cat(batch.reward).to(self.device)
213 |
214 |         # mask of non-final states; concatenate the batch elements (final states are those after the episode ends)
215 | non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
216 | batch.next_state)), device=self.device, dtype=torch.bool)
217 |
218 | try:
219 | non_final_next_states = torch.cat([s for s in batch.next_state if s is not None]).to(self.device)
220 | non_flag = False
221 | except Exception as e:
222 | non_final_next_states = None
223 | non_flag = True
224 |
225 | return state_batch, action_batch, reward_batch, non_final_next_states, non_flag, non_final_mask, indices, weights
226 |
227 | def calculate_loss(self, batch_vars):
228 | """
229 |         Compute the loss
230 | :param batch_vars: state_batch, action_batch, reward_batch, non_final_next_states, non_final_mask, indices, weight
231 | :return: loss
232 | """
233 | batch_state, batch_action, batch_reward, non_final_next_states, non_flag, non_final_mask, indices, weights = batch_vars
234 |
235 |         # Q-values of the chosen actions, from the policy net given states s
236 | _out = self.policy_net(batch_state)
237 | state_action_values = torch.gather(_out, 1, batch_action.type(torch.int64))
238 |
239 | with torch.no_grad():
240 |             # if all next states are None, max_next_state_values stays all zeros
241 | max_next_state_values = torch.zeros(self.batch_size, dtype=self.torch_type, device=self.device).unsqueeze(1)
242 | if not non_flag:
243 |                 # evaluate the non-final next states with the target net
244 | max_next_action = self.target_net(non_final_next_states).max(dim=1)[1].view(-1, 1)
245 | max_next_state_values[non_final_mask] = self.target_net(non_final_next_states).gather(1,
246 | max_next_action)
247 | # 计算期望的Q值
248 | expected_state_action_values = (max_next_state_values * self.gamma) + batch_reward
249 |
250 | self.memory_pool.update_priorities(indices, (
251 | state_action_values - expected_state_action_values).detach().squeeze(1).abs().cpu().numpy().tolist())
252 |
253 |         # Huber loss, weighted by the importance-sampling weights from the prioritized replay
254 | loss = self.loss_func(state_action_values, expected_state_action_values) * weights.unsqueeze(1)
255 | loss = loss.mean()
256 |
257 | return loss
258 |
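# Worked illustration (not in the original file): the TD target built in calculate_loss is
#     y = r + gamma * Q_target(s', argmax_a Q_policy(s', a))        (Double DQN)
# e.g. r = 1.0, gamma = 0.9, Q_target(s', a*) = 2.0  ->  y = 1.0 + 0.9 * 2.0 = 2.8,
# and the priority written back to the replay memory is |Q_policy(s, a) - y|.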
259 | def learn(self, s, a, r, s_, episode):
260 | """
261 | 从transition中学习
262 | :param s: 状态s
263 | :param a: 动作a
264 | :param r: 回报r
265 | :param s_: 状态s撇
266 | :return: None
267 | """
268 | if self.train_or_eval == 'eval':
269 | return
270 |
271 | self.store_transition_in_memory(s, a, r, s_)
272 |
273 | if len(self.memory_pool) < self.batch_size:
274 | return
275 |
276 | batch_vars = self.get_batch_vars()
277 | loss = self.calculate_loss(batch_vars)
278 |
279 | # Optimize the model
280 | self.optimizer.zero_grad()
281 | loss.backward()
282 | if self.clamp:
283 | for param in self.policy_net.parameters():
284 | param.grad.data.clamp_(-1, 1)
285 |
286 | self.optimizer.step()
287 |
288 | self.update_target_model(episode)
289 | self.append_loss_item(loss.data)
290 |
291 | def update_target_model(self, episode):
292 | """
293 | 更新target网络
294 | :return: None
295 | """
296 | self.update_target_count += 1
297 | self.update_count = self.update_target_count % self.update_target_frequency
298 | # self.update_count = episode % self.update_target_frequency
299 | if self.update_count == 0:
300 | # self.target_net.load_state_dict(self.policy_net.state_dict())
301 | self._soft_update()
302 |
303 | def _soft_update(self):
304 | """
305 | 软更新
306 | """
307 | for target_param, param in zip(self.target_net.parameters(), self.policy_net.parameters()):
308 | target_param.data.copy_(
309 | target_param.data * (1.0 - self.tau) + param.data * self.tau
310 | )
311 |
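# Illustration (not in the original file): with tau = 0.01 each call performs
#     theta_target <- 0.99 * theta_target + 0.01 * theta_policy,
# i.e. the target net tracks an exponential moving average of the policy net with a
# time constant of roughly 1 / tau update steps.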
312 | def append_loss_item(self, loss):
313 | """
314 | 将loss添加到 losses中
315 | :param loss: 损失
316 | :return: None
317 | """
318 |
319 | self.losses.append(loss)
320 |
321 | def save_weight(self, name):
322 | """
323 | 保存模型 和 优化器的参数
324 | :return: None
325 | """
326 |
327 | torch.save(self.policy_net.state_dict(),
328 | f'./saved_agents/policy_net-{name}.pt')
329 | torch.save(self.target_net.state_dict(),
330 | f'./saved_agents/target_net-{name}.pt')
331 | torch.save(self.optimizer.state_dict(),
332 | f'./saved_agents/optim-{name}.pt')
333 |
334 | def load_weight(self, fname_model, fname_optim):
335 | """
336 | 加载模型和优化器的参数,policy_net网络参数,target_net网络参数,optimizer参数
337 | :return: None
338 | """
339 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
340 | if os.path.isfile(fname_model):
341 | self.policy_net.load_state_dict(torch.load(fname_model, map_location=device))
342 | self.target_net.load_state_dict(self.policy_net.state_dict())
343 | else:
344 | raise FileNotFoundError
345 |
346 | if fname_optim is not None and os.path.isfile(fname_optim):
347 | self.optimizer.load_state_dict(torch.load(fname_optim))
348 |
349 | def save_replay(self):
350 | """
351 | 保存经验池的数据, pickle形式
352 | :return: None
353 | """
354 | pickle.dump(self.memory_pool, open(f'./saved_agents/exp_replay_agent-{self.experiment_time}.pkl', 'wb'))
355 |
356 | def load_replay(self, fname):
357 | """
358 | 加载经验池的数据
359 | :return: Nones
360 | """
361 | if os.path.isfile(fname):
362 | self.memory_pool = pickle.load(open(fname, 'rb'))
363 |
--------------------------------------------------------------------------------
/RL/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : test.py
3 | # @Date : 2022-05-18
4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来-
5 | # @From :
6 | import os
7 | import pickle
8 | from pathlib import Path
9 | from collections import namedtuple
10 | import random
11 | from collections.abc import Iterable
12 | import time
13 | from itertools import tee
14 |
15 | import networkx
16 | import xml.etree.ElementTree as ET
17 | import numpy as np
18 | import torch
19 | from torch.utils.tensorboard import SummaryWriter
20 | import networkx as nx
21 | from networkx.algorithms.approximation import steiner_tree
22 | import matplotlib.pyplot as plt
23 |
24 | from log import MyLog
25 | from rl import DQN
26 | from env import MulticastEnv
27 | from net import *
28 | from config import Config
29 | from env_config import *
30 |
31 | np.random.seed(2022)
32 | torch.random.manual_seed(2022)
33 | random.seed(2022)
34 |
35 | mylog = MyLog(Path(__file__), filesave=True, consoleprint=True)
36 | logger = mylog.logger
37 | RouteParams = namedtuple("RouteParams", ('bw', 'delay', 'loss'))
38 | plt.style.use("seaborn-whitegrid")
39 |
40 |
41 | class Train:
42 | def __init__(self, Conf, Env, RL, Net, name, mode):
43 | """
44 | Conf: class, 配置类
45 | Env: class, 环境类
46 | RL: class, 强化学习类
47 | Net: class, 神经网络类
48 | """
49 | self.graph = None
50 | self.nodes_num = None
51 | self.edges_num = None
52 |
53 | self.state_channel = 4
54 |
55 |         self.record_dict = {}  # records all route params during training
56 |
57 |         self.config = self.set_initial_config(Conf)  # configuration
58 |         self.config.log_params(logger)
59 |
60 |         # 1. set up the env
61 |         self.env = self.set_initial_env(Env)
62 |         # 2. initialize the graph, node count and edge count
63 |         self.set_init_topology()
64 |         # 3. set NUM_STATES and NUM_ACTIONS in the config
65 |         self.set_num_states_actions()
66 |         # 4. set up the RL agent
67 |         self.rl = self.set_initial_rl(RL, Net)
68 |
69 | self.writer = SummaryWriter(f"./runs/{name}") if mode != 'eval' else None
70 |
71 | self.reward_list_idx = ""
72 |
73 | def set_init_topology(self):
74 | """
75 | 1. 解析xml文件
76 | 2. 设置 self.graph
77 | self.nodes_num
78 | self.edges_num
79 | """
80 | graph, nodes_num, edges_num = self.parse_xml_topology(self.config.xml_topology_path)
81 | self.graph = graph
82 | self.nodes_num = nodes_num
83 | self.edges_num = edges_num
84 |
85 | def set_num_states_actions(self):
86 | """
87 | 根据解析topo设置config中的状态空间和动作空间,供深度网络使用
88 | """
89 | state_space_num = self.state_channel # input channel
90 | action_space_num = self.edges_num # output channel
91 | self.config.set_num_states_actions(state_space_num, action_space_num)
92 |
93 | def set_initial_env(self, Env):
94 | """
95 | Env类初始化
96 | :param Env: 环境类
97 | :return : Env的实例
98 | """
99 | env = Env(self.graph, self.config.NUMPY_TYPE)
100 | return env
101 |
102 | def set_initial_rl(self, RL, Net):
103 | """
104 | RL类初始化
105 | :param RL: RL类 如DQN
106 | :param Net: Net类
107 | :return : RL的实例
108 | """
109 | rl = RL(self.config, Net)
110 | return rl
111 |
112 | def set_initial_config(self, Config):
113 | """
114 | 配置初始化
115 | :param Config: 配置类
116 | :return: conf实例
117 | """
118 | conf = Config()
119 | return conf
120 |
121 | @staticmethod
122 | def parse_xml_topology(topology_path):
123 | """
124 | parse topology from topology.xml
125 |         :param topology_path: path of the topology xml file
126 | :return: topology graph, networkx.Graph()
127 | nodes_num, int
128 | edges_num, int
129 | """
130 | tree = ET.parse(topology_path)
131 | root = tree.getroot()
132 | topo_element = root.find("topology")
133 | graph = networkx.Graph()
134 | for child in topo_element.iter():
135 | # parse nodes
136 | if child.tag == 'node':
137 | node_id = int(child.get('id'))
138 | graph.add_node(node_id)
139 | # parse link
140 | elif child.tag == 'link':
141 | from_node = int(child.find('from').get('node'))
142 | to_node = int(child.find('to').get('node'))
143 | graph.add_edge(from_node, to_node)
144 |
145 | nodes_num = len(graph.nodes)
146 | edges_num = len(graph.edges)
147 |
148 | print('nodes: ', nodes_num, '\n', graph.nodes, '\n',
149 | 'edges: ', edges_num, '\n', graph.edges)
150 | return graph, nodes_num, edges_num
151 |
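# Illustrative input (not in the original file; the topology xml files shipped with the
# repo were stripped in this export). The parser above expects markup shaped like:
#     <network>
#         <topology>
#             <node id="1"/>
#             <node id="2"/>
#             <link><from node="1"/><to node="2"/></link>
#         </topology>
#     </network>
# Only the root element name ("network" here) is a guess; the rest follows the parsing code.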
152 | def pkl_file_path_yield(self, pkl_dir, n: int = 3000, step: int = 1):
153 | """
154 | 生成保存的pickle文件的路径, 按序号递增的方式生成
155 | :param pkl_dir: Path, pkl文件的目录
156 | :param n: 截取
157 | :param step: 间隔
158 | """
159 | a = os.listdir(pkl_dir)
160 |         assert n < len(a), "n should be smaller than len(a)"
161 | b = sorted(a, key=lambda x: int(x.split('-')[0]))
162 | for p in b[:n:step]:
163 | yield pkl_dir / p
164 |
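# Illustration (not in the original file): the sort key int(x.split('-')[0]) assumes
# file names that start with a numeric index, e.g. '0-graph.pkl', '1-graph.pkl', ...,
# so that '10-....pkl' sorts after '9-....pkl' (a plain string sort would not).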
165 | def record_route_params(self, episode, bw, delay, loss):
166 | """
167 | 将route信息存入字典中 {episode: RouteParams}
168 | :param episode: 训练的第几代, 作为key
169 | :param bw: 剩余带宽和
170 | :param delay: 时延和
171 | :param loss: 丢包率和
172 | :return: None
173 | """
174 | self.record_dict.setdefault(episode, RouteParams(bw, delay, loss))
175 | with open("./recorder.pkl", "wb+") as f:
176 | pickle.dump(self.record_dict, f)
177 |
178 | @staticmethod
179 | def record_reward(r):
180 | with open('./reward.pkl', 'wb') as f:
181 | pickle.dump(r, f)
182 |
183 | def choice_multicast_nodes(self):
184 | """
185 | 随机选择 源节点和目的节点
186 | source 至少 1
187 | destination 至少 2
188 | 所有节点至多 nodes_num - 1
189 | :return: multicast nodes
190 | """
191 | _a = list(self.graph.nodes())
192 | _size = np.random.randint(3, self.nodes_num)
193 |         # sample without replacement
194 | multicast_nodes = np.random.choice(_a, size=_size, replace=False)
195 | multicast_nodes = list(multicast_nodes)
196 | start_node = multicast_nodes.pop(0)
197 | end_nodes = multicast_nodes
198 |
199 | return start_node, end_nodes
200 |
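# Illustration (not in the original file): on the 14-node graph one draw could be
#     multicast_nodes = [12, 2, 4, 11]  ->  start_node = 12, end_nodes = [2, 4, 11],
# which is the same source/destination set hard-coded in update() and compare_test().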
201 | @staticmethod
202 | def flat_and_combine_state(route, bw, delay, loss):
203 | """
204 | 将二维矩阵展平成一维, 将多个一维拼接
205 | :param route: 路由矩阵
206 | :param bw: 剩余带宽矩阵
207 | :param delay: 时延矩阵
208 | :param loss: 丢包率矩阵
209 | :return: 展平的tensor
210 | """
211 | if route is not None:
212 | flatten_route = torch.flatten(torch.from_numpy(route)).unsqueeze(0)
213 | flatten_bw = torch.flatten(torch.from_numpy(bw)).unsqueeze(0)
214 | flatten_delay = torch.flatten(torch.from_numpy(delay)).unsqueeze(0)
215 | flatten_loss = torch.flatten(torch.from_numpy(loss)).unsqueeze(0)
216 | combine_state = torch.cat([flatten_route, flatten_bw, flatten_delay, flatten_loss], dim=1)
217 | return combine_state
218 | else:
219 | return None
220 |
221 | def combine_state(self, route):
222 | """
223 | 组合 多channel
224 | :param route: 路由矩阵
225 | :return: 多channel tensor
226 | """
227 | if route is not None:
228 | if self.state_channel == 1:
229 | combine_state = torch.stack(
230 | [torch.from_numpy(route) + torch.from_numpy(self.env.get_branches_matrix())], dim=0)
231 | elif self.state_channel == 2:
232 | combine_state = torch.stack(
233 | [
234 | torch.from_numpy(route),
235 | torch.from_numpy(self.env.get_branches_matrix()),
236 | ],
237 | dim=0)
238 | elif self.state_channel == 4:
239 | combine_state = torch.stack(
240 | [
241 | torch.from_numpy(route) + torch.from_numpy(self.env.get_branches_matrix()),
242 | torch.from_numpy(self.env.normal_bw_matrix),
243 | torch.from_numpy(self.env.normal_delay_matrix),
244 | torch.from_numpy(self.env.normal_loss_matrix)
245 | ],
246 | dim=0)
247 | else:
248 | combine_state = torch.stack(
249 | [
250 | torch.from_numpy(route),
251 | torch.from_numpy(self.env.get_branches_matrix()),
252 | torch.from_numpy(self.env.normal_bw_matrix),
253 | torch.from_numpy(self.env.normal_delay_matrix),
254 | torch.from_numpy(self.env.normal_loss_matrix)
255 | ],
256 | dim=0)
257 | return torch.unsqueeze(combine_state, dim=0)
258 | else:
259 | return None
260 |
261 | @staticmethod
262 | def kmb_algorithm(graph, src_node, dst_nodes, weight=None):
263 | """
264 | 经典KMB算法 组播树
265 |
266 | :param graph: networkx.graph
267 | :param src_node: 源节点
268 | :param dst_nodes: 目的节点
269 | :param weight: 计算的权重
270 | :return: 返回图形的最小Steiner树的近似值
271 | """
272 | terminals = [src_node] + dst_nodes
273 | st_tree = steiner_tree(graph, terminals, weight)
274 | return st_tree
275 |
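# Minimal usage sketch (not in the original file), assuming a 'delay' edge attribute:
#     g = nx.Graph()
#     g.add_edge(1, 2, delay=5); g.add_edge(2, 3, delay=2); g.add_edge(1, 3, delay=9)
#     list(steiner_tree(g, [1, 3], weight='delay').edges)  # -> [(1, 2), (2, 3)]
# the 5 + 2 two-hop path beats the direct edge of weight 9.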
276 | @staticmethod
277 | def spanning_tree(graph, weight=None):
278 | """
279 | 生成树算法
280 | :param graph: networkx.graph
281 | :param weight: 计算的权重
282 | :return: iterator 最小生成树
283 | """
284 | spanning_tree = nx.algorithms.minimum_spanning_tree(graph, weight)
285 | return spanning_tree
286 |
287 | def get_tree_params(self, tree, graph):
288 | """
289 | 计算tree的bw, delay, loss 参数和
290 | :param tree: 要计算的树
291 | :param graph: 计算图中的数据
292 | :return: bw, delay, loss, len
293 | """
294 | bw, delay, loss = 0, 0, 0
295 | if isinstance(tree, nx.Graph):
296 | edges = tree.edges
297 | elif isinstance(tree, Iterable):
298 | edges = tree
299 | else:
300 | raise ValueError("tree param error")
301 | num = 0
302 | for r in edges:
303 | bw += graph[r[0]][r[1]]["bw"]
304 | delay += graph[r[0]][r[1]]["delay"]
305 | loss += graph[r[0]][r[1]]["loss"]
306 | num += 1
307 | bw_mean = self.env.find_end_to_end_max_bw(tree, self.env.start, self.env.ends_constant).mean()
308 | return bw_mean, delay / num, loss / num, len(tree.edges)
309 |
310 | @staticmethod
311 | def modify_bw_weight(graph):
312 | """
313 | 将delay取负,越大表示越小
314 | :param graph: 图
315 | :return: weight
316 | """
317 | _g = graph.copy()
318 | for edge in graph.edges:
319 | _g[edge[0]][edge[1]]['bw'] = 1 / (graph[edge[0]][edge[1]]['bw'] + 1)
320 | return _g
321 |
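# Illustration (not in the original file): the inversion turns "maximize bandwidth"
# into "minimize weight": bw = 30 -> 1/31 ~ 0.032, bw = 5 -> 1/6 ~ 0.167, so trees
# computed with weight='bw' on the modified graph prefer high-bandwidth links.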
322 | def get_kmb_params(self, graph, start_node, end_nodes):
323 | """
324 | 获得以 bw 为权重的 steiner tree 返回该树的 bw和
325 | 获得以 delay 为权重的 steiner tree 返回该树的 delay和
326 | 获得以 loss 为权重的 steiner tree 返回该树的 loss和
327 | 获得以 hope 为权重的 steiner tree 返回该树的 长度length
328 | :param graph: 图
329 | :param start_node: 源节点
330 | :param end_nodes: 目的节点
331 | :return: bw, delay, loss, length
332 | """
333 | _g = self.modify_bw_weight(graph)
334 |         # KMB with the inverted bw as the weight
335 | kmb_bw_tree = self.kmb_algorithm(_g, start_node, end_nodes,
336 | weight='bw')
337 | bw_bw, bw_delay, bw_loss, bw_length = self.get_tree_params(kmb_bw_tree, graph)
338 |
339 |         # KMB with delay as the weight
340 | kmb_delay_tree = self.kmb_algorithm(graph, start_node, end_nodes,
341 | weight='delay')
342 | delay_bw, delay_delay, delay_loss, delay_length = self.get_tree_params(kmb_delay_tree, graph)
343 |
344 |         # KMB with loss as the weight
345 | kmb_loss_tree = self.kmb_algorithm(graph, start_node, end_nodes,
346 | weight='loss')
347 | loss_bw, loss_delay, loss_loss, loss_length = self.get_tree_params(kmb_loss_tree, graph)
348 |
349 |         # KMB with weight=None (hop count)
350 | kmb_hope_tree = self.kmb_algorithm(graph, start_node, end_nodes, weight=None)
351 | length_bw, length_delay, length_loss, length_length = self.get_tree_params(kmb_hope_tree, graph)
352 |
353 | bw_ = [bw_bw, delay_bw, loss_bw, length_bw]
354 | delay_ = [bw_delay, delay_delay, loss_delay, length_delay]
355 | loss_ = [bw_loss, delay_loss, loss_loss, length_loss]
356 | length_ = [bw_length, delay_length, loss_length, length_length]
357 | return bw_, delay_, loss_, length_
358 |
359 | def get_spanning_tree_params(self, graph):
360 | """
361 | 获得以 bw 为权重的 spanning tree 返回该树的 bw和
362 | 获得以 delay 为权重的 spanning tree 返回该树的 delay和
363 | 获得以 loss 为权重的 spanning tree 返回该树的 loss和
364 | 获得以 hope 为权重的 spanning tree 返回该树的 长度length
365 | :param graph:
366 | :return:
367 | """
368 | _g = self.modify_bw_weight(graph)
369 | spanning_bw_tree = self.spanning_tree(_g, weight='bw')
370 | bw, _, _, _ = self.get_tree_params(spanning_bw_tree, graph)
371 | spanning_delay_tree = self.spanning_tree(graph, weight='delay')
372 | _, delay, _, _ = self.get_tree_params(spanning_delay_tree, graph)
373 | spanning_loss_tree = self.spanning_tree(graph, weight='loss')
374 | _, _, loss, _ = self.get_tree_params(spanning_loss_tree, graph)
375 | spanning_length_tree = self.spanning_tree(graph, weight=None)
376 | _, _, _, length = self.get_tree_params(spanning_length_tree, graph)
377 |
378 | return bw, delay, loss, length
379 |
380 | def print_train_info(self, episode, index, reward, link):
381 | logger.info(f"[{episode}][{index}] reward: {reward}")
382 | logger.info(f"[{episode}][{index}] tree_nodes: {self.env.tree_nodes}")
383 | logger.info(f"[{episode}][{index}] route_list: {self.env.route_graph.edges}")
384 | logger.info(f"[{episode}][{index}] branches: {self.env.branches}")
385 | logger.info(f"[{episode}][{index}] link: {link}")
386 | logger.info(f"[{episode}][{index}] step_num: {self.env.step_num}")
387 | # logger.info(f'[{episode}][{index}]: {self.env.route_matrix}')
388 | logger.info("=======================================================")
389 |
390 | def update(self):
391 | """
392 | 状态更新 rl学习
393 | 1. 循环代数, 进行训练
394 | 2. 读取一个graph, 环境reset
395 | 3. while True 直到跑出path
396 |
397 | 2022/3/17 修改link方向BUG
398 | """
399 | pkl_cut_num = self.config.PKL_CUT_NUM
400 | pkl_step = self.config.PKL_STEP
401 | loss_step = 0
402 | for episode in range(self.config.EPISODES):
403 | # start_node, end_nodes = self.choice_multicast_nodes()
404 | start_node = 12
405 | end_nodes = [2, 4, 11]
406 |
407 | logger.info(f"[{episode}] start_node: {start_node}")
408 | logger.info(f"[{episode}] end_nodes: {end_nodes}")
409 |
410 | episode_reward = np.array([])
411 | episode_bw, kmb_bw, spanning_bw = np.array([]), np.array([]), np.array([])
412 | episode_delay, kmb_delay, spanning_delay = np.array([]), np.array([]), np.array([])
413 | episode_loss, kmb_loss, spanning_loss = np.array([]), np.array([]), np.array([])
414 | episode_length, kmb_length, spanning_length = np.array([]), np.array([]), np.array([])
415 | episode_steps = np.array([])
416 |
417 | for index, pkl_path in enumerate(
418 | self.pkl_file_path_yield(self.config.pkl_weight_path, n=pkl_cut_num, step=pkl_step)):
419 |
420 | self.env.read_pickle_and_modify(pkl_path)
421 | state = self.env.reset(start_node, end_nodes)
422 |
423 | reward_temp = 0
424 | while True:
425 |                     # 1. build the 2-D multi-channel state
426 |                     combine_state = self.combine_state(state)
427 |                     # 2.1 choose an action
428 |                     action = self.rl.choose_action(combine_state, episode)
429 |                     # 2.2 map the action to a link
430 |                     link = self.env.map_action(action)
431 |                     # 3. interact with the environment
432 |                     new_state, reward, done, flag = self.env.step(link)
433 |                     # 4. next state
434 |                     combine_new_state = self.combine_state(new_state)
435 |                     # 5. RL learning (learn() also expects the episode index)
436 |                     self.rl.learn(combine_state, action, reward, combine_new_state, episode)
437 | reward_temp += reward
438 | if len(self.rl.losses) > 0:
439 | self.writer.add_scalar("Optim/Loss", self.rl.losses.pop(0), loss_step)
440 | loss_step += 1
441 | if done:
442 | self.rl.finish_n_steps()
443 |
444 | if flag == "ALL":
445 |                             bw, delay, loss, length = self.env.get_route_params()  # sums of bw, delay and loss over all links of the route
446 |                             # append to the arrays
447 | episode_bw = np.append(episode_bw, bw)
448 | episode_delay = np.append(episode_delay, delay)
449 | episode_loss = np.append(episode_loss, loss)
450 | episode_length = np.append(episode_length, length)
451 |
452 |                             # KMB algorithm
453 | bw, delay, loss, length = self.get_kmb_params(self.env.graph, start_node, end_nodes)
454 | kmb_bw = np.append(kmb_bw, bw)
455 | kmb_delay = np.append(kmb_delay, delay)
456 | kmb_loss = np.append(kmb_loss, loss)
457 | kmb_length = np.append(kmb_length, length)
458 |
459 | episode_reward = np.append(episode_reward, reward_temp)
460 | episode_steps = np.append(episode_steps, self.env.step_num - 1)
461 | self.print_train_info(episode, index, reward, link)
462 | break
463 |
464 |                     # 6. state transition
465 | state = new_state
466 |
467 | # self.writer.add_scalars('Episode/reward',
468 | # {"reward": episode_reward.mean(), "reward_max": episode_reward.max(initial=0)},
469 | # episode)
470 | # self.writer.add_scalars('Episode/steps',
471 | # {"reward": episode_steps.mean(), "reward_max": episode_steps.max(initial=0)},
472 | # episode)
473 | #
474 | # self.writer.add_scalars('Episode/bw', {"rl": episode_bw.mean(), "kmb_bw": kmb_bw.mean(),
475 | # }, episode)
476 | # self.writer.add_scalars('Episode/delay', {"rl": episode_delay.mean(), "kmb_delay": kmb_delay.mean(),
477 | # }, episode)
478 | # self.writer.add_scalars('Episode/loss', {"rl": episode_loss.mean(), "kmb_loss": kmb_loss.mean(),
479 | # }, episode)
480 | # self.writer.add_scalars('Episode/length', {"rl": episode_length.mean(), "kmb_length": kmb_length.mean(),
481 | # }, episode)
482 | self.writer.add_scalar('Episode/reward', episode_reward.mean(), episode)
483 | self.writer.add_scalar('Episode/steps', episode_steps.mean(), episode)
484 | self.writer.add_scalar('Episode/bw', episode_bw.mean(), episode)
485 | self.writer.add_scalar('Episode/delay', episode_delay.mean(), episode)
486 | self.writer.add_scalar('Episode/loss', episode_loss.mean(), episode)
487 | self.writer.add_scalar('Episode/length', episode_length.mean(), episode)
488 | self.writer.add_scalar("learning_rate", self.rl.optimizer.param_groups[0]['lr'], episode)
489 | self.rl.scheduler.step()
490 |             self.rl.save_weight(self.rl.experiment_time)  # save_weight() requires a name tag; reuse the agent's experiment_time
491 |
492 | logger.debug('train over')
493 |
494 | def set_reward_list_idx(self, idx):
495 | self.reward_list_idx = idx
496 |
497 |     def compare_test(self, weight_file):
499 | self.rl.change_model_mode('eval')
500 |
501 | pkl_cut_num = self.config.PKL_CUT_NUM
502 | pkl_step = self.config.PKL_STEP
503 | start_node = 12
504 | end_nodes = [2, 4, 11]
505 |
506 | episode_bw, kmb_bw, spanning_bw = [], [], []
507 | episode_delay, kmb_delay, spanning_delay = [], [], []
508 | episode_loss, kmb_loss, spanning_loss = [], [], []
509 | episode_length, kmb_length, spanning_length = [], [], []
510 |
511 | self.rl.load_weight(weight_file, None)
512 | for index, pkl_path in enumerate(
513 | self.pkl_file_path_yield(self.config.pkl_weight_path, n=pkl_cut_num, step=pkl_step)):
514 |
515 | self.env.read_pickle_and_modify(pkl_path)
516 | state = self.env.reset(start_node, end_nodes)
517 |
518 | while True:
519 |                 # 1. build the 2-D multi-channel state
520 |                 combine_state = self.combine_state(state)
521 |                 # 2.1 choose the greedy action
522 |                 action = self.rl.choose_max_action(combine_state)
523 |                 # 2.2 map the action to a link
524 |                 link = self.env.map_action(action)
525 |                 # 3. interact with the environment
526 | new_state, reward, done, flag = self.env.step(link)
527 | if done:
528 | if flag == "ALL":
529 |                         bw, delay, loss, length = self.env.get_route_params()  # sums of bw, delay and loss over all links of the route
530 |                         # append to the lists
531 | episode_bw.append(bw)
532 | episode_delay.append(delay)
533 | episode_loss.append(loss)
534 | episode_length.append(length)
535 |
536 |                         # KMB algorithm
537 | bw, delay, loss, length = self.get_kmb_params(self.env.graph, start_node, end_nodes)
538 | kmb_bw.append(bw)
539 | kmb_delay.append(delay)
540 | kmb_loss.append(loss)
541 | kmb_length.append(length)
542 | break
543 |
544 |                 # 6. state transition
545 | state = new_state
546 | self.plot_compare_figure(episode_bw, kmb_bw, "traffic", "mean bw", "bw")
547 | self.plot_compare_figure(episode_delay, kmb_delay, "traffic", "mean delay", "delay")
548 | self.plot_compare_figure(episode_loss, kmb_loss, "traffic", "mean loss", "loss")
549 |
550 | # self.plot_compare_figure_subplots(episode_bw, episode_delay, episode_loss, kmb_bw, kmb_delay, kmb_loss)
551 |
552 | def plot_compare_figure(self, rl_result, kmb_result, x_label, y_label, title):
553 | width = 0.18
554 | plt.bar(range(len(kmb_result)), rl_result, width, label='rl')
555 |
556 | kmb_bw = [kmb_result[i][0] for i in range(len(kmb_result))]
557 | kmb_delay = [kmb_result[i][1] for i in range(len(kmb_result))]
558 | kmb_loss = [kmb_result[i][2] for i in range(len(kmb_result))]
559 |
560 | plt.bar([x + width for x in range(len(kmb_result))], kmb_bw, width, label='kmb_bw')
561 | plt.bar([x + 2 * width for x in range(len(kmb_result))], kmb_delay, width, label='kmb_delay')
562 | plt.bar([x + 3 * width for x in range(len(kmb_result))], kmb_loss, width, label='kmb_loss')
563 |
564 | plt.xticks(range(len(kmb_result)), range(len(kmb_result)), rotation=0, fontsize='small')
565 | plt.xlabel(x_label)
566 | plt.ylabel(y_label)
567 | # plt.title(title)
568 | plt.legend(bbox_to_anchor=(0., 1.0), loc='lower left', ncol=4, )
569 |
570 |         _path = Path('./images')
571 |         _path.mkdir(exist_ok=True)
572 |         plt.savefig(_path / f'{title}.png')
576 | plt.show()
577 |
578 | def plot_compare_figure_subplots(self, rl_bw_result, rl_delay_result, rl_loss_result, kmb_bw_result,
579 | kmb_delay_result, kmb_loss_result):
580 | width = 0.18
581 | fig, ax = plt.subplots(1, 3)
582 |
583 | def get_bw_delay_loss(result):
584 | _bw = [result[i][0] for i in range(len(result))]
585 | _delay = [result[i][1] for i in range(len(result))]
586 | _loss = [result[i][2] for i in range(len(result))]
587 | return _bw, _delay, _loss
588 |
589 | def ax_plot(i, kmb_result, rl_result):
590 | kmb_bw, kmb_delay, kmb_loss = get_bw_delay_loss(kmb_result)
591 | ax[i].bar(range(len(rl_result)), rl_result, width, label='rl')
592 | ax[i].bar([x + width for x in range(len(kmb_bw))], kmb_bw, width, label='kmb_bw')
593 | ax[i].bar([x + 2 * width for x in range(len(kmb_bw))], kmb_delay, width, label='kmb_delay')
594 | ax[i].bar([x + 3 * width for x in range(len(kmb_bw))], kmb_loss, width, label='kmb_loss')
595 |
596 | # bw
597 | ax_plot(0, kmb_bw_result, rl_bw_result)
598 | ax[0].set_ylabel("mean bw")
599 | ax_plot(1, kmb_delay_result, rl_delay_result)
600 | ax[1].set_ylabel("mean delay")
601 | ax_plot(2, kmb_loss_result, rl_loss_result)
602 | ax[2].set_ylabel("mean loss")
603 |
604 | plt.xticks(range(len(rl_bw_result)), range(len(rl_bw_result)), rotation=0, fontsize='small')
605 | # plt.title(title)
606 | plt.legend(bbox_to_anchor=(0., 1.0), loc='lower left', ncol=4, )
607 |
608 |         _path = Path('./images')
609 |         _path.mkdir(exist_ok=True)
610 |         plt.savefig(_path / 'result.png')
614 | plt.show()
615 |
616 | def read_data_path(self, data_path="./data"):
617 | data_path = Path(data_path)
618 | data_path_dict = {"lr": [], "nsteps": [], "batchsize": [], 'egreedy': [], "gamma": [], "update": [],
619 | "rewardslist": [], "tau": []}
620 |
621 | for data_doc in data_path.iterdir():
622 | if data_doc.match('*_lr_*'):
623 | data_path_dict['lr'].append(data_doc)
624 | elif data_doc.match('*_nsteps_*'):
625 | data_path_dict['nsteps'].append(data_doc)
626 | elif data_doc.match('*_batchsize_*'):
627 | data_path_dict['batchsize'].append(data_doc)
628 | elif data_doc.match('*_egreedy_*'):
629 | data_path_dict['egreedy'].append(data_doc)
630 | elif data_doc.match('*_gamma_*'):
631 | data_path_dict['gamma'].append(data_doc)
632 | elif data_doc.match('*_updatefrequency_*'):
633 | data_path_dict['update'].append(data_doc)
634 | elif data_doc.match('*_rewardslist_*'):
635 | data_path_dict['rewardslist'].append(data_doc)
636 | elif data_doc.match('*_tau_*'):
637 | data_path_dict['tau'].append(data_doc)
638 |
639 | return data_path_dict
640 |
641 | def get_diff_data_from_multi_file(self, data_path_list):
642 | data_dict = {"bw": [], "delay": [], "loss": [], 'length': [], "final_reward": [], "episode_reward": [],
643 | "steps": []}
644 |
645 | for path in data_path_list:
646 | for child in path.iterdir():
647 | data = np.load(child)
648 | if child.match("*bw*"):
649 | data_dict['bw'].append(data)
650 | elif child.match('*delay*'):
651 | data_dict['delay'].append(data)
652 | elif child.match("*loss*"):
653 | data_dict['loss'].append(data)
654 | elif child.match("*length*"):
655 | data_dict['length'].append(data)
656 | elif child.match("*final_reward*"):
657 | data_dict['final_reward'].append(data)
658 | elif child.match("*episode_reward*"):
659 | data_dict['episode_reward'].append(data)
660 | elif child.match("*steps*"):
661 | data_dict['steps'].append(data)
662 |
663 | return data_dict
664 |
665 | def get_compare_data(self, data_path="./data"):
666 | data_path_dict = self.read_data_path(data_path)
667 | data_npy_dict = {}
668 | for k in data_path_dict.keys():
669 | data_npy_dict[k] = self.get_diff_data_from_multi_file(data_path_dict[k])
670 |
671 | return data_npy_dict
672 |
673 |     def plot_line_chart(self, data_list, name):
674 |         for i, data in enumerate(data_list):
675 |             x = range(len(data))
676 |             plt.plot(x, data, label=f'{name}-{i}')
677 |         plt.legend()
678 |         plt.savefig(f'./{name}.png')  # matplotlib has no plt.save(); save under the given name
679 |
680 |
681 | if __name__ == '__main__':
682 | pt_file_path = "./saved_agents/policy_net-2022-05-10-16-05-33.pt"
683 | exp_time = time.strftime("%Y%m%d%H%M%S", time.localtime())
684 | _name = f"[{exp_time}]_test"
685 | train = Train(Config, MulticastEnv, DQN, MyMulticastNet3, name=_name, mode='eval')
686 | train.compare_test(pt_file_path)
687 |
--------------------------------------------------------------------------------
/mininet/generate_matrices.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : generate_matrices.py
3 | # @Date : 2021-12-27
4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来-
5 | # @From :
6 | import argparse
7 | from pathlib import Path
8 |
9 | import numpy as np
10 | import numpy.random
11 | from tmgen.models import modulated_gravity_tm
12 | import matplotlib.pyplot as plt
13 |
14 |
15 | def set_seed():
16 | numpy.random.seed(args.seed)
17 |
18 |
19 | def generate_tm():
20 | tm = modulated_gravity_tm(args.num_nodes, args.num_tms, args.mean_traffic, args.pm_ratio, args.t_ratio,
21 | args.diurnal_freq, args.spatial_variance, args.temporal_variance)
22 |
23 | mean_time_tm = []
24 | for t in range(args.num_tms):
25 | mean_time_tm.append(tm.at_time(t).mean())
26 | print(f"time: {t} h, mean traffic: {mean_time_tm[-1]}")
27 |
28 |     # build a (num_nodes, num_nodes, num_tms) matrix of uniform(0, 1) samples
29 | _size = (args.num_nodes,) * 2
30 | _size += (args.num_tms,)
31 |
32 | temp = np.random.random(_size)
33 | mask = temp < args.communicate_ratio
34 | communicate_tm = tm.matrix * mask
35 |
36 | mean_communicate_tm = []
37 | for t in range(args.num_tms):
38 | mean_communicate_tm.append(communicate_tm[:, :, t].mean())
39 | print(f"time: {t} h, mean communicate nodes traffic: {mean_communicate_tm[-1]}")
40 |
41 | np_save(tm.matrix, "traffic_matrix")
42 | np_save(mean_time_tm, "mean_time_tm")
43 |
44 | np_save(communicate_tm, "communicate_tm")
45 | np_save(mean_communicate_tm, "mean_communicate_tm")
46 |
47 | plot_tm_mean(mean_time_tm, title="mean_time_tm")
48 | plot_tm_mean(mean_communicate_tm, title="mean_communicate_tm")
49 |
50 |
51 | def np_save(file_data, file_name):
52 | Path('./tm_statistic').mkdir(exist_ok=True)
53 | np.save(f'./tm_statistic/{file_name}.npy', file_data)
54 | print(f'save {file_name}')
55 |
56 |
57 | def plot_tm_mean(mean_list, x_label='time', y_label='mean_traffic', title='mean'):
58 | fig = plt.figure()
59 | plt.xlabel(x_label)
60 | plt.ylabel(y_label)
61 | # plt.title(title)
62 | x = list(range(len(mean_list)))
63 | y = mean_list
64 | plt.bar(x, y)
65 | Path("./figure").mkdir(exist_ok=True)
66 | plt.savefig(f"./figure/{title}.pdf", dpi=300, bbox_inches='tight', pad_inches=0)
67 | plt.show()
68 |
69 |
70 | if __name__ == '__main__':
71 | parser = argparse.ArgumentParser(description="Generate traffic matrices")
72 | parser.add_argument("--seed", default=2020, help="random seed")
73 | parser.add_argument("--num_nodes", default=14, help="number of nodes of network")
74 | parser.add_argument("--num_tms", default=24, help="total number of matrices")
75 | # 1.55 * 1e3 * 0.75
76 | parser.add_argument("--mean_traffic", default=5 * 10 ** 3 * 0.75, help="mean volume of traffic (Kbps)")
77 | parser.add_argument("--pm_ratio", default=1.5, help="peak-to-mean ratio")
78 | parser.add_argument("--t_ratio", default=0.75, help="trough-to-mean ratio")
79 | parser.add_argument("--diurnal_freq", default=1 / 24, help="Frequency of modulation")
80 | parser.add_argument("--spatial_variance", default=500,
81 | help="Variance on the volume of traffic between origin-destination pairs")
82 | parser.add_argument("--temporal_variance", default=0.03, help="Variance on the volume in time")
83 | parser.add_argument("--communicate_ratio", default=0.7, help="percentage of nodes to communicate")
84 | args = parser.parse_args()
85 |
86 | # set_seed()
87 | # generate_tm()
88 |
89 | mean_time_tm = np.load(r"D:\WorkSpace\Hello_Myself\Hello_Multicast\RLMulticastProject\mininet\tm_statistic\tm_statistic\mean_time_tm.npy")
90 | plot_tm_mean(mean_time_tm)
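# Typical usage (illustrative, not in the original file): re-enable set_seed() and
# generate_tm() above and run
#     python generate_matrices.py --num_nodes 14 --num_tms 24 --seed 2020
# which writes the .npy files under ./tm_statistic and the bar charts under ./figure.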
--------------------------------------------------------------------------------
/mininet/generate_nodes_topo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : GEANT_23nodes_topo.py.py
3 | # @Date : 2021-12-09
4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来-
5 | import os
6 | import random
7 | from pathlib import Path
8 | import time
9 | import json
10 | import threading
11 | import xml.etree.ElementTree as ET
12 |
13 | import networkx
14 |
15 | from mininet.topo import Topo
16 | from mininet.net import Mininet
17 | from mininet.node import RemoteController
18 | from mininet.link import TCLink
19 | from mininet.cli import CLI
20 | from mininet.log import setLogLevel
21 | from mininet.util import dumpNodeConnections
22 |
23 | random.seed(2020)
24 |
25 |
26 | def generate_port(node_idx1, node_idx2):
27 |     if node_idx2 > 9:  # match the port scheme of iperf_script.py: pad with "0" for two-digit destinations
28 |         port = str(node_idx1) + "0" + str(node_idx2)
29 |     else:
30 |         port = str(node_idx1) + "00" + str(node_idx2)  # single-digit destinations are padded with "00"
31 |
32 |     return int(port)
33 |
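# Illustration (not in the original file): the scheme mirrors iperf_script.py, e.g.
#     generate_port(3, 12)  -> 3012    (two-digit destination: pad with "0")
#     generate_port(12, 3)  -> 12003   (single-digit destination: pad with "00")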
34 |
35 | def generate_switch_port(graph):
36 | switch_port_dict = {}
37 | for node in graph.nodes:
38 | switch_port_dict.setdefault(node, list(range(graph.degree[node])))
39 | return switch_port_dict
40 |
41 |
42 | def parse_xml_topology(topology_path):
43 | """
44 | parse topology from topology.xml
45 | :return: topology graph, networkx.Graph()
46 | nodes_num, int
47 | edges_num, int
48 | """
49 | tree = ET.parse(topology_path)
50 | root = tree.getroot()
51 | topo_element = root.find("topology")
52 | graph = networkx.Graph()
53 | for child in topo_element.iter():
54 | # parse nodes
55 | if child.tag == 'node':
56 | node_id = int(child.get('id'))
57 | graph.add_node(node_id)
58 | # parse link
59 | elif child.tag == 'link':
60 | from_node = int(child.find('from').get('node'))
61 | to_node = int(child.find('to').get('node'))
62 | graph.add_edge(from_node, to_node)
63 |
64 | nodes_num = len(graph.nodes)
65 | edges_num = len(graph.edges)
66 |
67 | print('nodes: ', nodes_num, '\n', graph.nodes, '\n',
68 | 'edges: ', edges_num, '\n', graph.edges)
69 | return graph, nodes_num, edges_num
70 |
71 |
72 | def create_topo_links_info_xml(path, links_info):
73 | """
74 |
75 | (switch1, switch2)
76 | (1, 1)
77 | 100
78 | 5ms
79 | 1
80 |
81 |
82 | :param path: 保存路径
83 | :param links_info: 链路信息字典 {link: {ports, bw, delay, loss}}
84 | :return: None
85 | """
86 | # 根节点
87 | root = ET.Element('links_info')
88 |
89 | for link, info in links_info.items():
90 |         # child element
91 | child = ET.SubElement(root, 'links')
92 | child.text = str(link)
93 |
94 |         # second-level child elements
95 | sub_child1 = ET.SubElement(child, 'ports')
96 | sub_child1.text = str((info['port1'], info['port2']))
97 |
98 | sub_child2 = ET.SubElement(child, 'bw')
99 | sub_child2.text = str(info['bw'])
100 |
101 | sub_child2 = ET.SubElement(child, 'delay')
102 | sub_child2.text = str(info['delay'])
103 |
104 | sub_child2 = ET.SubElement(child, 'loss')
105 | sub_child2.text = str(info['loss'])
106 |
107 | tree = ET.ElementTree(root)
108 | Path(path).parent.mkdir(exist_ok=True)
109 | tree.write(path, encoding='utf-8', xml_declaration=True)
110 | print('saved links info xml.')
111 |
112 |
113 | def get_mininet_device(net, idx: list, device='h'):
114 | """
115 | 获得idx中mininet的实例, 如 h1, h2 ...; s1, s2 ...
116 | :param net: mininet网络实例
117 | :param idx: 设备标号集合, list
118 | :param device: 设备名称 'h', 's'
119 | :return d: dict{idx: 设备mininet实例}
120 | """
121 | d = {}
122 | for i in idx:
123 | d.setdefault(i, net.get(f'{device}{i}'))
124 |
125 | return d
126 |
127 |
128 | def run_corresponding_sh_script(devices: dict, label_path):
129 |     """
130 |     Run the corresponding shell script on each host
131 |     :param devices: {idx: device}
132 |     :param label_path: './24nodes/TM-{}/{}/{}_'
133 |     """
134 |     p = label_path + '{}.sh'
135 |     for i, d in devices.items():
136 |         if i < 9:
137 |             i = f'0{i}'
138 |         else:
139 |             i = f'{i}'
140 |         _p = p.format(i)  # format into a fresh variable; re-formatting p itself breaks after the first host
141 |         _cmd = f'bash {_p}'
142 |         d.cmd(_cmd)
143 |     print(f"---> complete run {label_path}")
144 |
145 |
146 | def run_ip_add_default(hosts: dict):
147 | """
148 | 运行 ip route add default via 10.0.0.x 命令
149 | """
150 | _cmd = 'ip route add default via 10.0.0.'
151 | for i, h in hosts.items():
152 | print(_cmd + str(i))
153 | h.cmd(_cmd + str(i))
154 | print("---> run ip add default complete")
155 |
156 |
157 | def _test_cmd(devices: dict, my_cmd):
158 | for i, d in devices.items():
159 | d.cmd(my_cmd)
160 |         print(f'exec {my_cmd} on device {i}')
161 | # print(f'return {r}')
162 |
163 |
164 | def run_iperf(path, host):
165 | _cmd = 'bash ' + path + '&'
166 | host.cmd(_cmd)
167 |
168 |
169 | def all_host_run_iperf(hosts: dict, path, finish_file):
170 | """
171 | path = r'./iperfTM/'
172 | """
173 | idxs = len(os.listdir(path))
174 | path = path + '/TM-'
175 | for idx in range(idxs):
176 | script_path = path + str(idx)
177 | for i, h in hosts.items():
178 | servers_cmd = script_path + '/Servers/server_' + str(i) + '.sh'
179 | _cmd = 'bash '
180 | print(_cmd + servers_cmd)
181 | h.cmd(_cmd + servers_cmd)
182 |
183 | for i, h in hosts.items():
184 | clients_cmd = script_path + '/Clients/client_' + str(i) + '.sh'
185 | _cmd = 'bash '
186 | print(_cmd + clients_cmd + time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
187 | h.cmd(_cmd + clients_cmd)
188 |
189 | time.sleep(300)
190 |
191 | write_iperf_time(finish_file)
192 |
193 |
194 | def write_pinall_time(finish_file):
195 | with open(finish_file, "w+") as f:
196 | _content = {
197 | "ping_all_finish_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
198 | "start_save_flag": True,
199 | "finish_flag": False
200 | }
201 | json.dump(_content, f)
202 |
203 |
204 | def write_iperf_time(finish_file):
205 | with open(finish_file, "r+") as f:
206 | _read = json.load(f)
207 | _content = {
208 | "iperf_finish_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
209 | "finish_flag": True,
210 | }
211 | _read.update(_content)
212 |
213 | with open(finish_file, "w+") as f:
214 | json.dump(_read, f)
215 |
216 |
217 | def remove_finish_file(finish_file):
218 | try:
219 | os.remove(finish_file)
220 | except FileNotFoundError:
221 | pass
222 |
223 |
224 | def net_h1_ping_others(net):
225 | hosts = net.hosts
226 | for h in hosts[1:]:
227 | net.ping((hosts[0], h))
228 |
229 |
230 | class GEANT23nodesTopo(Topo):
231 | def __init__(self, graph):
232 | super(GEANT23nodesTopo, self).__init__()
233 | self.node_idx = graph.nodes
234 | self.edges_pairs = graph.edges
235 | self.bw1 = 100 # Gbps -> M
236 | self.bw2 = 25 # Gbps -> M
237 | self.bw3 = 1.15 # Mbps
238 | self.bw4 = 100 # host -- switch
239 | self.delay = 20
240 | self.loss = 10
241 |
242 | self.host_port = 9
243 | self.snooper_port = 10
244 |
245 | self.bw1_links = [(12, 22), (12, 10), (12, 2), (13, 17), (4, 2), (4, 16), (1, 3), (1, 7), (1, 16), (3, 10),
246 | (3, 21), (10, 16), (10, 17), (7, 17), (7, 2), (7, 21), (16, 9), (20, 17)]
247 | self.bw2_links = [(13, 19), (13, 2), (19, 7), (23, 17), (23, 2), (8, 5), (8, 9), (18, 2), (18, 21), (5, 16),
248 | (3, 11), (10, 11), (22, 20), (20, 15), (9, 15)]
249 | self.bw3_links = [(13, 14), (19, 6), (3, 14), (7, 6)]
250 |
251 | def _return_bw(link: tuple):
252 | if link in self.bw1_links:
253 | return self.bw1
254 | elif link in self.bw2_links:
255 | return self.bw2
256 | elif link in self.bw3_links:
257 | return self.bw3
258 | else:
259 | raise ValueError
260 |
261 |         # add switches
262 | switches = {}
263 | for s in self.node_idx:
264 | switches.setdefault(s, self.addSwitch('s{0}'.format(s)))
265 |             print('added switch:', s)
266 |
267 | switch_port_dict = generate_switch_port(graph)
268 | links_info = {}
269 |         # add links
270 | for l in self.edges_pairs:
271 | port1 = switch_port_dict[l[0]].pop(0) + 1
272 | port2 = switch_port_dict[l[1]].pop(0) + 1
273 | bw = _return_bw(l)
274 |
275 | _d = str(random.randint(0, self.delay)) + 'ms'
276 | _l = random.randint(0, self.loss)
277 |
278 | self.addLink(switches[l[0]], switches[l[1]], port1=port1, port2=port2,
279 | bw=bw, delay=_d, loss=_l)
280 |
281 | links_info.setdefault(l, {"port1": port1, "port2": port2, "bw": bw, "delay": _d, "loss": _l})
282 |
283 | create_topo_links_info_xml(links_info_xml_path, links_info)
284 |
285 |         # add hosts
286 | for i in self.node_idx:
287 |             _h = self.addHost(f'h{i}', ip=f'10.0.0.{i}', mac=f'00:00:00:00:00:{i:02x}')  # colon-separated, zero-padded hex MAC
288 | self.addLink(_h, switches[i], port1=0, port2=self.host_port,
289 | bw=self.bw4)
290 |
291 | # add snooper
292 | # snooper = self.addSwitch("s30")
293 | # for i in self.node_idx:
294 | # self.addLink(snooper, switches[i], port1=i, port2=self.snooper_port)
295 |
296 |
297 | class Nodes14Topo(Topo):
298 | def __init__(self, graph):
299 | super(Nodes14Topo, self).__init__()
300 | self.node_idx = graph.nodes
301 | self.edges_pairs = graph.edges
302 |
303 | self.random_bw = 30 # Gbps -> M * 10
304 | self.bw4 = 50 # host -- switch
305 |
306 | self.delay = 20 # ms
307 | self.loss = 10 # %
308 |
309 | self.host_port = 9
310 | self.snooper_port = 10
311 |
312 |         # add switches
313 | switches = {}
314 | for s in self.node_idx:
315 | switches.setdefault(s, self.addSwitch('s{0}'.format(s)))
316 |             print('added switch:', s)
317 |
318 | switch_port_dict = generate_switch_port(graph)
319 | links_info = {}
320 |         # add links
321 | for l in self.edges_pairs:
322 | port1 = switch_port_dict[l[0]].pop(0) + 1
323 | port2 = switch_port_dict[l[1]].pop(0) + 1
324 |
325 | _bw = random.randint(5, self.random_bw)
326 | _d = str(random.randint(1, self.delay)) + 'ms'
327 | _l = random.randint(0, self.loss)
328 |
329 | self.addLink(switches[l[0]], switches[l[1]], port1=port1, port2=port2,
330 | bw=_bw, delay=_d, loss=_l)
331 |
332 | links_info.setdefault(l, {"port1": port1, "port2": port2, "bw": _bw, "delay": _d, "loss": _l})
333 |
334 | create_topo_links_info_xml(links_info_xml_path, links_info)
335 |
336 |         # add hosts
337 | for i in self.node_idx:
338 |             _h = self.addHost(f'h{i}', ip=f'10.0.0.{i}', mac=f'00:00:00:00:00:{i:02x}')  # colon-separated, zero-padded hex MAC
339 | self.addLink(_h, switches[i], port1=0, port2=self.host_port,
340 | bw=self.bw4)
341 |
342 |
343 | def main(graph, topo, finish_file):
344 | print('===Remove old finish file')
345 | remove_finish_file(finish_file)
346 |
347 | net = Mininet(topo=topo, link=TCLink, controller=RemoteController, waitConnected=True, build=False)
348 | c0 = net.addController('c0', ip='127.0.0.1', port=6633)
349 |
350 | net.build()
351 | net.start()
352 |
353 | print("get hosts device list")
354 | hosts = get_mininet_device(net, graph.nodes, device='h')
355 |
356 | print("===Dumping host connections")
357 | dumpNodeConnections(net.hosts)
358 |     print('===Wait for ryu to initialize')
359 |     time.sleep(40)
360 |     # add the default gateway ip
361 | # run_ip_add_default(hosts)
362 |
363 | # net.pingAll()
364 | net_h1_ping_others(net)
365 | write_pinall_time(finish_file)
366 |
367 |     # iperf scripts
368 | # hosts[1].cmd('iperf -s -u -p 1002 -1 &')
369 | # hosts[2].cmd('iperf -c 10.0.0.1 -u -p 1002 -b 20000k -t 30 &')
370 | print('===Run iperf scripts')
371 | t = threading.Thread(target=all_host_run_iperf, args=(hosts, iperf_path, finish_file), name='iperf', daemon=True)
372 | print('===Thread iperf start')
373 | t.start()
374 | # all_host_run_iperf(hosts, iperf_path)
375 |
376 | CLI(net)
377 | net.stop()
378 |
379 |
380 | if __name__ == '__main__':
381 | xml_topology_path = r'./topologies/topology2.xml'
382 | links_info_xml_path = r'./links_info/links_info.xml'
383 | iperf_path = "./iperfTM"
384 | iperf_interval = 0
385 | finish_file = './finish_time.json'
386 |
387 | graph, nodes_num, edges_num = parse_xml_topology(xml_topology_path)
388 | # topo = GEANT23nodesTopo(graph)
389 | topo = Nodes14Topo(graph)
390 |
391 | setLogLevel('info')
392 | main(graph, topo, finish_file)
393 |
--------------------------------------------------------------------------------
/mininet/iperf_script.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : iperf_script.py
3 | # @Date : 2022-02-10
4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来-
5 | # @From :
6 | import argparse
7 | import shutil
8 | from pathlib import Path
9 | import numpy as np
10 |
11 |
12 | def read_npy(file=None):
13 | if file is None:
14 | file = args.file
15 | tms = np.load(file)
16 | return tms
17 |
18 |
19 | def create_script(tms):
20 | label = 0
21 | tms = np.transpose(tms, (2, 0, 1))
22 | for tm in tms:
23 |         # FOR CREATING FOLDERS PER TRAFFIC MATRIX
24 |         nameTM = Path(f'./iperfTM/TM-{label}')
25 |         nameTM.mkdir(parents=True, exist_ok=True)
26 |         label += 1
27 |         print('------', nameTM)
29 |
30 |         # --------------------FLOWS--------------------------
31 |         # FOR CREATING THE Clients AND Servers FOLDERS
32 |         Path.mkdir(nameTM / Path('Clients'), exist_ok=True)
33 |         Path.mkdir(nameTM / Path('Servers'), exist_ok=True)
35 |
36 | # Default parameters
37 | time_duration = args.time_duration
38 | port = args.port
39 | ip_dest = args.ip_dest
40 | throughput = args.throughput # take it in kbps from TM
41 |
42 | # UDP with time = 10s
43 | # -c: ip_destination
44 | # -b: throughput in k,m or g (Kbps, Mbps or Gbps)
45 | # -t: time in seconds
46 |
47 | # SERVER SIDE
48 | # iperf3 -s
49 |
50 | # CLIENT SIDE with iperf3
51 | # iperf3 -c -u -p -b -t -V -J
52 |
53 |         # Traffic within the same node is not considered, so zero the diagonal (src == dst)
54 | for src in range(len(tm[0])):
55 | for dst in range(len(tm[0])):
56 | if src == dst:
57 | print("src: ", src, "dst: ", dst)
58 | tm[src][dst] = 0.0
59 |
60 | for src in range(1, len(tm[0]) + 1):
61 | with open(str(nameTM) + "/Clients/client_{0}.sh".format(str(src)), 'w') as fileClient:
62 | outputstring_a1 = "#!/bin/bash \necho Generating traffic..."
63 | fileClient.write(outputstring_a1)
64 | for dst in range(1, len(tm[0]) + 1):
65 | throughput = float(tm[src - 1][dst - 1])
66 | # throughput_g = throughput / (100) # scale the throughput value to mininet link capacities
67 | temp1 = ''
68 | if src != dst:
69 | temp1 = ''
70 | temp1 += '\n'
71 | temp1 += 'iperf3 -c '
72 | temp1 += '10.0.0.{0} '.format(str(dst))
73 | if dst > 9:
74 | temp1 += '-p {0}0{1} '.format(str(src), str(dst))
75 | else:
76 | temp1 += '-p {0}00{1} '.format(str(src), str(dst))
77 | temp1 += '-u -b ' + str(format(throughput, '.2f')) + 'k'
78 | # temp1 += ' -w 256k -t ' + str(time_duration)
79 | temp1 += ' -t ' + str(time_duration)
80 | temp1 += ' >/dev/null 2>&1 &\n' # & at the end of the line it's for running the process in bkg
81 | temp1 += 'sleep 0.4'
82 | fileClient.write(temp1)
83 |
84 | # print(na)
85 | for dst in range(len(tm[0])):
86 | dst_ = dst + 1
87 | with open(str(nameTM) + "/Servers/server_{0}.sh".format(str(dst_)), 'w') as fileServer:
88 | outputstring_a2 = '#!/bin/bash \necho Initializing server listening...'
89 | fileServer.write(outputstring_a2)
90 | for src in range(len(tm[0])):
91 | src_ = src + 1
92 | temp2 = ''
93 | if src != dst:
94 | # n = n+1
95 | temp2 = ''
96 | temp2 += '\n'
97 | temp2 += 'iperf3 -s '
98 | if dst_ > 9:
99 | temp2 += '-p {0}0{1} '.format(str(src_), str(dst_))
100 | else:
101 | temp2 += '-p {0}00{1} '.format(str(src_), str(dst_))
102 | temp2 += '-1'
103 | temp2 += ' >/dev/null 2>&1 &\n' # & at the end of the line it's for running the process in bkg
104 | temp2 += 'sleep 0.3'
105 | fileServer.write(temp2)
106 |
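# Illustration (not in the original file): for src = 1 the generated
# Clients/client_1.sh looks roughly like
#     #!/bin/bash
#     echo Generating traffic...
#     iperf3 -c 10.0.0.2 -p 1002 -u -b 123.45k -t 30 >/dev/null 2>&1 &
#     sleep 0.4
# with one iperf3 invocation per destination dst != src; the -b value (123.45k is a
# made-up figure here) comes from the corresponding traffic-matrix entry.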
107 |
108 | if __name__ == '__main__':
109 | parser = argparse.ArgumentParser(description="Generate traffic matrices")
110 | parser.add_argument("--seed", default=2020, help="random seed")
111 | # time_duration = 30
112 | # port = 2022
113 | # ip_dest = "10.0.0.1"
114 | # throughput = 0.0 # take it in kbps from TM
115 | parser.add_argument("--time_duration", default=30, help="time_duration")
116 | parser.add_argument("--port", default=2022, help="port")
117 | parser.add_argument("--ip_dest", default="10.0.0.1", help="ip_dest")
118 | parser.add_argument("--throughput", default=0.0, help="take it in kbps from TM")
119 | parser.add_argument("--file",
120 | default=r'tm_statistic/communicate_tm.npy',
121 | help="take it in kbps from TM")
122 | args = parser.parse_args()
123 |     shutil.rmtree("iperfTM", ignore_errors=True)  # tolerate a missing directory on the first run
124 | tms = read_npy()
125 | create_script(tms)
126 |
--------------------------------------------------------------------------------
/mininet/links_info/links_info.xml:
--------------------------------------------------------------------------------
1 | <?xml version='1.0' encoding='utf-8'?>
2 | <links_info>
3 | <links>(1, 4)<ports>(1, 1)</ports><bw>24</bw><delay>20ms</delay><loss>2</loss></links>
4 | <links>(1, 5)<ports>(2, 1)</ports><bw>26</bw><delay>15ms</delay><loss>7</loss></links>
5 | <links>(1, 9)<ports>(3, 1)</ports><bw>20</bw><delay>12ms</delay><loss>6</loss></links>
6 | <links>(1, 3)<ports>(4, 1)</ports><bw>22</bw><delay>6ms</delay><loss>9</loss></links>
7 | <links>(1, 11)<ports>(5, 1)</ports><bw>7</bw><delay>15ms</delay><loss>3</loss></links>
8 | <links>(2, 4)<ports>(1, 2)</ports><bw>11</bw><delay>14ms</delay><loss>10</loss></links>
9 | <links>(2, 6)<ports>(2, 1)</ports><bw>11</bw><delay>20ms</delay><loss>2</loss></links>
10 | <links>(2, 5)<ports>(3, 2)</ports><bw>20</bw><delay>15ms</delay><loss>1</loss></links>
11 | <links>(4, 11)<ports>(3, 2)</ports><bw>27</bw><delay>6ms</delay><loss>7</loss></links>
12 | <links>(4, 9)<ports>(4, 2)</ports><bw>24</bw><delay>18ms</delay><loss>7</loss></links>
13 | <links>(4, 5)<ports>(5, 3)</ports><bw>30</bw><delay>2ms</delay><loss>7</loss></links>
14 | <links>(4, 7)<ports>(6, 1)</ports><bw>6</bw><delay>12ms</delay><loss>3</loss></links>
15 | <links>(4, 6)<ports>(7, 2)</ports><bw>8</bw><delay>18ms</delay><loss>1</loss></links>
16 | <links>(5, 6)<ports>(4, 3)</ports><bw>24</bw><delay>9ms</delay><loss>2</loss></links>
17 | <links>(5, 8)<ports>(5, 1)</ports><bw>5</bw><delay>18ms</delay><loss>4</loss></links>
18 | <links>(5, 7)<ports>(6, 2)</ports><bw>19</bw><delay>8ms</delay><loss>4</loss></links>
19 | <links>(7, 8)<ports>(3, 2)</ports><bw>21</bw><delay>15ms</delay><loss>3</loss></links>
20 | <links>(7, 9)<ports>(4, 3)</ports><bw>21</bw><delay>2ms</delay><loss>9</loss></links>
21 | <links>(8, 14)<ports>(3, 1)</ports><bw>21</bw><delay>20ms</delay><loss>10</loss></links>
22 | <links>(9, 10)<ports>(4, 1)</ports><bw>12</bw><delay>8ms</delay><loss>9</loss></links>
23 | <links>(9, 13)<ports>(5, 1)</ports><bw>14</bw><delay>10ms</delay><loss>2</loss></links>
24 | <links>(12, 14)<ports>(1, 2)</ports><bw>28</bw><delay>14ms</delay><loss>6</loss></links>
25 | <links>(12, 13)<ports>(2, 2)</ports><bw>24</bw><delay>10ms</delay><loss>2</loss></links>
26 | </links_info>
--------------------------------------------------------------------------------
/mininet/topologies/topology-anonymised.xml:
--------------------------------------------------------------------------------
1 | Topology generated by geant-TM-all.pl
2 | Steve Uhlig
3 | [The XML markup of this file (the node and link elements of the anonymised GEANT topology) was stripped during export; only the generator comment above survives.]
--------------------------------------------------------------------------------
/mininet/topologies/topology2.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/ryu/arp_handler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : arp_handler.py
3 | # @Date : 2022-03-09
4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来-
5 | from ryu.base import app_manager
6 | from ryu.base.app_manager import lookup_service_brick
7 | from ryu.ofproto import ofproto_v1_3
8 | from ryu.controller import ofp_event
9 | from ryu.controller.handler import set_ev_cls, MAIN_DISPATCHER
10 | from ryu.lib.packet import packet
11 | from ryu.lib.packet import arp, ipv4, ethernet
12 |
13 | ETHERNET = ethernet.ethernet.__name__
14 | ETHERNET_MULTICAST = "ff:ff:ff:ff:ff:ff"
15 | ARP = arp.arp.__name__
16 |
17 |
18 | class ArpHandler(app_manager.RyuApp):
19 |     OFP_VERSIONS = [ofproto_v1_3.OFP_VERSION]  # Ryu reads the plural OFP_VERSIONS attribute
20 |
21 | def __init__(self, *args, **kwargs):
22 | super(ArpHandler, self).__init__(*args, **kwargs)
23 | self.discovery = lookup_service_brick('discovery')
24 | self.monitor = lookup_service_brick('monitor')
25 |
26 | self.arp_table = {}
27 | self.sw = {}
28 | self.mac_to_port = {}
29 |
30 | @set_ev_cls(ofp_event.EventOFPPacketIn, MAIN_DISPATCHER)
31 | def _packet_in_handler(self, ev):
32 | """
33 | 处理PacketIn事件
34 | 1. arp包 是否已经记录
35 | """
36 | # print("shortest---> _packet_in_handler: PacketIn")
37 | msg = ev.msg
38 | datapath = msg.datapath
39 | ofproto = datapath.ofproto
40 | parser = datapath.ofproto_parser
41 |
42 | in_port = msg.match['in_port']
43 |
44 | pkt = packet.Packet(msg.data)
45 | # print("shortest---> _packet_in_handler: pkt:\n ", pkt)
46 |
47 | arp_pkt = pkt.get_protocol(arp.arp)
48 | ipv4_pkt = pkt.get_protocol(ipv4.ipv4)
49 |
50 | eth = pkt.get_protocols(ethernet.ethernet)[0]
51 | src = eth.src
52 |
53 | header_list = dict((p.protocol_name, p) for p in pkt.protocols if type(p) != bytes)
54 | # print("shortest--->_packet_in_handler: header_list:\n ", header_list)
55 | if isinstance(arp_pkt, arp.arp):
56 | self.arp_table[arp_pkt.src_ip] = src
57 |
58 | if self.arp_handler(header_list, datapath, in_port, msg.buffer_id):
59 | # 1:reply or drop; 0: flood
60 | # print("ARP_PROXY_13")
61 | return None
62 | else:
63 | arp_src_ip = arp_pkt.src_ip
64 | arp_dst_ip = arp_pkt.dst_ip
65 | location = self.discovery.get_host_ip_location(arp_dst_ip)
66 | # print("shortest--->zzzzzz----> location", location)
67 |             if location:  # the host's location is known
68 | # print("shortest--->_packet_in_handler: ==Reply Arp to knew host")
69 | dpid_dst, out_port = location
70 | datapath = self.monitor.datapaths_table[dpid_dst]
71 | out = self._build_packet_out(datapath, ofproto.OFP_NO_BUFFER, ofproto.OFPP_CONTROLLER,
72 | out_port, msg.data)
73 | datapath.send_msg(out)
74 | return
75 | else:
76 | print("shortest--->_packet_in_handler: ==Flooding")
77 | for dpid in self.discovery.switch_all_ports_table:
78 | for port in self.discovery.switch_all_ports_table[dpid]:
79 |                     if (dpid, port) not in self.discovery.access_table.keys():  # unknown (dpid, port): flood out of it
80 | datapath = self.monitor.datapaths_table[dpid]
81 | out = self._build_packet_out(datapath, ofproto.OFP_NO_BUFFER,
82 | ofproto.OFPP_CONTROLLER, port, msg.data)
83 | datapath.send_msg(out)
84 | return
85 |
86 |     def arp_handler(self, header_list, datapath, in_port, msg_buffer_id):
90 |
91 | # if ETHERNET in header_list:
92 | eth_dst = header_list[ETHERNET].dst
93 | eth_src = header_list[ETHERNET].src
94 |
95 | # print("shortest---> arp_handler eth_dst eth_src: \n", eth_dst, eth_src)
96 |
97 | if eth_dst == ETHERNET_MULTICAST and ARP in header_list:
98 | arp_dst_ip = header_list[ARP].dst_ip
99 | if (datapath.id, eth_src, arp_dst_ip) in self.sw: # break loop
100 | # print("shortest---> arp_handler: ====BREAK LOOP")
101 | out = datapath.ofproto_parser.OFPPacketOut(
102 | datapath=datapath,
103 | buffer_id=datapath.ofproto.OFP_NO_BUFFER,
104 | in_port=in_port,
105 | actions=[], data=None
106 | )
107 | datapath.send_msg(out)
108 | return True
109 | else:
110 | self.sw[(datapath.id, eth_src, arp_dst_ip)] = in_port
111 |
112 | if ARP in header_list:
113 | # print("discovery---> arp_handler: ====ARP ARP")
114 | # only the opcode and the IPs are needed below
115 | opcode = header_list[ARP].opcode
119 |
120 | arp_src_ip = header_list[ARP].src_ip
121 | arp_dst_ip = header_list[ARP].dst_ip
122 |
123 | actions = []
124 |
125 | if opcode == arp.ARP_REQUEST:
126 | if arp_dst_ip in self.arp_table: # arp reply
127 | # print("shortest---> arp_handler: ====ARP REPLY")
128 | actions.append(datapath.ofproto_parser.OFPActionOutput(in_port))
129 |
130 | ARP_Reply = packet.Packet()
131 | ARP_Reply.add_protocol(ethernet.ethernet(ethertype=header_list[ETHERNET].ethertype,
132 | dst=eth_src,
133 | src=self.arp_table[arp_dst_ip]))
134 | ARP_Reply.add_protocol(arp.arp(opcode=arp.ARP_REPLY,
135 | src_mac=self.arp_table[arp_dst_ip],
136 | src_ip=arp_dst_ip,
137 | dst_mac=eth_src,
138 | dst_ip=arp_src_ip))
139 |
140 | ARP_Reply.serialize()
141 |
142 | out = datapath.ofproto_parser.OFPPacketOut(
143 | datapath=datapath,
144 | buffer_id=datapath.ofproto.OFP_NO_BUFFER,
145 | in_port=datapath.ofproto.OFPP_CONTROLLER,
146 | actions=actions,
147 | data=ARP_Reply.data
148 | )
149 | datapath.send_msg(out)
150 | return True
151 | return False
152 |
153 | # build a packet-out message
154 | def _build_packet_out(self, datapath, buffer_id, src_port, dst_port, data):
155 | """ Build an OFPPacketOut message."""
156 | actions = []
157 | if dst_port:
158 | actions.append(datapath.ofproto_parser.OFPActionOutput(dst_port))
159 |
160 | msg_data = None
161 | if buffer_id == datapath.ofproto.OFP_NO_BUFFER:
162 | if data is None:
163 | return None
164 | msg_data = data
165 |
166 | out = datapath.ofproto_parser.OFPPacketOut(datapath=datapath, buffer_id=buffer_id,
167 | data=msg_data, in_port=src_port, actions=actions)
168 |
169 | return out
170 |
--------------------------------------------------------------------------------
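Note: a minimal standalone sketch (not a file in this repo) of the loop-breaking idea used by ArpHandler.arp_handler above. The first time a broadcast ARP request keyed by (dpid, eth_src, arp_dst_ip) is seen, its ingress port is remembered; seeing the same key again means the flood looped back, so the packet is dropped instead of re-flooded:

    # sw caches the first ingress port of each broadcast ARP request
    sw = {}

    def looped(dpid, eth_src, arp_dst_ip, in_port):
        key = (dpid, eth_src, arp_dst_ip)
        if key in sw:        # the same request reached this switch again
            return True      # -> drop it instead of flooding once more
        sw[key] = in_port    # first sighting: remember where it entered
        return False

    assert looped(1, "00:00:00:00:00:01", "10.0.0.2", 1) is False
    assert looped(1, "00:00:00:00:00:01", "10.0.0.2", 2) is True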
/ryu/network_delay.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : network_delay.py
3 | # @Date : 2021-08-12
4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来-
5 | # network_delay.py
6 | import copy
7 | import time
8 |
9 | from ryu.base import app_manager
10 | from ryu.base.app_manager import lookup_service_brick
11 | from ryu.controller import ofp_event
12 | from ryu.controller.handler import set_ev_cls, MAIN_DISPATCHER, DEAD_DISPATCHER
13 | from ryu.ofproto import ofproto_v1_3
14 | from ryu.lib import hub
15 |
16 | from ryu.topology.switches import Switches, LLDPPacket
17 |
18 | import setting
19 | import network_structure
20 | import network_monitor
21 |
22 |
23 | class NetworkDelayDetector(app_manager.RyuApp):
24 | """ 测量链路的时延"""
25 | OFP_VERSION = [ofproto_v1_3.OFP_VERSION]
26 | _CONTEXTS = {'switches': Switches}
27 |
28 | def __init__(self, *args, **kwargs):
29 | super(NetworkDelayDetector, self).__init__(*args, **kwargs)
30 | self.name = 'detector'
31 |
32 | self.network_structure = lookup_service_brick('discovery')
33 | self.network_monitor = lookup_service_brick('monitor')
34 |
35 | # 'switches' is supplied through _CONTEXTS, so take the instance from kwargs
36 | self.switch_module = kwargs['switches']
37 | # self.network_structure = kwargs['discovery']
38 | # self.network_monitor = kwargs['monitor']
39 |
40 | self.echo_delay_table = {} # {dpid: ryu_ofps_delay}
41 | self.lldp_delay_table = {} # {src_dpid: {dst_dpid: delay}}
42 | self.echo_interval = 0.05
43 |
44 | # self.datapaths_table = self.network_monitor.datapaths_table
45 | self._delay_thread = hub.spawn(self.scheduler)
46 |
47 | def scheduler(self):
48 | while True:
49 | self._send_echo_request()
50 | self.create_delay_graph()
51 |
52 | # if setting.PRINT_SHOW:
53 | # self.show_delay_stats()
54 |
55 | hub.sleep(setting.DELAY_PERIOD)
56 |
57 | # The echo send time is subtracted from the receive time
58 | # 1. Send echo requests
59 | def _send_echo_request(self):
60 | """ Send an echo request to every datapath."""
61 | datapaths_table = self.network_monitor.datapaths_table.values()
62 | if datapaths_table: # .values() is never None; test for emptiness instead
63 | for datapath in list(datapaths_table):
64 | parser = datapath.ofproto_parser
65 | data = time.time()
66 | echo_req = parser.OFPEchoRequest(datapath, b"%.12f" % data)
67 | datapath.send_msg(echo_req)
68 | hub.sleep(self.echo_interval) # throttle so the replies are not missed
69 |
70 | # 2. Receive echo replies
71 | @set_ev_cls(ofp_event.EventOFPEchoReply, MAIN_DISPATCHER)
72 | def _echo_reply_handler(self, ev):
73 | now_timestamp = time.time()
74 | data = ev.msg.data
75 | ryu_ofps_delay = now_timestamp - float(data) # now minus the send time; float() is safer than eval()
76 | self.echo_delay_table[ev.msg.datapath.id] = ryu_ofps_delay
77 |
78 | # Measure delay from LLDP packets
79 | @set_ev_cls(ofp_event.EventOFPPacketIn, MAIN_DISPATCHER)
80 | def _packet_in_handler(self, ev):
81 | """ Parse LLDP packets; this handler receives every packet it can, see switches.py l:769"""
82 | # print("detector---> PacketIn")
83 | try:
84 | recv_timestamp = time.time()
85 | msg = ev.msg
86 | dpid = msg.datapath.id
87 | src_dpid, src_port_no = LLDPPacket.lldp_parse(msg.data)
88 |
89 | # print("---> self.switch_module.ports", self.switch_module.ports)
90 |
91 | for port in self.switch_module.ports.keys():
92 | if src_dpid == port.dpid and src_port_no == port.port_no:
93 | send_timestamp = self.switch_module.ports[port].timestamp
94 | if send_timestamp:
95 | delay = recv_timestamp - send_timestamp
96 | else:
97 | delay = 0 # no send timestamp yet; avoid an unbound `delay`
98 | self.lldp_delay_table.setdefault(src_dpid, {})
99 | self.lldp_delay_table[src_dpid][dpid] = delay # store it
100 | except LLDPPacket.LLDPUnknownFormat as e:
101 | return
102 |
103 | def create_delay_graph(self):
104 | # walk over all edges
105 | # print('---> create delay graph')
106 | for src, dst in self.network_structure.graph.edges:
107 | delay = self.calculate_delay(src, dst)
108 | self.network_structure.graph[src][dst]['delay'] = delay * 1000 # ms
109 | # print("--->" * 2, self.network_structure.count + 1)
110 |
111 | def calculate_delay(self, src, dst):
112 | """
113 | ┌------Ryu------┐
114 | | |
115 | src echo latency| |dst echo latency
116 | | |
117 | SwitchA------------SwitchB
118 | --->fwd_delay--->
119 | <---reply_delay<---
120 | """
121 |
122 | fwd_delay = self.lldp_delay_table.get(src, {}).get(dst, 0) # default 0 before the first measurement
123 | reply_delay = self.lldp_delay_table.get(dst, {}).get(src, 0)
124 | ryu_ofps_src_delay = self.echo_delay_table.get(src, 0)
125 | ryu_ofps_dst_delay = self.echo_delay_table.get(dst, 0)
126 |
127 | delay = (fwd_delay + reply_delay - ryu_ofps_src_delay - ryu_ofps_dst_delay) / 2
128 | return max(delay, 0)
129 |
130 | def show_delay_stats(self):
131 | self.logger.info("==============================DDDD delay=================================")
132 | self.logger.info("src dst : delay")
133 | for src in self.lldp_delay_table.keys():
134 | for dst in self.lldp_delay_table[src].keys():
135 | delay = self.lldp_delay_table[src][dst]
136 | self.logger.info("%s <---> %s : %s", src, dst, delay)
137 |
--------------------------------------------------------------------------------
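Note: a worked numeric sketch (made-up values, not part of the repo) of the delay estimate in NetworkDelayDetector.calculate_delay above. Each LLDP one-way measurement includes a controller-to-switch leg, so both echo delays are subtracted before halving:

    fwd_delay = 0.0040    # LLDP latency src -> dst seen by the controller
    reply_delay = 0.0038  # LLDP latency dst -> src
    echo_src = 0.0010     # controller <-> src switch echo round trip
    echo_dst = 0.0012     # controller <-> dst switch echo round trip

    link_delay = max((fwd_delay + reply_delay - echo_src - echo_dst) / 2, 0)
    print("%.3f ms" % (link_delay * 1000))  # 2.800 ms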
/ryu/network_monitor.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @File : network_monitor.py
3 | # @Date : 2021-08-12
4 | # @Author : chenwei -剑衣沉沉晚霞归,酒杖津津神仙来-
5 | # network_monitor.py
6 | import copy
7 | from operator import attrgetter
8 |
9 | from ryu.base import app_manager
10 | from ryu.ofproto import ofproto_v1_3
11 | from ryu.controller import ofp_event
12 | from ryu.controller.handler import set_ev_cls, MAIN_DISPATCHER, DEAD_DISPATCHER
13 | from ryu.lib import hub
14 | from ryu.base.app_manager import lookup_service_brick
15 |
16 | import setting
17 | from setting import print_pretty_table, print_pretty_list
18 |
19 |
20 | class NetworkMonitor(app_manager.RyuApp):
21 | """ 监控网络流量状态"""
22 | OFP_VERSIONS = [ofproto_v1_3.OFP_VERSION]
23 |
24 | def __init__(self, *args, **kwargs):
25 | super(NetworkMonitor, self).__init__(*args, **kwargs)
26 | self.name = 'monitor'
27 | # {dpid: datapath}
28 | self.datapaths_table = {}
29 | # {dpid:{port_no: (config, state, curr_speed, max_speed)}}
30 | self.dpid_port_fueatures_table = {}
31 | # {(dpid, port_no): (stat.tx_bytes, stat.rx_bytes, stat.rx_errors, stat.duration_sec,
32 | # stat.duration_nsec, stat.tx_packets, stat.rx_packets)}
33 | self.port_stats_table = {}
34 | # {dpid:{(in_port, ipv4_dsts, out_port): (packet_count, byte_count, duration_sec, duration_nsec)}}
35 | self.flow_stats_table = {}
36 | # {(dpid, port_no): [speed, .....]}
37 | self.port_speed_table = {}
38 | # {dpid: {(in_port, ipv4_dsts, out_port): speed}}
39 | self.flow_speed_table = {}
40 |
41 | self.port_flow_dpid_stats = {'port': {}, 'flow': {}}
42 | # {dpid: {port_no: curr_bw}}
43 | self.port_curr_speed = {}
44 |
45 | self.port_loss = {}
46 |
47 | self.discovery = lookup_service_brick("discovery") # look up the running NetworkStructure instance
48 |
49 | # self.monitor_thread = hub.spawn(self._monitor)
50 | self.monitor_thread = hub.spawn(self.scheduler)
51 | self.save_thread = hub.spawn(self.save_bw_loss_graph)
52 |
53 | def print_parameters(self):
54 | # print("monitor---> self.datapaths_table", self.datapaths_table)
55 | # print("monitor---> self.dpid_port_fueatures_table", self.dpid_port_fueatures_table)
56 | # print("monitor---> self.port_stats_table", self.port_stats_table)
57 | # print("monitor---> self.flow_stats_table", self.flow_stats_table)
58 | # print("monitor---> self.port_speed_table", self.port_speed_table)
59 | # print("monitor---> self.flow_speed_table", self.flow_speed_table)
60 | # print("monitor---> self.port_curr_speed", self.port_curr_speed)
61 |
62 | logger = self.logger.info if setting.LOGGER else print
63 |
64 | # {dpid: datapath}
65 | # print_pretty_table(self.datapaths_table, ['dpid', 'datapath'], [6, 64], 'MMMM datapaths_table',
66 | # logger)
67 |
68 | # # {dpid:{port_no: (config, state, curr_speed, max_speed)}}
69 | # print_pretty_table(self.dpid_port_fueatures_table,
70 | # ['dpid', 'port_no:(config, state, curr_speed, max_speed)'],
71 | # [6, 40], 'MMMM dpid_port_fueatures_table', logger)
72 |
73 | # {(dpid, port_no): (stat.tx_bytes, stat.rx_bytes, stat.rx_errors, stat.duration_sec,
74 | # stat.duration_nsec, stat.tx_packets, stat.rx_packets)}
75 | # print_pretty_table(self.port_stats_table,
76 | # ['(dpid, port_no)',
77 | # '(stat.tx_bytes, stat.rx_bytes, stat.rx_errors, stat.duration_sec, stat.duration_nsec, stat.tx_packets, stat.rx_packets)'],
78 | # [18, 120], 'MMMM port_stats_table', logger)
79 |
80 | # {(dpid, port_no): [speed, .....]}
81 | # print_pretty_table(self.port_speed_table,
82 | # ['(dpid, port_no)', 'speed'],
83 | # [18, 40], 'MMMM port_speed_table', logger)
84 |
85 | print("'MMMM port_loss: \n", self.port_loss)
86 |
87 | def print_parameters_(self):
88 | print("monitor---------- %s ----------" % self.name)
89 | for attr, value in self.__dict__.items():
90 | print("monitor\n---> %s: %s" % (attr, value))
91 | print("monitor===================================")
92 |
93 | def scheduler(self):
94 | while True:
95 | self.port_flow_dpid_stats['flow'] = {}
96 | self.port_flow_dpid_stats['port'] = {}
97 |
98 | self._request_stats()
99 | if setting.PRINT_SHOW:
100 | self.print_parameters()
101 | hub.sleep(setting.MONITOR_PERIOD)
102 |
103 | def save_bw_loss_graph(self):
104 | while True:
105 | self.create_bandwidth_graph()
106 | self.create_loss_graph()
107 | hub.sleep(setting.MONITOR_PERIOD)
108 |
109 | @set_ev_cls(ofp_event.EventOFPStateChange, [MAIN_DISPATCHER, DEAD_DISPATCHER])
110 | def _state_change_handler(self, ev):
111 | """ 存放所有的datapath实例"""
112 | datapath = ev.datapath # OFPStateChange类可以直接获得datapath
113 | if ev.state == MAIN_DISPATCHER:
114 | if datapath.id not in self.datapaths_table:
115 | print("MMMM---> register datapath: %016x" % datapath.id)
116 | self.datapaths_table[datapath.id] = datapath
117 |
118 | # initialize the per-datapath tables
119 | self.dpid_port_fueatures_table.setdefault(datapath.id, {})
120 | self.flow_stats_table.setdefault(datapath.id, {})
121 |
122 | elif ev.state == DEAD_DISPATCHER:
123 | if datapath.id in self.datapaths_table:
124 | print("MMMM---> unreigster datapath: %016x" % datapath.id)
125 | del self.datapaths_table[datapath.id]
126 |
127 | # Proactively send requests for statistics
128 | def _request_stats(self):
129 | # print("MMMM---> send request ---> ---> send request ---> ")
130 | datapaths_table = self.datapaths_table.values()
131 |
132 | for datapath in list(datapaths_table):
133 | self.dpid_port_fueatures_table.setdefault(datapath.id, {})
134 | # print("MMMM---> send stats request: %016x", datapath.id)
135 | ofproto = datapath.ofproto
136 | parser = datapath.ofproto_parser
137 |
138 | # 1. Port description request
139 | req = parser.OFPPortDescStatsRequest(datapath, 0)
140 | datapath.send_msg(req)
141 |
142 | # 2. Port statistics request
143 | req = parser.OFPPortStatsRequest(datapath, 0, ofproto.OFPP_ANY) # all ports
144 | datapath.send_msg(req)
145 |
146 | # 3. Per-flow statistics request (disabled)
147 | # req = parser.OFPFlowStatsRequest(datapath)
148 | # datapath.send_msg(req)
149 |
150 | # Handle the OFPPortDescStatsReply to the request above
151 | @set_ev_cls(ofp_event.EventOFPPortDescStatsReply, MAIN_DISPATCHER)
152 | def port_desc_stats_reply_handler(self, ev):
153 | """ 存储端口描述信息, 见OFPPort类, 配置、状态、当前速度"""
154 | # print("MMMM---> EventOFPPortDescStatsReply")
155 | msg = ev.msg
156 | dpid = msg.datapath.id
157 | ofproto = msg.datapath.ofproto
158 |
159 | config_dict = {ofproto.OFPPC_PORT_DOWN: 'Port Down',
160 | ofproto.OFPPC_NO_RECV: 'No Recv',
161 | ofproto.OFPPC_NO_FWD: 'No Forward',
162 | ofproto.OFPPC_NO_PACKET_IN: 'No Packet-In'}
163 |
164 | state_dict = {ofproto.OFPPS_LINK_DOWN: "Link Down",
165 | ofproto.OFPPS_BLOCKED: "Blocked",
166 | ofproto.OFPPS_LIVE: "Live"}
167 |
168 | for ofport in ev.msg.body: # this kept being buggy; check the port properties
169 | if ofport.port_no != ofproto_v1_3.OFPP_LOCAL: # 0xfffffffe 4294967294
170 |
171 | if ofport.config in config_dict:
172 | config = config_dict[ofport.config]
173 | else:
174 | config = 'Up'
175 |
176 | if ofport.state in state_dict:
177 | state = state_dict[ofport.state]
178 | else:
179 | state = 'Up'
180 |
181 | # store config, state, curr_speed, max_speed (max_speed is 0 here)
182 | port_features = (config, state, ofport.curr_speed, ofport.max_speed)
183 | # print("MMMM---> ofport.curr_speed", ofport.curr_speed)
184 | self.dpid_port_fueatures_table[dpid][ofport.port_no] = port_features
185 |
186 | @set_ev_cls(ofp_event.EventOFPPortStatsReply, MAIN_DISPATCHER)
187 | def port_stats_table_reply_handler(self, ev):
188 | """ 存储端口统计信息, 见OFPPortStats, 发送bytes、接收bytes、生效时间duration_sec等
189 | Replay message content:
190 | (stat.port_no,
191 | stat.rx_packets, stat.tx_packets,
192 | stat.rx_bytes, stat.tx_bytes,
193 | stat.rx_dropped, stat.tx_dropped,
194 | stat.rx_errors, stat.tx_errors,
195 | stat.rx_frame_err, stat.rx_over_err,
196 | stat.rx_crc_err, stat.collisions,
197 | stat.duration_sec, stat.duration_nsec))
198 | """
199 | # print("MMMM---> EventOFPPortStatsReply")
200 | body = ev.msg.body
201 | dpid = ev.msg.datapath.id
202 | self.port_flow_dpid_stats['port'][dpid] = body
203 | # self.port_curr_speed.setdefault(dpid, {})
204 |
205 | for stat in sorted(body, key=attrgetter("port_no")):
206 | port_no = stat.port_no
207 | if port_no != ofproto_v1_3.OFPP_LOCAL:
208 | key = (dpid, port_no)
209 | value = (stat.tx_bytes, stat.rx_bytes, stat.rx_errors,
210 | stat.duration_sec, stat.duration_nsec, stat.tx_packets, stat.rx_packets)
211 | self._save_stats(self.port_stats_table, key, value, 5) # keep at most the 5 latest samples
212 |
213 | pre_bytes = 0
214 | # delta_time = setting.MONITOR_PERIOD
215 | delta_time = setting.SCHEDULE_PERIOD
216 | stats = self.port_stats_table[key] # the samples stored so far
217 |
218 | if len(stats) > 1: # at least two samples stored
219 | pre_bytes = stats[-2][0] + stats[-2][1]
220 | delta_time = self._calculate_delta_time(stats[-1][3], stats[-1][4],
221 | stats[-2][3], stats[-2][4]) # latest sample vs the previous one
222 |
223 | speed = self._calculate_speed(stats[-1][0] + stats[-1][1],
224 | pre_bytes, delta_time)
225 | self._save_stats(self.port_speed_table, key, speed, 5)
226 | self._calculate_port_speed(dpid, port_no, speed)
227 |
228 | self.calculate_loss_of_link()
229 |
230 | @set_ev_cls(ofp_event.EventOFPFlowStatsReply, MAIN_DISPATCHER)
231 | def _flow_stats_reply_handler(self, ev):
232 | """ 存储flow的状态,算这个干啥。。"""
233 | msg = ev.msg
234 | body = msg.body
235 | datapath = msg.datapath
236 | dpid = datapath.id
237 |
238 | self.port_flow_dpid_stats['flow'][dpid] = body
239 | # print("MMMM---> body", body)
240 |
241 | for stat in sorted([flowstats for flowstats in body if flowstats.priority == 1],
242 | key=lambda flowstats: (flowstats.match.get('in_port'), flowstats.match.get('ipv4_dst'))):
243 | # print("MMMM---> stat.match", stat.match)
244 | # print("MMMM---> stat", stat)
245 | key = (stat.match['in_port'], stat.match['ipv4_dst'],
246 | stat.instructions[0].actions[0].port)
247 | value = (stat.packet_count, stat.byte_count, stat.duration_sec, stat.duration_nsec)
248 | self._save_stats(self.flow_stats_table[dpid], key, value, 5)
249 |
250 | pre_bytes = 0
251 | # delta_time = setting.MONITOR_PERIOD
252 | delta_time = setting.SCHEDULE_PERIOD
253 | value = self.flow_stats_table[dpid][key]
254 | if len(value) > 1:
255 | pre_bytes = value[-2][1]
256 | # print("MMMM---> _flow_stats_reply_handler delta_time: now", value[-1][2], value[-1][3], "pre",
257 | # value[-2][2],
258 | # value[-2][3])
259 | delta_time = self._calculate_delta_time(value[-1][2], value[-1][3],
260 | value[-2][2], value[-2][3])
261 | speed = self._calculate_speed(self.flow_stats_table[dpid][key][-1][1], pre_bytes, delta_time)
262 | self.flow_speed_table.setdefault(dpid, {})
263 | self._save_stats(self.flow_speed_table[dpid], key, speed, 5)
264 |
265 | # Keep several samples, e.g. a port's previous and current statistics
266 | @staticmethod
267 | def _save_stats(_dict, key, value, keep):
268 | if key not in _dict:
269 | _dict[key] = []
270 | _dict[key].append(value)
271 |
272 | if len(_dict[key]) > keep:
273 | _dict[key].pop(0) # drop the oldest sample
274 |
275 | def _calculate_delta_time(self, now_sec, now_nsec, pre_sec, pre_nsec):
276 | """ 计算统计时间, 即两个消息时间差"""
277 | return self._calculate_seconds(now_sec, now_nsec) - self._calculate_seconds(pre_sec, pre_nsec)
278 |
279 | @staticmethod
280 | def _calculate_seconds(sec, nsec):
281 | """ 计算 sec + nsec 的和,单位为 seconds"""
282 | return sec + nsec / 10 ** 9
283 |
284 | @staticmethod
285 | def _calculate_speed(now_bytes, pre_bytes, delta_time):
286 | """ 计算统计流量速度"""
287 | if delta_time:
288 |
289 | return (now_bytes - pre_bytes) / delta_time
290 | else:
291 | return 0
292 |
293 | def _calculate_port_speed(self, dpid, port_no, speed):
294 | curr_bw = speed * 8 / 10 ** 6 # MBit/s
295 | # print(f"monitorMMMM---> _calculate_port_speed: {curr_bw} MBits/s", )
296 | self.port_curr_speed.setdefault(dpid, {})
297 | self.port_curr_speed[dpid][port_no] = curr_bw
298 |
299 | @set_ev_cls(ofp_event.EventOFPPortStatus, MAIN_DISPATCHER)
300 | def _port_status_handler(self, ev):
301 | """ 处理端口状态: ADD, DELETE, MODIFIED"""
302 | msg = ev.msg
303 | dp = msg.datapath
304 | ofp = dp.ofproto
305 |
306 | if msg.reason == ofp.OFPPR_ADD:
307 | reason = 'ADD'
308 | elif msg.reason == ofp.OFPPR_DELETE:
309 | reason = 'DELETE'
310 | elif msg.reason == ofp.OFPPR_MODIFY:
311 | reason = 'MODIFY'
312 | else:
313 | reason = 'unknown'
314 |
315 | print('MMMM---> _port_status_handler OFPPortStatus received: reason=%s desc=%s' % (reason, msg.desc))
316 |
317 | # Update the bw weights on the discovered topology
318 | def create_bandwidth_graph(self):
319 | # print("MMMM---> create bandwidth graph")
320 | for link in self.discovery.link_port_table:
321 | src_dpid, dst_dpid = link
322 | src_port, dst_port = self.discovery.link_port_table[link]
323 |
324 | if src_dpid in self.port_curr_speed.keys() and dst_dpid in self.port_curr_speed.keys():
325 | src_port_bw = self.port_curr_speed[src_dpid][src_port]
326 | dst_port_bw = self.port_curr_speed[dst_dpid][dst_port]
327 | src_dst_bandwidth = min(src_port_bw, dst_port_bw) # bottleneck bandwidth
328 |
329 | # print(f"monitor--> dst[{src_dpid}]_port[{src_port}]_bw: %.5f" % dst_port_bw)
330 | # print(f"monitor---> src[{dst_dpid}]_port[{dst_port}]_bw: %.5f" % src_port_bw)
331 | # print("monitor---> src_dst_bandwidth: %.5f" % src_dst_bandwidth)
332 |
333 | # set the available bw on the graph edge
334 | capacity = self.discovery.m_graph[src_dpid][dst_dpid]['bw']
335 | self.discovery.graph[src_dpid][dst_dpid]['bw'] = max(capacity - src_dst_bandwidth, 0)
336 |
337 | else:
338 | self.logger.info(
339 | "MMMM---> create_bandwidth_graph: [{}] [{}] not in port_free_bandwidth ".format(src_dpid,
340 | dst_dpid))
341 | self.discovery.graph[src_dpid][dst_dpid]['bw'] = -1
342 |
343 | # print("MMMM---> ", self.discovery.graph.edges(data=True))
344 | # print("MMMM---> " * 2, self.discovery.count + 1)
345 |
346 | # calculate loss tx - rx / tx
347 | def calculate_loss_of_link(self):
348 | """
349 | 发端口 和 收端口 ,端口loss
350 | """
351 | for link, port in self.discovery.link_port_table.items():
352 | src_dpid, dst_dpid = link
353 | src_port, dst_port = port
354 | if (src_dpid, src_port) in self.port_stats_table.keys() and \
355 | (dst_dpid, dst_port) in self.port_stats_table.keys():
356 | # {(dpid, port_no): (stat.tx_bytes, stat.rx_bytes, stat.rx_errors, stat.duration_sec,
357 | # stat.duration_nsec, stat.tx_packets, stat.rx_packets)}
358 | # 1. forward direction (2022/3/11: changed packets --> bytes)
359 | tx = self.port_stats_table[(src_dpid, src_port)][-1][0] # tx_bytes
360 | rx = self.port_stats_table[(dst_dpid, dst_port)][-1][1] # rx_bytes
361 | loss_ratio = abs(float(tx - rx) / tx) * 100 if tx else 0 # guard against tx == 0
362 | self._save_stats(self.port_loss, link, loss_ratio, 5)
363 | # print(f"MMMM--->[{link}]({dst_dpid}, {dst_port}) rx: ", rx, "tx: ", tx,
364 | # "loss_ratio: ", loss_ratio)
365 |
366 | # 2. reverse direction
367 | tx = self.port_stats_table[(dst_dpid, dst_port)][-1][0] # tx_bytes
368 | rx = self.port_stats_table[(src_dpid, src_port)][-1][1] # rx_bytes
369 | loss_ratio = abs(float(tx - rx) / tx) * 100 if tx else 0 # guard against tx == 0
370 | self._save_stats(self.port_loss, link[::-1], loss_ratio, 5)
371 |
372 | # print(f"MMMM--->[{link[::-1]}]({dst_dpid}, {dst_port}) rx: ", rx, "tx: ", tx,
373 | # "loss_ratio: ", loss_ratio)
374 | else:
375 | self.logger.info("MMMM---> calculate_loss_of_link error")
376 |
377 | # update graph loss
378 | def update_graph_loss(self):
379 | """从1 往2 和 从2 往1,取最大作为链路loss """
380 | for link in self.discovery.link_port_table:
381 | src_dpid = link[0]
382 | dst_dpid = link[1]
383 | if link in self.port_loss.keys() and link[::-1] in self.port_loss.keys():
384 | src_loss = self.port_loss[link][-1] # 1-->2, [-1] is the latest sample
385 | dst_loss = self.port_loss[link[::-1]][-1] # 2-->1
386 | link_loss = max(src_loss, dst_loss) # percentage; max loss between port1 and port2
387 | self.discovery.graph[src_dpid][dst_dpid]['loss'] = link_loss
388 |
389 | # print(f"MMMM---> update_graph_loss link[{link}]_loss: ", link_loss)
390 | else:
391 | self.discovery.graph[src_dpid][dst_dpid]['loss'] = 100
392 |
393 | def create_loss_graph(self):
394 | """
395 | 在graph中更新边的loss值
396 | """
397 | # self.calculate_loss_of_link()
398 | self.update_graph_loss()
399 |
--------------------------------------------------------------------------------
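Note: a standalone sketch (made-up numbers, not part of the repo) of the rate and loss arithmetic NetworkMonitor applies above. Sampling times come from duration_sec + duration_nsec, speed is the byte delta over the time delta, and loss compares tx bytes on one end of a link with rx bytes on the other:

    def seconds(sec, nsec):
        return sec + nsec / 10 ** 9

    now_bytes, pre_bytes = 2_500_000, 1_000_000
    delta_time = seconds(20, 0) - seconds(15, 0)        # 5 s between samples
    speed = (now_bytes - pre_bytes) / delta_time        # 300000.0 B/s
    curr_bw = speed * 8 / 10 ** 6                       # 2.4 Mbit/s

    tx, rx = 1_000_000, 990_000
    loss_ratio = abs(float(tx - rx) / tx) * 100 if tx else 0  # 1.0 %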
/ryu/network_structure.py:
--------------------------------------------------------------------------------
1 | # network_structure.py
2 | import copy
3 | import time
4 | import xml.etree.ElementTree as ET
5 |
6 | from ryu.base import app_manager
7 | from ryu.ofproto import ofproto_v1_3
8 | from ryu.controller import ofp_event
9 | from ryu.controller.handler import set_ev_cls, MAIN_DISPATCHER, CONFIG_DISPATCHER
10 | from ryu.lib import hub
11 | from ryu.lib import igmplib, mac
12 | from ryu.lib.dpid import str_to_dpid
13 | from ryu.lib.packet import packet, arp, ethernet, ipv4, igmp
14 | from ryu.topology import event
15 | from ryu.topology.api import get_switch, get_link, get_host
16 |
17 | import networkx as nx
18 | import matplotlib.pyplot as plt
19 |
20 | import setting
21 | from setting import print_pretty_table, print_pretty_list
22 |
23 |
24 | class NetworkStructure(app_manager.RyuApp):
25 | """
26 | 发现网络拓扑,保存网络结构
27 | """
28 | OFP_VERSION = [ofproto_v1_3.OFP_VERSION]
29 |
30 | # _CONTEXTS = {'igmplib': igmplib.IgmpLib}
31 |
32 | def __init__(self, *args, **kwargs):
33 | super(NetworkStructure, self).__init__(*args, **kwargs)
34 | self.start_time = time.time()
35 | self.name = 'discovery'
36 | # self._snoop = kwargs['igmplib']
37 | # self._snoop.set_querier_mode(dpid=str_to_dpid('000000000000001e'), server_port=2)
38 | self.topology_api_app = self
39 | self.link_info_xml = setting.LINKS_INFO # xml file path of links info
40 | self.m_graph = self.parse_topo_links_info() # parse the link info of the Mininet-built topo
41 |
42 | self.graph = nx.Graph()
43 | self.pre_graph = nx.Graph()
44 |
45 | self.access_table = {} # {(dpid, in_port): (src_ip, src_mac)}
46 | self.switch_all_ports_table = {} # {dpid: {port_no, ...}}
47 | self.all_switches_dpid = {} # dict_key[dpid]
48 | self.switch_port_table = {} # {dpid: {port, ...}
49 | self.link_port_table = {} # {(src.dpid, dst.dpid): (src.port_no, dst.port_no)}
50 | self.not_use_ports = {} # {dpid: {port, ...}} ports not used for switch-to-switch links
51 | self.shortest_path_table = {} # {(src.dpid, dst.dpid): [path]}
52 | self.arp_table = {} # {(dpid, eth_src, arp_dst_ip): in_port}
53 | self.arp_src_dst_ip_table = {}
54 |
55 | # self.multiple_access_table = {} # {(dpid, in_port): (src_ip, src_mac)}
56 | # self.group_table = {} # {group_address: [(dpid, in_port), ...]}
57 |
58 | # self._discover_thread = hub.spawn(self._discover_network_structure)
59 | # self._show_graph = hub.spawn(self.show_graph_plt())
60 | self.initiation_delay = setting.INIT_TIME
61 | self.first_flag = True
62 | self.cal_path_flag = False
63 |
64 | self._structure_thread = hub.spawn(self.scheduler)
65 | self._shortest_path_thread = hub.spawn(self.cal_shortest_path_thread)
66 |
67 | def print_parameters(self):
68 | # self.logger.info("discovery---> access_table: %s", self.access_table)
69 | # self.logger.info("discovery---> link_port_table: %s", self.link_port_table)
70 | # self.logger.info("discovery---> not_use_ports: %s", self.not_use_ports)
71 | # self.logger.info("discovery---> shortest_path_table: %s", self.shortest_path_table)
72 | logger = self.logger.info if setting.LOGGER else print
73 | # the graph
74 | # logger("============================= SSSS graph edges==============================")
75 | # logger('SSSS---> graph edges:\n', self.graph.edges)
76 | # logger("=============================end SSSS graph edges=============================")
77 |
78 | # switch dpid: {all port numbers of the switch}
79 | # {dpid: {port_no, ...}}
80 | print_pretty_table(self.switch_all_ports_table, ['dpid', 'port_no'], [10, 10],
81 | 'SSSS switch_all_ports_table', logger)
82 |
83 | # switch dpid: ports discovered via LLDP
84 | # {dpid: {port, ...}
85 | print_pretty_table(self.switch_port_table, ['dpid', 'port_no'], [10, 10], 'SSSS switch_port_table',
86 | logger)
87 |
88 | # {(dpid, in_port): (src_ip, src_mac)}
89 | print_pretty_table(self.access_table, ['(dpid, in_port)', '(src_ip, src_mac)'], [10, 40], 'SSSS access_table',
90 | logger)
91 |
92 | # {dpid: {port, ...}}
93 | print_pretty_table(self.not_use_ports, ['dpid', 'not_use_ports'], [10, 30], 'SSSS not_use_ports', logger)
94 |
95 | def scheduler(self):
96 | i = 0
97 | while True:
98 | if i == 3:
99 | self.get_topology(None)
100 | i = 0
101 | hub.sleep(setting.DISCOVERY_PERIOD)
102 | if setting.PRINT_SHOW:
103 | self.print_parameters()
104 | i += 1
105 |
106 | def cal_shortest_path_thread(self):
107 | self.cal_path_flag = False
108 | while True:
109 | if self.cal_path_flag:
110 | self.calculate_all_nodes_shortest_paths(weight=setting.WEIGHT)
111 | # print("*****discovery---> self.shortest_path_table:\n", self.shortest_path_table)
112 | hub.sleep(setting.DISCOVERY_PERIOD)
113 |
114 | # Flow mod and Table miss
115 | @set_ev_cls(ofp_event.EventOFPSwitchFeatures, CONFIG_DISPATCHER)
116 | def switch_features_handler(self, ev):
117 | datapath = ev.msg.datapath
118 | ofproto = datapath.ofproto
119 | parser = datapath.ofproto_parser
120 |
121 | self.logger.info("discovery---> switch: %s connected", datapath.id)
122 |
123 | # install table miss flow entry
124 | match = parser.OFPMatch() # match all
125 | actions = [parser.OFPActionOutput(ofproto.OFPP_CONTROLLER,
126 | ofproto.OFPCML_NO_BUFFER)]
127 |
128 | self.add_flow(datapath, 0, match, actions)
129 |
130 | def add_flow(self, datapath, priority, match, actions):
131 | inst = [datapath.ofproto_parser.OFPInstructionActions(datapath.ofproto.OFPIT_APPLY_ACTIONS,
132 | actions)]
133 | mod = datapath.ofproto_parser.OFPFlowMod(datapath=datapath, priority=priority,
134 | match=match, instructions=inst)
135 | datapath.send_msg(mod)
136 |
137 | # Packet In
138 | @set_ev_cls(ofp_event.EventOFPPacketIn, MAIN_DISPATCHER)
139 | def _packet_in_handler(self, ev):
140 | # print("discovery---> discovery PacketIn")
141 | msg = ev.msg
142 | datapath = msg.datapath
143 |
144 | # input port number
145 | in_port = msg.match['in_port']
146 | pkt = packet.Packet(msg.data)
147 | arp_pkt = pkt.get_protocol(arp.arp)
148 |
149 | if isinstance(arp_pkt, arp.arp):
150 | # print("SSSS---> _packet_in_handler: arp packet")
151 | arp_src_ip = arp_pkt.src_ip
152 | src_mac = arp_pkt.src_mac
153 | self.storage_access_info(datapath.id, in_port, arp_src_ip, src_mac)
154 | # print("discovery---> access_table:\n ", self.access_table)
155 |
156 | # Store the host access info parsed from packet-in ARP packets
157 | def storage_access_info(self, dpid, in_port, src_ip, src_mac):
158 | # print(f"SSSS--->storage_access_info, self.access_table: {self.access_table}")
159 | if in_port in self.not_use_ports[dpid]:
160 | # print("discovery--->", dpid, in_port, src_ip, src_mac)
161 | if (dpid, in_port) in self.access_table:
162 | if self.access_table[(dpid, in_port)] == (src_ip, src_mac):
163 | return
164 | else:
165 | self.access_table[(dpid, in_port)] = (src_ip, src_mac)
166 | return
167 | else:
168 | self.access_table[(dpid, in_port)] = (src_ip, src_mac)
170 | return
171 |
172 | # Use the topology library to obtain topology info
173 | events = [event.EventSwitchEnter, event.EventSwitchLeave,
174 | event.EventPortAdd, event.EventPortDelete, event.EventPortModify,
175 | event.EventLinkAdd, event.EventLinkDelete]
176 |
177 | @set_ev_cls(events)
178 | def get_topology(self, ev):
179 | present_time = time.time()
180 | if present_time - self.start_time < self.initiation_delay: # Set to 30s
181 | print(f'SSSS--->get_topology: need to WAIT {self.initiation_delay - (present_time - self.start_time):.2f}s')
182 | return
183 | elif self.first_flag:
184 | self.first_flag = False
185 | print("SSSS--->get_topology: complete WAIT")
186 |
187 | # self.logger.info("discovery---> EventSwitch/Port/Link")
188 | self.logger.info("[Topology Discovery Ok]")
189 | # when a topology event fires, fetch the switch list
190 | switch_list = get_switch(self.topology_api_app, None)
191 | # add each switch to self.switch_all_ports_table
192 | for switch in switch_list:
193 | dpid = switch.dp.id
194 | self.switch_all_ports_table.setdefault(dpid, set())
195 | self.switch_port_table.setdefault(dpid, set())
196 | self.not_use_ports.setdefault(dpid, set())
197 | # print("discovery---> ",switch, switch.ports)
198 | for p in switch.ports:
199 | self.switch_all_ports_table[dpid].add(p.port_no)
200 |
201 | self.all_switches_dpid = self.switch_all_ports_table.keys()
202 |
203 | # time.sleep(0.5)
204 | # fetch the links
205 | link_list = get_link(self.topology_api_app, None)
206 | # print("discovery---> ",len(link_list))
207 |
208 | # add each link to self.link_port_table
209 | for link in link_list:
210 | src = link.src # actually a Port instance; took a while to figure out
211 | dst = link.dst
212 | self.link_port_table[(src.dpid, dst.dpid)] = (src.port_no, dst.port_no)
213 |
214 | if src.dpid in self.all_switches_dpid:
215 | self.switch_port_table[src.dpid].add(src.port_no)
216 | if dst.dpid in self.all_switches_dpid:
217 | self.switch_port_table[dst.dpid].add(dst.port_no)
218 |
219 | # collect the unused ports
220 | for sw_dpid in self.switch_all_ports_table.keys():
221 | all_ports = self.switch_all_ports_table[sw_dpid]
222 | linked_port = self.switch_port_table[sw_dpid]
223 | # print("discovery---> all_ports, linked_port", all_ports, linked_port)
224 | self.not_use_ports[sw_dpid] = all_ports - linked_port
225 |
226 | # build the topology; bw and delay are not set yet
227 | self.build_topology_between_switches()
228 | self.cal_path_flag = True
229 |
230 | def build_topology_between_switches(self, bw=0, delay=0, loss=0):
231 | """ 根据 src_dpid 和 dst_dpid 建立拓扑,bw 和 delay 信息还未定"""
232 | # networkxs使用已有Link的src_dpid和dst_dpid信息建立拓扑
233 | _graph = nx.Graph()
234 |
235 | # self.graph.clear()
236 | for (src_dpid, dst_dpid) in self.link_port_table.keys():
237 | # connect the switches; the ports can be looked up in link_port_table
238 | _graph.add_edge(src_dpid, dst_dpid, bw=bw, delay=delay, loss=loss)
239 | if _graph.edges == self.graph.edges:
240 | return
241 | else:
242 | self.graph = _graph
243 |
244 | def calculate_weight(self, node1, node2, weight_dict):
245 | """ 计算路径时,weight可以调用函数,该函数根据因子计算 bw * factor - delay * (1 - factor) 后的weight"""
246 | # weight可以调用的函数
247 | assert 'bw' in weight_dict and 'delay' in weight_dict, "edge weight should have bw and delay"
248 | try:
249 | weight = weight_dict['bw'] * setting.FACTOR - weight_dict['delay'] * (1 - setting.FACTOR)
250 | return weight
251 | except TypeError:
252 | print("discovery ERROR---> weight_dict['bw']: ", weight_dict['bw'])
253 | print("discovery ERROR---> weight_dict['delay']: ", weight_dict['delay'])
254 | return None
255 |
256 | def get_shortest_paths(self, src_dpid, dst_dpid, weight=None):
257 | """ 计算src到dst的最短路径,存在self.shortest_path_table中"""
258 | graph = self.graph.copy()
259 | # print(graph.edges)
260 | # print("SSSS--->get_shortest_paths ==calculate shortest path %s to %s" % (src_dpid, dst_dpid))
261 | self.shortest_path_table[(src_dpid, dst_dpid)] = nx.shortest_path(graph,
262 | source=src_dpid,
263 | target=dst_dpid,
264 | weight=weight,
265 | method=setting.METHOD)
266 | # print("SSSS--->get_shortest_paths ==[PATH] %s <---> %s: %s" % (
267 | # src_dpid, dst_dpid, self.shortest_path_table[(src_dpid, dst_dpid)]))
268 |
269 | def calculate_all_nodes_shortest_paths(self, weight=None):
270 | """ 根据已构建的图,计算所有nodes间的最短路径,weight为权值,可以为calculate_weight()该函数"""
271 | self.shortest_path_table = {} # 先清空,再计算
272 | for src in self.graph.nodes():
273 | for dst in self.graph.nodes():
274 | if src != dst:
275 | self.get_shortest_paths(src, dst, weight=weight)
276 | else:
277 | continue
278 |
279 | def get_host_ip_location(self, host_ip):
280 | """
281 | 通过host_ip查询 self.access_table: {(dpid, in_port): (src_ip, src_mac)}
282 | 获得(dpid, in_port)
283 | """
284 | if host_ip == "0.0.0.0" or host_ip == "255.255.255.255":
285 | return None
286 |
287 | for key in self.access_table.keys(): # {(dpid, in_port): (src_ip, src_mac)}
288 | if self.access_table[key][0] == host_ip:
289 | # print("discovery--->zzzz---> key", key)
290 | return key
291 | print("SSS--->get_host_ip_location: %s location is not found" % host_ip)
292 | return None
293 |
294 | def get_ip_by_dpid(self, dpid):
295 | """
296 | 通过 dpid 查询 {(dpid, in_port): (src_ip, src_mac)}
297 | 获得 ip src_ip
298 | """
299 | for key, value in self.access_table.items():
300 | if key[0] == dpid:
301 | return value[0]
302 | print("SSS--->get_ip_by_dpid: %s ip is not found" % dpid)
303 | return None
304 |
305 | def parse_topo_links_info(self):
306 | m_graph = nx.Graph()
307 | parser = ET.parse(self.link_info_xml)
308 | root = parser.getroot()
309 |
310 | # links_info_element = root.find("links_info")
311 |
312 | def _str_tuple2int_list(s: str):
313 | s = s.strip()
314 | assert s.startswith('(') and s.endswith(")"), 'expected a tuple-like string, e.g. "(1, 2)"'
315 | s_ = s[1: -1].split(', ')
316 | return [int(i) for i in s_]
317 |
318 | node1, node2, port1, port2, bw, delay, loss = None, None, None, None, None, None, None
319 | for e in root.iter():
320 | if e.tag == 'links':
321 | node1, node2 = _str_tuple2int_list(e.text)
322 | elif e.tag == 'ports':
323 | port1, port2 = _str_tuple2int_list(e.text)
324 | elif e.tag == 'bw':
325 | bw = float(e.text)
326 | elif e.tag == 'delay':
327 | delay = float(e.text[:-2]) # strip the trailing "ms" unit
328 | elif e.tag == 'loss':
329 | loss = float(e.text)
330 | else:
331 | print(e.tag)
332 | continue
333 | m_graph.add_edge(node1, node2, port1=port1, port2=port2, bw=bw, delay=delay, loss=loss) # attributes are updated incrementally as the tags arrive
334 |
335 | for edge in m_graph.edges(data=True):
336 | print(edge)
337 | return m_graph
338 |
--------------------------------------------------------------------------------
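Note: a standalone sketch (toy values, not part of the repo) showing how a calculate_weight-style callable plugs into networkx. Keep in mind that nx.shortest_path minimizes the callable's return value, so under bw * FACTOR - delay * (1 - FACTOR) a larger residual bw actually raises the cost; the sketch only demonstrates the plumbing:

    import networkx as nx

    FACTOR = 0.9

    def calculate_weight(node1, node2, weight_dict):
        return weight_dict['bw'] * FACTOR - weight_dict['delay'] * (1 - FACTOR)

    g = nx.Graph()
    g.add_edge(1, 2, bw=10.0, delay=2.0)   # weight 8.8
    g.add_edge(2, 3, bw=10.0, delay=2.0)   # weight 8.8
    g.add_edge(1, 3, bw=1.0, delay=5.0)    # weight 0.4

    print(nx.shortest_path(g, 1, 3, weight=calculate_weight, method='dijkstra'))
    # -> [1, 3]: the single low-weight edge beats the two high-weight hops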
/ryu/setting.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from functools import reduce
3 |
4 | WORK_DIR = Path.cwd().parent
5 |
6 | # setting.py
7 | FACTOR = 0.9 # the coefficient of 'bw' , 1 - FACTOR is the coefficient of 'delay'
8 |
9 | METHOD = 'dijkstra' # the calculation method of shortest path
10 |
11 | DISCOVERY_PERIOD = 10 # discover network structure's period, the unit is seconds.
12 |
13 | MONITOR_PERIOD = 5 # monitor period, bw
14 |
15 | DELAY_PERIOD = 1.3 # detector period, delay
16 |
17 | SCHEDULE_PERIOD = 6 # shortest forwarding network awareness period
18 |
19 | PRINT_SHOW = False # show or not show print
20 |
21 | INIT_TIME = 30 # wait init for awareness
22 |
23 | PRINT_NUM_OF_LINE = 8 # print 8 values per line
24 |
25 | LOGGER = True # whether to log output via self.logger instead of print
26 |
27 | LINKS_INFO = WORK_DIR / "mininet/links_info/links_info.xml" # path of the links-info xml file
28 |
29 | # SRC_IP = "10.0.0.1"
30 | # DST_MULTICAST_IP = {'224.1.1.1': 1, } # multicast address: index (into the list below)
31 | # DST_GROUP_IP = [["10.0.0.2", "10.0.0.3", "10.0.0.4"], ] # group member IPs (indexed by the number above)
32 |
33 | DST_MULTICAST_IP = {'224.1.1.1': ["10.0.0.2", "10.0.0.4", "10.0.0.11"], } # multicast address: group member IPs
34 |
35 | WEIGHT = 'bw'
36 | # FINAL_SWITCH_FLOW_IDEA = 1
37 |
38 | finish_time_file = WORK_DIR / "mininet/finish_time.json"
39 |
40 |
41 | def list_insert_one_by_one(list1, list2):
42 | l = []
43 | for x, y in zip(list1, list2):
44 | l.extend([x, y])
45 | return l
46 |
47 |
48 | def gen_format_str(num):
49 | fmt = ''
50 | for i in range(num):
51 | fmt += '{{:<{}}}'
52 | # fmt += '\n'
53 | return fmt
54 |
55 |
56 | # Only prints two key/value columns; pandas would arguably do this better
57 | def print_pretty_table(param, titles, widths, table_name='zzlong', logger=None):
58 | """
59 | 打印一个漂亮的表
60 | :param param: 要打印的字典,dict
61 | :param titles: 每列的title
62 | :param widths: 每列的宽度
63 | :param table_name: 表名字
64 | :param logger: 用什么打印 print / logger.info
65 | :return: None
66 | """
67 | f = logger if logger else print
68 | all_width = reduce(lambda x, y: x + y, widths)
69 | cut_line = "=" * all_width
70 | # table name
71 | w = all_width - len(table_name)
72 | if w > 1:
73 | f(cut_line[:w // 2] + table_name + cut_line[w // 2: w])
74 | else:
75 | f("=" + table_name + "=")
76 |
77 | # output as a table
78 | if isinstance(param, dict):
79 | # build the '{:<{}}' placeholder for each column
80 | fmt = gen_format_str(len(titles))
81 | # fix the column widths
82 | width_fmt = fmt.format(*widths)
83 | # fill in the titles
84 | title_fmt = width_fmt.format(*titles)
85 | # print the title row
86 | f(title_fmt)
87 | # print the separator line
88 | f(cut_line)
89 | # print one row per entry
90 | for k, v in param.items():
91 | content_fmt = width_fmt.format(str(k), str(v))
92 | # print the row content
93 | f(content_fmt)
94 |
95 | # print the closing separator line
96 | f(cut_line + '\n')
97 |
98 |
99 | def print_pretty_list(param, num, width=10, table_name='zzlong', logger=None):
100 | """
101 | Print the values of a list, a fixed number per line
102 | :param param: the list to print
103 | :param num: how many values per line
104 | :param width: the width of each value
105 | :param table_name: the table name
106 | :param logger: what to print with: print / logger.info
107 | :return: None
108 | """
109 | f = logger if logger else print
110 | all_widths = num * width
111 | cut_line = "=" * all_widths
112 | # table name
113 | w = all_widths - len(table_name)
114 | if w > 1:
115 | f(cut_line[:w // 2] + table_name + cut_line[w // 2: w])
116 | else:
117 | f("=" + table_name + "=")
118 |
119 | # print num values per row
120 | temp = 0
121 | for i in range(len(param) // num):
122 | f(param[temp: temp + num])
123 | temp += num
124 | # print any leftover values
125 | if param[temp:]:
126 | f(param[temp:])
127 |
128 | # print the closing separator line
129 | f(cut_line + '\n')
130 |
131 |
132 |
133 | if __name__ == '__main__':
134 | # a = {'test1': [11, 12, 13], 'test2': [21, 22, 23], 'test3': [31, 32, 33]}
135 | # print_pretty_table(a, ['my_test', 'values'], [10, 14], 'test_table', print)
136 | #
137 | # b = list(range(30))
138 | # print_pretty_list(b, 10, 10)
139 | print(WORK_DIR)
140 | print(LINKS_INFO)
141 | print(finish_time_file)
142 |
--------------------------------------------------------------------------------
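Note: a hedged usage sketch (not part of the repo) for the helpers in setting.py, restored above so the `from setting import print_pretty_table, print_pretty_list` imports in network_monitor.py and network_structure.py resolve:

    from setting import print_pretty_table, print_pretty_list, gen_format_str

    # two columns: keys under 'dpid', values under 'ports', widths 10 and 20
    print_pretty_table({1: {1, 2}, 2: {1, 3}}, ['dpid', 'ports'], [10, 20],
                       'switch_port_table')

    # the list printed 8 values per line, each padded to width 6
    print_pretty_list(list(range(20)), 8, 6, 'demo_list')

    # gen_format_str(2) -> '{{:<{}}}{{:<{}}}'; formatting fixes the widths
    print(gen_format_str(2).format(10, 20))  # '{:<10}{:<20}'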