├── README.md
├── checkpoint
│   └── model.pt
├── class_env.py
├── error.svg
├── main.py
└── red.svg

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

## An Air-Combat Maneuver Decision Algorithm Based on Reinforcement Learning

### air-combat-Reinforcement-Learning

The original author implemented this project in TensorFlow 1.x. Since TF 1.x code is hard to follow, this fork re-implements it in PyTorch and adds trajectory plots; the re-implementer does not claim to fully understand the finer details of the algorithm, so they are left as described in the original author's notes below.

Thanks to
https://github.com/y8107928/air-combat-Reinforcement-Learning.git

### The following notes are from the original author

A UAV air-combat autonomous decision system is built around a value-function approximation network with a three-layer structure and an epsilon-greedy policy.
The project includes the kinematic and dynamic modelling of the UAV treated as a point mass.
Because real air-combat actions are continuous and complex, this project only considers changes of the pitch angle gamma (also called the flight-path angle) and the heading angle pusin. Each change is discretised to steps of 10 degrees, and the speed v is assumed constant. From the aircraft motion model, the pitch angle, heading angle and speed give the change of position, i.e. the velocity components along x, y and z; at every step the position is updated with these components. The three values in position are the x, y, z coordinates in an East-North-Up frame. From the positions, the angles and the speeds, the relative combat situation (state) of the two aircraft is computed. (A short sketch of this kinematic update is given at the end of this README.)
As mentioned above, an action only changes the pitch and heading angles (each can increase, decrease or stay unchanged), so the combinations of the two angle changes give 3 × 3 = 9 actions. In every situation there are 9 candidate actions; the 9 new situations that these actions would produce are used as the network input, and the network outputs 9 numbers, one value-function estimate per action.
Since there are no supervised labels, the labels are generated from the Bellman equation of the value function using the temporal-difference idea: the updated value is V(s) + alpha * ( r + gamma * V(s') - V(s) ).
How are the labels generated? First one complete game is played and all the values along the way are recorded. When the game ends, a terminal reward is assigned according to the outcome: win (red is our side), loss, reaching the maximum number of steps, leaving the combat area, and so on. A decay rate is then chosen, and the terminal reward is added to the target of every step such that earlier steps are influenced less. (A sketch of this label generation is also given at the end of this README.)
The winning condition is reaching the attack angle within the attack distance; the details are left for later.
The whole process above is repeated in a loop.

Results after training:

Successful engagement

![](red.svg)

Engagement out of the combat envelope

![](error.svg)

When the opponent flies towards us in uniform straight-line motion, our aircraft first climbs quickly to a certain altitude and then dives towards the enemy.
The action set is defined rather crudely; it can later be refined with other algorithms.
#### This is still a preliminary implementation; it will be updated and explained in more detail later. {so says the original author!}

class_env.py is the environment class; it contains the UAV model and the conversions between coordinates, angles and the combat situation.
main.py is the main file; it contains the network, the main training procedure and the test procedure.
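
The point-mass position update described above (constant speed, pitch angle gamma, heading angle pusin, East-North-Up coordinates) is essentially what `generate_next_position` in class_env.py computes each step. A minimal sketch of that update, assuming a one-second step and angles given in degrees:

```python
import math

def step_position(position, v, gamma_deg, pusin_deg, dt=1.0):
    """Advance a point-mass aircraft by one step in an East-North-Up frame.

    v         : speed in m/s (assumed constant)
    gamma_deg : pitch / flight-path angle in degrees
    pusin_deg : heading angle in degrees
    """
    gamma = math.radians(gamma_deg)
    pusin = math.radians(pusin_deg)
    # velocity components along x (east), y (north) and z (up)
    vx = v * math.cos(gamma) * math.sin(pusin)
    vy = v * math.cos(gamma) * math.cos(pusin)
    vz = v * math.sin(gamma)
    x, y, z = position
    return [x + vx * dt, y + vy * dt, z + vz * dt]

# e.g. the red aircraft's initial state from class_env.py:
# step_position([130000, 100000, 3000], v=250, gamma_deg=0, pusin_deg=0)
```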
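
The 3 × 3 = 9 discrete actions and the epsilon-greedy choice among them can be sketched as follows; this mirrors `action` and `epsilon_greedy` in class_env.py, with the eight situation features (q_r, q_b, d, beta, delta_h, delta_v2, v2, h) of the nine candidate situations stacked into the 9 × 8 = 72-dimensional network input:

```python
import random
import numpy as np

def candidate_actions(v, gamma_deg, pusin_deg, step=10):
    """The nine discrete actions: pitch and heading each increase by 10 degrees,
    stay unchanged, or decrease by 10 degrees; the speed stays constant."""
    return [[v, gamma_deg + dg, pusin_deg + dp]
            for dg in (step, 0, -step)
            for dp in (step, 0, -step)]

def choose_action(q_values, epsilon=0.1):
    """Epsilon-greedy selection over the nine value estimates output by the network."""
    if random.random() > epsilon:
        return int(np.argmax(q_values))  # exploit: best predicted value
    return random.randint(0, 8)          # explore: random action index
```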
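
The label generation described above (play one full game, assign a terminal reward R, decay it backwards, and build temporal-difference targets) corresponds to the code at the end of `main()` in main.py. A condensed sketch of the idea, using the standard TD(0) target; the constants K (reward decay), alpha (value-update step size) and gamma (discount factor) are the ones defined at the top of main.py, although the exact accumulation in main() differs slightly:

```python
import numpy as np

def build_targets(predictions, R, K=0.5, alpha=0.1, gamma=0.4):
    """predictions: list of network outputs, one per step of the finished game.
    R: terminal reward (+10 red win, -10 blue win, -5 otherwise)."""
    j = len(predictions)
    preds = np.array(predictions)
    # terminal reward decayed backwards so that earlier steps are influenced less
    r = np.array([K ** (j - i) * R for i in range(j)])
    targets = preds.copy()
    for k in range(j - 1):
        # TD(0) update: V(s) <- V(s) + alpha * (r + gamma * V(s') - V(s))
        targets[k] = preds[k] + alpha * (r[k] + gamma * preds[k + 1] - preds[k])
    return targets
```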

--------------------------------------------------------------------------------
/checkpoint/model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linaom1214/RL_air-combat/e9c6fa83a9109a115f0cf7ef6a61f70b6e4f599f/checkpoint/model.pt
--------------------------------------------------------------------------------
/class_env.py:
--------------------------------------------------------------------------------
import math
import random

import numpy as np

'''This version does not consider changes of speed.'''

DEG2RAD = math.pi / 180.0


# Environment class
class AirCombat():
    def __init__(self):
        # Initial positions in the East-North-Up frame; red is our side, blue is the opponent
        self.position_r = [130000, 100000, 3000]
        self.position_b = [130000, 110000, 3000]

        # Initial attitude: pitch angle gamma and heading angle pusin in degrees, speed v constant
        self.gamma_r = 0
        self.gamma_b = 0

        self.pusin_r = 0
        self.pusin_b = 180

        self.v_r = 250
        self.v_b = 250
        self.active_r = [self.v_r, self.gamma_r, self.pusin_r]
        self.active_b = [self.v_b, self.gamma_b, self.pusin_b]

    '''Given the current positions and the chosen actions, return the positions after one step.'''
    '''The enemy (blue) is assumed to fly in simple uniform straight-line motion.'''
    def generate_next_position(self, position_r, position_b, action_now, action_b):
        x_r, y_r, z_r = position_r
        x_b, y_b, z_b = position_b
        v_r, gamma_r, pusin_r = action_now
        v_b, gamma_b, pusin_b = action_b

        # velocity components of red along x (east), y (north) and z (up); one step = one second
        x_r_ = v_r * math.cos(gamma_r * DEG2RAD) * math.sin(pusin_r * DEG2RAD)
        y_r_ = v_r * math.cos(gamma_r * DEG2RAD) * math.cos(pusin_r * DEG2RAD)
        z_r_ = v_r * math.sin(gamma_r * DEG2RAD)

        # velocity components of blue
        x_b_ = v_b * math.cos(gamma_b * DEG2RAD) * math.sin(pusin_b * DEG2RAD)
        y_b_ = v_b * math.cos(gamma_b * DEG2RAD) * math.cos(pusin_b * DEG2RAD)
        z_b_ = v_b * math.sin(gamma_b * DEG2RAD)

        position_r_next = [x_r + x_r_, y_r + y_r_, z_r + z_r_]
        position_b_next = [x_b + x_b_, y_b + y_b_, z_b + z_b_]

        return position_r_next, position_b_next

    '''Given the current positions, speeds and angles, return the corresponding situation vector.'''
    def generate_state(self, position_r, position_b, action_r, action_b):
        x_r, y_r, z_r = position_r
        x_b, y_b, z_b = position_b
        v_r, gamma_r, pusin_r = action_r
        v_b, gamma_b, pusin_b = action_b

        # distance between the two aircraft
        d = math.sqrt((x_r - x_b) ** 2 + (y_r - y_b) ** 2 + (z_r - z_b) ** 2)

        # q_r: angle between red's velocity vector and the line of sight towards blue
        q_r = math.acos(((x_b - x_r) * math.cos(gamma_r * DEG2RAD) * math.sin(pusin_r * DEG2RAD)
                         + (y_b - y_r) * math.cos(gamma_r * DEG2RAD) * math.cos(pusin_r * DEG2RAD)
                         + (z_b - z_r) * math.sin(gamma_r * DEG2RAD)) / d)
        q_r_ = math.degrees(q_r)

        # q_b: angle between blue's velocity vector and the line of sight towards red
        # (the original passed blue's angles in degrees straight to cos/sin; converted here for consistency with q_r)
        q_b = math.acos(((x_r - x_b) * math.cos(gamma_b * DEG2RAD) * math.sin(pusin_b * DEG2RAD)
                         + (y_r - y_b) * math.cos(gamma_b * DEG2RAD) * math.cos(pusin_b * DEG2RAD)
                         + (z_r - z_b) * math.sin(gamma_b * DEG2RAD)) / d)
        q_b_ = math.degrees(q_b)
        # beta: angle between the two velocity vectors
        beta = math.acos(math.cos(gamma_r * DEG2RAD) * math.sin(pusin_r * DEG2RAD) * math.cos(gamma_b * DEG2RAD) * math.sin(pusin_b * DEG2RAD)
                         + math.cos(gamma_r * DEG2RAD) * math.cos(pusin_r * DEG2RAD) * math.cos(gamma_b * DEG2RAD) * math.cos(pusin_b * DEG2RAD)
                         + math.sin(gamma_r * DEG2RAD) * math.sin(gamma_b * DEG2RAD))
        beta_ = math.degrees(beta)

        delta_h = z_r - z_b              # height advantage
        delta_v2 = v_r ** 2 - v_b ** 2   # difference of squared speeds
        v2 = v_r ** 2
        h = z_r
        taishi = [q_r_, q_b_, d, beta_, delta_h, delta_v2, v2, h]  # the situation vector
        return taishi

    '''Note: angle wrap-around has not been handled yet.'''
    '''Given the current state and the chosen action, return the state(s) produced by the action.'''
    def action(self, v_r, gamma_r, pusin_r, flag, choose):
        # positive direction is to the right; pitch and heading each change by +10, 0 or -10 degrees,
        # in the same order as the original action_1 ... action_9
        actions = [[v_r, gamma_r + dg, pusin_r + dp]
                   for dg in (10, 0, -10)
                   for dp in (10, 0, -10)]
        if flag:
            # return all nine candidate actions
            return actions
        else:
            # return only the chosen action
            return actions[choose]

    def action_b(self, action_b):
        # blue keeps its speed and angles (uniform straight-line motion)
        v_b_ = action_b[0]
        gamma_b = action_b[1]
        pusin_b = action_b[2]
        action_b = [v_b_, gamma_b, pusin_b]
        return action_b

    def normalize(self, array):
        # roughly scale the situation components into comparable ranges
        array[0] = array[0] / 200      # q_r (degrees)
        array[1] = array[1] / 200      # q_b (degrees)
        array[2] = array[2] / 20000    # distance d (metres)
        array[3] = array[3] / 200      # beta (degrees)
        array[4] = array[4] / 10000    # delta_h (metres)
        array[5] = array[5]            # delta_v2 (zero while the speeds are constant)
        array[6] = array[6] / 40000    # v^2
        array[7] = array[7] / 10000    # height h (metres)
        return array

    def position_clip(self, position_r_list, position_b_list):
        # keep red's x coordinate close to blue's x coordinate
        x_r = position_r_list[0]
        x_b = position_b_list[0]
        if x_r > (x_b + 100) or x_r < (x_b - 100):
            x_r = x_b + random.randint(0, 10)
            position_r_list[0] = x_r
        return position_r_list

    def epsilon_greedy(self, prediction, epsilon):
        # with probability 1 - epsilon take the greedy action, otherwise explore randomly
        num = random.random()
        if num > epsilon:
            temp = np.argmax(prediction)
        else:
            temp = random.randint(0, 8)
        return temp
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
from torch import nn
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # needed for the '3d' projection on older matplotlib

from class_env import AirCombat

CHECKPOINT_DIR = './checkpoint/'
EPOCH = 1000        # number of games to play
MAX_STEPS = 1000    # maximum number of steps per game
num_Action = 0
K = 0.5             # decay rate of the terminal reward
alpha = 0.1         # step size of the value update
gamma = 0.4         # TD discount factor
EPSILON = 0.1       # exploration rate of the epsilon-greedy policy


# Value network: the input is the nine candidate next situations (9 x 8 = 72 values),
# the output is one value estimate per candidate action.
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = nn.Linear(72, 100)
        self.layer2 = nn.Linear(100, 30)
        self.layer3 = nn.Linear(30, 9)

    def forward(self, x):
        x = torch.sigmoid(self.layer1(x))
        x = torch.sigmoid(self.layer2(x))
        x = self.layer3(x)
        return x


def greedy(prediction, action_list):
    # pick the action whose predicted value is largest
    temp = np.argmax(prediction)
    action = action_list[temp]
    return action


def build_input(aircombat, position_r, position_b, action_list, action_b):
    """For each of the nine candidate red actions, simulate one step, compute the resulting
    situation, normalize it, and stack the nine situations into the (1, 72) network input.
    (This replaces the nine copy-pasted generate_next_position / generate_state blocks of the
    original main.py without changing what they compute.)"""
    states = []
    for action in action_list:
        position_r_next, position_b_next = aircombat.generate_next_position(
            position_r, position_b, action, action_b)
        state = aircombat.generate_state(position_r_next, position_b_next, action, action_b)
        states.append(aircombat.normalize(state))
    return np.array(states).reshape((1, 72))


def main():
    model = Net().to('cuda')
    opt = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fun = nn.MSELoss()

    for num in range(EPOCH):
        X, Y, Z = [], [], []
        X_B, Y_B, Z_B = [], [], []
        print('Game %d' % num)

        aircombat = AirCombat()
        position_r_0 = aircombat.position_r
        x_r, y_r, z_r = position_r_0
        X.append(x_r)
        Y.append(y_r)
        Z.append(z_r)

        position_b_0 = aircombat.position_b
        x_b, y_b, z_b = position_b_0
        X_B.append(x_b)
        Y_B.append(y_b)
        Z_B.append(z_b)

        action_r_0 = aircombat.active_r
        action_b = aircombat.active_b
        # The first situation, computed from the states and dynamics of both sides
        s0 = aircombat.generate_state(position_r_0, position_b_0, action_r_0, action_b)

        v_r, gamma_r, pusin_r = action_r_0
        action_list = aircombat.action(v_r, gamma_r, pusin_r, flag=True, choose=1)
        action_b = aircombat.action_b(action_b)  # uniform straight-line motion: blue keeps its state

        # Network input: the situations produced by all nine candidate actions
        state_now = build_input(aircombat, position_r_0, position_b_0, action_list, action_b)
        j = 0
        prediction_list = []
        state_list = []
        t_f = True
        while t_f:
            # forward pass: value estimates of the nine candidate actions
            input = torch.unsqueeze(torch.from_numpy(state_now).type(torch.FloatTensor), 0).cuda()
            prediction = model(input)
            prediction = prediction.cpu().detach().numpy()
            prediction_list.append(prediction)

            # epsilon-greedy choice among the nine actions
            temp = aircombat.epsilon_greedy(prediction, EPSILON)
            action_now = aircombat.action(v_r, gamma_r, pusin_r, flag=False, choose=temp)
            position_r_now, position_b_now = aircombat.generate_next_position(
                position_r_0, position_b_0, action_now, action_b)
            # position_r_now = aircombat.position_clip(position_r_now, position_b_now)
            state_next = aircombat.generate_state(position_r_now, position_b_now, action_now, action_b)

            v_r, gamma_r, pusin_r = action_now
            action_list_now = aircombat.action(v_r, gamma_r, pusin_r, flag=True, choose=1)
            # Next network input: the situations produced by all nine candidate actions
            state_now = build_input(aircombat, position_r_now, position_b_now, action_list_now, action_b)
            state_list.append(state_now)
            q_r = state_next[0]
            q_b = state_next[1]
            d = state_next[2]
            delta_h = state_next[4]
            position_r_0 = position_r_now
            position_b_0 = position_b_now
            x_r, y_r, z_r = position_r_now
            x_b, y_b, z_b = position_b_now

            X.append(x_r)
            Y.append(y_r)
            Z.append(z_r)

            X_B.append(x_b)
            Y_B.append(y_b)
            Z_B.append(z_b)

            j += 1
            print('Step %d' % j)
            print('Situation:', state_next)
            print('Red position:', x_r, y_r, z_r)
            print('Red speed and angles:', action_now)
            print('Blue position:', x_b, y_b, z_b)
            print('Blue speed and angles:', action_b)

            if j == MAX_STEPS:
                print('Maximum number of steps reached')
                t_f = False
                R = -5
            if d < 2500:
                if q_r < 30 and q_b > 30 and delta_h > 100:
                    print('red win!')
                    t_f = False
                    R = 10
                if q_r > 30 and q_b < 30 and delta_h < 100:
                    print('blue win...')
                    t_f = False
                    R = -10
            if (x_r > 200000 or x_r < 0 or y_r > 200000 or y_r < 0 or z_r > 11000 or z_r < 0
                    or x_b > 200000 or x_b < 0 or y_b > 200000 or y_b < 0 or z_b > 11000 or z_b < 0):
                print('Out of the combat area')
                t_f = False
                R = -5

        X = np.array(X)
        Y = np.array(Y)
        Z = np.array(Z)
        X_B = np.array(X_B)
        Y_B = np.array(Y_B)
        Z_B = np.array(Z_B)
        # (the training trajectories in X, Y, Z / X_B, Y_B, Z_B could be plotted here, as test() does)
        # Record the rewards and value estimates of this whole game to train the network:
        # the terminal reward R is decayed so that earlier steps are influenced less.
        r_list = np.zeros(j)
        for i in range(j):
            r_list[i] = K ** (j - i) * R
        delta = np.zeros_like(prediction_list)
        for k in range(j - 1):
            delta[k] = prediction_list[k] + alpha * (r_list[k] + gamma * prediction_list[k + 1] - prediction_list[k])
        delta[j - 1] = 0.01
        target = np.zeros_like(prediction_list)
        for k in range(j):
            target[k] = prediction_list[k] + delta[k]
        print('Training')

        for i in range(j):  # training loop over the recorded steps
            opt.zero_grad()
            state_list_ = torch.from_numpy(state_list[i]).type(torch.FloatTensor).cuda()
            pred = model(state_list_)
            target_ = torch.from_numpy(target[i]).type(torch.FloatTensor).cuda()
            loss = loss_fun(pred, target_)
            loss.backward()
            opt.step()  # the original never called opt.step(), so the weights were never updated

        if num % 100 == 0:
            torch.save(model.state_dict(), f'{CHECKPOINT_DIR}/model.pt')
            print('Write checkpoint at %s' % num)


def test():
    model = Net()
    # the original called model.state_dict(torch.load(...)), which does not actually load the weights
    model.load_state_dict(torch.load(f'{CHECKPOINT_DIR}/model.pt'))
    model = model.to('cuda')

    # initialise everything as in main()
    ax1 = plt.axes(projection='3d')
    X, Y, Z = [], [], []
    X_B, Y_B, Z_B = [], [], []
    aircombat = AirCombat()
    position_r_0 = aircombat.position_r
    action_r_0 = aircombat.active_r
    position_b_0 = aircombat.position_b
    action_b = aircombat.active_b
    # The first situation, computed from the states and dynamics of both sides
    s0 = aircombat.generate_state(position_r_0, position_b_0, action_r_0, action_b)

    v_r, gamma_r, pusin_r = action_r_0
    action_list = aircombat.action(v_r, gamma_r, pusin_r, flag=True, choose=1)
    action_b = aircombat.action_b(action_b)  # uniform straight-line motion: blue keeps its state
    # Network input: the situations produced by all nine candidate actions
    state_now = build_input(aircombat, position_r_0, position_b_0, action_list, action_b)
    j = 0
    prediction_list = []
    state_list = []
    t_f = True
    while t_f:
        # forward pass: value estimates of the nine candidate actions
        input = torch.unsqueeze(torch.from_numpy(state_now).type(torch.FloatTensor), 0).cuda()
        prediction = model(input)
        prediction = prediction.cpu().detach().numpy()
        prediction_list.append(prediction)

        temp = np.argmax(prediction)  # purely greedy at test time
        action_now = aircombat.action(v_r, gamma_r, pusin_r, flag=False, choose=temp)
        position_r_now, position_b_now = aircombat.generate_next_position(
            position_r_0, position_b_0, action_now, action_b)

        state_next = aircombat.generate_state(position_r_now, position_b_now, action_now, action_b)

        v_r, gamma_r, pusin_r = action_now
        action_list_now = aircombat.action(v_r, gamma_r, pusin_r, flag=True, choose=1)
        # Next network input: the situations produced by all nine candidate actions
        state_now = build_input(aircombat, position_r_now, position_b_now, action_list_now, action_b)
        state_list.append(state_now)
        q_r = state_next[0]
        q_b = state_next[1]
        d = state_next[2]
        delta_h = state_next[4]
        position_r_0 = position_r_now
        position_b_0 = position_b_now
        x_r, y_r, z_r = position_r_now
        x_b, y_b, z_b = position_b_now

        X.append(x_r)
        Y.append(y_r)
        Z.append(z_r)

        X_B.append(x_b)
        Y_B.append(y_b)
        Z_B.append(z_b)

        j += 1
        print('Step %d' % j)
        print('Situation:', state_next)
        print('Red position:', x_r, y_r, z_r)
        print('Red speed and angles:', action_now)
        print('Blue position:', x_b, y_b, z_b)
        print('Blue speed and angles:', action_b)

        if j == MAX_STEPS:
            print('Maximum number of steps reached')
            t_f = False
            R = -5
        if d < 2500:
            if q_r < 30 and q_b > 30 and delta_h > 0:
                print('red win!')
                t_f = False
                R = 10
            if q_r > 30 and q_b < 30 and delta_h < 0:
                print('blue win...')
                t_f = False
                R = -10
        if (x_r > 200000 or x_r < 0 or y_r > 200000 or y_r < 0 or z_r > 11000 or z_r < 0
                or x_b > 200000 or x_b < 0 or y_b > 200000 or y_b < 0 or z_b > 11000 or z_b < 0):
            print('Out of the combat area')
            t_f = False
            R = -5

    X = np.array(X)
    Y = np.array(Y)
    Z = np.array(Z)
    X_B = np.array(X_B)
    Y_B = np.array(Y_B)
    Z_B = np.array(Z_B)

    # plot the trajectories of both aircraft and save the figure named after the outcome
    ax1.scatter3D(X_B, Y_B, Z_B, label='Blue')
    ax1.plot3D(X, Y, Z, 'gray')
    ax1.scatter3D(X, Y, Z, label='Red')
    plt.legend()

    if R == 10:
        flag = 'red'
    elif R == -10:
        flag = 'blue'
    else:
        flag = 'error'
    plt.savefig(f'{flag}.svg')
    plt.show()


if __name__ == '__main__':
    main()
    # test()
--------------------------------------------------------------------------------
/red.svg:
--------------------------------------------------------------------------------
(Matplotlib v3.3.3 SVG export of the successful-engagement trajectory shown in the README; the raw SVG markup is omitted here.)
--------------------------------------------------------------------------------