├── __init__.py ├── analysis ├── p1.png ├── p2.png ├── p4.png ├── p3-1.png └── p3-2.png ├── models ├── perddqn │ └── test_model │ │ ├── checkpoint │ │ ├── test.meta │ │ ├── test.index │ │ ├── test_featch.pkl │ │ └── test.data-00000-of-00001 └── ddpg │ └── ddpg.py ├── logs ├── events.out.tfevents.1650539775.DESKTOP-A11M4AH ├── events.out.tfevents.1650539812.DESKTOP-A11M4AH ├── events.out.tfevents.1650539831.DESKTOP-A11M4AH ├── events.out.tfevents.1650539880.DESKTOP-A11M4AH ├── events.out.tfevents.1650539972.DESKTOP-A11M4AH ├── events.out.tfevents.1650540065.DESKTOP-A11M4AH ├── events.out.tfevents.1650540307.DESKTOP-A11M4AH ├── events.out.tfevents.1650540404.DESKTOP-A11M4AH ├── events.out.tfevents.1650540853.DESKTOP-A11M4AH ├── events.out.tfevents.1650540946.DESKTOP-A11M4AH ├── events.out.tfevents.1650541183.DESKTOP-A11M4AH ├── events.out.tfevents.1650541199.DESKTOP-A11M4AH ├── events.out.tfevents.1650541226.DESKTOP-A11M4AH ├── events.out.tfevents.1650541248.DESKTOP-A11M4AH ├── events.out.tfevents.1650541384.DESKTOP-A11M4AH ├── events.out.tfevents.1650541425.DESKTOP-A11M4AH ├── events.out.tfevents.1650541720.DESKTOP-A11M4AH ├── events.out.tfevents.1650541759.DESKTOP-A11M4AH ├── events.out.tfevents.1650542270.DESKTOP-A11M4AH ├── events.out.tfevents.1650542307.DESKTOP-A11M4AH ├── events.out.tfevents.1650543077.DESKTOP-A11M4AH ├── events.out.tfevents.1650543080.DESKTOP-A11M4AH ├── events.out.tfevents.1650543619.DESKTOP-A11M4AH ├── events.out.tfevents.1650543621.DESKTOP-A11M4AH ├── events.out.tfevents.1650544171.DESKTOP-A11M4AH ├── events.out.tfevents.1650544177.DESKTOP-A11M4AH ├── events.out.tfevents.1650544722.DESKTOP-A11M4AH ├── events.out.tfevents.1650544739.DESKTOP-A11M4AH ├── events.out.tfevents.1650545246.DESKTOP-A11M4AH └── events.out.tfevents.1650545273.DESKTOP-A11M4AH ├── README.md ├── __delete_model.py ├── common.py ├── analysis_train_all.py ├── tools.py ├── analysis_test_all.py ├── env.py └── tf_model_perddqn.py /__init__.py: 
-------------------------------------------------------------------------------- 1 | # Use for package -------------------------------------------------------------------------------- /analysis/p1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/analysis/p1.png -------------------------------------------------------------------------------- /analysis/p2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/analysis/p2.png -------------------------------------------------------------------------------- /analysis/p4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/analysis/p4.png -------------------------------------------------------------------------------- /analysis/p3-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/analysis/p3-1.png -------------------------------------------------------------------------------- /analysis/p3-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/analysis/p3-2.png -------------------------------------------------------------------------------- /models/perddqn/test_model/checkpoint: -------------------------------------------------------------------------------- 1 | model_checkpoint_path: "test" 2 | all_model_checkpoint_paths: "test" 3 | -------------------------------------------------------------------------------- /models/perddqn/test_model/test.meta: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/models/perddqn/test_model/test.meta -------------------------------------------------------------------------------- /models/perddqn/test_model/test.index: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/models/perddqn/test_model/test.index -------------------------------------------------------------------------------- /models/perddqn/test_model/test_featch.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/models/perddqn/test_model/test_featch.pkl -------------------------------------------------------------------------------- /models/perddqn/test_model/test.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/models/perddqn/test_model/test.data-00000-of-00001 -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650539775.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650539775.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650539812.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650539812.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650539831.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650539831.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650539880.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650539880.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650539972.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650539972.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650540065.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650540065.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650540307.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650540307.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650540404.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650540404.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650540853.DESKTOP-A11M4AH: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650540853.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650540946.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650540946.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650541183.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650541183.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650541199.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650541199.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650541226.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650541226.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650541248.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650541248.DESKTOP-A11M4AH -------------------------------------------------------------------------------- 
/logs/events.out.tfevents.1650541384.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650541384.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650541425.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650541425.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650541720.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650541720.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650541759.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650541759.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650542270.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650542270.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650542307.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650542307.DESKTOP-A11M4AH 
-------------------------------------------------------------------------------- /logs/events.out.tfevents.1650543077.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650543077.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650543080.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650543080.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650543619.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650543619.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650543621.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650543621.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650544171.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650544171.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650544177.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650544177.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650544722.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650544722.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650544739.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650544739.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650545246.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650545246.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /logs/events.out.tfevents.1650545273.DESKTOP-A11M4AH: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/perpetualbrighten/PRE-DDQN/HEAD/logs/events.out.tfevents.1650545273.DESKTOP-A11M4AH -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PER-DDQN 2 | This is the code of the PER-DDQN algorithm for paper: 3 | 4 | "Deep Reinforcement Learning-Based UAV Path Planning Algorithm in Agricultural Time-Constrained Data Collection". 
5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /__delete_model.py: -------------------------------------------------------------------------------- 1 | # 删除相关学习模型、训练和测试数据以及各种log 2 | # 以便重新训练、重头分析,同时清理空间 3 | # python删除文件的方法 os.remove(path) path指的是文件的绝对路径,如: 4 | # os.remove(r"E:\code\practice\data\1.py") #删除文件 5 | # os.rmdir(r"E:\code\practice\data\2") #删除文件夹(只能删除空文件夹) 6 | # path_data = "E:\code\practice\data"# 7 | 8 | import os 9 | 10 | 11 | def del_file(path_data): 12 | for ld in os.listdir(path_data): # os.listdir(path_data)#返回一个列表,里面是当前目录下面的所有东西的相对路径 13 | file_data = os.path.join(path_data, ld) # 当前文件夹的下面的所有东西的绝对路径 14 | if os.path.isfile(file_data): # os.path.isfile判断是否为文件,如果是文件,就删除.如果是文件夹.递归给del_file. 15 | os.remove(file_data) 16 | else: 17 | del_file(file_data) 18 | 19 | 20 | path_list = [ 21 | r".\logs", 22 | r".\train_output", 23 | r".\test_output", 24 | r".\models\PER_DDQN", 25 | r".\analysis" 26 | ] 27 | 28 | for path in path_list: 29 | del_file(path) 30 | -------------------------------------------------------------------------------- /common.py: -------------------------------------------------------------------------------- 1 | """ 2 | 公共参数部分 3 | """ 4 | import numpy as np 5 | import math 6 | 7 | # UAV 环境相关 8 | WIDTH = 1000 # 地图的宽度 9 | HEIGHT = 1000 # 地图的高度 10 | UNIT = 1 # 每个方块的大小(像素值) 1000m*1000m的地图 11 | LDA = [4., 8., 15., 20.] # 假设有4类传感器,即有4个不同的泊松参数,传感器数据生成服从泊松分布 12 | max_LDA = max(LDA) 13 | E = 400 # UAV的能量 14 | C = 10000 # 传感器的容量都假设为10000 15 | P_u = pow(10, -5) # 传感器的发射功率 0.01mW,-20dbm 16 | P_d = 10 # 无人机下行发射功率 10W,40dBm 17 | H = 15. # 无人机固定悬停高度 18 | R_d = 30. 
# 无人机充电覆盖范围 10m能接受到0.1mW,30m能接收到0.01mW 19 | N_S_ = 30 # 设备个数 20 | V = 1 # 无人机均速飞行速度 1m/s 21 | 22 | b_S_pos = np.random.randint(0, 1000, [N_S_, 2]) # 初始化传感器位置 23 | b_S_cache = np.random.randint(200, 800, N_S_) # 初始化传感器当前数据缓存量 24 | b_S_death = np.random.randint(4000, 8000, N_S_) # 初始化数据的死亡时间 25 | 26 | # 能耗 27 | PF = 0.01 28 | PH = 0.02 29 | PT = 0.25 30 | 31 | 32 | # 非线性能量接收机模型 33 | Mj = 9.079 * pow(10, -6) 34 | aj = 47083 35 | bj = 2.9 * pow(10, -6) 36 | Oj = 1 / (1 + math.exp(aj * bj)) 37 | 38 | # 马尔可夫决策相关 39 | STATE_DIM = 4 40 | ACTION_DIM = 1 41 | 42 | # 奖励函数相关 43 | GAMMA = 0.99 44 | ALPHA = 0.01 45 | BETA = 0.001 46 | LAMBDA = 0.1 47 | OMEGA = 100 48 | 49 | 50 | -------------------------------------------------------------------------------- /analysis_train_all.py: -------------------------------------------------------------------------------- 1 | """ 2 | 这个文件对生成的训练阶段文件进行数据读取、画图等操作 3 | """ 4 | 5 | import seaborn as sns 6 | import numpy as np 7 | import pandas as pd 8 | import matplotlib.pyplot as plt 9 | from matplotlib.pyplot import MultipleLocator # 设置刻度 10 | 11 | # sns.plotting_context('paper') 12 | 13 | plt.rcParams["font.family"] = "Times New Roman" 14 | # plt.rcParams['font.sans-serif'] = ['SimHei'] # 如果要显示中文字体,则在此处设为:SimHei 15 | plt.rcParams['axes.unicode_minus'] = False # 显示负号 16 | plt.rcParams['font.size'] = 20 # 设置字体大小,全局有效 17 | 18 | 19 | params = { 20 | 'legend.fontsize': 15, 21 | 'figure.figsize': (12, 8), 22 | } 23 | plt.rcParams.update(params) 24 | 25 | 26 | PATH_LIST = [ 27 | r'.\analysis\train_ucb.csv', 28 | r'.\analysis\train_greedy.csv' 29 | ] 30 | 31 | raw_data = [] 32 | for i in range(2): 33 | train_df = pd.read_csv(PATH_LIST[i]) 34 | raw_data.append(train_df) 35 | print(train_df.info()) 36 | 37 | 38 | x = raw_data[0].index 39 | 40 | y_dt_d = raw_data[0]['number'] 41 | y_ql_d = raw_data[1]['number'] 42 | 43 | 44 | # print(y_mr, y_tr) 45 | plt.figure(figsize=(12, 8)) 46 | plt.grid(linestyle="--") # 设置背景网格线为虚线 47 | # 
颜色绿色,点形圆形,线性虚线,设置图例显示内容,线条宽度为2 48 | 49 | plt.plot(x, y_dt_d, 'r-', linewidth=1, markersize=5, label=r"Adaptive UCB") 50 | plt.plot(x, y_ql_d, 'b-', linewidth=1, markersize=5, label=r"$\epsilon$-greedy") 51 | 52 | 53 | plt.xlabel("Number of Epoch") # 横坐标轴的标题 54 | plt.ylabel(r"Data-Energy Ratio $\eta^{total}$ (KB)") # 纵坐标轴的标题 55 | plt.legend() 56 | 57 | plt.savefig(fname=r".\demo\data_process.png", dpi=600) 58 | plt.show() 59 | 60 | # plt.ylim([0.50, 1]) # 设置纵坐标轴范围为 -2 到 2 61 | # plt.grid(linestyle="--") # 设置背景网格线为虚线 62 | # # plt.grid() # 显示网格 63 | # # plt.title('result') # 图形的标题 64 | 65 | 66 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | """ 2 | 实现各种函数计算相关 3 | """ 4 | 5 | import numpy as np 6 | from math import pi, atan2 7 | from common import * 8 | 9 | ETX = 50 * (10 ** (-9)) # Energy for transferring of each bit:发射单位报文损耗能量:50nJ/bit 10 | ERX = 50 * (10 ** (-9)) # Energy for receiving of each bit:接收单位报文损耗能量:50nJ/bit 11 | Efs = 10 * (10 ** (-12)) # Energy of free space model:自由空间传播模型:10pJ/bit/m2 12 | K = 2000 # 系数k用于对齐 13 | P = 0.001 # 功率为0.001J/s 14 | 15 | 16 | def ucb_Q(sum_cnt, single_cnt): 17 | res = 2 * np.log(sum_cnt + 1) 18 | res /= single_cnt + 1 19 | return np.sqrt(res) 20 | 21 | 22 | def get_reward(D, T): 23 | base_reward = 0 24 | base_reward -= ALPHA * D 25 | if T >= 0: 26 | base_reward -= BETA * T 27 | return base_reward 28 | 29 | 30 | def calc_move(pos_start, pos_end): 31 | """ 32 | 计算两个位置间移动向量 33 | :param pos_start:起始点 34 | :param pos_end:终点 35 | :return:向量 36 | """ 37 | return pos_end[0][0] - pos_start[0][0], pos_end[0][1] - pos_start[0][1] 38 | 39 | 40 | def calc_distance(pos_1, pos_2): 41 | """ 42 | 计算两个位置间距离 43 | :param pos_1:位置1 44 | :param pos_2:位置2 45 | :return:距离 46 | """ 47 | d1, d2 = calc_move(pos_1, pos_2) 48 | return (d1 ** 2 + d2 ** 2) ** 0.5 49 | 50 | 51 | def calc_region_id(base_pos, neighbor_pos): 52 | move 
= calc_move(base_pos, neighbor_pos) 53 | theta = atan2(move[1], move[0]) / pi * 180 54 | # 45 degree pre step 55 | move_pos = round(theta / 45) 56 | # using right hand 57 | if move_pos < 0: 58 | move_pos += 8 59 | return move_pos 60 | 61 | 62 | def cross_product(move_1, move_2): 63 | return move_1[0] * move_2[1] - move_1[1] * move_2[0] 64 | 65 | 66 | def calc_counterclockwise(move_1, move_2): 67 | return cross_product(move_1, move_2) > 0 68 | 69 | 70 | def calc_intersect(line_1_start, line_1_end, line_2_start, line_2_end): 71 | def get_A_B_C(line_1, line_2): 72 | x0, y0 = line_1 73 | x1, y1 = line_2 74 | a = y0 - y1 75 | b = x1 - x0 76 | c = x0 * y1 - x1 * y0 77 | return a, b, c 78 | 79 | (a0, b0, c0) = get_A_B_C(line_1_start, line_1_end) 80 | (a1, b1, c1) = get_A_B_C(line_2_start, line_2_end) 81 | D = a0 * b1 - a1 * b0 82 | if D == 0: 83 | return None 84 | x = (b0 * c1 - b1 * c0) / D 85 | y = (a1 * c0 - a0 * c1) / D 86 | 87 | # line1:A1,A2 88 | # line2:B1,B2 89 | # test if intersect 90 | B1B2 = calc_move(line_2_start, line_2_end) 91 | B1A1 = calc_move(line_2_start, line_1_start) 92 | B1A2 = calc_move(line_2_start, line_1_end) 93 | c_1 = cross_product(B1B2, B1A1) 94 | c_2 = cross_product(B1B2, B1A2) 95 | pass_1_flag = c_2 * c_1 <= 0 96 | if not pass_1_flag: 97 | return None 98 | 99 | A1A2 = calc_move(line_1_start, line_1_end) 100 | A1B1 = calc_move(line_1_start, line_2_start) 101 | A1B2 = calc_move(line_1_start, line_2_end) 102 | c_3 = cross_product(A1A2, A1B1) 103 | c_4 = cross_product(A1A2, A1B2) 104 | pass_2_flag = c_3 * c_4 <= 0 105 | if not pass_2_flag: 106 | return None 107 | 108 | return x, y 109 | 110 | 111 | def calc_cos_theta(move_1, move_2, is_360=True): 112 | """ 113 | 计算两个向量间cosθ 114 | :param move_1: 移动向量1 115 | :param move_2: 移动向量2 116 | :param is_360: 逆时针360° 117 | :return: cosθ 118 | """ 119 | # to avoid move_1 == move_2 120 | if is_360 and move_1 == move_2: 121 | return -3 122 | a_norm = (move_1[0] ** 2 + move_1[1] ** 2) ** 0.5 + 1e-5 123 | b_norm 
= (move_2[0] ** 2 + move_2[1] ** 2) ** 0.5 + 1e-5 124 | cos_theta = (move_1[0] * move_2[0] + move_1[1] * move_2[1]) / ( 125 | a_norm * b_norm) 126 | # calc_counterclockwise or not 127 | if is_360 and calc_counterclockwise(move_1, move_2): 128 | cos_theta = -2 - cos_theta 129 | return cos_theta 130 | 131 | 132 | def rotate_right_hand(vector, theta): 133 | x, y = vector 134 | x1 = x * np.cos(theta) + y * np.sin(theta) 135 | y1 = -x * np.sin(theta) + y * np.cos(theta) 136 | return x1, y1 137 | 138 | 139 | def softmax(x): 140 | """Compute softmax values for each sets of scores in x.""" 141 | e_x = np.exp(x - np.max(x)) 142 | return e_x / e_x.sum() 143 | 144 | 145 | def trans_speed(): 146 | return np.random.rand() * 450 + 50 147 | -------------------------------------------------------------------------------- /analysis_test_all.py: -------------------------------------------------------------------------------- 1 | """ 2 | 这个文件对生成的测试阶段文件进行数据读取、画图等操作 3 | """ 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | 9 | def draw_png(save_path, xticklabels, xlabel, ylabel, data_list, ymax=None): 10 | n_groups = 5 11 | fig, ax = plt.subplots() 12 | ax.grid(axis="y", linestyle="--") 13 | index = np.arange(n_groups) 14 | bar_width = 0.2 15 | 16 | error_config = {'ecolor': '0.3'} 17 | 18 | rects1 = ax.bar(index, data_list[0], 19 | bar_width, color='g', 20 | error_kw=error_config, zorder=3, 21 | label='MDTF') 22 | 23 | rects2 = ax.bar(index + bar_width * 1, data_list[1], 24 | bar_width, color='b', 25 | error_kw=error_config, zorder=3, 26 | label='MDLF') 27 | 28 | rects3 = ax.bar(index + bar_width * 2, data_list[2], 29 | bar_width, color='m', 30 | error_kw=error_config, zorder=3, 31 | label='ACO') 32 | 33 | rects4 = ax.bar(index + bar_width * 3, data_list[3], 34 | bar_width, color='r', 35 | error_kw=error_config, zorder=3, 36 | label='PER-DDQN') 37 | 38 | ax.set_xticks(index + bar_width) 39 | ax.set_xticklabels(xticklabels) 40 | if ymax: 41 | 
plt.ylim(ymax=ymax) 42 | plt.xlabel(xlabel) 43 | plt.ylabel(ylabel) 44 | 45 | plt.rcParams['font.size'] = 15 # 设置字体大小,全局有效 46 | plt.legend(title="Strategy") 47 | 48 | plt.rcParams['font.size'] = 20 # 设置字体大小,全局有效 49 | fig.tight_layout() 50 | plt.savefig(save_path, dpi=600) 51 | plt.show() 52 | 53 | 54 | # sns.plotting_context('paper') 55 | # sns.set(font = 'fangsong') # 解决Seaborn中文显示问题 56 | 57 | plt.rcParams["font.family"] = "Times New Roman" 58 | # plt.rcParams['font.sans-serif'] = ['SimHei'] # 如果要显示中文字体,则在此处设为:SimHei 59 | plt.rcParams['axes.unicode_minus'] = False # 显示负号 60 | plt.rcParams['font.size'] = 20 # 设置字体大小,全局有效 61 | 62 | params = { 63 | 'legend.fontsize': 15, 64 | 'figure.figsize': (12, 8), 65 | } 66 | plt.rcParams.update(params) 67 | 68 | PATH = r'.\analysis\test_all.csv' 69 | test_df = pd.read_csv(PATH) 70 | 71 | mdata = test_df.values[0] 72 | mdeath = test_df.values[1] 73 | aco = test_df.values[2] 74 | dqn = test_df.values[3] 75 | # -------------------------------------- # 76 | 77 | 78 | xticklabels = ('20', '25', '30', '35', '40') 79 | xlabel = "Number of Sensors" 80 | ylabel = r"Data-Energy Ratio $\eta^{total} $(KB)" 81 | 82 | lmdata = [] 83 | lmdeath = [] 84 | laco = [] 85 | ldqn = [] 86 | 87 | start = 0 88 | for i in range(start, start+15, 3): 89 | lmdata.append(mdata[i]) 90 | lmdeath.append(mdeath[i]) 91 | laco.append(aco[i]) 92 | ldqn.append(dqn[i]) 93 | 94 | data_list = (lmdata, lmdeath, laco, ldqn) 95 | 96 | draw_png( 97 | r".\demo\p2.png", 98 | xticklabels, xlabel, ylabel, data_list) 99 | 100 | 101 | xticklabels = ('20', '25', '30', '35', '40') 102 | xlabel = "Number of Sensors" 103 | ylabel = r"Flight Time(s)" 104 | 105 | lmdata = [] 106 | lmdeath = [] 107 | laco = [] 108 | ldqn = [] 109 | 110 | start = 1 111 | for i in range(start, start+15, 3): 112 | lmdata.append(mdata[i]) 113 | lmdeath.append(mdeath[i]) 114 | laco.append(aco[i]) 115 | ldqn.append(dqn[i]) 116 | 117 | data_list = (lmdata, lmdeath, laco, ldqn) 118 | 119 | draw_png( 120 | 
r".\demo\p3-1.png", 121 | xticklabels, xlabel, ylabel, data_list) 122 | 123 | 124 | xticklabels = ('20', '25', '30', '35', '40') 125 | xlabel = "Number of Sensors" 126 | ylabel = r"Execution Time (s)" 127 | 128 | lmdata = [] 129 | lmdeath = [] 130 | laco = [] 131 | ldqn = [] 132 | 133 | start = 2 134 | for i in range(start, start+15, 3): 135 | lmdata.append(mdata[i]) 136 | lmdeath.append(mdeath[i]) 137 | laco.append(aco[i]) 138 | ldqn.append(dqn[i]) 139 | 140 | data_list = (lmdata, lmdeath, laco, ldqn) 141 | 142 | draw_png( 143 | r".\demo\p4.png", 144 | xticklabels, xlabel, ylabel, data_list) 145 | 146 | 147 | # 数据采集量 148 | xticklabels = ('20', '25', '30', '35', '40') 149 | xlabel = "Number of Sensors" 150 | ylabel = r"Amount of Data Collected (MB)" 151 | 152 | lmdata = [] 153 | lmdeath = [] 154 | laco = [] 155 | ldqn = [] 156 | 157 | start = 15 158 | for i in range(start, start+5, 1): 159 | lmdata.append(mdata[i]) 160 | lmdeath.append(mdeath[i]) 161 | laco.append(aco[i]) 162 | ldqn.append(dqn[i]) 163 | 164 | data_list = (lmdata, lmdeath, laco, ldqn) 165 | 166 | draw_png( 167 | r".\demo\p3-2.png", 168 | xticklabels, xlabel, ylabel, data_list) 169 | 170 | 171 | -------------------------------------------------------------------------------- /models/ddpg/ddpg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tensorflow as tf 4 | import tensorlayer as tl 5 | 6 | 7 | ############################### DDPG #################################### 8 | 9 | class AGENT(object): 10 | """ 11 | DDPG class 12 | """ 13 | 14 | def __init__(self, args, a_num, a_dim, s_dim, a_bound, is_train): 15 | self.memory = np.zeros((args.mem_size, s_dim * 2 + a_dim + 1), dtype=np.float32) 16 | self.pointer = 0 17 | self.mem_size = args.mem_size 18 | self.batch_size = args.batch_size 19 | self.replace_tau = args.replace_tau 20 | self.a_num = a_num 21 | self.a_dim, self.s_dim, self.a_bound = a_dim, s_dim, a_bound.high 
22 | self.is_train = is_train 23 | self.gamma = args.gamma 24 | 25 | w_init = tf.initializers.he_normal() 26 | b_init = tf.constant_initializer(0.001) 27 | 28 | # 建立actor网络,输入s, 输出a 29 | def get_actor(input_state, name=''): 30 | """ 31 | Build actor network 32 | :param input_state_shape: state 33 | :param name: name 34 | :return: act 35 | """ 36 | s = tl.layers.Input(input_state, name='A_s_input') 37 | x = tl.layers.Dense(n_units=400, act=tf.nn.relu6, W_init=w_init, b_init=b_init, name='A_l1')(s) 38 | x = tl.layers.Dense(n_units=300, act=tf.nn.relu6, W_init=w_init, b_init=b_init, name='A_l2')(x) 39 | x_v = tl.layers.Dense(n_units=300, act=tf.nn.relu6, W_init=w_init, b_init=b_init, name='v_l3')(x) 40 | x_angle = tl.layers.Dense(n_units=300, act=tf.nn.relu6, W_init=w_init, b_init=b_init, name='angle_l3')(x) 41 | v = tl.layers.Dense(n_units=1, act=tf.nn.sigmoid, W_init=w_init, name='action_v')(x_v) 42 | angle = tl.layers.Dense(n_units=1, act=tf.nn.tanh, W_init=w_init, name='action_angle')(x_angle) 43 | output = tl.layers.Concat(1)([v, angle]) 44 | 45 | return tl.models.Model(inputs=s, outputs=output, name='Actor' + name) 46 | 47 | # 建立Critic网络,输入s, a。输出Q值 48 | def get_critic(input_state, input_action_params, name=''): 49 | """ 50 | Build critic network 51 | :param input_state_shape: state 52 | :param input_action_shape: act 53 | :param name: name 54 | :return: Q value Q(s,a) 55 | """ 56 | s = tl.layers.Input(input_state, name='C_s_input') 57 | a_paras = tl.layers.Input(input_action_params, name='C_a_paras_input') 58 | inputs = tl.layers.Concat(1)([s, a_paras]) 59 | x = tl.layers.Dense(n_units=400, act=tf.nn.relu6, W_init=w_init, b_init=tf.constant_initializer(0.01), name='C_l1')(inputs) 60 | x = tl.layers.Dense(n_units=300, act=tf.nn.relu6, W_init=w_init, b_init=tf.constant_initializer(0.01), name='C_l2')(x) 61 | q = tl.layers.Dense(n_units=1, W_init=w_init, b_init=tf.constant_initializer(0.01), name='C_svalue')(x) 62 | return tl.models.Model(inputs=[s, a_paras], 
outputs=q, name='Critic' + name) 63 | 64 | self.actor = get_actor([None, s_dim]) 65 | self.critic = get_critic([None, s_dim], [None, a_dim]) 66 | self.actor.train() 67 | self.critic.train() 68 | 69 | # 更新参数,只用于首次赋值,之后就没用了 70 | def copy_para(from_model, to_model): 71 | """ 72 | Copy parameters for soft updating 73 | :param from_model: latest model 74 | :param to_model: target model 75 | :return: None 76 | """ 77 | for i, j in zip(from_model.trainable_weights, to_model.trainable_weights): 78 | j.assign(i) 79 | 80 | # 建立actor_target网络,并和actor参数一致,不能训练 81 | self.actor_target = get_actor([None, s_dim], name='_target') 82 | copy_para(self.actor, self.actor_target) 83 | self.actor_target.eval() 84 | 85 | # 建立critic_target网络,并和critic参数一致,不能训练 86 | self.critic_target = get_critic([None, s_dim], [None, a_dim], name='_target') 87 | copy_para(self.critic, self.critic_target) 88 | self.critic_target.eval() 89 | 90 | # 建立ema,滑动平均值 91 | self.ema = tf.train.ExponentialMovingAverage(decay=1 - args.replace_tau) # soft replacement 92 | 93 | self.actor_opt = tf.optimizers.Adam(lr=args.lr_actor) 94 | self.critic_opt = tf.optimizers.Adam(lr=args.lr_critic) 95 | 96 | def ema_update(self): 97 | """ 98 | 滑动平均更新 99 | """ 100 | # 其实和之前的硬更新类似,不过在更新赋值之前,用一个ema.average。 101 | paras = self.actor.trainable_weights + self.critic.trainable_weights # 获取要更新的参数包括actor和critic的 102 | self.ema.apply(paras) # 主要是建立影子参数 103 | for i, j in zip(self.actor_target.trainable_weights + self.critic_target.trainable_weights, paras): 104 | i.assign(self.ema.average(j)) # 用滑动平均赋值 105 | 106 | def soft_replace(self): 107 | paras = self.actor.trainable_weights + self.critic.trainable_weights # 获取要更新的参数包括actor和critic的 108 | for i, j in zip(self.actor_target.trainable_weights + self.critic_target.trainable_weights, paras): 109 | i.assign((1 - self.replace_tau) * i + self.replace_tau * j) 110 | 111 | def choose_action(self, s): 112 | """ 113 | Choose action 114 | :param s: state 115 | :return: act 116 | """ 117 | s = 
np.array([s], dtype=np.float32) 118 | a = self.actor(s) 119 | return a.numpy().squeeze() 120 | 121 | def learn(self): 122 | indices = np.random.choice(self.mem_size, size=self.batch_size) # 随机BATCH_SIZE个随机数 123 | bt = self.memory[indices, :] # 根据indices,选取数据bt,相当于随机 124 | bs = bt[:, :self.s_dim] # 从bt获得数据s 125 | ba = bt[:, self.s_dim:self.s_dim + self.a_dim] # 从bt获得数据a 126 | br = bt[:, self.s_dim + self.a_dim:-self.s_dim] # 从bt获得数据r 127 | bs_ = bt[:, -self.s_dim:] # 从bt获得数据s_ 128 | 129 | # Critic: 130 | # Critic更新和DQN很像,不过target不是argmax了,是用critic_target计算出来的。 131 | # br + GAMMA * q_ 132 | with tf.GradientTape() as tape: 133 | # batch_size*2 134 | a_ = self.actor_target(bs_) 135 | # batch_size*4 136 | q_ = self.critic_target([bs_, a_]) 137 | y = br + self.gamma * q_ 138 | q = self.critic([bs, ba]) 139 | td_error = tf.losses.mean_squared_error(tf.reshape(y, [-1]), tf.reshape(q, [-1])) 140 | c_grads = tape.gradient(td_error, self.critic.trainable_weights) 141 | self.critic_opt.apply_gradients(zip(c_grads, self.critic.trainable_weights)) 142 | 143 | # Actor: 144 | # Actor的目标就是获取最多Q值的。 145 | with tf.GradientTape() as tape: 146 | a = self.actor(bs) 147 | q = self.critic([bs, a]) 148 | a_loss = -tf.reduce_mean(q) # 【敲黑板】:注意这里用负号,是梯度上升!也就是离目标会越来越远的,就是越来越大。 149 | a_grads = tape.gradient(a_loss, self.actor.trainable_weights) 150 | self.actor_opt.apply_gradients(zip(a_grads, self.actor.trainable_weights)) 151 | 152 | # self.ema_update() 153 | self.soft_replace() 154 | return td_error.numpy(), -a_loss.numpy() 155 | 156 | # 保存s,a,r,s_ 157 | def store_transition(self, s, a, r, s_): 158 | """ 159 | Store data in data buffer 160 | :param s: state 161 | :param a: act 162 | :param r: reward 163 | :param s_: next state 164 | :return: None 165 | """ 166 | s = s.astype(np.float32) 167 | s_ = s_.astype(np.float32) 168 | # 把s, a, [r], s_横向堆叠 169 | transition = np.hstack((s, a, [r], s_)) 170 | 171 | # pointer是记录了曾经有多少数据进来。 172 | # index是记录当前最新进来的数据位置。 173 | # 
所以是一个循环,当MEMORY_CAPACITY满了以后,index就重新在最底开始了 174 | index = self.pointer % self.mem_size # replace the old memory with new memory 175 | # 把transition,也就是s, a, [r], s_存进去。 176 | self.memory[index, :] = transition 177 | self.pointer += 1 178 | 179 | def save_ckpt(self, model_path, eps): 180 | """ 181 | save trained weights 182 | :return: None 183 | """ 184 | model_path = os.getcwd() + model_path 185 | if not os.path.exists(model_path): 186 | os.makedirs(model_path) 187 | tl.files.save_weights_to_hdf5('{}actor_{}.hdf5'.format(model_path, str(eps)), self.actor) 188 | tl.files.save_weights_to_hdf5('{}actor_target_{}.hdf5'.format(model_path, str(eps)), self.actor_target) 189 | tl.files.save_weights_to_hdf5('{}critic_{}.hdf5'.format(model_path, str(eps)), self.critic) 190 | tl.files.save_weights_to_hdf5('{}critic_target_{}.hdf5'.format(model_path, str(eps)), self.critic_target) 191 | 192 | def load_ckpt(self, version, eps): 193 | """ 194 | load trained weights 195 | :return: None 196 | """ 197 | model_path = os.getcwd() + '/{}/{}/'.format(version, 'models') 198 | tl.files.load_hdf5_to_weights_in_order('{}actor_{}.hdf5'.format(model_path, str(eps)), self.actor) 199 | tl.files.load_hdf5_to_weights_in_order('{}actor_target_{}.hdf5'.format(model_path, str(eps)), self.actor_target) 200 | tl.files.load_hdf5_to_weights_in_order('{}critic_{}.hdf5'.format(model_path, str(eps)), self.critic) 201 | tl.files.load_hdf5_to_weights_in_order('{}critic_target_{}.hdf5'.format(model_path, str(eps)), self.critic_target) -------------------------------------------------------------------------------- /env.py: -------------------------------------------------------------------------------- 1 | """ 2 | 无人机数据采集任务的路径规划算法的底层实现 3 | """ 4 | 5 | import sys 6 | import numpy as np 7 | import time 8 | import math 9 | import copy 10 | 11 | from common import * 12 | from tools import * 13 | from tf_model_perddqn import tf_model_perddqn 14 | from gym import spaces 15 | from gym.utils import seeding 16 | 17 | 
np.random.seed(1)

