├── .idea
│   ├── .gitignore
│   ├── vcs.xml
│   ├── inspectionProfiles
│   │   ├── profiles_settings.xml
│   │   └── Project_Default.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── MARL-UAVs-Targets-Tracking.iml
├── imgs
│   └── 2d-demo.png
├── src
│   ├── __pycache__
│   │   └── environment.cpython-38.pyc
│   ├── configs
│   │   ├── C-METHOD.yaml
│   │   ├── MAAC.yaml
│   │   ├── MAAC-G.yaml
│   │   └── MAAC-R.yaml
│   ├── agent
│   │   ├── target.py
│   │   └── uav.py
│   ├── utils
│   │   ├── data_util.py
│   │   ├── args_util.py
│   │   └── draw_util.py
│   ├── models
│   │   ├── PMINet.py
│   │   └── actor_critic.py
│   ├── main.py
│   ├── environment.py
│   └── train.py
├── .gitignore
└── README.md
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 |
--------------------------------------------------------------------------------
/imgs/2d-demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidweidawang/MARL-UAVs-Targets-Tracking/HEAD/imgs/2d-demo.png
--------------------------------------------------------------------------------
/src/__pycache__/environment.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/davidweidawang/MARL-UAVs-Targets-Tracking/HEAD/src/__pycache__/environment.cpython-38.pyc
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
(XML content not captured in this dump)
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
(XML content not captured in this dump)
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
(XML content not captured in this dump)
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
(XML content not captured in this dump)
--------------------------------------------------------------------------------
/.idea/MARL-UAVs-Targets-Tracking.iml:
--------------------------------------------------------------------------------
(XML content not captured in this dump)
--------------------------------------------------------------------------------
/src/configs/C-METHOD.yaml:
--------------------------------------------------------------------------------
1 | exp_name: C-METHOD
2 | result_dir: ../results/C-METHOD
3 | # index of the first cuda device; -1 means cpu
4 | first_device: 0
5 | # number of cuda devices; ignored when running on cpu
6 | gpus: 1
7 | seed: 42
8 | cooperative: 0
9 | environment:
10 | n_uav: 10
11 | m_targets: 10
12 | x_max: 2000
13 | y_max: 2000
14 | na: 12 # dimension of the discretized action space
15 |
16 | uav:
17 | dt: 1
18 | v_max: 20
19 | h_max: 6 # interpreted as pi / 6
20 | dc: 500
21 | dp: 200
22 | alpha: 0.6
23 | beta: 0.2
24 | gamma: 0.2
25 |
26 | target:
27 | v_max: 5
28 | h_max: 6 # interpreted as pi / 6
29 |
30 |
31 |
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | .idea/misc.xml
3 | .idea/MARL-UAVs-Targets-Tracking.iml
4 |
5 | src/agent/__pycache__/target.cpython-310.pyc
6 | src/agent/__pycache__/uav.cpython-310.pyc
7 | src/models/__pycache__/actor_critic.cpython-310.pyc
8 | src/models/__pycache__/PMINet.cpython-310.pyc
9 | src/utils/__pycache__/args_util.cpython-310.pyc
10 | src/utils/__pycache__/data_util.cpython-310.pyc
11 | src/utils/__pycache__/draw_util.cpython-310.pyc
12 |
13 | src/__pycache__/train.cpython-310.pyc
14 | src/__pycache__/environment.cpython-310.pyc
15 |
16 |
17 |
18 | /.idea
19 | *.iml
20 |
21 |
--------------------------------------------------------------------------------
/src/configs/MAAC.yaml:
--------------------------------------------------------------------------------
1 | exp_name: MAAC
2 | result_dir: ../results/MAAC
3 | # index of the first cuda device; -1 means cpu
4 | first_device: 0
5 | # number of cuda devices; ignored when running on cpu
6 | gpus: 1
7 | seed: 42
8 | environment:
9 | n_uav: 10
10 | m_targets: 10
11 | x_max: 2000
12 | y_max: 2000
13 | na: 12 # dimension of the discretized action space
14 |
15 | uav:
16 | dt: 1
17 | v_max: 20
18 | h_max: 6 # interpreted as pi / 6
19 | dc: 500
20 | dp: 200
21 | alpha: 0.6
22 | beta: 0.2
23 | gamma: 0.2
24 |
25 | target:
26 | v_max: 5
27 | h_max: 6 # interpreted as pi / 6
28 |
29 | actor_critic:
30 | buffer_size: 1000000
31 | sample_size: 0 # 0 means the sample size equals the number of steps per epoch
32 | actor_lr: 1e-4
33 | critic_lr: 5e-4
34 | hidden_dim: 128
35 | gamma: 0.95
36 |
--------------------------------------------------------------------------------
/src/configs/MAAC-G.yaml:
--------------------------------------------------------------------------------
1 | exp_name: MAAC-G
2 | result_dir: ../results/MAAC-G
3 | # index of the first cuda device; -1 means cpu
4 | first_device: 0
5 | # number of cuda devices; ignored when running on cpu
6 | gpus: 1
7 | seed: 42
8 | cooperative: 0.3
9 | environment:
10 | n_uav: 10
11 | m_targets: 10
12 | x_max: 2000
13 | y_max: 2000
14 | na: 12 # dimension of the discretized action space
15 |
16 | uav:
17 | dt: 1
18 | v_max: 20
19 | h_max: 6 # interpreted as pi / 6
20 | dc: 500
21 | dp: 200
22 | alpha: 0.6
23 | beta: 0.2
24 | gamma: 0.2
25 |
26 | target:
27 | v_max: 5
28 | h_max: 6 # interpreted as pi / 6
29 |
30 | actor_critic:
31 | buffer_size: 1000000
32 | sample_size: 0 # 0 means the sample size equals the number of steps per epoch
33 | actor_lr: 1e-4
34 | critic_lr: 5e-4
35 | hidden_dim: 128
36 | gamma: 0.95
37 |
--------------------------------------------------------------------------------
/src/configs/MAAC-R.yaml:
--------------------------------------------------------------------------------
1 | exp_name: MAAC-R
2 | result_dir: ../results/MAAC-R
3 | # index of the first cuda device; -1 means cpu
4 | first_device: 0
5 | # number of cuda devices; ignored when running on cpu
6 | gpus: 1
7 | seed: 42
8 | cooperative: 0.3
9 | environment:
10 | n_uav: 10
11 | m_targets: 10
12 | x_max: 2000
13 | y_max: 2000
14 | na: 12 # dimension of the discretized action space
15 |
16 | uav:
17 | dt: 1
18 | v_max: 20
19 | h_max: 6 # interpreted as pi / 6
20 | dc: 500
21 | dp: 200
22 | alpha: 0.6
23 | beta: 0.2
24 | gamma: 0.2
25 |
26 | target:
27 | v_max: 5
28 | h_max: 6 # interpreted as pi / 6
29 |
30 | actor_critic:
31 | buffer_size: 1000000
32 | sample_size: 0 # 0 means the sample size equals the number of steps per epoch
33 | actor_lr: 1e-4
34 | critic_lr: 1e-4
35 | hidden_dim: 256
36 | gamma: 0.95
37 |
38 | pmi:
39 | hidden_dim: 128
40 | b2_size: 3000
41 | batch_size: 128
--------------------------------------------------------------------------------
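
The YAML files above are parsed by `utils/args_util.get_config` into a plain nested dict. The sketch below is an illustration rather than part of the repository: it assumes it is run from `src/` with PyYAML installed (already required by `utils/args_util.py`) and reads the MAAC-R config directly, skipping the directory side effects of `get_config`. Values such as `1e-4` are typically loaded as strings, which is why `main.py` wraps the learning rates in `float()`, and `h_max: 6` is interpreted as `pi / 6` in `Environment.reset`.

```python
# Minimal sketch (assumes cwd is src/ and PyYAML is installed); not part of the repo.
import yaml
from math import pi

with open("configs/MAAC-R.yaml", "r", encoding="UTF-8") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# exponent-style numbers like "1e-4" usually load as strings, hence float() in main.py
actor_lr = float(config["actor_critic"]["actor_lr"])

# "h_max: 6" stands for pi / 6, converted exactly as in Environment.reset
u_h_max = pi / float(config["uav"]["h_max"])

print(actor_lr, u_h_max, config["pmi"]["hidden_dim"])
```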
/README.md:
--------------------------------------------------------------------------------
1 | # MARL-UAVs-Targets-Tracking
2 | An implementation and improvement of the paper “Improving multi-target cooperative tracking guidance for UAV swarms using multi-agent reinforcement learning”.
3 |
4 | 
5 |
6 | ### Environment
7 |
8 | You can simply use pip to configure the environment:
9 |
10 | ```sh
11 | pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
12 | pip install numpy matplotlib tqdm tensorboard scipy
13 | pip install "imageio[ffmpeg]"
14 | ```
15 |
16 | ### Run the code
17 |
18 | ```sh
19 | cd src
20 | python main.py
21 | ```
22 |
23 | ### ToDo List
24 |
25 | - [x] vanilla MAAC
26 | - [x] Actor-Critic framework
27 | - [x] MAAC-R
28 | - [x] reciprocal reward (with PMI network)
29 | - [x] MAAC-G
30 | - [x] receive the global reward
31 | - [x] 3D demo
32 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
(XML content not captured in this dump)
--------------------------------------------------------------------------------
/src/agent/target.py:
--------------------------------------------------------------------------------
1 | from math import cos, sin, pi
2 | import random
3 |
4 |
5 | class TARGET:
6 | def __init__(self, x0: float, y0: float, h0: float, a0: float, v_max: float, h_max: float, dt):
7 | """
8 | :param x0: scalar
9 | :param y0: scalar
10 | :param h0: scalar
11 | :param v_max: scalar
12 | :param h_max: scalar
13 | """
14 | # the position, velocity and heading of this target
15 | self.x = x0
16 | self.y = y0
17 | self.h = h0
18 | self.v_max = v_max
19 |
20 | # the max heading angular rate and the action of this target
21 | self.h_max = h_max
22 | self.a = a0
23 |
24 | # time interval
25 | self.dt = dt
26 |
27 | def update_position(self, x_max, y_max) -> (float, float):
28 | """
29 | receive the action (heading angular rate), then update the current position
30 | :param y_max:
31 | :param x_max:
32 | :return:
33 | """
34 | self.a = random.uniform(-self.h_max, self.h_max)
35 | dx = self.dt * self.v_max * cos(self.h) # displacement along x
36 | dy = self.dt * self.v_max * sin(self.h) # displacement along y
37 | self.x += dx
38 | self.y += dy
39 |
40 | # if self.x > x_max:
41 | # self.x = x_max
42 | # if self.x < 0:
43 | # self.x = 0
44 | #
45 | # if self.y > y_max:
46 | # self.y = y_max
47 | # if self.y < 0:
48 | # self.y = 0
49 |
50 | # self.h += self.dt * self.a # update the heading angle
51 | # self.h = (self.h + pi) % (2 * pi) - pi # keep the heading within [-pi, pi)
52 | if 0 > self.y or self.y > y_max:
53 | self.h = -self.h
54 | elif self.x < 0 or self.x > x_max:
55 | if self.h > 0:
56 | self.h = pi-self.h
57 | else:
58 | self.h = -pi-self.h
59 |
60 | return self.x, self.y
61 |
--------------------------------------------------------------------------------
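
A minimal usage sketch (not part of the repository): it assumes it is run from `src/` so that `agent.target` is importable, with parameter values mirroring the configs, and illustrates how `update_position` flips the heading once a target crosses the map boundary.

```python
# Minimal sketch (assumes cwd is src/); parameter values mirror the YAML configs.
from math import pi
import random

from agent.target import TARGET

random.seed(0)
target = TARGET(x0=1000.0, y0=1990.0, h0=pi / 2,   # heading towards the upper boundary
                a0=0.0, v_max=5.0, h_max=pi / 6, dt=1)

for step in range(5):
    x, y = target.update_position(x_max=2000, y_max=2000)
    print(f"step {step}: x={x:.1f}, y={y:.1f}, heading={target.h:+.2f} rad")
# Once y exceeds y_max the heading sign flips, so the target bounces back into the area.
```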
/src/utils/data_util.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import os.path
3 | import numpy as np
4 |
5 |
6 | def save_csv(config, return_list):
7 | """
8 | :param config:
9 | :param return_list:
10 | return_list = {
11 | 'return_list': self.return_list,
12 | 'target_tracking_return_list' :target_tracking_return_list,
13 | 'boundary_punishment_return_list':boundary_punishment_return_list,
14 | 'duplicate_tracking_punishment_return_list':duplicate_tracking_punishment_return_list
15 | }
16 | :return:
17 | """
18 | with open(os.path.join(config["save_dir"], 'return_list.csv'), mode='w', newline='') as file:
19 | writer = csv.writer(file)
20 | writer.writerow(['Reward']) # write the header row
21 | for reward in return_list['return_list']:
22 | writer.writerow([reward])
23 |
24 | with open(os.path.join(config["save_dir"], 'target_tracking_return_list.csv'), mode='w', newline='') as file:
25 | writer = csv.writer(file)
26 | writer.writerow(['target_tracking']) # write the header row
27 | for reward in return_list['target_tracking_return_list']:
28 | writer.writerow([reward])
29 |
30 | with open(os.path.join(config["save_dir"], 'boundary_punishment_return_list.csv'), mode='w', newline='') as file:
31 | writer = csv.writer(file)
32 | writer.writerow(['boundary_punishment']) # write the header row
33 | for reward in return_list['boundary_punishment_return_list']:
34 | writer.writerow([reward])
35 |
36 | with open(os.path.join(config["save_dir"], 'duplicate_tracking_punishment_return_list.csv'), mode='w', newline='') as file:
37 | writer = csv.writer(file)
38 | writer.writerow(['duplicate_tracking_punishment']) # write the header row
39 | for reward in return_list['duplicate_tracking_punishment_return_list']:
40 | writer.writerow([reward])
41 |
42 |
43 | def clip_and_normalize(val, floor, ceil, choice=1):
44 | if val < floor or val > ceil:
45 | val = max(val, floor)
46 | val = min(val, ceil)
47 | print("overstep in clip.")
48 | val = np.clip(val, floor, ceil)
49 | mid = (floor + ceil) / 2
50 | if choice == -1:
51 | val = (val - floor) / (ceil - floor) - 1 # (-1, 0)
52 | elif choice == 0:
53 | val = (val - floor) / (ceil - floor) # (0, 1)
54 | else:
55 | val = (val - mid) / (mid - floor) # (-1, 1)
56 | return val
57 |
--------------------------------------------------------------------------------
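
A small illustration of the three normalization modes of `clip_and_normalize` (assumes it is run from `src/`; not part of the repository):

```python
# Minimal sketch (assumes cwd is src/) of the three normalization modes.
from utils.data_util import clip_and_normalize

# choice=0 maps [floor, ceil] to [0, 1]
print(clip_and_normalize(5, 0, 10, choice=0))    # 0.5
# choice=-1 maps [floor, ceil] to [-1, 0]
print(clip_and_normalize(5, 0, 10, choice=-1))   # -0.5
# default choice=1 maps [floor, ceil] to [-1, 1]
print(clip_and_normalize(5, 0, 10))              # 0.0
# out-of-range values are clipped first (and a warning is printed)
print(clip_and_normalize(42, 0, 10))             # 1.0
```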
/src/utils/args_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import time
4 | import yaml
5 | import numpy as np
6 | import torch
7 |
8 |
9 | def get_config(config_file):
10 | """
11 | :param config_file: str, path to the hyperparameter file
12 | :return: dict, the parsed hyperparameter dictionary
13 | """
14 | with open(config_file, 'r', encoding="UTF-8") as f:
15 | config = yaml.load(f, Loader=yaml.FullLoader)
16 |
17 | # set global seed of random, numpy and torch
18 | if 'seed' in config and config['seed'] is not None:
19 | np.random.seed(config['seed'])
20 | random.seed(config['seed'])
21 | torch.manual_seed(config['seed'])
22 |
23 | # create name for this experiment
24 | run_id = str(os.getpid())
25 | exp_name = '_'.join([
26 | config['exp_name'],
27 | time.strftime('%Y-%b-%d-%H-%M-%S'), run_id
28 | ])
29 |
30 | # save paths
31 | save_dir = os.path.join(config['result_dir'], exp_name)
32 | args_save_name = os.path.join(save_dir, 'args.yaml')
33 | config['save_dir'] = save_dir
34 |
35 | # snapshot hyperparameters
36 | mkdir(config['result_dir'])
37 | mkdir(save_dir)
38 | mkdir(os.path.join(save_dir, "actor"))
39 | mkdir(os.path.join(save_dir, "critic"))
40 | mkdir(os.path.join(save_dir, "pmi"))
41 | mkdir(os.path.join(save_dir, "animated"))
42 | mkdir(os.path.join(save_dir, "t_xy"))
43 | mkdir(os.path.join(save_dir, "u_xy"))
44 | mkdir(os.path.join(save_dir, "covered_target_num"))
45 |
46 | # create cuda devices
47 | set_device(config)
48 |
49 | with open(args_save_name, 'w') as f:
50 | yaml.dump(config, f, default_flow_style=False)
51 |
52 | return config
53 |
54 |
55 | def mkdir(folder):
56 | if not os.path.isdir(folder):
57 | os.makedirs(folder)
58 |
59 |
60 | def set_device(config):
61 | """
62 | :param config: dict
63 | :return: None
64 | """
65 | if config['gpus'] == -1 or not torch.cuda.is_available():
66 | os.environ["CUDA_VISIBLE_DEVICES"] = ""
67 | print('use cpu')
68 | config['devices'] = [torch.device('cpu')]
69 | else:
70 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(i) for i in range(config['first_device'],
71 | config['first_device'] + config['gpus']))
72 | print('use gpus: {}'.format(config['gpus']))
73 | config['devices'] = [torch.device('cuda', i) for i in range(config['first_device'],
74 | config['first_device'] + config['gpus'])]
75 |
76 |
77 | if __name__ == "__main__":
78 | example = get_config("../configs/MAAC.yaml")
79 | print(type(example))
80 | print(example)
81 |
--------------------------------------------------------------------------------
/src/models/PMINet.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | import numpy as np
3 | # model definition
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as f
7 | import os
8 |
9 |
10 | class CustomLoss(nn.Module):
11 | def __init__(self):
12 | super(CustomLoss, self).__init__()
13 |
14 | @staticmethod
15 | def forward(output1, output2):
16 | loss = torch.mean(torch.log(1 + torch.exp(-output1)) + torch.log(1 + torch.exp(output2)))
17 | return loss
18 |
19 |
20 | class PMINetwork(nn.Module):
21 | def __init__(self, comm_dim=5, obs_dim=4, boundary_state_dim=3, hidden_dim=64, b2_size=3000):
22 | super(PMINetwork, self).__init__()
23 | self.comm_dim = comm_dim
24 | self.obs_dim = obs_dim
25 | self.boundary_state_dim = boundary_state_dim
26 | self.hidden_dim = hidden_dim
27 | self.b2_size = b2_size
28 |
29 | self.fc_comm = nn.Linear(comm_dim, hidden_dim)
30 | self.bn_comm = nn.BatchNorm1d(hidden_dim)
31 | self.fc_obs = nn.Linear(obs_dim, hidden_dim)
32 | self.bn_obs = nn.BatchNorm1d(hidden_dim)
33 | self.fc_boundary_state = nn.Linear(boundary_state_dim, hidden_dim)
34 | self.bn_boundary_state = nn.BatchNorm1d(hidden_dim)
35 |
36 | self.fc1 = nn.Linear(hidden_dim * 3, hidden_dim)
37 | self.bn1 = nn.BatchNorm1d(hidden_dim)
38 | self.fc2 = nn.Linear(hidden_dim, 1)
39 | self.optimizer = optim.Adam(self.parameters(), lr=0.001)
40 |
41 | def forward(self, x):
42 | if isinstance(x, np.ndarray):
43 | x = torch.tensor(x, dtype=torch.float32)
44 | x = x.float()
45 | comm = x[:, :self.comm_dim]
46 | obs = x[:, self.comm_dim:self.comm_dim + self.obs_dim]
47 | boundary_state = x[:, self.comm_dim + self.obs_dim:self.comm_dim + self.obs_dim + self.boundary_state_dim]
48 |
49 | # Process each part with BatchNorm
50 | comm_vec = self.fc_comm(comm)
51 | comm_vec = f.relu(self.bn_comm(comm_vec))
52 | obs_vec = self.fc_obs(obs)
53 | obs_vec = f.relu(self.bn_obs(obs_vec))
54 | boundary_state_vec = self.fc_boundary_state(boundary_state)
55 | boundary_state_vec = f.relu(self.bn_boundary_state(boundary_state_vec))
56 |
57 | # Concatenate and process through further layers with BatchNorm
58 | combined = torch.cat((comm_vec, obs_vec, boundary_state_vec), dim=1)
59 | x = self.fc1(combined)
60 | x = f.relu(self.bn1(x))
61 | output = self.fc2(x)
62 | return output
63 |
64 | def inference(self, single_data):
65 | self.eval()
66 | if isinstance(single_data, np.ndarray):
67 | single_data = torch.tensor(single_data, dtype=torch.float32)
68 |
69 | if single_data.ndim == 1:
70 | single_data = single_data.unsqueeze(0)
71 | output = self.forward(single_data)
72 | return output.item() # Extract and return the single scalar value
73 |
74 | def train_pmi(self, config, train_data, n_uav):
75 | self.train()
76 | loss_function = CustomLoss()
77 | # train_data (timesteps*n_uav,12)
78 | timesteps = train_data.size(0) // n_uav
79 | train_data = train_data.view(timesteps, n_uav, 12)
80 | timestep_indices = torch.randint(low=0, high=timesteps, size=(self.b2_size,))
81 | uav_indices = torch.randint(low=0, high=n_uav, size=(self.b2_size, 2))
82 | selected_data = torch.zeros((self.b2_size, 2, 12))
83 | for i in range(self.b2_size):
84 | selected_data[i] = train_data[timestep_indices[i], uav_indices[i]]
85 |
86 | avg_loss = 0
87 | for i in range(self.b2_size // config["pmi"]["batch_size"]):
88 | self.optimizer.zero_grad()
89 | batch_data = selected_data[i * config["pmi"]["batch_size"]:(i + 1) * config["pmi"]["batch_size"]]
90 | input_1_2 = batch_data[:, 0].squeeze(1)
91 | input_1_3 = batch_data[:, 1].squeeze(1)
92 | output_1_2 = self.forward(input_1_2)
93 | output_1_3 = self.forward(input_1_3)
94 | loss = loss_function(output_1_2, output_1_3)
95 | avg_loss += abs(loss.item())
96 | # backpropagation and optimization step
97 | loss.backward()
98 | self.optimizer.step()
99 | avg_loss /= (self.b2_size // config["pmi"]["batch_size"])
100 | return avg_loss
101 |
102 | def save(self, save_dir, epoch_i):
103 | torch.save({
104 | 'model_state_dict': self.state_dict(),
105 | 'optimizer_state_dict': self.optimizer.state_dict()
106 | }, os.path.join(save_dir, "pmi", 'pmi_weights_' + str(epoch_i) + '.pth'))
107 |
108 | def load(self, path):
109 | if path and os.path.exists(path):
110 | checkpoint = torch.load(path)
111 | self.load_state_dict(checkpoint['model_state_dict'])
112 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
113 |
--------------------------------------------------------------------------------
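
A minimal sketch of how `PMINetwork` is used (assumes it is run from `src/`; not part of the repository): `inference` takes a single 12-dimensional state vector (5 communication + 4 observation + 3 boundary/state features, matching the defaults), and `CustomLoss` pushes its first argument up and its second down.

```python
# Minimal sketch (assumes cwd is src/): one inference pass and one loss evaluation.
import numpy as np
import torch

from models.PMINet import PMINetwork, CustomLoss

torch.manual_seed(0)
pmi = PMINetwork(comm_dim=5, obs_dim=4, boundary_state_dim=3, hidden_dim=64)

# inference() accepts a single 12-dim state (5 comm + 4 obs + 3 boundary/state)
state = np.random.rand(12).astype(np.float32)
print("PMI score:", pmi.inference(state))

# the custom loss rewards large output1 and small output2
o1 = torch.tensor([[1.0]])
o2 = torch.tensor([[-1.0]])
print("loss:", CustomLoss()(o1, o2).item())
```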
/src/utils/draw_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import matplotlib.pyplot as plt
3 | import matplotlib.patches as patches
4 | import matplotlib.colors as mcolors
5 | import imageio
6 | from PIL import Image
7 | import numpy as np
8 |
9 | # initialize the text object as None
10 | text_obj = None
11 |
12 | def resize_image(image_path):
13 | img = Image.open(image_path).convert('RGB')
14 | # Resize the image to be divisible by 16
15 | new_width = (img.width // 16) * 16
16 | new_height = (img.height // 16) * 16
17 | if new_width != img.width or new_height != img.height:
18 | img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) # Updated to use Image.Resampling.LANCZOS
19 | return np.array(img)
20 |
21 | def get_gradient_color(start_color, end_color, num_points, idx):
22 | start_rgba = np.array(mcolors.to_rgba(start_color))
23 | end_rgba = np.array(mcolors.to_rgba(end_color))
24 | ratio = idx / max(1, num_points - 1)
25 | gradient_rgba = start_rgba + (end_rgba - start_rgba) * ratio
26 | return mcolors.to_hex(gradient_rgba)
27 |
28 | def update(ax, env, uav_plots, target_plots, uav_search_patches, frame, frames, num_steps, interval=2, paint_all=True):
29 | global text_obj
30 |
31 | if frame == 0:
32 | return
33 | for i, uav in enumerate(env.uav_list):
34 | uav_x = env.position['all_uav_xs'][0: frame: interval]
35 | uav_y = env.position['all_uav_ys'][0: frame: interval]
36 | uav_x = [sublist[i] for sublist in uav_x]
37 | uav_y = [sublist[i] for sublist in uav_y]
38 | if uav_x and uav_y: # Ensure the lists are not empty
39 | colors = [get_gradient_color('#E1FFFF', '#0000FF', frame, idx) for idx in range(len(uav_x))]
40 | uav_plots[i].set_offsets(np.column_stack([uav_x, uav_y]))
41 | uav_plots[i].set_color(colors)
42 | uav_search_patches[i].center = (uav_x[-1], uav_y[-1])
43 | else:
44 | print(f"Warning: UAV {i} position list is empty at frame {frame}.")
45 |
46 | for i in range(env.m_targets):
47 | target_x = env.position['all_target_xs'][0: frame: interval]
48 | target_y = env.position['all_target_ys'][0: frame: interval]
49 | target_x = [sublist[i] for sublist in target_x]
50 | target_y = [sublist[i] for sublist in target_y]
51 | if target_x and target_y: # Ensure the lists are not empty
52 | colors = [get_gradient_color('#FFC0CB', '#DC143C', frame, idx) for idx in range(len(target_x))]
53 | target_plots[i].set_offsets(np.column_stack([target_x, target_y]))
54 | target_plots[i].set_color(colors)
55 | else:
56 | print(f"Warning: Target {i} position list is empty at frame {frame}.")
57 |
58 | text_str = (
59 | f"detected target num = {env.covered_target_num[frame]}\n"
60 | f"detected target rate = {env.covered_target_num[frame] / env.m_targets * 100:.2f}%"
61 | )
62 |
63 | # remove the previous text object (if any)
64 | if text_obj is not None:
65 | text_obj.remove()
66 |
67 | # draw the new text object: no bounding box, black text
68 | text_obj = ax.text(0.02, 0.98, text_str, transform=ax.transAxes, fontsize=10, verticalalignment='top',
69 | color='black')
70 |
71 |
72 | def draw_animation(config, env, num_steps, ep_num, frames=100):
73 | fig, ax = plt.subplots(figsize=(6, 6))
74 | ax.set_xlim(-env.x_max / 3, env.x_max / 3 * 4)
75 | ax.set_ylim(-env.y_max / 3, env.y_max / 3 * 4)
76 | uav_plots = [ax.scatter([], [], marker='o', color='b', linestyle='None', s=2,alpha=1) for _ in range(env.n_uav)]
77 | target_plots = [ax.scatter([], [], marker='o', color='r', linestyle='None', s=3,alpha=1) for _ in range(env.m_targets)]
78 | uav_search_patches = [patches.Circle((0, 0), uav.dp, color='lightblue', alpha=0.2) for uav in env.uav_list]
79 | for patch in uav_search_patches:
80 | ax.add_patch(patch)
81 |
82 | save_dir = os.path.join(config["save_dir"], "frames")
83 | os.makedirs(save_dir, exist_ok=True)
84 |
85 | # Save frames at intervals of 5 num_steps
86 | step_interval = 5
87 | for frame in range(0, num_steps, step_interval):
88 | update(ax, env, uav_plots, target_plots, uav_search_patches, frame, frames, num_steps)
89 | plt.draw()
90 | plt.savefig(os.path.join(save_dir, f'frame_{frame:04d}.png'))
91 | plt.pause(0.001) # Pause to ensure the plot updates visibly if needed
92 |
93 | plt.close(fig)
94 |
95 | # Generate MP4
96 | video_path = os.path.join(config["save_dir"], "animated", f'animated_plot_{ep_num + 1}.mp4')
97 | writer = imageio.get_writer(video_path, fps=5, codec='libx264', format='FFMPEG', pixelformat='yuv420p')
98 |
99 | for frame in range(0, num_steps, step_interval):
100 | frame_path = os.path.join(save_dir, f'frame_{frame:04d}.png')
101 | if os.path.exists(frame_path):
102 | img_array = resize_image(frame_path)
103 | writer.append_data(img_array)
104 | writer.close()
105 |
106 | # Optionally remove PNG files
107 | for frame in range(0, num_steps, step_interval):
108 | frame_path = os.path.join(save_dir, f'frame_{frame:04d}.png')
109 | if os.path.exists(frame_path):
110 | os.remove(frame_path)
111 |
112 |
113 | def plot_reward_curve(config, return_list, name):
114 | plt.figure(figsize=(6, 6))
115 | plt.plot(return_list)
116 | plt.xlabel('Episodes')
117 | plt.ylabel('Total Return')
118 | plt.title(name)
119 | plt.grid(True)
120 | plt.savefig(os.path.join(config["save_dir"], name + ".png"))
121 | # plt.show()
122 |
123 |
--------------------------------------------------------------------------------
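
A short usage sketch for `plot_reward_curve` (assumes it is run from `src/`; the output directory below is hypothetical and not part of the repository):

```python
# Minimal sketch (assumes cwd is src/); the save_dir value is a hypothetical example.
import os

from utils.draw_util import plot_reward_curve

config = {"save_dir": "../results/demo"}
os.makedirs(config["save_dir"], exist_ok=True)

# plot a dummy return curve; the figure is written to <save_dir>/overall_return.png
plot_reward_curve(config, return_list=[float(i) for i in range(100)], name="overall_return")
```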
/src/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path
3 | from environment import Environment
4 | from models.actor_critic import ActorCritic
5 | from utils.args_util import get_config
6 | from train import train, evaluate, run
7 | from models.PMINet import PMINetwork
8 | from utils.data_util import save_csv
9 | from utils.draw_util import plot_reward_curve
10 |
11 |
12 | def print_config(vdict, name="config"):
13 | """
14 | :param vdict: dict, the dictionary to print
15 | :param name: str, name of the dictionary
16 | :return: None
17 | """
18 | print("-----------------------------------------")
19 | print("|This is the summary of {}:".format(name))
20 | var = vdict
21 | for i in var:
22 | if var[i] is None:
23 | continue
24 | print("|{:11}\t: {}".format(i, var[i]))
25 | print("-----------------------------------------")
26 |
27 |
28 | def print_args(args, name="args"):
29 | """
30 | :param args:
31 | :param name: str, name of the dictionary
32 | :return: None
33 | """
34 | print("-----------------------------------------")
35 | print("|This is the summary of {}:".format(name))
36 | for arg in vars(args):
37 | print("| {:<11} : {}".format(arg, getattr(args, arg)))
38 | print("-----------------------------------------")
39 |
40 |
41 | def add_args_to_config(config, args):
42 | for arg in vars(args):
43 | # print("| {:<11} : {}".format(arg, getattr(args, arg)))
44 | config[str(arg)] = getattr(args, arg)
45 |
46 |
47 | def main(args):
48 | # load the parameters for the chosen method
49 | config = get_config(os.path.join("configs", args.method + ".yaml"))
50 | add_args_to_config(config, args)
51 | print_config(config)
52 | print_args(args)
53 |
54 | # initialize the environment and the agent
55 | env = Environment(n_uav=config["environment"]["n_uav"],
56 | m_targets=config["environment"]["m_targets"],
57 | x_max=config["environment"]["x_max"],
58 | y_max=config["environment"]["y_max"],
59 | na=config["environment"]["na"])
60 | if args.method == "C-METHOD":
61 | agent = None
62 | else:
63 | agent = ActorCritic(state_dim=12,
64 | hidden_dim=config["actor_critic"]["hidden_dim"],
65 | action_dim=config["environment"]["na"],
66 | actor_lr=float(config["actor_critic"]["actor_lr"]),
67 | critic_lr=float(config["actor_critic"]["critic_lr"]),
68 | gamma=float(config["actor_critic"]["gamma"]),
69 | device=config["devices"][0]) # use only the first device
70 | agent.load(args.actor_path, args.critic_path)
71 |
72 | # initialize the PMI network
73 | if args.method == "MAAC" or args.method == "MAAC-G" or args.method == "C-METHOD":
74 | pmi = None
75 | if args.method == "MAAC":
76 | config["cooperative"] = 0 # 只考虑无人机自己的奖励
77 | elif args.method == "MAAC-R":
78 | pmi = PMINetwork(hidden_dim=config["pmi"]["hidden_dim"],
79 | b2_size=config["pmi"]["b2_size"])
80 | pmi.load(args.pmi_path)
81 | else:
82 | return
83 |
84 | if args.phase == "train":
85 | return_list = train(config=config,
86 | env=env,
87 | agent=agent,
88 | pmi=pmi,
89 | num_episodes=args.num_episodes,
90 | num_steps=args.num_steps,
91 | frequency=args.frequency)
92 | elif args.phase == "evaluate":
93 | return_list = evaluate(config=config,
94 | env=env,
95 | agent=agent,
96 | pmi=pmi,
97 | num_steps=args.num_steps)
98 | elif args.phase == "run":
99 | return_list = run(config=config,
100 | env=env,
101 | pmi=pmi,
102 | num_steps=args.num_steps)
103 | else:
104 | return
105 |
106 | save_csv(config, return_list)
107 |
108 | plot_reward_curve(config, return_list['return_list'], "overall_return")
109 | plot_reward_curve(config, return_list["target_tracking_return_list"],
110 | "target_tracking_return_list")
111 | plot_reward_curve(config, return_list["boundary_punishment_return_list"],
112 | "boundary_punishment_return_list")
113 | plot_reward_curve(config, return_list["duplicate_tracking_punishment_return_list"],
114 | "duplicate_tracking_punishment_return_list")
115 | plot_reward_curve(config, return_list["average_covered_targets_list"],
116 | "average_covered_targets_list")
117 | plot_reward_curve(config, return_list["max_covered_targets_list"],
118 | "max_covered_targets_list")
119 |
120 |
121 | if __name__ == "__main__":
122 | # create the command-line argument parser
123 | parser = argparse.ArgumentParser(description="")
124 |
125 | # add hyperparameters
126 | parser.add_argument("--phase", type=str, default="train", choices=["train", "evaluate", "run"])
127 | parser.add_argument("-e", "--num_episodes", type=int, default=10000, help="训练轮数")
128 | parser.add_argument("-s", "--num_steps", type=int, default=200, help="每轮进行步数")
129 | parser.add_argument("-f", "--frequency", type=int, default=100, help="打印信息及保存的频率")
130 | parser.add_argument("-a", "--actor_path", type=str, default=None, help="actor网络权重的路径")
131 | parser.add_argument("-c", "--critic_path", type=str, default=None, help="critic网络权重的路径")
132 | parser.add_argument("-p", "--pmi_path", type=str, default=None, help="pmi网络权重的路径")
133 | parser.add_argument("-m", "--method", help="", default="MAAC-R", choices=["MAAC", "MAAC-R", "MAAC-G", "C-METHOD"])
134 | # parse command-line arguments
135 | main_args = parser.parse_args()
136 |
137 | # call the main function
138 | main(main_args)
139 |
--------------------------------------------------------------------------------
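
For reference, the sketch below (not part of the repository) isolates the `add_args_to_config` helper from `main.py` to show how command-line arguments are merged into the YAML-derived config dict; the argument values passed to `parse_args` are hypothetical.

```python
# Minimal standalone sketch mirroring add_args_to_config from main.py.
import argparse


def add_args_to_config(config, args):          # same logic as in main.py
    for arg in vars(args):
        config[str(arg)] = getattr(args, arg)


parser = argparse.ArgumentParser()
parser.add_argument("--phase", type=str, default="train")
parser.add_argument("-m", "--method", default="MAAC-R")
args = parser.parse_args(["--phase", "evaluate", "-m", "MAAC"])   # hypothetical CLI input

config = {"exp_name": "MAAC"}                  # stand-in for the YAML-derived dict
add_args_to_config(config, args)
print(config)                                  # config now also contains 'phase' and 'method'
```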
/src/models/actor_critic.py:
--------------------------------------------------------------------------------
1 | import os.path
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as f
6 | import numpy as np
7 |
8 |
9 | class ResidualBlock(nn.Module):
10 | def __init__(self, in_channels, out_channels, stride=1):
11 | super(ResidualBlock, self).__init__()
12 | self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
13 | self.bn1 = nn.BatchNorm1d(out_channels)
14 | self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
15 | self.bn2 = nn.BatchNorm1d(out_channels)
16 | self.relu = nn.ReLU(inplace=True)
17 | self.down_sample = None
18 | if stride != 1 or in_channels != out_channels:
19 | self.down_sample = nn.Sequential(
20 | nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
21 | nn.BatchNorm1d(out_channels)
22 | )
23 |
24 | def forward(self, x):
25 | residual = x
26 | out = self.conv1(x)
27 | out = self.bn1(out)
28 | out = self.relu(out)
29 | out = self.conv2(out)
30 | out = self.bn2(out)
31 | if self.down_sample is not None:
32 | residual = self.down_sample(x)
33 | out += residual
34 | out = self.relu(out)
35 | return out
36 |
37 |
38 | class ResPolicyNet(nn.Module):
39 | def __init__(self, state_dim, hidden_dim, action_dim):
40 | super(ResPolicyNet, self).__init__()
41 | self.conv1 = nn.Conv1d(1, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False)
42 | self.bn1 = nn.BatchNorm1d(hidden_dim)
43 | self.relu = nn.ReLU(inplace=True)
44 | self.residual_block1 = ResidualBlock(hidden_dim, hidden_dim)
45 | self.residual_block2 = ResidualBlock(hidden_dim, hidden_dim)
46 | self.fc = nn.Linear(hidden_dim, action_dim)
47 |
48 | def forward(self, x):
49 | x = x.unsqueeze(1)
50 | x = self.conv1(x)
51 | x = self.bn1(x)
52 | x = self.relu(x)
53 | x = self.residual_block1(x)
54 | x = self.residual_block2(x)
55 | x = f.avg_pool1d(x, 12) # average pooling; other pooling methods could be used instead
56 | x = x.view(x.size(0), -1)
57 | x = self.fc(x)
58 | return f.softmax(x, dim=1)
59 |
60 |
61 | class ResValueNet(nn.Module):
62 | def __init__(self, state_dim, hidden_dim):
63 | super(ResValueNet, self).__init__()
64 | self.conv1 = nn.Conv1d(1, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False)
65 | self.bn1 = nn.BatchNorm1d(hidden_dim)
66 | self.relu = nn.ReLU(inplace=True)
67 | self.residual_block1 = ResidualBlock(hidden_dim, hidden_dim)
68 | self.residual_block2 = ResidualBlock(hidden_dim, hidden_dim)
69 | self.avg_pool = nn.AdaptiveAvgPool1d(1)
70 | self.fc = nn.Linear(hidden_dim, 1)
71 |
72 | def forward(self, x):
73 | x = x.unsqueeze(1)
74 | x = self.conv1(x)
75 | x = self.bn1(x)
76 | x = self.relu(x)
77 | x = self.residual_block1(x)
78 | x = self.residual_block2(x)
79 | x = self.avg_pool(x)
80 | x = x.view(x.size(0), -1)
81 | x = self.fc(x)
82 | return x.squeeze(1)
83 |
84 |
85 | class FnnPolicyNet(nn.Module):
86 | def __init__(self, n_states, n_hiddens, n_actions):
87 | super(FnnPolicyNet, self).__init__()
88 | self.fc1 = nn.Linear(n_states, n_hiddens)
89 | self.fc2 = nn.Linear(n_hiddens, n_actions)
90 |
91 | # forward pass
92 | def forward(self, x):
93 | x = self.fc1(x) # [b,n_states]-->[b,n_hiddens]
94 | x = f.relu(x)
95 | x = self.fc2(x) # [b,n_hiddens]-->[b,n_actions]
96 | # probability of each action for the given state
97 | x = f.softmax(x, dim=1) # [b,n_actions]-->[b,n_actions]
98 | return x
99 |
100 |
101 | class FnnValueNet(nn.Module):
102 | def __init__(self, n_states, n_hiddens):
103 | super(FnnValueNet, self).__init__()
104 | self.fc1 = nn.Linear(n_states, n_hiddens)
105 | self.fc2 = nn.Linear(n_hiddens, 1)
106 |
107 | # forward pass
108 | def forward(self, x):
109 | x = self.fc1(x) # [b,n_states]-->[b,n_hiddens]
110 | x = f.relu(x)
111 | x = self.fc2(x) # [b,n_hiddens]-->[b,1]
112 | return x.squeeze(1)
113 |
114 |
115 | class ActorCritic:
116 | def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
117 | gamma, device):
118 | """
119 | :param state_dim: dimension of the feature (state) space
120 | :param hidden_dim: dimension of the hidden layer
121 | :param action_dim: dimension of the action space
122 | :param actor_lr: learning rate of the actor network
123 | :param critic_lr: learning rate of the critic network
124 | :param gamma: discount factor
125 | :param device: device used for training
126 | """
127 | # policy network
128 | self.actor = FnnPolicyNet(state_dim, hidden_dim, action_dim).to(device)
129 | self.critic = FnnValueNet(state_dim, hidden_dim).to(device) # value network
130 | # policy network optimizer
131 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
132 | lr=actor_lr)
133 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
134 | lr=critic_lr) # value network optimizer
135 | self.gamma = gamma
136 | self.device = device
137 |
138 | def take_action(self, states):
139 | """
140 | :param states: np.ndarray of size (state_dim,), the state of a single UAV
141 | :return:
142 | """
143 | states_np = np.array(states)[np.newaxis, :] # convert directly with np.array
144 | states_tensor = torch.tensor(states_np, dtype=torch.float).to(self.device)
145 | probs = self.actor(states_tensor)
146 | action_dist = torch.distributions.Categorical(probs) # TODO ?
147 | action = action_dist.sample()
148 | return action, probs
149 |
150 | def update(self, transition_dict):
151 | """
152 | :param transition_dict: dict containing states, actions, per-UAV rewards and next states
153 | :return: None
154 | """
155 | states = torch.tensor(np.array(transition_dict['states']),
156 | dtype=torch.float).to(self.device)
157 | actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(
158 | self.device)
159 | # actions = actions.long()
160 | rewards = torch.tensor(transition_dict['rewards'],
161 | dtype=torch.float).view(-1, 1).to(self.device).squeeze()
162 | next_states = torch.tensor(np.array(transition_dict['next_states']),
163 | dtype=torch.float).to(self.device)
164 |
165 | # temporal-difference target
166 | td_target = rewards + self.gamma * self.critic(next_states)
167 | td_delta = td_target - self.critic(states) # temporal-difference error
168 | log_probs = torch.log(self.actor(states).gather(1, actions))
169 |
170 | actor_loss = torch.mean(-log_probs * td_delta.detach())
171 | critic_loss = torch.mean(f.mse_loss(self.critic(states), td_target.detach()))
172 | self.actor_optimizer.zero_grad()
173 | self.critic_optimizer.zero_grad()
174 | actor_loss.backward()
175 | critic_loss.backward()
176 | self.actor_optimizer.step()
177 | self.critic_optimizer.step()
178 |
179 | return actor_loss, critic_loss, td_delta
180 |
181 | def save(self, save_dir, epoch_i):
182 | torch.save({
183 | 'model_state_dict': self.actor.state_dict(),
184 | 'optimizer_state_dict': self.actor_optimizer.state_dict()
185 | }, os.path.join(save_dir, "actor", 'actor_weights_' + str(epoch_i) + '.pth'))
186 | torch.save({
187 | 'model_state_dict': self.critic.state_dict(),
188 | 'optimizer_state_dict': self.critic_optimizer.state_dict()
189 | }, os.path.join(save_dir, "critic", 'critic_weights_' + str(epoch_i) + '.pth'))
190 |
191 | def load(self, actor_path, critic_path):
192 | if actor_path and os.path.exists(actor_path):
193 | checkpoint = torch.load(actor_path)
194 | self.actor.load_state_dict(checkpoint['model_state_dict'])
195 | self.actor_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
196 |
197 | if critic_path and os.path.exists(critic_path):
198 | checkpoint = torch.load(critic_path)
199 | self.critic.load_state_dict(checkpoint['model_state_dict'])
200 | self.critic_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
201 |
--------------------------------------------------------------------------------
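
A minimal CPU sketch of the `ActorCritic` agent (assumes it is run from `src/`; not part of the repository): it samples one action for a 12-dimensional state and performs one update on a toy two-transition batch.

```python
# Minimal sketch (assumes cwd is src/): one action sample and one update on CPU.
import numpy as np
import torch

from models.actor_critic import ActorCritic

torch.manual_seed(0)
agent = ActorCritic(state_dim=12, hidden_dim=128, action_dim=12,
                    actor_lr=1e-4, critic_lr=5e-4, gamma=0.95,
                    device=torch.device("cpu"))

state = np.random.rand(12).astype(np.float32)
action, probs = agent.take_action(state)
print("sampled action:", action.item(), "probs shape:", tuple(probs.shape))

# a toy batch of two transitions
transition = {
    "states": [state, state],
    "actions": [int(action.item()), 0],
    "rewards": [0.1, -0.2],
    "next_states": [state, state],
}
actor_loss, critic_loss, _ = agent.update(transition)
print("actor loss:", actor_loss.item(), "critic loss:", critic_loss.item())
```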
/src/environment.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from math import e
3 | from utils.data_util import clip_and_normalize
4 | from agent.uav import UAV
5 | from agent.target import TARGET
6 | import numpy as np
7 | from math import pi
8 | import random
9 | from typing import List
10 |
11 |
12 | class Environment:
13 | def __init__(self, n_uav: int, m_targets: int, x_max: float, y_max: float, na: int):
14 | """
15 | :param n_uav: scalar
16 | :param m_targets: scalar
17 | :param x_max: scalar
18 | :param y_max: scalar
19 | :param na: scalar
20 | """
21 | # size of the environment
22 | self.x_max = x_max
23 | self.y_max = y_max
24 |
25 | # dim of action space and state space
26 | # communication(4 scalar, a), observation(4 scalar), boundary and state information(2 scalar, a)
27 | # self.state_dim = (4 + na) + 4 + (2 + na)
28 | self.state_dim = (4 + 1) + 4 + (2 + 1)
29 | self.action_dim = na
30 |
31 | # agents parameters in the environments
32 | self.n_uav = n_uav
33 | self.m_targets = m_targets
34 |
35 | # agents
36 | self.uav_list = []
37 | self.target_list = []
38 |
39 | # position of uav and target
40 | self.position = {'all_uav_xs': [], 'all_uav_ys': [], 'all_target_xs': [], 'all_target_ys': []}
41 |
42 | # coverage rate of target
43 | self.covered_target_num = []
44 |
45 | def __reset(self, t_v_max, t_h_max, u_v_max, u_h_max, na, dc, dp, dt, init_x, init_y):
46 | """
47 | reset the location for all uav_s at (init_x, init_y)
48 | reset the store position to empty
49 | :return: should be the initial states !!!!
50 | """
51 | if isinstance(init_x, List) and isinstance(init_y, List):
52 | self.uav_list = [UAV(init_x[i],
53 | init_y[i],
54 | random.uniform(-pi, pi),
55 | random.randint(0, self.action_dim - 1),
56 | u_v_max, u_h_max, na, dc, dp, dt) for i in range(self.n_uav)]
57 | elif not isinstance(init_x, List) and not isinstance(init_y, List):
58 | self.uav_list = [UAV(init_x,
59 | init_y,
60 | random.uniform(-pi, pi),
61 | random.randint(0, self.action_dim - 1),
62 | u_v_max, u_h_max, na, dc, dp, dt) for _ in range(self.n_uav)]
63 | elif isinstance(init_x, List):
64 | self.uav_list = [UAV(init_x[i],
65 | init_y,
66 | random.uniform(-pi, pi),
67 | random.randint(0, self.action_dim - 1),
68 | u_v_max, u_h_max, na, dc, dp, dt) for i in range(self.n_uav)]
69 | elif isinstance(init_y, List):
70 | self.uav_list = [UAV(init_x,
71 | init_y[i],
72 | random.uniform(-pi, pi),
73 | random.randint(0, self.action_dim - 1),
74 | u_v_max, u_h_max, na, dc, dp, dt) for i in range(self.n_uav)]
75 | else:
76 | print("wrong init position")
77 | # the initial positions of the targets are random, with random headings
78 | self.target_list = [TARGET(random.uniform(0, self.x_max),
79 | random.uniform(0, self.y_max),
80 | random.uniform(-pi, pi),
81 | random.uniform(-pi / 6, pi / 6),
82 | t_v_max, t_h_max, dt)
83 | for _ in range(self.m_targets)]
84 | self.position = {'all_uav_xs': [], 'all_uav_ys': [], 'all_target_xs': [], 'all_target_ys': []}
85 | self.covered_target_num = []
86 |
87 | def reset(self, config):
88 | # self.__reset(t_v_max=config["target"]["v_max"],
89 | # t_h_max=pi / float(config["target"]["h_max"]),
90 | # u_v_max=config["uav"]["v_max"],
91 | # u_h_max=pi / float(config["uav"]["h_max"]),
92 | # na=config["environment"]["na"],
93 | # dc=config["uav"]["dc"],
94 | # dp=config["uav"]["dp"],
95 | # dt=config["uav"]["dt"],
96 | # init_x=config['environment']['x_max']/2, init_y=config['environment']['y_max']/2)
97 | self.__reset(t_v_max=config["target"]["v_max"],
98 | t_h_max=pi / float(config["target"]["h_max"]),
99 | u_v_max=config["uav"]["v_max"],
100 | u_h_max=pi / float(config["uav"]["h_max"]),
101 | na=config["environment"]["na"],
102 | dc=config["uav"]["dc"],
103 | dp=config["uav"]["dp"],
104 | dt=config["uav"]["dt"],
105 | init_x=[x * config['environment']['x_max'] / (config['environment']['n_uav'] + 1)
106 | for x in range(1, config['environment']['n_uav']+1)],
107 | init_y=config['environment']['y_max']/2)
108 |
109 | def get_states(self) -> (List['np.ndarray']):
110 | """
111 | get the state of the uav_s
112 | :return: list of np array, each element is a 1-dim array with size of 12
113 | """
114 | uav_states = []
115 | # collect the overall communication and target observation by each uav
116 | for uav in self.uav_list:
117 | uav_states.append(uav.get_local_state())
118 | return uav_states
119 |
120 | def step(self, config, pmi, actions):
121 | """
122 | state transfer functions
123 | :param config:
124 | :param pmi: PMI network
125 | :param actions: {0,1,...,Na - 1}
126 | :return: states, rewards
127 | """
128 | # update the position of targets
129 | for i, target in enumerate(self.target_list):
130 | target.update_position(self.x_max, self.y_max)
131 |
132 | # update the position of UAVs
133 | for i, uav in enumerate(self.uav_list):
134 | uav.update_position(actions[i])
135 |
136 | # observation and communication
137 | uav.observe_target(self.target_list)
138 | uav.observe_uav(self.uav_list)
139 |
140 | (rewards,
141 | target_tracking_reward,
142 | boundary_punishment,
143 | duplicate_tracking_punishment) = self.calculate_rewards(config=config, pmi=pmi)
144 | next_states = self.get_states()
145 |
146 | covered_targets = self.calculate_covered_target()
147 | self.covered_target_num.append(covered_targets)
148 |
149 | # trace the position matrix
150 | target_xs, target_ys = self.__get_all_target_position()
151 | self.position['all_target_xs'].append(target_xs)
152 | self.position['all_target_ys'].append(target_ys)
153 | uav_xs, uav_ys = self.__get_all_uav_position()
154 | self.position['all_uav_xs'].append(uav_xs)
155 | self.position['all_uav_ys'].append(uav_ys)
156 |
157 | reward = {
158 | 'rewards': rewards,
159 | 'target_tracking_reward': target_tracking_reward,
160 | 'boundary_punishment': boundary_punishment,
161 | 'duplicate_tracking_punishment': duplicate_tracking_punishment
162 | }
163 |
164 | return next_states, reward, covered_targets
165 |
166 | def __get_all_uav_position(self) -> (List[float], List[float]):
167 | """
168 | :return: all the position of the uav through this epoch
169 | """
170 | uav_xs = []
171 | uav_ys = []
172 | for uav in self.uav_list:
173 | uav_xs.append(uav.x)
174 | uav_ys.append(uav.y)
175 | return uav_xs, uav_ys
176 |
177 | def __get_all_target_position(self) -> (List[float], List[float]):
178 | """
179 | :return: all the position of the targets through this epoch
180 | """
181 | target_xs = []
182 | target_ys = []
183 | for target in self.target_list:
184 | target_xs.append(target.x)
185 | target_ys.append(target.y)
186 | return target_xs, target_ys
187 |
188 | def get_uav_and_target_position(self) -> (List[float], List[float], List[float], List[float]):
189 | """
190 | :return: both the uav and the target position matrix
191 | """
192 | return (self.position['all_uav_xs'], self.position['all_uav_ys'],
193 | self.position['all_target_xs'], self.position['all_target_ys'])
194 |
195 | def calculate_rewards(self, config, pmi) -> (List[float], List[float], List[float], List[float]):
196 | # raw reward first
197 | target_tracking_rewards = []
198 | boundary_punishments = []
199 | duplicate_tracking_punishments = []
200 | for uav in self.uav_list:
201 | # raw reward for each uav (not clipped)
202 | (target_tracking_reward,
203 | boundary_punishment,
204 | duplicate_tracking_punishment) = uav.calculate_raw_reward(self.uav_list, self.target_list, self.x_max, self.y_max)
205 |
206 | # clip op
207 | target_tracking_reward = clip_and_normalize(target_tracking_reward,
208 | 0, 2 * config['environment']['m_targets'], 0)
209 | duplicate_tracking_punishment = clip_and_normalize(duplicate_tracking_punishment,
210 | -e / 2 * config['environment']['n_uav'], 0, -1)
211 | boundary_punishment = clip_and_normalize(boundary_punishment, -1/2, 0, -1)
212 |
213 | # append
214 | target_tracking_rewards.append(target_tracking_reward)
215 | boundary_punishments.append(boundary_punishment)
216 | duplicate_tracking_punishments.append(duplicate_tracking_punishment)
217 |
218 | # weights
219 | uav.raw_reward = (config["uav"]["alpha"] * target_tracking_reward + config["uav"]["beta"] *
220 | boundary_punishment + config["uav"]["gamma"] * duplicate_tracking_punishment)
221 |
222 | rewards = []
223 | for uav in self.uav_list:
224 | reward = uav.calculate_cooperative_reward(self.uav_list, pmi, config['cooperative'])
225 | uav.reward = clip_and_normalize(reward, -1, 1)
226 | rewards.append(uav.reward)
227 | return rewards, target_tracking_rewards, boundary_punishments, duplicate_tracking_punishments
228 |
229 | def save_position(self, save_dir, epoch_i):
230 | u_xy = np.array([self.position["all_uav_xs"],
231 | self.position["all_uav_ys"]]).transpose() # n_uav * num_steps * 2
232 | t_xy = np.array([self.position["all_target_xs"],
233 | self.position["all_target_ys"]]).transpose() # m_target * num_steps * 2
234 |
235 | np.savetxt(os.path.join(save_dir, "u_xy", 'u_xy' + str(epoch_i) + '.csv'),
236 | u_xy.reshape(-1, 2), delimiter=',', header='x,y', comments='')
237 | np.savetxt(os.path.join(save_dir, "t_xy", 't_xy' + str(epoch_i) + '.csv'),
238 | t_xy.reshape(-1, 2), delimiter=',', header='x,y', comments='')
239 |
240 | def save_covered_num(self, save_dir, epoch_i):
241 | covered_target_num_array = np.array(self.covered_target_num).reshape(-1, 1)
242 |
243 | np.savetxt(os.path.join(save_dir, "covered_target_num", 'covered_target_num' + str(epoch_i) + '.csv'),
244 | covered_target_num_array, delimiter=',', header='covered_target_num', comments='')
245 |
246 | def calculate_covered_target(self):
247 | covered_target_num = 0
248 | for target in self.target_list:
249 | for uav in self.uav_list:
250 | if uav.distance(uav.x, uav.y, target.x, target.y) < uav.dp:
251 | covered_target_num += 1
252 | break
253 | return covered_target_num
254 |
255 |
256 |
--------------------------------------------------------------------------------
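
A minimal end-to-end sketch of the environment loop (assumes it is run from `src/`; not part of the repository): the config dict below mirrors the structure of the YAML files, and one step is taken with random actions, no PMI network and `cooperative: 0`.

```python
# Minimal sketch (assumes cwd is src/): one environment step with random actions.
import random

from environment import Environment

config = {  # mirrors the structure of the YAML configs
    "environment": {"n_uav": 10, "m_targets": 10, "x_max": 2000, "y_max": 2000, "na": 12},
    "uav": {"dt": 1, "v_max": 20, "h_max": 6, "dc": 500, "dp": 200,
            "alpha": 0.6, "beta": 0.2, "gamma": 0.2},
    "target": {"v_max": 5, "h_max": 6},
    "cooperative": 0,
}

env = Environment(n_uav=10, m_targets=10, x_max=2000, y_max=2000, na=12)
env.reset(config)

states = env.get_states()                       # list of 12-dim np arrays, one per UAV
actions = [random.randrange(12) for _ in states]
next_states, reward, covered = env.step(config, pmi=None, actions=actions)
print("covered targets:", covered,
      "mean reward:", sum(reward["rewards"]) / len(reward["rewards"]))
```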
/src/agent/uav.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | from math import cos, sin, sqrt, exp, pi, e, atan2
4 | from typing import List, Tuple
5 | from models.PMINet import PMINetwork
6 | from agent.target import TARGET
7 | from scipy.special import softmax
8 | from utils.data_util import clip_and_normalize
9 |
10 |
11 | class UAV:
12 | def __init__(self, x0, y0, h0, a_idx, v_max, h_max, na, dc, dp, dt):
13 | """
14 | :param dt: float, sampling time interval
15 | :param x0: float, x coordinate
16 | :param y0: float, y coordinate
17 | :param h0: float, heading
18 | :param v_max: float, maximum linear velocity
19 | :param h_max: float, maximum heading angular rate
20 | :param na: int, dimension of the action space
21 | :param dc: float, maximum communication distance with other UAVs
22 | :param dp: float, maximum target observation distance
23 | """
24 | # the position, velocity and heading of this uav
25 | self.x = x0
26 | self.y = y0
27 | self.h = h0
28 | self.v_max = v_max
29 |
30 | # the max heading angular rate and the action of this uav
31 | self.h_max = h_max
32 | self.Na = na
33 |
34 | # action
35 | self.a = a_idx
36 |
37 | # the maximum communication distance and maximum perception distance
38 | self.dc = dc
39 | self.dp = dp
40 |
41 | # time interval
42 | self.dt = dt
43 |
44 | # set of local information
45 | # self.communication = []
46 | self.target_observation = []
47 | self.uav_communication = []
48 |
49 | # reward
50 | self.raw_reward = 0
51 | self.reward = 0
52 |
53 | def __distance(self, target) -> float:
54 | """
55 | calculate the distance from uav to target
56 | :param target: class UAV or class TARGET
57 | :return: scalar
58 | """
59 | return sqrt((self.x - target.x) ** 2 + (self.y - target.y) ** 2)
60 |
61 | @staticmethod
62 | def distance(x1, y1, x2, y2) -> float:
63 | """
64 | calculate the distance from uav to target
65 | :param x2:
66 | :param y1:
67 | :param x1:
68 | :param y2:
69 | :return: scalar
70 | """
71 | return sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
72 |
73 | def discrete_action(self, a_idx: int) -> float:
74 | """
75 | from the action space index to the real difference
76 | :param a_idx: {0,1,...,Na - 1}
77 | :return: action: scalar, the heading change
78 | """
79 | # from action space to the real world action
80 | na = a_idx + 1 # index starting from 1
81 | return (2 * na - self.Na - 1) * self.h_max / (self.Na - 1)
82 |
83 | def update_position(self, action: 'int') -> (float, float, float):
84 | """
85 | receive the index from action space, then update the current position
86 | :param action: {0,1,...,Na - 1}
87 | :return:
88 | """
89 | self.a = action
90 | a = self.discrete_action(action) # this line could possibly be moved elsewhere
91 |
92 | dx = self.dt * self.v_max * cos(self.h) # displacement along x
93 | dy = self.dt * self.v_max * sin(self.h) # displacement along y
94 | self.x += dx
95 | self.y += dy
96 | self.h += self.dt * a # update the heading angle
97 | self.h = (self.h + pi) % (2 * pi) - pi # keep the heading within [-pi, pi)
98 |
99 | return self.x, self.y, self.h # return the agent's position and heading (theta)
100 |
101 | def observe_target(self, targets_list: List['TARGET'], relative=True):
102 | """
103 | Observing target with a radius within dp
104 | :param relative: relative to uav itself
105 | :param targets_list: [class UAV]
106 | :return: None
107 | """
108 | self.target_observation = [] # Reset observed targets
109 | for target in targets_list:
110 | dist = self.__distance(target)
111 | if dist <= self.dp:
112 | # add (x, y, vx, vy) information
113 | if relative:
114 | self.target_observation.append(((target.x - self.x) / self.dp,
115 | (target.y - self.y) / self.dp,
116 | cos(target.h) * target.v_max / self.v_max - cos(self.h),
117 | sin(target.h) * target.v_max / self.v_max - sin(self.h)))
118 | else:
119 | self.target_observation.append((target.x / self.dp,
120 | target.y / self.dp,
121 | cos(target.h) * target.v_max / self.v_max,
122 | sin(target.h) * target.v_max / self.v_max))
123 |
124 | def observe_uav(self, uav_list: List['UAV'], relative=True): # communication
125 | """
126 | communicate with other uav_s with a radius within dp
127 | :param relative: relative to uav itself
128 | :param uav_list: [class UAV]
129 | :return:
130 | """
131 | self.uav_communication = [] # Reset observed targets
132 | for uav in uav_list:
133 | dist = self.__distance(uav)
134 | if dist <= self.dc and uav != self:
135 | # add (x, y, vx, vy, a) information
136 | if relative:
137 | self.uav_communication.append(((uav.x - self.x) / self.dc,
138 | (uav.y - self.y) / self.dc,
139 | cos(uav.h) - cos(self.h),
140 | sin(uav.h) - sin(self.h),
141 | (uav.a - self.a) / self.Na))
142 | else:
143 | self.uav_communication.append((uav.x / self.dc,
144 | uav.y / self.dc,
145 | cos(uav.h),
146 | sin(uav.h),
147 | uav.a / self.Na))
148 |
149 | def __get_all_local_state(self) -> (List[Tuple[float, float, float, float, float]],
150 | List[Tuple[float, float, float, float]], Tuple[float, float, float]):
151 | """
152 | :return: [(x, y, vx, vy, na),...] for uav, [(x, y, vx, vy)] for targets, (x, y, na) for itself
153 | """
154 | return self.uav_communication, self.target_observation, (self.x / self.dc, self.y / self.dc, self.a / self.Na)
155 |
156 | def __get_local_state_by_weighted_mean(self) -> 'np.ndarray':
157 | """
158 | :return: return weighted state: ndarray: (12)
159 | """
160 | communication, observation, sb = self.__get_all_local_state()
161 |
162 | if communication:
163 | d_communication = [] # store the distance from each uav to itself
164 | for x, y, vx, vy, na in communication:
165 | d_communication.append(min(self.distance(x, y, self.x, self.y), 1))
166 |
167 | # regularization by the distance
168 | # communication = self.__transform_to_array2d(communication)
169 | communication = np.array(communication)
170 | communication_weighted = communication / np.array(d_communication)[:, np.newaxis]
171 | average_communication = np.mean(communication_weighted, axis=0)
172 | else:
173 | # average_communication = np.zeros(4 + self.Na) # empty communication
174 | average_communication = -np.ones(4 + 1) # empty communication # TODO is -1 a valid placeholder?
175 |
176 | if observation:
177 | d_observation = [] # store the distance from each target to itself
178 | for x, y, vx, vy in observation:
179 | d_observation.append(min(self.distance(x, y, self.x, self.y), 1))
180 |
181 | # regularization by the distance
182 | observation = np.array(observation)
183 | observation_weighted = observation / np.array(d_observation)[:, np.newaxis]
184 | average_observation = np.mean(observation_weighted, axis=0)
185 | else:
186 | average_observation = -np.ones(4) # empty observation # TODO is -1 a valid placeholder?
187 |
188 | sb = np.array(sb)
189 | result = np.hstack((average_communication, average_observation, sb))
190 | return result
191 |
192 | def get_local_state(self) -> 'np.ndarray':
193 | """
194 | :return: np.ndarray
195 | """
196 | # using weighted mean method:
197 | return self.__get_local_state_by_weighted_mean()
198 |
199 | def __calculate_multi_target_tracking_reward(self, uav_list) -> float:
200 | """
201 | calculate multi target tracking reward
202 | :return: scalar [1, 2)
203 | """
204 | track_reward = 0
205 | for other_uav in uav_list:
206 | if other_uav != self:
207 | distance = self.__distance(other_uav)
208 | if distance <= self.dp:
209 | reward = 1 + (self.dp - distance) / self.dp
210 | # track_reward += clip_and_normalize(reward, 1, 2, 0)
211 | track_reward += reward # not clipped here; the caller clips it
212 | return track_reward
213 |
214 | def __calculate_duplicate_tracking_punishment(self, uav_list: List['UAV'], radio=2) -> float:
215 | """
216 | calculate duplicate tracking punishment
217 | :param uav_list: [class UAV]
218 | :param radio: radio用来控制惩罚的范围, 超出多远才算入惩罚
219 | :return: scalar (-e/2, -1/2]
220 | """
221 | total_punishment = 0
222 | for other_uav in uav_list:
223 | if other_uav != self:
224 | distance = self.__distance(other_uav)
225 | if distance <= radio * self.dp:
226 | punishment = -0.5 * exp((radio * self.dp - distance) / (radio * self.dp))
227 | # total_punishment += clip_and_normalize(punishment, -e/2, -1/2, -1)
228 | total_punishment += punishment # not clipped here; the caller clips it
229 | return total_punishment
230 |
231 | def __calculate_boundary_punishment(self, x_max: float, y_max: float) -> float:
232 | """
233 | :param x_max: border of the map at x-axis, scalar
234 | :param y_max: border of the map at y-axis, scalar
235 | :return:
236 | """
237 | x_to_0 = self.x - 0
238 | x_to_max = x_max - self.x
239 | y_to_0 = self.y - 0
240 | y_to_max = y_max - self.y
241 | d_bdr = min(x_to_0, x_to_max, y_to_0, y_to_max)
242 | if 0 <= self.x <= x_max and 0 <= self.y <= y_max:
243 | if d_bdr < self.dp:
244 | boundary_punishment = -0.5 * (self.dp - d_bdr) / self.dp
245 | else:
246 | boundary_punishment = 0
247 | else:
248 | boundary_punishment = -1/2
249 | return boundary_punishment # not clipped here; the caller clips it
250 | # return clip_and_normalize(boundary_punishment, -1/2, 0, -1)
251 |
252 | def calculate_raw_reward(self, uav_list: List['UAV'], target__list: List['TARGET'], x_max, y_max):
253 | """
254 | calculate three parts of the reward/punishment for this uav
255 | :return: float, float, float
256 | """
257 | reward = self.__calculate_multi_target_tracking_reward(target__list)
258 | boundary_punishment = self.__calculate_boundary_punishment(x_max, y_max)
259 | punishment = self.__calculate_duplicate_tracking_punishment(uav_list)
260 | return reward, boundary_punishment, punishment
261 |
262 | def __calculate_cooperative_reward_by_pmi(self, uav_list: List['UAV'], pmi_net: "PMINetwork", a) -> float:
263 | """
264 | calculate cooperative reward by pmi network
265 | :param pmi_net: class PMINetwork
266 | :param uav_list: [class UAV]
267 | :param a: float, proportion of selfish and sharing
268 | :return:
269 | """
270 | if a == 0: # early exit to save computation
271 | return self.raw_reward
272 |
273 | neighbor_rewards = []
274 | neighbor_dependencies = []
275 | la = self.get_local_state()
276 |
277 | for other_uav in uav_list:
278 | if other_uav != self and self.__distance(other_uav) <= self.dp:
279 | neighbor_rewards.append(other_uav.raw_reward)
280 | other_uav_la = other_uav.get_local_state()
281 | _input = la * other_uav_la
282 | neighbor_dependencies.append(pmi_net.inference(_input.squeeze()))
283 |
284 | if len(neighbor_rewards):
285 | neighbor_rewards = np.array(neighbor_rewards)
286 | neighbor_dependencies = np.array(neighbor_dependencies).astype(np.float32)
287 | softmax_values = softmax(neighbor_dependencies)
288 | reward = (1 - a) * self.raw_reward + a * np.sum(neighbor_rewards * softmax_values).item()
289 | else:
290 | reward = (1 - a) * self.raw_reward
291 | return reward
292 |
293 | def __calculate_cooperative_reward_by_mean(self, uav_list: List['UAV'], a) -> float:
294 | """
295 | calculate cooperative reward by mean
296 | :param uav_list: [class UAV]
297 | :param a: float, proportion of selfish and sharing
298 | :return:
299 | """
300 | if a == 0:  # early exit: no reward sharing, so skip the neighbour computation
301 | return self.raw_reward
302 |
303 | neighbor_rewards = []
304 | for other_uav in uav_list:
305 | if other_uav != self and self.__distance(other_uav) <= self.dp:
306 | neighbor_rewards.append(other_uav.raw_reward)
307 | # simple average of the neighbours' raw rewards, without the PMI network
308 | neighbor_mean = sum(neighbor_rewards) / len(neighbor_rewards) if neighbor_rewards else 0
309 | reward = (1 - a) * self.raw_reward + a * neighbor_mean
310 | return reward
311 |
312 | def calculate_cooperative_reward(self, uav_list: List['UAV'], pmi_net=None, a=0.5) -> float:
313 | """
314 | :param uav_list:
315 | :param pmi_net:
316 | :param a: 0: selfish, 1: completely shared
317 | :return:
318 | """
319 | if pmi_net:
320 | return self.__calculate_cooperative_reward_by_pmi(uav_list, pmi_net, a)
321 | else:
322 | return self.__calculate_cooperative_reward_by_mean(uav_list, a)
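
    # Worked example (illustrative, not part of the original code): with a = 0.5,
    # raw_reward = 1.0 and two neighbours whose raw rewards are 2.0 and 0.0, the PMI
    # path weights the neighbours by the softmax of their learned dependencies, e.g.
    # softmax([1.0, -1.0]) ≈ [0.88, 0.12], giving
    #   0.5 * 1.0 + 0.5 * (0.88 * 2.0 + 0.12 * 0.0) ≈ 1.38,
    # while the mean path uses (2.0 + 0.0) / 2 = 1.0, giving 0.5 * 1.0 + 0.5 * 1.0 = 1.0.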
323 |
324 | def get_action_by_direction(self, target_list, uav_list):
325 | def distance(x1, y1, x2, y2):
326 | return np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)
327 |
328 | # reward and penalty weights
329 | target_reward_weight = 1.0
330 | repetition_penalty_weight = 0.8
331 | self.epsilon = 0.25
332 | self.continue_tracing = 0.3
333 |
334 | best_score = float('-inf')
335 | best_angle = 0.0
336 |
337 | # random perturbation: with probability epsilon take a random action
338 | if random.random() < self.epsilon:
339 | return np.random.randint(0, self.Na)
340 | else:
341 | for target in target_list:
342 | target_x, target_y = target.x, target.y
343 |
344 | # distance from this UAV to the target
345 | dist_to_target = distance(self.x, self.y, target_x, target_y)
346 |
347 | # duplicate-tracking penalty: count other UAVs already within dc of this target
348 | repetition_penalty = 0.0
349 | for uav in uav_list:
350 | uav_x, uav_y = uav.x, uav.y
351 | if (uav_x, uav_y) != (self.x, self.y):
352 | dist_to_target_from_other_uav = distance(uav_x, uav_y, target_x, target_y)
353 | if dist_to_target_from_other_uav < self.dc:
354 | repetition_penalty += repetition_penalty_weight
355 |
356 | # score of the current target
357 | score = target_reward_weight / dist_to_target - repetition_penalty
358 |
359 | # keep the target with the highest score
360 | if score > best_score:
361 | best_score = score
362 | best_angle = np.arctan2(target_y - self.y, target_x - self.x) - self.h
363 |
364 | # with probability continue_tracing keep the current heading instead of turning towards the best target
365 | if random.random() < self.continue_tracing:
366 | best_angle = 0
367 |
368 | actual_action = self.find_closest_a_idx(best_angle)
369 | return actual_action
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import csv
3 | from tqdm import tqdm
4 | import numpy as np
5 | import torch
6 | from utils.draw_util import draw_animation
7 | from torch.utils.tensorboard import SummaryWriter
8 | import random
9 | import collections
10 |
11 |
12 | class ReturnValueOfTrain:
13 | def __init__(self):
14 | self.return_list = []
15 | self.target_tracking_return_list = []
16 | self.boundary_punishment_return_list = []
17 | self.duplicate_tracking_punishment_return_list = []
18 | self.average_covered_targets_list = []
19 | self.max_covered_targets_list = []
20 |
21 | def item(self):
22 | value_dict = {
23 | 'return_list': self.return_list,
24 | 'target_tracking_return_list': self.target_tracking_return_list,
25 | 'boundary_punishment_return_list': self.boundary_punishment_return_list,
26 | 'duplicate_tracking_punishment_return_list': self.duplicate_tracking_punishment_return_list,
27 | 'average_covered_targets_list': self.average_covered_targets_list,
28 | 'max_covered_targets_list': self.max_covered_targets_list
29 | }
30 | return value_dict
31 |
32 | def save_epoch(self, reward, tt_return, bp_return, dtp_return, average_targets, max_targets):
33 | self.return_list.append(reward)
34 | self.target_tracking_return_list.append(tt_return)
35 | self.boundary_punishment_return_list.append(bp_return)
36 | self.duplicate_tracking_punishment_return_list.append(dtp_return)
37 | self.average_covered_targets_list.append(average_targets)
38 | self.max_covered_targets_list.append(max_targets)
39 |
40 |
41 | class ReplayBuffer:
42 | def __init__(self, capacity):
43 | self.buffer = collections.deque(maxlen=capacity)
44 |
45 | def add(self, transition_dict):
46 | # extract the individual lists from transition_dict
47 | states = transition_dict['states']
48 | actions = transition_dict['actions']
49 | rewards = transition_dict['rewards']
50 | next_states = transition_dict['next_states']
51 |
52 | # zip the fields into (state, action, reward, next_state) tuples and append them to the buffer
53 | experiences = zip(states, actions, rewards, next_states)
54 | self.buffer.extend(experiences)
55 |
56 | def sample(self, batch_size):
57 | transitions = random.sample(self.buffer, min(batch_size, self.size()))
58 | states, actions, rewards, next_states = zip(*transitions)
59 |
60 | # build the returned dict
61 | sample_dict = {
62 | 'states': states,
63 | 'actions': actions,
64 | 'rewards': rewards,
65 | 'next_states': next_states
66 | }
67 | return sample_dict
68 |
69 | def size(self):
70 | return len(self.buffer)
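

def _replay_buffer_demo():
    # Illustrative usage sketch (not part of the original training code): feed the
    # ReplayBuffer above with a tiny, made-up transition_dict and draw a batch.
    # State dimension, actions and rewards below are placeholders.
    buffer = ReplayBuffer(capacity=1000)
    dummy_transitions = {
        'states': [np.zeros(8) for _ in range(4)],
        'actions': [0, 3, 7, 1],
        'rewards': [0.1, -0.2, 0.5, 0.0],
        'next_states': [np.zeros(8) for _ in range(4)],
    }
    buffer.add(dummy_transitions)        # flattened into (s, a, r, s') tuples
    batch = buffer.sample(batch_size=2)  # dict of tuples, one entry per field
    return batch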
71 |
72 |
73 | class PrioritizedReplayBuffer:
74 | def __init__(self, capacity, alpha=0.6):
75 | self.capacity = capacity
76 | self.alpha = alpha
77 | self.buffer = collections.deque(maxlen=capacity)
78 | self.priorities = np.zeros((capacity,), dtype=np.float32)
79 | self.pos = 0
80 |
81 | def add(self, transition_dict):
82 | states = transition_dict['states']
83 | actions = transition_dict['actions']
84 | rewards = transition_dict['rewards']
85 | next_states = transition_dict['next_states']
86 |
87 | experiences = zip(states, actions, rewards, next_states)
88 |
89 | for experience in experiences:
90 | max_priority = self.priorities.max() if self.buffer else 1.0
91 |
92 | if len(self.buffer) < self.capacity:
93 | self.buffer.append(experience)
94 | else:
95 | self.buffer[self.pos] = experience
96 |
97 | self.priorities[self.pos] = max_priority
98 | self.pos = (self.pos + 1) % self.capacity
99 |
100 | def sample(self, batch_size, beta=0.4):
101 | if len(self.buffer) == 0:
102 | return dict(states=[], actions=[], rewards=[], next_states=[]), None, None
103 |
104 | if len(self.buffer) == self.capacity:
105 | priorities = self.priorities
106 | else:
107 | priorities = self.priorities[:self.pos]
108 |
109 | probabilities = priorities ** self.alpha
110 | probabilities /= probabilities.sum()
111 |
112 | indices = np.random.choice(len(self.buffer), min(batch_size, len(self.buffer)), p=probabilities)
113 | samples = [self.buffer[idx] for idx in indices]
114 |
115 | total = len(self.buffer)
116 | weights = (total * probabilities[indices]) ** (-beta)
117 | weights /= weights.max()
118 | weights = np.array(weights, dtype=np.float32)
119 |
120 | batch = list(zip(*samples))
121 | states = np.array(batch[0])
122 | actions = np.array(batch[1])
123 | rewards = np.array(batch[2])
124 | next_states = np.array(batch[3])
125 |
126 | sample_dict = {
127 | 'states': states,
128 | 'actions': actions,
129 | 'rewards': rewards,
130 | 'next_states': next_states,
131 | }
132 | return sample_dict, indices, weights
133 |
134 | def update_priorities(self, batch_indices, batch_priorities):
135 | for idx, priority in zip(batch_indices, batch_priorities):
136 | self.priorities[idx] = priority
137 |
138 | def size(self):
139 | return len(self.buffer)
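

def _prioritized_replay_demo():
    # Illustrative usage sketch (not part of the original training code): the usual
    # prioritized-replay loop with the buffer above -- sample a batch together with its
    # importance weights, then write the absolute TD errors back as new priorities.
    # The TD errors here are random placeholders for the values returned by agent.update().
    per_buffer = PrioritizedReplayBuffer(capacity=1000, alpha=0.6)
    dummy_transitions = {
        'states': [np.zeros(8) for _ in range(4)],
        'actions': [2, 5, 1, 9],
        'rewards': [1.0, 0.0, -0.5, 0.3],
        'next_states': [np.zeros(8) for _ in range(4)],
    }
    per_buffer.add(dummy_transitions)
    sample_dict, indices, weights = per_buffer.sample(batch_size=2, beta=0.4)
    fake_td_errors = np.abs(np.random.randn(len(indices)))
    per_buffer.update_priorities(indices, fake_td_errors)
    return sample_dict, weights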
140 |
141 |
142 | def operate_epoch(config, env, agent, pmi, num_steps, cwriter_state=None, cwriter_prob=None):
143 | """
144 | :param config:
145 | :param env:
146 | :param agent:
147 | :param pmi:
148 | :param num_steps:
149 | :param cwriter_state: csv writer that records the states within an epoch, used for debugging
150 | :param cwriter_prob: csv writer that records the action probabilities within an epoch, used for debugging
151 | :return:
152 | """
153 | transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': []}
154 | episode_return = 0
155 | episode_target_tracking_return = 0
156 | episode_boundary_punishment_return = 0
157 | episode_duplicate_tracking_punishment_return = 0
158 | covered_targets_list = []
159 |
160 | for i in range(num_steps):
161 | config['step'] = i + 1
162 | action_list = []
163 |
164 | # each uav makes choices first
165 | for uav in env.uav_list:
166 | state = uav.get_local_state()
167 | if cwriter_state:
168 | cwriter_state.writerow(state.tolist())
169 | action, probs = agent.take_action(state)
170 | if cwriter_prob:
171 | cwriter_prob.writerow(probs.tolist())
172 | transition_dict['states'].append(state)
173 | action_list.append(action.item())
174 |
175 | # use action_list to update the environment
176 | next_state_list, reward_list, covered_targets = env.step(config, pmi, action_list) # action: List[int]
177 | transition_dict['actions'].extend(action_list)
178 | transition_dict['next_states'].extend(next_state_list)
179 | transition_dict['rewards'].extend(reward_list['rewards'])
180 |
181 | episode_return += sum(reward_list['rewards'])
182 | episode_target_tracking_return += sum(reward_list['target_tracking_reward'])
183 | episode_boundary_punishment_return += sum(reward_list['boundary_punishment'])
184 | episode_duplicate_tracking_punishment_return += sum(reward_list['duplicate_tracking_punishment'])
185 | covered_targets_list.append(covered_targets)
186 |
187 | episode_return /= num_steps * env.n_uav
188 | episode_target_tracking_return /= num_steps * env.n_uav
189 | episode_boundary_punishment_return /= num_steps * env.n_uav
190 | episode_duplicate_tracking_punishment_return /= num_steps * env.n_uav
191 | average_covered_targets = np.mean(covered_targets_list)
192 | max_covered_targets = np.max(covered_targets_list)
193 |
194 | return (transition_dict, episode_return, episode_target_tracking_return,
195 | episode_boundary_punishment_return, episode_duplicate_tracking_punishment_return,
196 | average_covered_targets, max_covered_targets)
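
# Note on scaling (illustrative, not part of the original code): the episode sums above
# are divided by num_steps * env.n_uav, so with, for example, 10 UAVs and 200 steps a
# summed episode return of 4000 is reported as 4000 / (200 * 10) = 2.0 per UAV per step.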
197 |
198 |
199 | def train(config, env, agent, pmi, num_episodes, num_steps, frequency):
200 | """
201 | :param config: experiment configuration dict
202 | :param env: environment instance
203 | :param agent: the single shared agent (all UAVs are trained with shared weights)
204 | :param pmi: pmi network, or None to fall back to mean-based reward sharing
205 | :param num_episodes: number of training episodes
206 | :param num_steps: number of steps per episode
207 | :param frequency: how often (in episodes) to print progress and save checkpoints
208 | :return:
209 | """
210 | # initialize saving list
211 | save_dir = os.path.join(config["save_dir"], "logs")
212 | writer = SummaryWriter(log_dir=save_dir)  # the TensorBoard log directory can be specified here
213 | return_value = ReturnValueOfTrain()
214 | # buffer = ReplayBuffer(config["actor_critic"]["buffer_size"])
215 | buffer = PrioritizedReplayBuffer(config["actor_critic"]["buffer_size"])
216 | if config["actor_critic"]["sample_size"] > 0:
217 | sample_size = config["actor_critic"]["sample_size"]
218 | else:
219 | sample_size = config["environment"]["n_uav"] * num_steps
220 |
221 | with open(os.path.join(save_dir, 'state.csv'), mode='w', newline='') as state_file, \
222 | open(os.path.join(save_dir, 'prob.csv'), mode='w', newline='') as prob_file:
223 | cwriter_state = csv.writer(state_file)
224 | cwriter_prob = csv.writer(prob_file)
225 |
226 | cwriter_state.writerow(['state'])  # write the header of state.csv
227 | cwriter_prob.writerow(['prob'])  # write the header of prob.csv
228 |
229 | with tqdm(total=num_episodes, desc='Episodes') as pbar:
230 | for i in range(num_episodes):
231 | # reset environment from config yaml file
232 | env.reset(config=config)
233 |
234 | # episode start
235 | # transition_dict, reward, tt_return, bp_return, \
236 | # dtp_return = operate_epoch(config, env, agent, pmi, num_steps, cwriter_state, cwriter_prob)
237 | transition_dict, reward, tt_return, bp_return, \
238 | dtp_return, average_targets, max_targets = operate_epoch(config, env, agent, pmi, num_steps)
239 | writer.add_scalar('reward', reward, i)
240 | writer.add_scalar('target_tracking_return', tt_return, i)
241 | writer.add_scalar('boundary_punishment', bp_return, i)
242 | writer.add_scalar('duplicate_tracking_punishment', dtp_return, i)
243 | writer.add_scalar('average_covered_targets', average_targets, i)
244 | writer.add_scalar('max_covered_targets', max_targets, i)
245 |
246 | # saving return lists
247 | return_value.save_epoch(reward, tt_return, bp_return, dtp_return, average_targets, max_targets)
248 |
249 | # sample from buffer
250 | buffer.add(transition_dict)
251 | # sample_dict = buffer.sample(sample_size)
252 | sample_dict, indices, _ = buffer.sample(sample_size)
253 |
254 | # update actor-critic network
255 | actor_loss, critic_loss, td_errors = agent.update(sample_dict)
256 | writer.add_scalar('actor_loss', actor_loss, i)
257 | writer.add_scalar('critic_loss', critic_loss, i)
258 |
259 | # update buffer
260 | buffer.update_priorities(indices, td_errors.abs().detach().cpu().numpy())
261 |
262 | # update pmi network
263 | if pmi:
264 | avg_pmi_loss = pmi.train_pmi(config, torch.tensor(np.array(sample_dict["states"])), env.n_uav)
265 | writer.add_scalar('avg_pmi_loss', avg_pmi_loss, i)
266 |
267 | # save & print
268 | if (i + 1) % frequency == 0:
269 | # print some information
270 | if pmi:
271 | pbar.set_postfix({'episode': '%d' % (i + 1),
272 | 'return': '%.3f' % np.mean(return_value.return_list[-frequency:]),
273 | 'actor loss': '%f' % actor_loss,
274 | 'critic loss': '%f' % critic_loss,
275 | 'avg pmi loss': '%f' % avg_pmi_loss})
276 | else:
277 | pbar.set_postfix({'episode': '%d' % (i + 1),
278 | 'return': '%.3f' % np.mean(return_value.return_list[-frequency:]),
279 | 'actor loss': '%f' % actor_loss,
280 | 'critic loss': '%f' % critic_loss})
281 |
282 | # save results and weights
283 | draw_animation(config=config, env=env, num_steps=num_steps, ep_num=i)
284 | agent.save(save_dir=config["save_dir"], epoch_i=i + 1)
285 | if pmi:
286 | pmi.save(save_dir=config["save_dir"], epoch_i=i + 1)
287 | env.save_position(save_dir=config["save_dir"], epoch_i=i + 1)
288 | env.save_covered_num(save_dir=config["save_dir"], epoch_i=i + 1)
289 |
290 | # episode end
291 | pbar.update(1)
292 |
293 | writer.close()
294 |
295 | return return_value.item()
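

def _example_train_config():
    # Hypothetical sketch (not part of the original code): the minimal structure of the
    # config dict that train() reads directly in this file. Real runs presumably build it
    # from the YAML files under src/configs (see utils/args_util.py); every value below is
    # an illustrative placeholder.
    return {
        'save_dir': '../results/example',   # logs, weights and csv files are written here
        'actor_critic': {
            'buffer_size': 100000,          # capacity of the (prioritized) replay buffer
            'sample_size': 0,               # <= 0 falls back to n_uav * num_steps
        },
        'environment': {
            'n_uav': 10,                    # used to derive the default sample size
        },
        # 'step' is filled in by operate_epoch() at every environment step.
    }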
296 |
297 |
298 | def evaluate(config, env, agent, pmi, num_steps):
299 | """
300 | :param config: experiment configuration dict
301 | :param pmi: pmi network, or None
302 | :param num_steps: number of steps per episode
303 | :param env: environment instance
304 | :param agent: the single shared agent (all UAVs share the same weights)
305 | :return:
306 | """
307 | # initialize saving list
308 | return_value = ReturnValueOfTrain()
309 |
310 | # reset environment from config yaml file
311 | env.reset(config=config)
312 |
313 | # episode start
314 | transition_dict, reward, tt_return, bp_return, dtp_return, average_targets, max_targets = operate_epoch(config, env, agent, pmi, num_steps)
315 |
316 | # saving return lists
317 | return_value.save_epoch(reward, tt_return, bp_return, dtp_return, average_targets, max_targets)
318 |
319 | # save results and weights
320 | draw_animation(config=config, env=env, num_steps=num_steps, ep_num=0)
321 | env.save_position(save_dir=config["save_dir"], epoch_i=0)
322 | env.save_covered_num(save_dir=config["save_dir"], epoch_i=0)
323 |
324 | return return_value.item()
325 |
326 | def run_epoch(config, pmi, env, num_steps):
327 | """
328 | :param config:
329 | :param env:
330 | :param num_steps:
331 | :return:
332 | """
333 | transition_dict = {'states': [], 'actions': [], 'next_states': [], 'rewards': []}
334 | episode_return = 0
335 | episode_target_tracking_return = 0
336 | episode_boundary_punishment_return = 0
337 | episode_duplicate_tracking_punishment_return = 0
338 | covered_targets_list = []
339 |
340 | for _ in range(num_steps):
341 | action_list = []
342 | # uav_tracking_status = [0] * len(env.uav_list)
343 |
344 | # # each uav makes choices first
345 | # for uav in env.uav_list:
346 | # action, target_index = uav.get_action_by_direction(env.target_list, env.uav_list, uav_tracking_status) # TODO
347 | # uav_tracking_status[target_index] = 1
348 | # action_list.append(action)
349 | for uav in env.uav_list:
350 | action = uav.get_action_by_direction(env.target_list, env.uav_list) # TODO
351 | action_list.append(action)
352 |
353 | next_state_list, reward_list, covered_targets = env.step(config, pmi, action_list) # TODO
354 |
355 | # use action_list to update the environment
356 | transition_dict['actions'].extend(action_list)
357 | transition_dict['rewards'].extend(reward_list['rewards'])
358 |
359 | episode_return += sum(reward_list['rewards'])
360 | episode_target_tracking_return += sum(reward_list['target_tracking_reward'])
361 | episode_boundary_punishment_return += sum(reward_list['boundary_punishment'])
362 | episode_duplicate_tracking_punishment_return += sum(reward_list['duplicate_tracking_punishment'])
363 | covered_targets_list.append(covered_targets)
364 |
365 | average_covered_targets = np.mean(covered_targets_list)
366 | max_covered_targets = np.max(covered_targets_list)
367 |
368 | return (transition_dict, episode_return, episode_target_tracking_return,
369 | episode_boundary_punishment_return, episode_duplicate_tracking_punishment_return,
370 | average_covered_targets, max_covered_targets)
371 |
372 | def run(config, env, pmi, num_steps):
373 | """
374 | :param config:
375 | :param num_steps: number of steps per episode
376 | :param env:
377 | :return:
378 | """
379 | # initialize saving list
380 | return_value = ReturnValueOfTrain()
381 |
382 | # reset environment from config yaml file
383 | env.reset(config=config)
384 |
385 | # episode start
386 | transition_dict, reward, tt_return, bp_return, dtp_return, average_targets, max_targets = run_epoch(config, pmi, env, num_steps)
387 |
388 | # saving return lists
389 | return_value.save_epoch(reward, tt_return, bp_return, dtp_return, average_targets, max_targets)
390 |
391 | # save results and weights
392 | draw_animation(config=config, env=env, num_steps=num_steps, ep_num=0)
393 | env.save_position(save_dir=config["save_dir"], epoch_i=0)
394 | env.save_covered_num(save_dir=config["save_dir"], epoch_i=0)
395 |
396 | return return_value.item()
--------------------------------------------------------------------------------