├── README.md
├── checkpoint
│   └── model.pt
├── class_env.py
├── error.svg
├── main.py
└── red.svg
/README.md:
--------------------------------------------------------------------------------
1 | ## Air Combat Decision Algorithm Based on Reinforcement Learning
2 |
3 | ### air-combat-Reinforcement-Learning
4 |
5 | The original author implemented this in TensorFlow 1.x. Since TF 1.x code is hard to follow, I took some time to re-implement it in PyTorch and added plots; as for the finer details of the algorithm, I do not fully understand them!
6 |
7 |
8 | Thanks to
9 | https://github.com/y8107928/air-combat-Reinforcement-Learning.git
10 |
11 | ### The notes below are from the original author
12 |
13 |
14 | An autonomous UAV air-combat decision system is built with a value-function approximation network, using an epsilon-greedy policy and a three-layer network.
15 | It includes the kinematic and dynamic modelling of the UAV treated as a point mass.
16 | Because real air-combat maneuvers are continuous and complex, this project only varies the pitch angle gamma (also called the flight-path angle) and the heading angle pusin, discretized so that each change is 10 degrees, while the speed v is assumed constant. From the aircraft's kinematic model, the pitch angle, heading angle, and speed determine how the position changes, i.e. the velocity components along x, y, and z; at every step these components update the position. The three values in position are the x, y, z coordinates in an East-North-Up frame. From the positions, angles, and speeds of both aircraft, the relative combat situation (state) can be computed.
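For reference, `generate_next_position` in class_env.py integrates these point-mass kinematics with an implicit time step of one unit (the velocity components are added directly to the position), converting the angles from degrees to radians. With γ the pitch angle (`gamma`) and ψ the heading angle (`pusin`):

$$
\dot{x} = v\cos\gamma\sin\psi,\qquad
\dot{y} = v\cos\gamma\cos\psi,\qquad
\dot{z} = v\sin\gamma
$$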
17 | As mentioned above, an action only changes the pitch and heading angles (increase, decrease, or hold), so the combinations of the two angle changes give 3 × 3 = 9 actions. In every situation there are 9 candidate actions; the 9 new situations those actions would produce are fed to the network as its input, and the network outputs 9 numbers, one value estimate per action.
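A condensed sketch of how main.py assembles the 1 × 72 network input: each of the 9 candidate actions (every combination of a ±10° or 0° change to the pitch and heading angles) is rolled forward one step, the resulting 8-feature situation is normalized, and the 9 situations are concatenated:

```python
import numpy as np
from class_env import AirCombat

env = AirCombat()
pos_r, pos_b = env.position_r, env.position_b   # initial red / blue positions
act_r, act_b = env.active_r, env.active_b       # [v, gamma, pusin] for each side

candidates = env.action(act_r[0], act_r[1], act_r[2], flag=True, choose=1)  # 9 candidate actions
states = []
for a in candidates:
    p_r, p_b = env.generate_next_position(pos_r, pos_b, a, act_b)  # one-step rollout
    states.append(env.normalize(env.generate_state(p_r, p_b, a, act_b)))

net_input = np.array(states).reshape(1, 72)     # 9 situations x 8 features
```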
18 | Since there are no supervised labels, we generate the targets from the Bellman equation for the value function, using the temporal-difference idea (the update formula is given below).
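The targets are built from the TD(0) update (main.py uses `alpha = 0.1` and `gamma = 0.4`, applied elementwise to the 9-value prediction vectors):

$$
Q_{\text{target}}(s_k) = Q(s_k) + \alpha\bigl(r_k + \gamma\,Q(s_{k+1}) - Q(s_k)\bigr)
$$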
19 | How are the targets generated? We first play one complete game and record all the values along the way. When the game ends, a reward is assigned according to the outcome: a win (red is our side), a loss, reaching the maximum number of steps, or leaving the combat area. A decay rate is then applied so that the reward contributes less to steps further from the end, and this discounted reward is added to the target of every step.
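This reward spreading is coded in main.py as `r_list[i] = K ** (j - i) * R` with decay rate `K = 0.5`; a standalone sketch (the helper name `discounted_rewards` is ours):

```python
K = 0.5  # decay rate: steps further from the end receive an exponentially smaller share

def discounted_rewards(R, j):
    """Spread the terminal reward R backwards over the j steps of an episode."""
    return [K ** (j - i) * R for i in range(j)]
```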
20 | The win condition is reaching the attack angle within attack range; this is detailed later (see the check sketched below).
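Concretely, during training main() declares a red win when the distance drops below 2500 (position units), red's line-of-sight angle `q_r` is under 30°, blue's `q_b` is over 30°, and red sits more than 100 m above blue; the test() variant only requires a positive height advantage. A condensed paraphrase (the helper name `judge` is ours):

```python
def judge(d, q_r, q_b, delta_h):
    """Return red's reward R, or None while the engagement continues (angles in degrees)."""
    if d < 2500:
        if q_r < 30 and q_b > 30 and delta_h > 100:
            return 10    # red win
        if q_r > 30 and q_b < 30 and delta_h < 100:
            return -10   # blue win
    return None
```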
21 | The above process is repeated in a loop.
22 |
23 |
24 | Result plots after training:
25 | Successful engagement
26 |
27 | 
28 |
29 | Engagement exceeding the combat envelope
30 |
31 | 
32 |
33 | While the opponent flies toward us in uniform straight-line motion, our aircraft first climbs quickly to a certain altitude and then dives toward the enemy.
34 | The action set is defined very crudely; it could be refined with other algorithms in the future.
35 | #### This is a preliminary implementation; I have been busy lately and will keep updating it with more detailed explanations. {the original author's words!}
36 |
37 |
38 | class_env.py is the environment class; it contains the UAV model and the conversions between coordinates, angles, and the combat situation.
39 | main.py is the main script; it contains the network, the main training procedure, and the test procedure.
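Per the `__main__` block at the bottom of main.py, running the script trains for `EPOCH = 1000` episodes and writes `checkpoint/model.pt` every 100 episodes. To evaluate a trained model instead, swap the two calls there (this is the only change needed):

```python
if __name__ == '__main__':
    # main()   # train and periodically save checkpoint/model.pt
    test()     # load the checkpoint, fly one episode greedily, and save red.svg / blue.svg / error.svg
```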
40 |
41 |
42 |
--------------------------------------------------------------------------------
/checkpoint/model.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linaom1214/RL_air-combat/e9c6fa83a9109a115f0cf7ef6a61f70b6e4f599f/checkpoint/model.pt
--------------------------------------------------------------------------------
/class_env.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import math
4 |
5 | '''This version does not consider changes in speed.'''
6 |
7 | # Environment class
8 | class AirCombat():
9 |     def __init__(self):
10 |         # Initialize the state space: red (our side) and blue positions, [x, y, z] in an East-North-Up frame
11 |         self.position_r = [130000,100000,3000]
12 |         self.position_b = [130000,110000,3000]
13 |         # Initialize the action space
14 |         # Each step applies an action, transitions the state, and returns the reward and the next state
15 | x_r = self.position_r[0]
16 | x_b = self.position_b[0]
17 | y_r = self.position_r[1]
18 | y_b = self.position_b[1]
19 | z_r = self.position_r[2]
20 | z_b = self.position_b[2]
21 | self.gamma_r = 0
22 | self.gamma_b = 0
23 |
24 | self.pusin_r = 0
25 | self.pusin_b = 180
26 |
27 | self.v_r = 250
28 | self.v_b = 250
29 | self.active_r = [self.v_r,self.gamma_r,self.pusin_r]
30 | self.active_b = [self.v_b,self.gamma_b,self.pusin_b]
31 |
32 |
33 |     '''Given the current positions and the chosen actions, compute the next positions.'''
34 |     '''The enemy aircraft is assumed to fly in simple uniform straight-line motion.'''
35 | def generate_next_position(self,position_r,position_b,action_now,action_b):
36 | x_r = position_r[0]
37 | x_b = position_b[0]
38 | y_r = position_r[1]
39 | y_b = position_b[1]
40 | z_r = position_r[2]
41 | z_b = position_b[2]
42 |
43 | v_r = action_now[0]
44 | gamma_r = action_now[1]
45 | pusin_r = action_now[2]
46 |
47 | v_b= action_b[0]
48 | gamma_b= action_b[1]
49 | pusin_b= action_b[2]
50 |
51 | x_r_ = v_r*math.cos(gamma_r*(3.141592653/180))*math.sin(pusin_r*(3.141592653/180))
52 | y_r_ = v_r*math.cos(gamma_r*(3.141592653/180))*math.cos(pusin_r*(3.141592653/180))
53 | z_r_ = v_r*math.sin(gamma_r*(3.141592653/180))
54 |
55 | x_b_ = v_b*math.cos(gamma_b*(3.141592653/180))*math.sin(pusin_b*(3.141592653/180))
56 |
57 | y_b_ = v_b*math.cos(gamma_b*(3.141592653/180))*math.cos(pusin_b*(3.141592653/180))
58 |
59 | z_b_ = v_b*math.sin(gamma_b*(3.141592653/180))
60 |
61 | x_r_next = x_r + x_r_
62 | y_r_next = y_r + y_r_
63 | z_r_next = z_r + z_r_
64 |
65 | x_b_next = x_b + x_b_
66 | y_b_next = y_b + y_b_
67 | z_b_next = z_b + z_b_
68 |
69 | position_r_next = [x_r_next,y_r_next,z_r_next]
70 | position_b_next = [x_b_next,y_b_next,z_b_next]
71 |
72 | return position_r_next,position_b_next
73 |
74 |     '''Given the current positions and the speed/angle information, compute the corresponding combat-situation vector.'''
75 | def generate_state(self,position_r,position_b,action_r,action_b):
76 | x_r = position_r[0]
77 | x_b = position_b[0]
78 | y_r = position_r[1]
79 | y_b = position_b[1]
80 | z_r = position_r[2]
81 | z_b = position_b[2]
82 | v_r= action_r[0]
83 | gamma_r= action_r[1]
84 | pusin_r= action_r[2]
85 | v_b= action_b[0]
86 | gamma_b= action_b[1]
87 | pusin_b= action_b[2]
88 | d = math.sqrt((x_r-x_b)**2+(y_r-y_b)**2+(z_r-z_b)**2)
89 | q_r = math.acos(((x_b-x_r)*math.cos(gamma_r*(3.141592653/180))*math.sin(pusin_r*(3.141592653/180))+(y_b-y_r)*math.cos(gamma_r*(3.141592653/180))*math.cos(pusin_r*(3.141592653/180))+(z_b-z_r)*math.sin(gamma_r*(3.141592653/180)))/d)
90 | q_r_ = q_r*(180/3.141592653)
91 |         q_b = math.acos(((x_r-x_b)*math.cos(gamma_b*(3.141592653/180))*math.sin(pusin_b*(3.141592653/180))+(y_r-y_b)*math.cos(gamma_b*(3.141592653/180))*math.cos(pusin_b*(3.141592653/180))+(z_r-z_b)*math.sin(gamma_b*(3.141592653/180)))/d)  # convert degrees to radians, as done for q_r
92 | q_b_ = q_b*(180/3.141592653)
93 | beta = math.acos(math.cos(gamma_r*(3.141592653/180))*math.sin(pusin_r*(3.141592653/180))*math.cos(gamma_b*(3.141592653/180))*math.sin(pusin_b*(3.141592653/180))+math.cos(gamma_r*(3.141592653/180))*math.cos(pusin_r*(3.141592653/180))*math.cos(gamma_b*(3.141592653/180))*math.cos(pusin_b*(3.141592653/180))+math.sin(gamma_r*(3.141592653/180))*math.sin(gamma_b*(3.141592653/180)))
94 | beta_ = beta*(180/3.141592653)
95 | delta_h = z_r-z_b
96 | delta_v2 = v_r**2-v_b**2
97 | v2 = v_r**2
98 | h = z_r
99 |         taishi = [q_r_,q_b_,d,beta_,delta_h,delta_v2,v2,h]  # situation ('taishi') vector: aspect angles (deg), distance, heading difference (deg), height advantage, speed-squared advantage, own speed squared, own altitude
100 | return taishi
101 |
102 |     '''Note: the angle handling has not been revised yet.'''
103 |     '''Given the current angles and an action choice, return either all 9 candidate actions (flag=True) or the single chosen one.'''
104 | def action(self,v_r,gamma_r,pusin_r,flag,choose):
105 |
106 |         # rightward is taken as the positive direction
107 | gamma_r_increase = gamma_r + 10
108 | gamma_r_decrease = gamma_r - 10
109 | gamma_r_constant = gamma_r
110 | pusin_r_increase = pusin_r + 10
111 | pusin_r_decrease = pusin_r - 10
112 | pusin_r_constant = pusin_r
113 |
114 | action_1 = [v_r,gamma_r_increase,pusin_r_increase]
115 | action_2 = [v_r,gamma_r_increase,pusin_r_constant]
116 | action_3 = [v_r,gamma_r_increase,pusin_r_decrease]
117 | action_4 = [v_r,gamma_r_constant,pusin_r_increase]
118 | action_5 = [v_r,gamma_r_constant,pusin_r_constant]
119 | action_6 = [v_r,gamma_r_constant,pusin_r_decrease]
120 | action_7 = [v_r,gamma_r_decrease,pusin_r_increase]
121 | action_8 = [v_r,gamma_r_decrease,pusin_r_constant]
122 | action_9 = [v_r,gamma_r_decrease,pusin_r_decrease]
123 | if flag == True:
124 | return [action_1,action_2,action_3,action_4,action_5,action_6,action_7,action_8,action_9]
125 | else:
126 | if choose == 0:
127 | return action_1
128 | elif choose == 1:
129 | return action_2
130 | elif choose == 2:
131 | return action_3
132 | elif choose == 3:
133 | return action_4
134 | elif choose == 4:
135 | return action_5
136 | elif choose == 5:
137 | return action_6
138 | elif choose == 6:
139 | return action_7
140 | elif choose == 7:
141 | return action_8
142 | elif choose == 8:
143 | return action_9
144 |
145 | def action_b(self,action_b):
146 | v_b_ = action_b[0]
147 | gamma_b = action_b[1]
148 | pusin_b = action_b[2]
149 | action_b = [v_b_,gamma_b,pusin_b]
150 | return action_b
151 |
152 | def normalize(self,array):
153 | array[0] = array[0] / 200
154 | array[1] = array[1] / 200
155 | array[2] = array[2] / 20000
156 | array[3] = array[3] / 200
157 | array[4] = array[4] / 10000
158 | array[5] = array[5]
159 | array[6] = array[6] / 40000
160 | array[7] = array[7] / 10000
161 | return array
162 | #def normal(self,state_array):
163 | def position_clip(self,position_r_list,position_b_list):
164 | x_r = position_r_list[0]
165 | x_b = position_b_list[0]
166 | if x_r>(x_b+100) or x_r<(x_b-100):
167 | x_r = x_b+random.randint(0,10)
168 | position_r_list[0] = x_r
169 | return position_r_list
170 | def epsilon_greedy(self,prediction,epsilon):
171 | num = random.random()
172 | if num>epsilon:
173 | temp = np.argmax(prediction)
174 | else:
175 | temp = random.randint(0,8)
176 | return temp
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | # import tensorflow as tf
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 | from matplotlib import pyplot as plt
7 | from mpl_toolkits.mplot3d import Axes3D
8 | from class_env import AirCombat
9 |
10 | CHECKPOINT_DIR = './checkpoint/'
11 | EPOCH = 1000
12 | MAX_STEPS = 1000
13 | num_Action = 0
14 | K = 0.5        # decay rate for spreading the terminal reward backwards over an episode
15 | alpha = 0.1    # TD learning rate
16 | gamma = 0.4    # TD discount factor
17 | EPSILON = 0.1  # exploration probability of the epsilon-greedy policy
18 |
19 |
20 | # Win/loss judging function
21 | # fig = plt.figure()
22 | # ax = fig.add_subplot(111,projection='3d')  # this approach can also be used to draw several subplots
23 | # Network model
24 | class Net(nn.Module):
25 | def __init__(self):
26 | super(Net, self).__init__()
27 | self.layer1 = nn.Linear(72, 100)
28 | self.layer2 = nn.Linear(100, 30)
29 | self.layer3 = nn.Linear(30, 9)
30 |
31 | def forward(self, x):
32 | x = torch.sigmoid(self.layer1(x))
33 | x = torch.sigmoid(self.layer2(x))
34 | x = self.layer3(x)
35 | return x
36 |
37 |
38 | def greedy(prediction, list):
39 | temp = np.argmax(prediction)
40 | action = list[temp]
41 | return action
42 |
43 |
44 | def main():
45 | model = Net()
46 | model = model.to('cuda')
47 | opt = torch.optim.Adam(model.parameters(), lr=0.001)
48 | loss_fun = nn.MSELoss()
49 | # NUMPY S0_A_List
50 | for num in range(EPOCH):
51 | X = []
52 | Y = []
53 | Z = []
54 | X_B = []
55 | Y_B = []
56 | Z_B = []
57 |         print("Game %d" % num)
58 | aircombat = AirCombat()
59 | position_r_0 = aircombat.position_r # [1,8]
60 | x_r = position_r_0[0]
61 | y_r = position_r_0[1]
62 | z_r = position_r_0[2]
63 | X.append(x_r)
64 | Y.append(y_r)
65 | Z.append(z_r)
66 |
67 | position_b_0 = aircombat.position_b
68 | x_b = position_b_0[0]
69 | y_b = position_b_0[1]
70 | z_b = position_b_0[2]
71 | X_B.append(x_b)
72 | Y_B.append(y_b)
73 | Z_B.append(z_b)
74 |
75 | action_r_0 = aircombat.active_r
76 | action_b = aircombat.active_b
77 |         # The situation is computed from both sides' positions and kinematic states; below is the initial situation
78 | s0 = aircombat.generate_state(position_r_0, position_b_0, action_r_0, action_b)
79 | # position_r_next, position_b_next = aircombat.generate_next_position(position_r_0, position_b_0,
80 | # action_now=action_r_0)
81 | v_r = action_r_0[0]
82 | gamma_r = action_r_0[1]
83 | pusin_r = action_r_0[2]
84 | action_list = aircombat.action(v_r, gamma_r, pusin_r, flag=True, choose=1)
85 |         action_b = aircombat.action_b(action_b)  # uniform straight-line motion, i.e. the blue state stays unchanged
86 |         '''The outcomes of all candidate actions for the current situation, i.e. the network input'''
87 | position_r_state1, position_b_state1 = aircombat.generate_next_position(position_r_0, position_b_0,
88 | action_list[0], action_b)
89 | position_r_state2, position_b_state2 = aircombat.generate_next_position(position_r_0, position_b_0,
90 | action_list[1], action_b)
91 | position_r_state3, position_b_state3 = aircombat.generate_next_position(position_r_0, position_b_0,
92 | action_list[2], action_b)
93 | position_r_state4, position_b_state4 = aircombat.generate_next_position(position_r_0, position_b_0,
94 | action_list[3], action_b)
95 | position_r_state5, position_b_state5 = aircombat.generate_next_position(position_r_0, position_b_0,
96 | action_list[4], action_b)
97 | position_r_state6, position_b_state6 = aircombat.generate_next_position(position_r_0, position_b_0,
98 | action_list[5], action_b)
99 | position_r_state7, position_b_state7 = aircombat.generate_next_position(position_r_0, position_b_0,
100 | action_list[6], action_b)
101 | position_r_state8, position_b_state8 = aircombat.generate_next_position(position_r_0, position_b_0,
102 | action_list[7], action_b)
103 | position_r_state9, position_b_state9 = aircombat.generate_next_position(position_r_0, position_b_0,
104 | action_list[8], action_b)
105 |
106 | state_1 = aircombat.generate_state(position_r_state1, position_b_state1, action_list[0], action_b)
107 | state_2 = aircombat.generate_state(position_r_state2, position_b_state2, action_list[1], action_b)
108 | state_3 = aircombat.generate_state(position_r_state3, position_b_state3, action_list[2], action_b)
109 | state_4 = aircombat.generate_state(position_r_state4, position_b_state4, action_list[3], action_b)
110 | state_5 = aircombat.generate_state(position_r_state5, position_b_state5, action_list[4], action_b)
111 | state_6 = aircombat.generate_state(position_r_state6, position_b_state6, action_list[5], action_b)
112 | state_7 = aircombat.generate_state(position_r_state7, position_b_state7, action_list[6], action_b)
113 | state_8 = aircombat.generate_state(position_r_state8, position_b_state8, action_list[7], action_b)
114 | state_9 = aircombat.generate_state(position_r_state9, position_b_state9, action_list[8], action_b)
115 | state_1 = aircombat.normalize(state_1)
116 | state_2 = aircombat.normalize(state_2)
117 | state_3 = aircombat.normalize(state_3)
118 | state_4 = aircombat.normalize(state_4)
119 | state_5 = aircombat.normalize(state_5)
120 | state_6 = aircombat.normalize(state_6)
121 | state_7 = aircombat.normalize(state_7)
122 | state_8 = aircombat.normalize(state_8)
123 | state_9 = aircombat.normalize(state_9)
124 |
125 | start_input_list = [state_1, state_2, state_3, state_4, state_5, state_6, state_7, state_8, state_9]
126 |
127 | start_input_array = np.array(start_input_list)
128 |
129 | start_input = start_input_array.reshape((1, 72))
130 | state_list = []
131 | state_list.append(start_input)
132 |
133 | r_list = []
134 | j = 0
135 | prediction_list = []
136 | state_list = []
137 | # labels_0 = np.zeros([9])
138 | # labels_0 = tf.convert_to_tensor(labels_0)
139 | state_now = start_input
140 | # action_list_now = action_list
141 | t_f = True
142 | while (t_f == True): # 前向传播,生成结果
143 | input = torch.unsqueeze(torch.from_numpy(state_now).type(torch.FloatTensor),0).cuda()
144 | prediction = model(input)
145 | prediction = prediction.cpu().detach().numpy()
146 | # print('prediction',prediction)
147 | prediction_list.append(prediction)
148 | temp = aircombat.epsilon_greedy(prediction, EPSILON)
149 | # temp = np.argmax(prediction)
150 | action_now = aircombat.action(v_r, gamma_r, pusin_r, flag=False, choose=temp)
151 | position_r_now, position_b_now = aircombat.generate_next_position(position_r_0, position_b_0, action_now,
152 | action_b)
153 | # position_r_now = aircombat.position_clip(position_r_now,position_b_now)
154 | state_next = aircombat.generate_state(position_r_now, position_b_now, action_now, action_b)
155 |
156 | v_r = action_now[0]
157 | gamma_r = action_now[1]
158 | pusin_r = action_now[2]
159 | action_list_now = aircombat.action(v_r, gamma_r, pusin_r, flag=True, choose=1)
160 |             # uniform straight-line motion, i.e. the blue state stays unchanged
161 |             '''The outcomes of all candidate actions for the current situation, i.e. the network input'''
162 | position_r_state1, position_b_state1 = aircombat.generate_next_position(position_r_now, position_b_now,
163 | action_list_now[0], action_b)
164 | position_r_state2, position_b_state2 = aircombat.generate_next_position(position_r_now, position_b_now,
165 | action_list_now[1], action_b)
166 | position_r_state3, position_b_state3 = aircombat.generate_next_position(position_r_now, position_b_now,
167 | action_list_now[2], action_b)
168 | position_r_state4, position_b_state4 = aircombat.generate_next_position(position_r_now, position_b_now,
169 | action_list_now[3], action_b)
170 | position_r_state5, position_b_state5 = aircombat.generate_next_position(position_r_now, position_b_now,
171 | action_list_now[4], action_b)
172 | position_r_state6, position_b_state6 = aircombat.generate_next_position(position_r_now, position_b_now,
173 | action_list_now[5], action_b)
174 | position_r_state7, position_b_state7 = aircombat.generate_next_position(position_r_now, position_b_now,
175 | action_list_now[6], action_b)
176 | position_r_state8, position_b_state8 = aircombat.generate_next_position(position_r_now, position_b_now,
177 | action_list_now[7], action_b)
178 | position_r_state9, position_b_state9 = aircombat.generate_next_position(position_r_now, position_b_now,
179 | action_list_now[8], action_b)
180 |
181 | state_1 = aircombat.generate_state(position_r_state1, position_b_state1, action_list_now[0], action_b)
182 | state_2 = aircombat.generate_state(position_r_state2, position_b_state2, action_list_now[1], action_b)
183 | state_3 = aircombat.generate_state(position_r_state3, position_b_state3, action_list_now[2], action_b)
184 | state_4 = aircombat.generate_state(position_r_state4, position_b_state4, action_list_now[3], action_b)
185 | state_5 = aircombat.generate_state(position_r_state5, position_b_state5, action_list_now[4], action_b)
186 | state_6 = aircombat.generate_state(position_r_state6, position_b_state6, action_list_now[5], action_b)
187 | state_7 = aircombat.generate_state(position_r_state7, position_b_state7, action_list_now[6], action_b)
188 | state_8 = aircombat.generate_state(position_r_state8, position_b_state8, action_list_now[7], action_b)
189 | state_9 = aircombat.generate_state(position_r_state9, position_b_state9, action_list_now[8], action_b)
190 | state_1 = aircombat.normalize(state_1)
191 | state_2 = aircombat.normalize(state_2)
192 | state_3 = aircombat.normalize(state_3)
193 | state_4 = aircombat.normalize(state_4)
194 | state_5 = aircombat.normalize(state_5)
195 | state_6 = aircombat.normalize(state_6)
196 | state_7 = aircombat.normalize(state_7)
197 | state_8 = aircombat.normalize(state_8)
198 | state_9 = aircombat.normalize(state_9)
199 | start_input_now = [state_1, state_2, state_3, state_4, state_5, state_6, state_7, state_8, state_9]
200 |
201 | state_now = np.array(start_input_now)
202 | state_now = state_now.reshape((1, 72))
203 | state_list.append(state_now)
204 | q_r = state_next[0]
205 | q_b = state_next[1]
206 | d = state_next[2]
207 | delta_h = state_next[4]
208 | position_r_0 = position_r_now
209 | position_b_0 = position_b_now
210 | x_r = position_r_now[0]
211 | y_r = position_r_now[1]
212 | z_r = position_r_now[2]
213 | x_b = position_b_now[0]
214 | y_b = position_b_now[1]
215 | z_b = position_b_now[2]
216 |
217 | X.append(x_r)
218 | Y.append(y_r)
219 | Z.append(z_r)
220 |
221 | X_B.append(x_b)
222 | Y_B.append(y_b)
223 | Z_B.append(z_b)
224 | x_instance = abs(x_r - x_b)
225 | j += 1
226 |             print('Step %d' % j)
227 |             print('Situation:', state_next)
228 |             print('Red position:', x_r, y_r, z_r)
229 |             print('Red speed/angles:', action_now)
230 |             print('Blue position:', x_b, y_b, z_b)
231 |             print('Blue speed/angles:', action_b)
232 | if j == MAX_STEPS:
233 |                 print('Reached the maximum number of steps')
234 | t_f = False
235 | R = -5
236 | if d < 2500:
237 | if q_r < 30 and q_b > 30 and delta_h > 100:
238 | print('red win!')
239 | t_f = False
240 | R = 10
241 | if q_r > 30 and q_b < 30 and delta_h < 100:
242 | print('blue win...')
243 | t_f = False
244 | R = -10
245 | if x_r > 200000 or x_r < 0 or y_r > 200000 or y_r < 0 or z_r > 11000 or z_r < 0 or x_b > 200000 or x_b < 0 or y_b > 200000 or y_b < 0 or z_b > 11000 or z_b < 0:
246 |                 print('Out of the combat area')
247 | t_f = False
248 | R = -5
249 |
250 | X = np.array(X)
251 | Y = np.array(Y)
252 | Z = np.array(Z)
253 | X_B = np.array(X_B)
254 | Y_B = np.array(Y_B)
255 | Z_B = np.array(Z_B)
256 |
257 | # ax.scatter3D(X, Y, Z, cmap='Blues')
258 | # ax.plot3D(X, Y, Z, 'gray')
259 | # ax.scatter3D(X_B, Y_B, Z_B, cmap='Reds')
260 | # plt.show()
261 |         # Record all rewards and value estimates from this episode to train the network
262 | r_list = np.zeros(j)
263 | for i in range(j):
264 | r_list[i] = K ** (j - i) * R
265 |         delta = np.zeros_like(prediction_list)
266 |         for k in range(j - 1):
267 |             delta[k] = alpha * (r_list[k] + gamma * prediction_list[k + 1] - prediction_list[k])  # TD(0) increment: alpha * (r + gamma * Q(s') - Q(s))
268 |         delta[j - 1] = 0.01
269 | target = np.zeros_like(prediction_list)
270 | for k in range(j):
271 | target[k] = prediction_list[k] + delta[k]
272 |         print('Training...')
273 |
274 |         for i in range(j):  # training loop over all recorded steps of this episode
275 |             opt.zero_grad()
276 |             state_list_ = torch.from_numpy(state_list[i]).type(torch.FloatTensor).cuda()
277 |             pred = model(state_list_)
278 |             target_ = torch.from_numpy(target[i]).cuda()
279 |             loss = loss_fun(pred, target_)
280 |             loss.backward()
281 |             opt.step()  # apply the gradient update
282 |         if num % 100 == 0:
283 |             torch.save(model.state_dict(), f'{CHECKPOINT_DIR}/model.pt')
284 |             print('Write checkpoint at epoch %d' % num)
285 |
286 | def test():
287 | model = Net()
288 |     model.load_state_dict(torch.load(f'{CHECKPOINT_DIR}/model.pt'))
289 | model = model.to('cuda')
290 |     # Initialize everything first
291 | # fig = plt.figure()
292 | ax1 = plt.axes(projection='3d')
293 | X = []
294 | Y = []
295 | Z = []
296 | X_B = []
297 | Y_B = []
298 | Z_B = []
299 | aircombat = AirCombat()
300 | position_r_0 = aircombat.position_r # [1,8]
301 | action_r_0 = aircombat.active_r
302 | position_b_0 = aircombat.position_b
303 | action_b = aircombat.active_b
304 |     # The situation is computed from both sides' positions and kinematic states; below is the initial situation
305 | s0 = aircombat.generate_state(position_r_0, position_b_0, action_r_0, action_b)
306 | # position_r_next, position_b_next = aircombat.generate_next_position(position_r_0, position_b_0,
307 | # action_now=action_r_0)
308 | v_r = action_r_0[0]
309 | gamma_r = action_r_0[1]
310 | pusin_r = action_r_0[2]
311 | action_list = aircombat.action(v_r, gamma_r, pusin_r, flag=True, choose=1)
312 |     action_b = aircombat.action_b(action_b)  # uniform straight-line motion, i.e. the blue state stays unchanged
313 |     '''The outcomes of all candidate actions for the current situation, i.e. the network input'''
314 | position_r_state1, position_b_state1 = aircombat.generate_next_position(position_r_0, position_b_0,
315 | action_list[0], action_b)
316 | position_r_state2, position_b_state2 = aircombat.generate_next_position(position_r_0, position_b_0,
317 | action_list[1], action_b)
318 | position_r_state3, position_b_state3 = aircombat.generate_next_position(position_r_0, position_b_0,
319 | action_list[2], action_b)
320 | position_r_state4, position_b_state4 = aircombat.generate_next_position(position_r_0, position_b_0,
321 | action_list[3], action_b)
322 | position_r_state5, position_b_state5 = aircombat.generate_next_position(position_r_0, position_b_0,
323 | action_list[4], action_b)
324 | position_r_state6, position_b_state6 = aircombat.generate_next_position(position_r_0, position_b_0,
325 | action_list[5], action_b)
326 | position_r_state7, position_b_state7 = aircombat.generate_next_position(position_r_0, position_b_0,
327 | action_list[6], action_b)
328 | position_r_state8, position_b_state8 = aircombat.generate_next_position(position_r_0, position_b_0,
329 | action_list[7], action_b)
330 | position_r_state9, position_b_state9 = aircombat.generate_next_position(position_r_0, position_b_0,
331 | action_list[8], action_b)
332 |
333 | state_1 = aircombat.generate_state(position_r_state1, position_b_state1, action_list[0], action_b)
334 | state_2 = aircombat.generate_state(position_r_state2, position_b_state2, action_list[1], action_b)
335 | state_3 = aircombat.generate_state(position_r_state3, position_b_state3, action_list[2], action_b)
336 | state_4 = aircombat.generate_state(position_r_state4, position_b_state4, action_list[3], action_b)
337 | state_5 = aircombat.generate_state(position_r_state5, position_b_state5, action_list[4], action_b)
338 | state_6 = aircombat.generate_state(position_r_state6, position_b_state6, action_list[5], action_b)
339 | state_7 = aircombat.generate_state(position_r_state7, position_b_state7, action_list[6], action_b)
340 | state_8 = aircombat.generate_state(position_r_state8, position_b_state8, action_list[7], action_b)
341 | state_9 = aircombat.generate_state(position_r_state9, position_b_state9, action_list[8], action_b)
342 | state_1 = aircombat.normalize(state_1)
343 | state_2 = aircombat.normalize(state_2)
344 | state_3 = aircombat.normalize(state_3)
345 | state_4 = aircombat.normalize(state_4)
346 | state_5 = aircombat.normalize(state_5)
347 | state_6 = aircombat.normalize(state_6)
348 | state_7 = aircombat.normalize(state_7)
349 | state_8 = aircombat.normalize(state_8)
350 | state_9 = aircombat.normalize(state_9)
351 | start_input_list = [state_1, state_2, state_3, state_4, state_5, state_6, state_7, state_8, state_9]
352 |
353 | start_input_array = np.array(start_input_list)
354 |
355 | start_input = start_input_array.reshape((1, 72))
356 | state_list = []
357 | state_list.append(start_input)
358 |
359 | r_list = []
360 | j = 0
361 | prediction_list = []
362 | state_list = []
363 | # labels_0 = np.zeros([9])
364 | # labels_0 = tf.convert_to_tensor(labels_0)
365 | state_now = start_input
366 |
367 | # action_list_now = action_list
368 | t_f = True
369 | while (t_f == True): # 前向传播,生成结果
370 | input = torch.unsqueeze(torch.from_numpy(state_now).type(torch.FloatTensor), 0).cuda()
371 | prediction = model(input)
372 | prediction = prediction.cpu().detach().numpy()
373 | # print('prediction',prediction)
374 | prediction_list.append(prediction)
375 | temp = np.argmax(prediction)
376 | action_now = aircombat.action(v_r, gamma_r, pusin_r, flag=False, choose=temp)
377 | position_r_now, position_b_now = aircombat.generate_next_position(position_r_0, position_b_0, action_now,
378 | action_b)
379 |
380 | state_next = aircombat.generate_state(position_r_now, position_b_now, action_now, action_b)
381 |
382 | v_r = action_now[0]
383 | gamma_r = action_now[1]
384 | pusin_r = action_now[2]
385 | action_list_now = aircombat.action(v_r, gamma_r, pusin_r, flag=True, choose=1)
386 |         # uniform straight-line motion, i.e. the blue state stays unchanged
387 |         '''The outcomes of all candidate actions for the current situation, i.e. the network input'''
388 | position_r_state1, position_b_state1 = aircombat.generate_next_position(position_r_now, position_b_now,
389 | action_list_now[0], action_b)
390 | position_r_state2, position_b_state2 = aircombat.generate_next_position(position_r_now, position_b_now,
391 | action_list_now[1], action_b)
392 | position_r_state3, position_b_state3 = aircombat.generate_next_position(position_r_now, position_b_now,
393 | action_list_now[2], action_b)
394 | position_r_state4, position_b_state4 = aircombat.generate_next_position(position_r_now, position_b_now,
395 | action_list_now[3], action_b)
396 | position_r_state5, position_b_state5 = aircombat.generate_next_position(position_r_now, position_b_now,
397 | action_list_now[4], action_b)
398 | position_r_state6, position_b_state6 = aircombat.generate_next_position(position_r_now, position_b_now,
399 | action_list_now[5], action_b)
400 | position_r_state7, position_b_state7 = aircombat.generate_next_position(position_r_now, position_b_now,
401 | action_list_now[6], action_b)
402 | position_r_state8, position_b_state8 = aircombat.generate_next_position(position_r_now, position_b_now,
403 | action_list_now[7], action_b)
404 | position_r_state9, position_b_state9 = aircombat.generate_next_position(position_r_now, position_b_now,
405 | action_list_now[8], action_b)
406 |
407 | state_1 = aircombat.generate_state(position_r_state1, position_b_state1, action_list_now[0], action_b)
408 | state_2 = aircombat.generate_state(position_r_state2, position_b_state2, action_list_now[1], action_b)
409 | state_3 = aircombat.generate_state(position_r_state3, position_b_state3, action_list_now[2], action_b)
410 | state_4 = aircombat.generate_state(position_r_state4, position_b_state4, action_list_now[3], action_b)
411 | state_5 = aircombat.generate_state(position_r_state5, position_b_state5, action_list_now[4], action_b)
412 | state_6 = aircombat.generate_state(position_r_state6, position_b_state6, action_list_now[5], action_b)
413 | state_7 = aircombat.generate_state(position_r_state7, position_b_state7, action_list_now[6], action_b)
414 | state_8 = aircombat.generate_state(position_r_state8, position_b_state8, action_list_now[7], action_b)
415 | state_9 = aircombat.generate_state(position_r_state9, position_b_state9, action_list_now[8], action_b)
416 | state_1 = aircombat.normalize(state_1)
417 | state_2 = aircombat.normalize(state_2)
418 | state_3 = aircombat.normalize(state_3)
419 | state_4 = aircombat.normalize(state_4)
420 | state_5 = aircombat.normalize(state_5)
421 | state_6 = aircombat.normalize(state_6)
422 | state_7 = aircombat.normalize(state_7)
423 | state_8 = aircombat.normalize(state_8)
424 | state_9 = aircombat.normalize(state_9)
425 | start_input_now = [state_1, state_2, state_3, state_4, state_5, state_6, state_7, state_8, state_9]
426 |
427 | state_now = np.array(start_input_now)
428 | state_now = state_now.reshape((1, 72))
429 | state_list.append(state_now)
430 | q_r = state_next[0]
431 | q_b = state_next[1]
432 | d = state_next[2]
433 | delta_h = state_next[4]
434 | position_r_0 = position_r_now
435 | position_b_0 = position_b_now
436 | x_r = position_r_now[0]
437 | y_r = position_r_now[1]
438 | z_r = position_r_now[2]
439 | x_b = position_b_now[0]
440 | y_b = position_b_now[1]
441 | z_b = position_b_now[2]
442 |
443 | X.append(x_r)
444 | Y.append(y_r)
445 | Z.append(z_r)
446 |
447 | X_B.append(x_b)
448 | Y_B.append(y_b)
449 | Z_B.append(z_b)
450 |
451 | j += 1
452 |         print('Step %d' % j)
453 |         print('Situation:', state_next)
454 |         print('Red position:', x_r, y_r, z_r)
455 |         print('Red speed/angles:', action_now)
456 |         print('Blue position:', x_b, y_b, z_b)
457 |         print('Blue speed/angles:', action_b)
458 | if j == MAX_STEPS:
459 |             print('Reached the maximum number of steps')
460 | t_f = False
461 | R = -5
462 | if d < 2500:
463 | if q_r < 30 and q_b > 30 and delta_h > 0:
464 | print('red win!')
465 | t_f = False
466 | R = 10
467 | if q_r > 30 and q_b < 30 and delta_h < 0:
468 | print('blue win...')
469 | t_f = False
470 | R = -10
471 | if x_r > 200000 or x_r < 0 or y_r > 200000 or y_r < 0 or z_r > 11000 or z_r < 0 or x_b > 200000 or x_b < 0 or y_b > 200000 or y_b < 0 or z_b > 11000 or z_b < 0:
472 |             print('Out of the combat area')
473 | t_f = False
474 | R = -5
475 | X = np.array(X)
476 | Y = np.array(Y)
477 | Z = np.array(Z)
478 | X_B = np.array(X_B)
479 | Y_B = np.array(Y_B)
480 | Z_B = np.array(Z_B)
481 |
482 | ax1.scatter3D(X_B, Y_B, Z_B, label='Blue') # , cmap='Reds'
483 | ax1.plot3D(X, Y, Z,'gray' ) #
484 | ax1.scatter3D(X, Y, Z, label='Red') # cmap='Blues'
485 | plt.legend()
486 |
487 | if R==10:
488 | flag='red'
489 | elif R==-10:
490 | flag='blue'
491 | else:
492 | flag='error'
493 | plt.savefig(f'{flag}.svg')
494 | plt.show()
495 |     # Record all rewards and value estimates from one episode to train the network
496 |
497 |
498 | if __name__ == '__main__':
499 | main()
500 | # test()
501 |
--------------------------------------------------------------------------------
/red.svg:
--------------------------------------------------------------------------------
(SVG image data omitted — trajectory plot written by test())
--------------------------------------------------------------------------------