if sys.version_info.major == 2:
    import Tkinter as tk
else:
    import tkinter as tk


# UAV environment (doubles as the Tk window)
class UAV(tk.Tk):
    def __init__(self, is_train=True, R_dc=10., R_eh=30.):
        """
        :param is_train: build the PER-DDQN model in training mode
        :param R_dc: horizontal data-collection coverage radius (m)
        :param R_eh: horizontal energy-harvesting coverage radius (m)
        """
        super(UAV, self).__init__()
        # PoI (sensor) bookkeeping
        self.canvas = None
        self.np_random = None
        self.N_POI = N_S_                       # number of sensors
        self.dis = np.zeros(self.N_POI)         # UAV-sensor distances
        self.elevation = np.zeros(self.N_POI)   # elevation angles
        self.pro = np.zeros(self.N_POI)         # line-of-sight probabilities
        self.h = np.zeros(self.N_POI)           # channel gains
        self.N_UAV = 1
        self.max_speed = V
        self.H = H                              # flight altitude, 10 m
        self.energy = E
        self.X_min = 0
        self.Y_min = 0
        self.X_max = WIDTH * UNIT
        self.Y_max = HEIGHT * UNIT              # map boundary
        self.R_dc = R_dc                        # DC coverage radius, 10 m
        self.R_eh = R_eh                        # EH coverage radius, 30 m
        self.sdc = math.sqrt(pow(self.R_dc, 2) + pow(self.H, 2))  # max DC service distance
        self.seh = math.sqrt(pow(self.R_eh, 2) + pow(self.H, 2))  # max EH service distance
        self.noise = pow(10, -12)               # noise power, -90 dBm
        self.AutoUAV = []
        self.Aim = []
        self.N_AIM = 1                          # number of users served at once
        self.FX = 0.
        self.time = 0
        self.has_visited = {}
        self.has_collected = 0.
        self.SoPcenter = np.random.randint(0, 1000, size=[self.N_POI, 2])

        self.action_space = spaces.Box(low=np.array([0., -1.]), high=np.array([1., 1.]),
                                       dtype=np.float32)
        # State row per sensor: distance to UAV, expected remaining lifetime,
        # buffered data amount, current UAV energy.
        self.state_dim = STATE_DIM
        self.state = np.zeros([N_S_, self.state_dim])
        self.xy = np.zeros((self.N_UAV, 2))     # UAV position

        # Four sensor classes, each with its own Poisson rate, randomly assigned.
        CoLDA = np.random.randint(0, len(LDA), self.N_POI)
        self.lda = [LDA[CoLDA[i]] for i in range(self.N_POI)]  # data-growth rate per sensor

        self.b_S_pos = np.random.randint(0, 1000, [N_S_, 2])   # sensor positions
        self.b_S_cache = np.random.randint(200, 800, N_S_)     # current buffered data
        self.b_S_death = np.random.randint(4000, 8000, N_S_)   # data expiry times

        self.Fully_buffer = C
        self.N_Data_death = 0                   # data-expiry counter
        self.Q = np.array(
            [self.lda[i] * self.b_S_cache[i] / self.Fully_buffer for i in range(self.N_POI)])
        self.idx_target = np.argmax(self.Q)
        self.updata = self.b_S_cache[self.idx_target] / self.Fully_buffer

        # Build the PER-DDQN network
        self.model = tf_model_perddqn(is_trainable=is_train)

        self.title('MAP')
        self.geometry('{0}x{1}'.format(WIDTH * UNIT, HEIGHT * UNIT))  # Tk geometry
        self.build_maze()

    # Build the map
    def build_maze(self):
        """Create the canvas, draw every sensor, the UAV and the first target."""
        # White canvas sized to the map.
        self.canvas = tk.Canvas(self, bg='white', width=WIDTH * UNIT, height=HEIGHT * UNIT)

        # Draw sensors, coloured by their Poisson class (self.lda[i] is always
        # one of LDA by construction, so the lookup cannot miss).
        colour_by_rate = {LDA[0]: 'pink', LDA[1]: 'blue', LDA[2]: 'green', LDA[3]: 'red'}
        for i in range(self.N_POI):
            self.canvas.create_oval(
                self.SoPcenter[i][0] - 5, self.SoPcenter[i][1] - 5,
                self.SoPcenter[i][0] + 5, self.SoPcenter[i][1] + 5,
                fill=colour_by_rate[self.lda[i]])

        # Draw the UAV(s)
        self.xy = [[0., 0.]]

        for i in range(self.N_UAV):
            L_UAV = self.canvas.create_oval(
                self.xy[i][0] - R_d, self.xy[i][1] - R_d,
                self.xy[i][0] + R_d, self.xy[i][1] + R_d,
                fill='yellow')
            self.AutoUAV.append(L_UAV)

        # Mark the highest-priority sensor as the current target.
        pxy = self.SoPcenter[np.argmax(self.Q)]
        L_AIM = self.canvas.create_rectangle(
            pxy[0] - 10, pxy[1] - 10,
            pxy[0] + 10, pxy[1] + 10,
            fill='red')
        self.Aim.append(L_AIM)

        self.canvas.pack()

    def seed(self, seed=None):
        """Seed the gym RNG; returns the seed actually used."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    # Reset: re-randomise the sensors and put the UAV back at the origin.
    def reset(self):
        """
        Reset the environment.
        :return: the initial state matrix, shape [N_S_, state_dim]
        NOTE(review): self.time, self.energy, self.has_visited and
        self.has_collected are NOT reset here — confirm that is intended.
        """
        self.render()
        for i in range(self.N_UAV):
            self.canvas.delete(self.AutoUAV[i])
        self.AutoUAV = []

        for i in range(len(self.Aim)):
            self.canvas.delete(self.Aim[i])

        # Reset the UAV position
        self.xy = [[0., 0.]]
        for i in range(self.N_UAV):
            L_UAV = self.canvas.create_oval(
                self.xy[i][0] - R_d, self.xy[i][1] - R_d,
                self.xy[i][0] + R_d, self.xy[i][1] + R_d,
                fill='yellow')
            self.AutoUAV.append(L_UAV)
        self.FX = 0.

        self.b_S_pos = np.random.randint(0, 1000, [N_S_, 2])   # sensor positions
        self.b_S_cache = np.random.randint(200, 800, N_S_)     # buffered data
        self.b_S_death = np.random.randint(4000, 8000, N_S_)   # data expiry times

        self.N_Data_death = 0
        self.Q = np.array([self.lda[i] * self.b_S_cache[i] / self.Fully_buffer for i in range(self.N_POI)])  # collection priority

        # Initial target: the sensor with the highest priority.
        self.idx_target = np.argmax(self.Q)
        self.updata = self.b_S_cache[self.idx_target] / self.Fully_buffer
        self.pxy = self.SoPcenter[self.idx_target]

        L_AIM = self.canvas.create_rectangle(
            self.pxy[0] - 10, self.pxy[1] - 10,
            self.pxy[0] + 10, self.pxy[1] + 10,
            fill='red')
        self.Aim.append(L_AIM)

        self.state = np.zeros([N_S_, self.state_dim])
        for i in range(N_S_):
            self.state[i][0] = calc_distance(self.xy, [[self.b_S_pos[i][0], self.b_S_pos[i][1]]])
            # FIX: these read bare b_S_death / b_S_cache (NameError); the
            # instance attributes set above are clearly intended.
            self.state[i][1] = self.b_S_death[i] - self.time
            self.state[i][2] = self.b_S_cache[i]
            self.state[i][3] = self.energy

        return self.state

    # Take one movement step; returns the next state and the reward.
    def step_move(self, action):
        """
        Fly to sensor `action`, collect its data if it is still alive,
        and record the transition in the model's replay buffer.
        :param action: index of the sensor to visit
        :return: (next_state, reward, done)
        """
        state_ = np.zeros([N_S_, self.state_dim])
        xy_ = [[self.b_S_pos[action][0], self.b_S_pos[action][1]]]
        D = calc_distance(self.xy, xy_)

        ec = 0  # energy consumption
        ec += PF * D / self.max_speed  # flight energy
        if self.b_S_death[action] >= self.time + D / self.max_speed:
            # hover energy + transmission energy (only if data alive on arrival)
            ec += (PH + PT) * self.b_S_cache[action] / trans_speed()
        else:
            ec += 0

        for i in range(self.N_UAV):
            self.canvas.move(self.AutoUAV[i], xy_[i][0] - self.xy[i][0], xy_[i][1] - self.xy[i][1])

        self.xy = xy_
        for i in range(self.N_POI):  # buffer-overflow handling
            if self.b_S_cache[i] >= self.Fully_buffer:
                self.N_Data_death += 1  # overflowing sensor count
                self.b_S_cache[i] = self.Fully_buffer
        self.updata = self.b_S_cache[self.idx_target] / self.Fully_buffer

        # Next state
        for i in range(N_S_):
            state_[i][0] = calc_distance(self.xy, [[self.b_S_pos[i][0], self.b_S_pos[i][1]]])
            # FIX: bare b_S_death / b_S_cache were NameErrors; use the attributes.
            state_[i][1] = self.b_S_death[i] - self.time
            state_[i][2] = self.b_S_cache[i]
            state_[i][3] = self.energy

        reward = get_reward(D, self.b_S_death[action] - self.time - D / self.max_speed)
        self.Q_dis()

        Done = False
        if self.b_S_death[action] < self.time + D / self.max_speed:
            reward -= LAMBDA * D  # penalty for arriving after the data expired
        if self.energy <= 0 or len(self.has_visited) == N_S_:
            Done = True
            # terminal penalty: fraction of data left uncollected
            # (FIX: sum(b_S_cache) read a bare name; the attribute is intended)
            reward -= 1 - self.has_collected / sum(self.b_S_cache)
        # Only the target user's transition is stored.

        self.model.insert_experience(self.state[action], reward, Done, state_[action])

        self.state = state_
        self.energy -= ec
        self.time += D / self.max_speed
        self.has_visited[action] = True
        self.has_collected += self.b_S_cache[action]

        return state_, reward, Done

    def step_hover(self, hover_time):
        """Hover in place for `hover_time`; sensor buffers keep growing."""
        self.N_Data_death = 0  # overflow count for this slot
        self.b_S_cache += [np.random.poisson(self.lda[i]) * hover_time for i in range(self.N_POI)]  # buffer growth
        for i in range(self.N_POI):  # buffer-overflow handling
            if self.b_S_cache[i] >= self.Fully_buffer:
                self.N_Data_death += 1
                self.b_S_cache[i] = self.Fully_buffer
        self.updata = self.b_S_cache[self.idx_target] / self.Fully_buffer
        # NOTE(review): this writes a scalar into ROW 5 of the state matrix
        # (shape [N_S_, state_dim]); the original comment says "s[:5] unchanged"
        # — confirm the intended slot for the overflow ratio.
        self.state[5] = self.N_Data_death / self.N_POI

    # After every position update: distance, elevation, LoS probability and
    # channel gain between the UAV and every sensor.
    def Q_dis(self):
        for i in range(self.N_POI):
            self.dis[i] = math.sqrt(
                pow(self.SoPcenter[i][0] - self.xy[0][0], 2) + pow(self.SoPcenter[i][1] - self.xy[0][1], 2) + pow(
                    self.H, 2))  # raw 3-D distance
            self.elevation[i] = 180 / math.pi * np.arcsin(self.H / self.dis[i])  # elevation angle
            self.pro[i] = 1 / (1 + 10 * math.exp(-0.6 * (self.elevation[i] - 10)))  # LoS probability
            self.h[i] = \
                (self.pro[i] + (1 - self.pro[i]) * 0.2) * pow(self.dis[i], -2.3) * pow(10, -30 / 10)  # -30 dB reference gain

    # Input 1e-5..1e-4 W, output 0..9.079 muW
    def Non_linear_EH(self, Pr):
        if Pr == 0:
            return 0
        P_prac = Mj / (1 + math.exp(-aj * (Pr - bj)))
        Peh = (P_prac - Mj * Oj) / (1 - Oj)  # in watts
        return Peh * pow(10, 6)

    # Input 1e-5..1e-4 W, output 0..9.079 muW
    def linear_EH(self, Pr):
        if Pr == 0:
            return 0
        return Pr * pow(10, 6) * 0.2

    # Re-select the target user
    def CHOOSE_AIM(self):
        for i in range(len(self.Aim)):
            self.canvas.delete(self.Aim[i])

        # New target: sensor with the highest collection priority.
        self.Q = np.array([self.lda[i] * self.b_S_cache[i] / C for i in range(self.N_POI)])
        self.idx_target = np.argmax(self.Q)
        self.updata = self.b_S_cache[self.idx_target] / self.Fully_buffer
        self.pxy = self.SoPcenter[self.idx_target]
        L_AIM = self.canvas.create_rectangle(
            self.pxy[0] - 10, self.pxy[1] - 10,
            self.pxy[0] + 10, self.pxy[1] + 10,
            fill='red')
        self.Aim.append(L_AIM)

        # NOTE(review): this assigns a (2,) vector into state[:2], which is a
        # (2, state_dim) slice — confirm the intended target of this write.
        self.state[:2] = (self.pxy - self.xy[0]).flatten() / 400.
        self.render()
        return self.state

    # Let Tk redraw; one step every 0.01 s by default.
    def render(self, epcho_time=0.01):
        time.sleep(epcho_time)
        self.update()

    def sample_action(self):
        """Ask the PER-DDQN model for an action in the current state."""
        action, is_max_Q = self.model.get_decision(self.state)
        return action


def update():
    """Demo loop: run 30 episodes driven by the model's own decisions."""
    for epcho in range(30):
        env.reset()
        while True:
            env.render()
            action = env.sample_action()
            s, r, done = env.step_move(action)
            if done:
                break


if __name__ == '__main__':
    env = UAV()
    env.after(10, update)
    env.mainloop()


# ===== tf_model_perddqn.py =====
"""
Builds the PER-DDQN network.
"""

import os
import pickle

import pandas as pd
import tensorflow as tf
import numpy as np

from common import STATE_DIM, ACTION_DIM
from tools import ucb_Q
from numpy import random

# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# Hyper Parameters for PER-DDQN
GAMMA = 0.99              # discount factor for target Q
GREEDY_EPSILON = 0.01
REPLAY_SIZE = 3200        # experience replay buffer size
BATCH_SIZE = 32           # size of mini-batch
REPLACE_TARGET_FREQ = 10  # frequency to update target Q network

# Per-feature bounds used for min-max normalisation of the state input.
FEATURE_MAX = np.array(
    [1500, 8000, 800, 400],
    dtype=np.float32)
FEATURE_MIN = np.array(
    [0, 0, 0, 0],
    dtype=np.float32)


# SumTree backing the prioritised replay buffer
class SumTree:
    data_pointer = 0

    def __init__(self, capacity):
        """
        :param capacity: number of leaves, i.e. stored transitions
        """
        self.capacity = capacity  # for all priority values
        # Layout: [capacity - 1 parent nodes][capacity leaves holding priorities]
        self.tree = np.zeros(2 * capacity - 1)
        self.data = np.zeros(capacity, dtype=object)  # for all transitions
| # [--------------data frame-------------] 44 | # size: capacity 45 | 46 | def add(self, p, data): 47 | tree_idx = self.data_pointer + self.capacity - 1 48 | self.data[self.data_pointer] = data # update data_frame 49 | self.update(tree_idx, p) # update tree_frame 50 | 51 | self.data_pointer += 1 52 | if self.data_pointer >= self.capacity: # replace when exceed the capacity 53 | self.data_pointer = 0 54 | 55 | def update(self, tree_idx, p): 56 | change = p - self.tree[tree_idx] 57 | self.tree[tree_idx] = p 58 | # then propagate the change through tree 59 | while tree_idx != 0: # this method is faster than the recursive loop in the reference code 60 | tree_idx = (tree_idx - 1) // 2 61 | self.tree[tree_idx] += change 62 | 63 | def get_leaf(self, v): 64 | """ 65 | Tree structure and array storage: 66 | Tree index: 67 | 0 -> storing priority sum 68 | / \ 69 | 1 2 70 | / \ / \ 71 | 3 4 5 6 -> storing priority for transitions 72 | Array type for storing: 73 | [0,1,2,3,4,5,6] 74 | """ 75 | parent_idx = 0 76 | while True: # the while loop is faster than the method in the reference code 77 | cl_idx = 2 * parent_idx + 1 # this leaf's left and right kids 78 | cr_idx = cl_idx + 1 79 | if cl_idx >= len(self.tree): # reach bottom, end search 80 | leaf_idx = parent_idx 81 | break 82 | else: # downward search, always search for a higher priority node 83 | if v <= self.tree[cl_idx]: 84 | parent_idx = cl_idx 85 | else: 86 | v -= self.tree[cl_idx] 87 | parent_idx = cr_idx 88 | 89 | data_idx = leaf_idx - self.capacity + 1 90 | return leaf_idx, self.tree[leaf_idx], self.data[data_idx] 91 | 92 | @property 93 | def total_p(self): 94 | return self.tree[0] # the root 95 | 96 | 97 | # 经验回放池 98 | class Memory(object): # stored as ( s, a, r, s_ ) in SumTree 99 | epsilon = 0.01 # small amount to avoid zero priority 100 | alpha = 0.6 # [0~1] convert the importance of TD error to priority 101 | beta = 0.4 # importance-sampling, from initial value increasing to 1 102 | beta_increment_per_sampling 
= 0.001 103 | abs_err_upper = 1. # clipped abs error 104 | 105 | def __init__(self, capacity): 106 | self.tree = SumTree(capacity) 107 | 108 | def store(self, transition): 109 | max_p = np.max(self.tree.tree[-self.tree.capacity:]) 110 | if max_p == 0: 111 | max_p = self.abs_err_upper 112 | self.tree.add(max_p, transition) # set the max p for new p 113 | 114 | def sample(self, n): 115 | b_idx, b_memory, ISWeights = np.empty((n,), dtype=np.int32), \ 116 | np.empty((n,), dtype=object), np.empty((n, 1)) 117 | pri_seg = self.tree.total_p / n # priority segment 118 | self.beta = np.min([1., self.beta + self.beta_increment_per_sampling]) # max = 1 119 | 120 | min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p # for later calculate ISweight 121 | if min_prob == 0: 122 | min_prob = 0.00001 123 | for i in range(n): 124 | a, b = pri_seg * i, pri_seg * (i + 1) 125 | v = np.random.uniform(a, b) 126 | idx, p, data = self.tree.get_leaf(v) 127 | prob = p / self.tree.total_p 128 | ISWeights[i, 0] = np.power(prob / min_prob, -self.beta) 129 | b_idx[i], b_memory[i] = idx, data 130 | return b_idx, b_memory, ISWeights 131 | 132 | def batch_update(self, tree_idx, abs_errors): 133 | abs_errors += self.epsilon # convert to abs and avoid 0 134 | clipped_errors = np.minimum(abs_errors, self.abs_err_upper) 135 | ps = np.power(clipped_errors, self.alpha) 136 | for ti, p in zip(tree_idx, ps): 137 | self.tree.update(ti, p) 138 | 139 | 140 | class tf_model_perddqn: 141 | 142 | def __init__(self, 143 | learning_rate=1e-3, 144 | temperature=1e-3, 145 | is_trainable=True): 146 | 147 | self.target_replace_op = None 148 | self.learning_rate = learning_rate 149 | self.is_trainable = is_trainable 150 | self.temperature = temperature 151 | self.model_path_head = r'./models/' 152 | self.model_path_rear = r'/test_model/test' 153 | self.cnt = 0 154 | self.ucb_cnt = np.zeros(50) 155 | self.graph = tf.Graph() 156 | 157 | self.replay_total = 0 158 | # init some parameters 159 | 
self.time_step = 0 160 | self.epsilon = GREEDY_EPSILON 161 | self.state_dim = STATE_DIM 162 | self.action_dim = ACTION_DIM 163 | self.memory = Memory(capacity=REPLAY_SIZE) 164 | 165 | with self.graph.as_default(): 166 | self.create_Q_network() 167 | self.create_training_method() 168 | 169 | # Init session 170 | self.saver = tf.train.Saver() 171 | self.session = tf.InteractiveSession() 172 | self.session.run(tf.global_variables_initializer()) 173 | 174 | def create_Q_network(self): 175 | # input layer 176 | self.state_input = tf.placeholder("float", [None, self.state_dim]) 177 | self.ISWeights = tf.placeholder(tf.float32, [None, 1]) 178 | # network weights 179 | self.state_input = (self.state_input - FEATURE_MIN) / (FEATURE_MAX - FEATURE_MIN) 180 | self.state_input = tf.clip_by_value(self.state_input, 0, 1) 181 | with tf.variable_scope('current_net'): 182 | 183 | W1 = self.weight_variable([self.state_dim, 16]) 184 | b1 = self.bias_variable([16]) 185 | W2 = self.weight_variable([16, 64]) 186 | b2 = self.bias_variable([64]) 187 | W3 = self.weight_variable([64, self.action_dim]) 188 | b3 = self.bias_variable([self.action_dim]) 189 | 190 | # hidden layers 191 | h_layer1 = tf.nn.relu(tf.matmul(self.state_input, W1) + b1) 192 | h_layer2 = tf.nn.relu(tf.matmul(h_layer1, W2) + b2) 193 | 194 | # Q Value layer 195 | self.Q_value = tf.matmul(h_layer2, W3) + b3 196 | 197 | with tf.variable_scope('target_net'): 198 | W1t = self.weight_variable([self.state_dim, 16]) 199 | b1t = self.bias_variable([16]) 200 | W2t = self.weight_variable([16, 64]) 201 | b2t = self.bias_variable([64]) 202 | W3t = self.weight_variable([64, self.action_dim]) 203 | b3t = self.bias_variable([self.action_dim]) 204 | 205 | # hidden layers 206 | h_layer1t = tf.nn.relu(tf.matmul(self.state_input, W1t) + b1t) 207 | h_layer2t = tf.nn.relu(tf.matmul(h_layer1t, W2t) + b2t) 208 | 209 | # Q Value layer 210 | self.target_Q_value = tf.matmul(h_layer2t, W3) + b3 211 | 212 | with tf.variable_scope('predict_Q'): 213 | 
self.output_layer = self.target_Q_value 214 | self.probability = tf.nn.softmax( 215 | tf.reshape(self.output_layer, [-1]) / self.temperature) 216 | 217 | t_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='target_net') 218 | e_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='current_net') 219 | 220 | with tf.variable_scope('soft_replacement'): 221 | self.target_replace_op = [tf.assign(t, e) for t, e in zip(t_params, e_params)] 222 | 223 | def create_training_method(self): 224 | self.y_input = tf.placeholder("float", [None]) 225 | Q_action = self.Q_value 226 | Q_action = tf.reshape(Q_action, [-1]) 227 | self.cost = tf.reduce_mean(self.ISWeights * (tf.square(self.y_input - Q_action))) 228 | self.abs_errors = tf.abs(self.y_input - Q_action) 229 | self.optimizer = tf.train.AdamOptimizer(self.learning_rate).minimize(self.cost) 230 | 231 | def store_transition(self, ss, r, done, s_,): 232 | transition = np.hstack((ss, r, done, s_,)) 233 | self.memory.store(transition) # have high priority for newly arrived transition 234 | 235 | def perceive(self, sub_state, reward, done, next_state): 236 | self.store_transition(sub_state, reward, done, next_state) 237 | self.replay_total += 1 238 | if self.replay_total > BATCH_SIZE: 239 | self.train_Q_network() 240 | 241 | def train_Q_network(self): 242 | self.time_step += 1 243 | # Step 1: obtain random mini-batch from replay memory 244 | tree_idx, minibatch, ISWeights = self.memory.sample(BATCH_SIZE) 245 | 246 | sub_state_batch = [data[0:self.state_dim] for data in minibatch] 247 | reward_batch = [data[self.state_dim] for data in minibatch] 248 | next_state_batch = [data[self.state_dim+2:] for data in minibatch] 249 | 250 | # Step 2: calculate y 251 | 252 | target_Q_batch = [] 253 | for next_state in next_state_batch: 254 | next_state = next_state.reshape(-1, self.state_dim) 255 | current_Q_batch = self.Q_value.eval(feed_dict={self.state_input: next_state}) 256 | current_Q_batch = 
current_Q_batch.reshape(-1) 257 | max_action_next = np.argmax(current_Q_batch) 258 | target_Q = self.target_Q_value.eval( 259 | feed_dict={self.state_input: next_state[max_action_next][np.newaxis, :]}) 260 | target_Q_batch.append(target_Q[0, 0]) 261 | 262 | y_batch = [] 263 | for i in range(0, BATCH_SIZE): 264 | done = minibatch[i][self.state_dim+1] 265 | if done: 266 | y_batch.append(reward_batch[i]) 267 | else: 268 | target_Q_value = target_Q_batch[i] 269 | y_batch.append(reward_batch[i] + GAMMA * target_Q_value) 270 | 271 | self.optimizer.run(feed_dict={ 272 | self.y_input: y_batch, 273 | self.state_input: sub_state_batch, 274 | self.ISWeights: ISWeights 275 | }) 276 | _, abs_errors, _ = self.session.run([self.optimizer, self.abs_errors, self.cost], feed_dict={ 277 | self.y_input: y_batch, 278 | self.state_input: sub_state_batch, 279 | self.ISWeights: ISWeights 280 | }) 281 | self.memory.batch_update(tree_idx, abs_errors) # update priority 282 | 283 | def update_target_q_network(self, episode): 284 | # update target Q netowrk 285 | if episode % REPLACE_TARGET_FREQ == 0: 286 | self.session.run(self.target_replace_op) 287 | # print('episode '+str(episode) +', target Q network params replaced!') 288 | 289 | def weight_variable(self, shape): 290 | initial = tf.truncated_normal(shape) 291 | return tf.Variable(initial) 292 | 293 | def bias_variable(self, shape): 294 | initial = tf.constant(0.01, shape=shape) 295 | return tf.Variable(initial) 296 | # softmax最终输出层 297 | 298 | def save_model(self, model_name): 299 | model_path = self.model_path_head + model_name + self.model_path_rear 300 | return self.saver.save(self.session, model_path) 301 | 302 | def load_model(self, model_name): 303 | model_path = self.model_path_head + model_name + self.model_path_rear 304 | if not os.path.exists(model_path + '.index'): 305 | return 306 | tf.reset_default_graph() 307 | return self.saver.restore(self.session, model_path) 308 | 309 | def learn(self): 310 | # 计算每次迭代的td_error和损失 311 | 
""" 312 | indices = np.random.choice(self.mem_size, size=self.batch_size) # 随机BATCH_SIZE个随机数 313 | bt = self.memory[indices, :] # 根据indices,选取数据bt,相当于随机 314 | bs = bt[:, :self.s_dim] # 从bt获得数据s 315 | ba = bt[:, self.s_dim:self.s_dim + self.a_dim] # 从bt获得数据a 316 | br = bt[:, self.s_dim + self.a_dim:-self.s_dim] # 从bt获得数据r 317 | bs_ = bt[:, -self.s_dim:] # 从bt获得数据s_ 318 | with tf.GradientTape() as tape: 319 | # batch_size*2 320 | a_ = self.target(bs_) 321 | # batch_size*4 322 | q_ = self.target([bs_, a_]) 323 | y = br + self.gamma * q_ 324 | q = self.target([bs, ba]) 325 | td_error = tf.losses.mean_squared_error(tf.reshape(y, [-1]), tf.reshape(q, [-1])) 326 | c_grads = tape.gradient(td_error, self.target.trainable_weights) 327 | self.target.apply_gradients(zip(c_grads, self.target.trainable_weights)) 328 | 329 | with tf.GradientTape() as tape: 330 | a = self.actor(bs) 331 | q = self.critic([bs, a]) 332 | a_loss = -tf.reduce_mean(q) 333 | a_grads = tape.gradient(a_loss, self.actor.trainable_weights) 334 | self.actor_opt.apply_gradients(zip(a_grads, self.actor.trainable_weights)) 335 | self.ema_update() 336 | self.soft_replace() 337 | """ 338 | pass 339 | 340 | def predict(self, xs): 341 | """ 342 | 预测 343 | :param xs:特征 344 | :return:预测Q值 345 | """ 346 | y_predict = self.session.run(self.output_layer, feed_dict={self.state_input: xs}) 347 | return y_predict.reshape(-1) 348 | 349 | def get_probability(self, xs): 350 | """ 351 | 得到概率 352 | :param xs:特征 353 | :return:softmax概率概率 354 | """ 355 | 356 | return self.session.run(self.probability, feed_dict={self.state_input: xs}) 357 | 358 | def get_decision(self, xs): 359 | """ 360 | 做出决策 361 | :param xs: 特征 362 | :return:node(softmax策略后选择的转发节点) is_max_Q(是否为最大Q的点) 363 | """ 364 | prob = self.get_probability(xs) 365 | prob_list = list(enumerate(prob)) 366 | prob_list = sorted(prob_list, key=lambda x: x[1]) 367 | # 累加和 368 | accu_threshold = 0 369 | rand = random.rand() 370 | selected_next = prob_list[-1][0] 371 | 
is_select_max_Q = True 372 | ucb_max = 0 373 | 374 | # 在训练阶段使用改进UCB策略或者Greedy-epsilon选择最优节点 375 | if self.is_trainable: 376 | i = 0 377 | pos = 0 378 | sum_cnt = sum(self.ucb_cnt) 379 | for item in prob_list[::-1]: 380 | Q = item[1] + ucb_Q(self.ucb_cnt[i], sum_cnt) 381 | if ucb_max < Q: 382 | selected_next = item[0] 383 | pos = i 384 | ucb_max = Q 385 | i += 1 386 | self.ucb_cnt[pos] += 1 387 | 388 | # 在测试阶段使用softmax或者最大值策略 389 | else: 390 | rand = random.rand() 391 | 392 | for item in prob_list: 393 | accu_threshold += item[1] 394 | if rand <= accu_threshold: 395 | selected_next = item[0] 396 | break 397 | 398 | if len(prob_list) > 1 and selected_next != prob_list[-1][0]: 399 | is_select_max_Q = False 400 | return selected_next, is_select_max_Q 401 | 402 | def insert_experience(self, sub_state, reward, done, next_state): 403 | """ 404 | 插入经验池 405 | :return:None 406 | """ 407 | 408 | # path = r".\analysis\a.csv" 409 | # if not os.path.exists(path): 410 | # 411 | # col.to_csv(path, index=False) 412 | # 413 | # t = next_state.reshape(-1, self.state_dim) 414 | # res = np.vstack((sub_state, t)) 415 | # p = pd.DataFrame(res) 416 | # p.to_csv(path, mode='a', header=False, index=False) 417 | 418 | if self.is_trainable: 419 | self.perceive(sub_state, reward, done, next_state) 420 | self.cnt += 1 421 | self.update_target_q_network(self.cnt) 422 | 423 | 424 | 425 | 426 | 427 | 428 | --------------------------------------------------------------------------------