├── save
│   ├── features
│   │   └── .gitkeep
│   └── models
│       └── .gitkeep
├── .gitignore
├── docs
│   ├── dataset.md
│   └── architecture.md
├── README.md
├── model.py
├── layers.py
├── data_process.py
└── run.py
--------------------------------------------------------------------------------
/save/features/.gitkeep:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/save/models/.gitkeep:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .vscode
3 | *.save
4 | *.pt
--------------------------------------------------------------------------------
/docs/dataset.md:
--------------------------------------------------------------------------------
1 | # Dataset
2 | 
3 | Dataset: argoverse_forecasting
4 | 
5 | - Each csv file is one sequence;
6 | - each sequence contains several track objects (identified by track_id_list).
--------------------------------------------------------------------------------
/docs/architecture.md:
--------------------------------------------------------------------------------
1 | # Architecture
2 | 
3 | - data_process.py:
4 |   - Processes the data with the `argoverse` API, pads it, and converts it into features.
5 | - layers.py:
6 |   - Implements the subgraph layer and self-attention.
7 | - model.py:
8 |   - Implements the subgraph, the global graph, and the decoder.
9 | - run.py:
10 |   - Model training and evaluation.
11 | - ./save:
12 |   - Stores the processed features and trained models.
13 | - ./docs:
14 |   - Documentation.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VectorNet
2 | 
3 | [![HitCount](http://hits.dwyl.com/ForeverFancy/VectorNet.svg)](http://hits.dwyl.com/ForeverFancy/VectorNet)
4 | 
5 | PyTorch implementation of the paper "VectorNet: Encoding HD Maps and Agent Dynamics from Vectorized Representation".
6 | 
7 | Directory structure:
8 | 
9 | ```
10 | .
11 | ├── README.md
12 | ├── data_process.py
13 | ├── docs
14 | │   └── dataset.md
15 | ├── layers.py
16 | ├── model.py
17 | ├── run.py
18 | └── save
19 |     ├── features
20 |     └── models
21 | ```
22 | 
23 | Usage example:
24 | 
25 | ```
26 | python3 run.py --root_dir ../forecasting_sample/data/ --epochs 50 --feature_path ./save/features/ --logging_steps 50 --train_batch_size=16 --enable_logging
27 | ```
28 | 
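The following is a minimal forward-pass sketch of how the three modules described in `docs/architecture.md` fit together. It is not part of the repository sources; it assumes `model.py` below is importable from the working directory, and the batch/polyline sizes are made up for illustration.

```python
# Minimal forward-pass sketch (assumes model.py from this repo is importable).
# Shapes follow the docstrings: features are (batch, num_polylines, max_vectors, 7).
import torch

from model import SubGraph, GlobalGraph, TrajectoryDecoder

batch_size, num_polylines, max_vectors = 2, 5, 49   # illustrative sizes
max_groundtruth_length = 30                         # decoder predicts this many future vectors

features = torch.rand(batch_size, num_polylines, max_vectors, 7)
subgraph_mask = torch.ones(batch_size, num_polylines, max_vectors)  # 1 = real vector, 0 = padding
attention_mask = torch.ones(batch_size, num_polylines)              # 1 = real polyline node

subgraph = SubGraph(in_features=7, hidden_size=64)
globalgraph = GlobalGraph(in_features=128, out_features=128)
decoder = TrajectoryDecoder(in_features=128, out_features=max_groundtruth_length * 4)

polyline_features = subgraph(features, subgraph_mask)       # (batch, num_polylines, 128)
# index 0 is the agent polyline (padding_trajectory in data_process.py swaps it there)
agent_query = polyline_features[:, 0, :].unsqueeze(1)        # (batch, 1, 128)
global_features = globalgraph(agent_query, polyline_features, attention_mask)  # (batch, 1, 128)
prediction = decoder(global_features).squeeze(1)             # (batch, max_groundtruth_length * 4)
print(prediction.shape)
```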
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from layers import SubGraphLayer
4 | from layers import SelfAttentionLayer
5 | import numpy as np
6 | 
7 | 
8 | class SubGraph(nn.Module):
9 |     def __init__(self, in_features: int = 7, hidden_size: int = 64):
10 |         '''
11 |         Subgraph model
12 |         '''
13 |         super(SubGraph, self).__init__()
14 |         self.hidden_size = hidden_size
15 |         self.in_features = in_features
16 |         self.sglayer1 = SubGraphLayer(
17 |             in_features=self.in_features, hidden_size=self.hidden_size)
18 |         self.sglayer2 = SubGraphLayer(
19 |             in_features=self.hidden_size * 2, hidden_size=self.hidden_size)
20 |         self.sglayer3 = SubGraphLayer(
21 |             in_features=self.hidden_size * 2, hidden_size=self.hidden_size)
22 | 
23 |     def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
24 |         '''
25 |         @input x of shape(batch_size, num_of_seqs, max_seq_size, in_features)
26 | 
27 |         @input mask of shape(batch_size, num_of_seqs, max_seq_size)
28 | 
29 |         @return out of shape(batch_size, num_of_seqs, hidden_size * 2): polyline level feature
30 |         '''
31 |         # each SubGraphLayer outputs shape (batch_size, num_of_seqs, max_seq_size, hidden_size * 2)
32 |         x = self.sglayer1(x, mask)
33 |         x = self.sglayer2(x, mask)
34 |         out = self.sglayer3(x, mask)
35 |         out = self.aggregate(out)
36 |         return out
37 | 
38 |     def aggregate(self, x: torch.Tensor) -> torch.Tensor:
39 |         '''
40 |         @input x of shape(batch_size, num_of_seqs, max_seq_size, hidden_size * 2)
41 | 
42 |         @return x of shape(batch_size, num_of_seqs, hidden_size * 2)
43 |         '''
44 |         y, _ = torch.max(x, dim=2)
45 |         return y
46 | 
47 | 
48 | class GlobalGraph(nn.Module):
49 |     def __init__(self, in_features: int = 128, out_features: int = 128):
50 |         '''
51 |         Global graph model
52 |         '''
53 |         super(GlobalGraph, self).__init__()
54 |         self.in_features = in_features
55 |         self.out_features = out_features
56 |         self.attention_layer = SelfAttentionLayer(self.in_features, self.out_features)
57 | 
58 |     def forward(self, query: torch.Tensor, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
59 |         '''
60 |         Do self-attention
61 | 
62 |         @input query of shape(batch_size, 1, 2 * hidden_size): agent features
63 | 
64 |         @input x of shape (batch_size, num_of_nodes, 2 * hidden_size): all polyline node features
65 | 
66 |         @input attention_mask of shape (batch_size, num_of_nodes): mask for self attention
67 | 
68 |         @return out of shape(batch_size, 1, 2 * hidden_size): self-attention output
69 |         '''
70 |         out = self.attention_layer.forward(query, x, attention_mask)
71 |         return out
72 | 
73 | 
74 | class TrajectoryDecoder(nn.Module):
75 |     def __init__(self, in_features: int = 128, out_features: int = 2):
76 |         '''
77 |         Decode future trajectories
78 |         '''
79 |         super(TrajectoryDecoder, self).__init__()
80 |         self.linear = nn.Linear(in_features=in_features, out_features=out_features, bias=True)
81 | 
82 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
83 |         '''
84 |         @input x of shape(*, in_features)
85 | 
86 |         @return out of shape(*, out_features)
87 |         '''
88 |         out = self.linear(x)
89 |         return out
90 | 
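The decoder emits each future trajectory as a single flat vector of length `max_groundtruth_length * 4`, mirroring how `handle_ground_truth` in `data_process.py` flattens the ground truth. The helper below is hypothetical (not part of the repository) and only sketches one way to fold such a prediction back into per-step vectors.

```python
# Hypothetical helper: reshape a flat decoder output back into future vectors.
# handle_ground_truth flattens each row (x_start, y_start, x_end, y_end) in
# row-major order, so a plain reshape recovers the per-step layout.
import numpy as np


def unflatten_prediction(pred: np.ndarray, max_groundtruth_length: int = 30) -> np.ndarray:
    '''pred has shape (max_groundtruth_length * 4,); returns (max_groundtruth_length, 4).'''
    return pred.reshape(max_groundtruth_length, 4)


flat = np.arange(30 * 4, dtype=float)   # stand-in for one decoder output row
vectors = unflatten_prediction(flat)
end_points = vectors[:, 2:4]            # predicted (x_end, y_end) of each future vector
print(vectors.shape, end_points.shape)  # (30, 4) (30, 2)
```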
--------------------------------------------------------------------------------
/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | 
5 | class SubGraphLayer(nn.Module):
6 |     def __init__(self, in_features: int = 7, hidden_size: int = 64):
7 |         '''
8 |         Layer of subgraph
9 |         '''
10 |         super(SubGraphLayer, self).__init__()
11 |         self.hidden_size = hidden_size
12 |         self.in_features = in_features
13 |         self.linear = nn.Linear(
14 |             in_features=self.in_features,
15 |             out_features=self.hidden_size,
16 |             bias=True
17 |         )
18 |         self.lm = nn.LayerNorm(normalized_shape=(self.hidden_size))
19 | 
20 |     def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
21 |         '''
22 |         @input x of shape (batch_size, num_of_seqs, max_seq_size, in_features)
23 | 
24 |         @input mask of shape (batch_size, num_of_seqs, max_seq_size)
25 | 
26 |         @return out of shape (batch_size, num_of_seqs, max_seq_size, hidden_size * 2): output features
27 |         '''
28 |         x = self.encode(x)
29 |         ag = self.aggregate(x).repeat(1, 1, x.shape[2], 1)
30 |         out = torch.cat([x, ag], dim=3)
31 |         assert out.shape == (x.shape[0], x.shape[1], x.shape[2], self.hidden_size * 2)
32 |         out = out * (mask.unsqueeze(dim=-1).repeat(1, 1, 1, self.hidden_size * 2))
33 |         assert out.shape == (x.shape[0], x.shape[1], x.shape[2], self.hidden_size * 2)
34 | 
35 |         return out
36 | 
37 |     def encode(self, x: torch.Tensor) -> torch.Tensor:
38 |         '''
39 |         MLP + layer normalization + relu activation
40 | 
41 |         @input x of shape (batch_size, num_of_seqs, max_seq_size, in_features)
42 | 
43 |         @return x of shape (batch_size, num_of_seqs, max_seq_size, hidden_size)
44 |         '''
45 |         x = self.linear(x)
46 |         x = self.lm(x)
47 |         return torch.relu(x)
48 | 
49 |     def aggregate(self, x: torch.Tensor) -> torch.Tensor:
50 |         '''
51 |         Use maxpooling to aggregate
52 | 
53 |         @input x of shape (batch_size, num_of_seqs, max_seq_size, hidden_size)
54 | 
55 |         @return x of shape (batch_size, num_of_seqs, 1, hidden_size)
56 |         '''
57 |         y, _ = torch.max(x, dim=2)
58 |         return y.unsqueeze(2)
59 | 
60 | 
61 | class SelfAttentionLayer(nn.Module):
62 |     def __init__(self, in_features: int = 128, out_features: int = 128):
63 |         '''
64 |         Self-attention layer
65 |         '''
66 |         super(SelfAttentionLayer, self).__init__()
67 |         self.Proj_Q = nn.Linear(
68 |             in_features=in_features,
69 |             out_features=out_features,
70 |             bias=False
71 |         )
72 |         self.Proj_K = nn.Linear(
73 |             in_features=in_features,
74 |             out_features=out_features,
75 |             bias=False
76 |         )
77 |         self.Proj_V = nn.Linear(
78 |             in_features=in_features,
79 |             out_features=out_features,
80 |             bias=False
81 |         )
82 |         self.softmax = nn.Softmax(dim=2)
83 | 
84 |     def forward(self, query: torch.Tensor, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
85 |         '''
86 |         Do self-attention
87 | 
88 |         @input query of shape(batch_size, 1, 2 * hidden_size): agent features
89 | 
90 |         @input x of shape (batch_size, num_of_nodes, 2 * hidden_size): all polyline node features
91 | 
92 |         @input attention_mask of shape (batch_size, num_of_nodes): mask for self attention
93 | 
94 |         @return out of shape(batch_size, 1, 2 * hidden_size): self-attention output
95 |         '''
96 |         P_q = self.Proj_Q(query)
97 |         P_k = self.Proj_K(x)
98 |         P_v = self.Proj_V(x)
99 |         out = torch.bmm(P_q, P_k.transpose(1, 2))
100 |         # mask for self attention
101 |         if attention_mask is not None:
102 |             out = out.masked_fill(attention_mask.unsqueeze(1).expand(-1, query.shape[1], -1) == 0, -1e9)
103 |         out = torch.bmm(self.softmax(out), P_v)
104 |         return out
105 | 
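A small sanity check of the attention mask, assuming `layers.py` above is importable. Padded polyline nodes get a `-1e9` score before the softmax, so their attention weight collapses to zero and they do not contribute to the output; the sizes below are arbitrary.

```python
# Sanity check: padded nodes (mask == 0) should not influence the attention output.
import torch

from layers import SelfAttentionLayer

torch.manual_seed(0)
layer = SelfAttentionLayer(in_features=128, out_features=128)

query = torch.rand(1, 1, 128)            # agent polyline feature
nodes = torch.rand(1, 4, 128)            # 4 polyline nodes, the last one is padding
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])

with torch.no_grad():
    out_masked = layer(query, nodes, mask)
    # zeroing the padded node's features must not change the masked output
    nodes_zeroed = nodes.clone()
    nodes_zeroed[:, 3, :] = 0.0
    out_zeroed = layer(query, nodes_zeroed, mask)

print(torch.allclose(out_masked, out_zeroed))  # expected: True
```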
--------------------------------------------------------------------------------
/data_process.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pickle
3 | from argoverse.data_loading.argoverse_forecasting_loader import ArgoverseForecastingLoader
4 | import pandas as pd
5 | import os
6 | 
7 | 
8 | OBJ_MAP = {"AGENT": 0, "AV": 1, "OTHERS": 2}
9 | 
10 | 
11 | def load_features(root_dir: str, feature_path: str = None):
12 |     '''
13 |     load features from paths
14 | 
15 |     @input root_dir (str): path to raw data root directory
16 | 
17 |     @input feature_path (str): path to saved features directory
18 | 
19 |     @return features (np.ndarray) of shape (num_of_paths, maxnum_of_global_nodes, maxnum_of_sub_nodes, 7): features of all subgraphs, padded to maxnum_of_seqs
20 | 
21 |     @return subgraph_mask (np.ndarray) of shape (len(feature_list), maxnum_of_global_nodes, maxnum_of_sub_nodes): mask of subgraph nodes
22 | 
23 |     @return attention_mask (np.ndarray) of shape (len(feature_list), maxnum_of_global_nodes): attention mask of padding nodes
24 | 
25 |     @return groundtruth (np.ndarray): groundtruth for prediction
26 | 
27 |     @return groundtruth_mask (np.ndarray): mask for the real groundtruth (excludes padding)
28 | 
29 |     @return max_groundtruth_length (int): max ground truth length
30 |     '''
31 |     if feature_path is None:
32 |         load_raw_data(root_dir)
33 |         paths = os.listdir("./save/features")
34 |         paths = [os.path.join("./save/features", p) for p in paths if p.endswith(".save")]
35 |     else:
36 |         paths = os.listdir(feature_path)
37 |         paths = [os.path.join(feature_path, p)
38 |                  for p in paths if p.endswith(".save")]
39 |         if len(paths) == 0:
40 |             load_raw_data(root_dir, feature_path)
41 |             paths = os.listdir(feature_path)
42 |             paths = [os.path.join(feature_path, p) for p in paths if p.endswith(".save")]
43 |     groundtruth_list = []
44 |     padding_features_list = []
45 |     mask_list = []
46 |     agent_index = None
47 |     for path in paths:
48 |         with open(path, "rb") as f:
49 |             traj_list = pickle.load(f)
50 |         features = []
51 |         for i in range(len(traj_list)):
52 |             vec, ground_truth = build_vector(traj_list[i], i)
53 |             features.append(vec)
54 |             if ground_truth is not None:
55 |                 groundtruth_list.append(ground_truth)
56 |                 agent_index = i
57 | 
58 |         padding_features, mask = padding_trajectory(features, agent_index)
59 |         padding_features_list.append(padding_features)
60 |         mask_list.append(mask)
61 |     features, subgraph_mask, attention_mask = global_padding(padding_features_list, mask_list)
62 | 
63 |     groundtruth, groundtruth_mask, max_groundtruth_length = handle_ground_truth(groundtruth_list)
64 |     return features, subgraph_mask, attention_mask, groundtruth, groundtruth_mask, max_groundtruth_length
65 | 
66 | 
67 | def build_vector(traj: np.ndarray, id: int):
68 |     '''
69 |     build vectors based on the input trajectory
70 | 
71 |     @input traj (np.ndarray): trajectory of one object
72 | 
73 |     @input id (int): j in the paper, integer id of P_j, indicating v_i is in P_j
74 | 
75 |     @return vector (np.ndarray) of shape (len(traj) - 1, 7): vectors built from the trajectory, each row contains (x_start, y_start, x_end, y_end, obj_type, time_stamp, j)
76 | 
77 |     @return ground_truth (np.ndarray): return groundtruth trajectory if the input is an agent, otherwise return None
78 |     '''
79 |     # print(len(traj))
80 |     # print(traj)
81 |     ground_truth = None
82 |     vector = np.zeros((len(traj) - 1, 7))
83 | 
84 |     # start coordinates (x_start, y_start)
85 |     vector[:, 0] = traj[:, 3][:-1]
86 |     vector[:, 1] = traj[:, 4][:-1]
87 | 
88 |     # end coordinates (x_end, y_end)
89 |     vector[:, 2] = traj[:, 3][1:]
90 |     vector[:, 3] = traj[:, 4][1:]
91 | 
92 |     # obj_type, time_stamp, j
93 |     vector[:, 4] = OBJ_MAP[traj[0, 2]]
94 |     vector[:, 5] = traj[:, 0][:-1]
95 |     vector[:, 6] = id
96 | 
97 |     if traj[0, 2] == "AGENT":
98 |         ground_truth = vector[np.where(vector[:, 5] > 2), :].squeeze(axis=0)
99 |         vector = vector[np.where(vector[:, 5] <= 2), :].squeeze(axis=0)
100 | 
101 |     return vector, ground_truth
102 | 
103 | 
104 | def padding_trajectory(features: list, agent_index: int, max_seq_length: int = 49):
105 |     '''
106 |     Padding the input features to max sequence length (max number of sub nodes), and swap agent index to 0
107 | 
108 |     @input features: raw features
109 | 
110 |     @input agent_index (int): if not None, swap agent feature to index 0
111 | 
112 |     @input max_seq_length: padding to max sequence length (num of sub nodes), default is 49 (5 sec, 0.1 sec sampling)
113 | 
114 |     @return padding_features of shape (len(features), maxnum_of_sub_nodes, 7)
115 | 
116 |     @return mask of shape (len(features), maxnum_of_sub_nodes)
117 |     '''
118 |     if agent_index is not None:
119 |         # Swap agent index to 0
120 |         tmp = features[0]
121 |         features[0] = features[agent_index]
122 |         features[agent_index] = tmp
123 | 
124 |     seq_length = [x.shape[0] for x in features]
125 |     max_seq_length = max(
126 |         seq_length) if max_seq_length is None else max_seq_length
127 |     mask = np.zeros((len(features), max_seq_length))
128 |     padding_features = np.zeros((len(features), max_seq_length, 7))
129 |     for i, feature in enumerate(features):
130 |         mask[i, : feature.shape[0]] = 1
131 |         mask[i, feature.shape[0]:] = 0
132 |         padding_features[i, :, :] = np.concatenate(
133 |             (feature, np.zeros((max_seq_length - feature.shape[0], 7))), axis=0)
134 |     return padding_features, mask
135 | 
136 | 
137 | def global_padding(feature_list: list, mask_list: list):
138 |     '''
139 |     padding all trajectories to the same number of nodes
140 | 
141 |     @input feature_list (list): input feature list
142 | 
143 |     @input mask_list (list): subgraph mask list
144 | 
145 |     @return features (np.ndarray) of shape (len(feature_list), maxnum_of_global_nodes, maxnum_of_sub_nodes, 7)
146 | 
147 |     @return subgraph_mask (np.ndarray) of shape (len(feature_list), maxnum_of_global_nodes, maxnum_of_sub_nodes)
148 | 
149 |     @return attention_mask (np.ndarray) of shape (len(feature_list), maxnum_of_global_nodes)
150 |     '''
151 |     assert len(feature_list) == len(mask_list)
152 | 
153 |     length = len(feature_list)
154 |     num_of_seqs = [feature.shape[0] for feature in feature_list]
155 |     maxnum_of_seqs = max(num_of_seqs)
156 |     features = np.zeros((len(feature_list), maxnum_of_seqs,
157 |                          feature_list[0].shape[1], feature_list[0].shape[2]))
158 |     attention_mask = np.zeros((len(feature_list), maxnum_of_seqs))
159 |     subgraph_mask = np.zeros((len(feature_list), maxnum_of_seqs, feature_list[0].shape[1]))
160 | 
161 |     for i, f in enumerate(feature_list):
162 |         features[i, :, :, :] = np.concatenate(
163 |             (f, np.zeros((maxnum_of_seqs - f.shape[0], f.shape[1], f.shape[2]))), axis=0)
164 |         attention_mask[i, : f.shape[0]] = 1
165 |         subgraph_mask[i, : f.shape[0], :] = mask_list[i]
166 |     return features, subgraph_mask, attention_mask
167 | 
168 | 
169 | def handle_ground_truth(groundtruth_list: list, max_groundtruth_length: int = None):
170 |     '''
171 |     @input groundtruth_list (list): input groundtruth
172 | 
173 |     @input max_groundtruth_length (int): maximum length of the groundtruth, default is None (pad to the longest groundtruth in the list)
174 | 
175 |     @return groundtruth (np.ndarray) of shape (len(groundtruth_list), max_groundtruth_length * 4): each row contains flattened (x_start, y_start, x_end, y_end) vectors
176 | 
177 |     @return groundtruth_mask (np.ndarray) of shape (len(groundtruth_list), max_groundtruth_length * 4): mask marking entries that are not padding
178 | 
179 |     @return max_groundtruth_length (int): max ground truth length
180 |     '''
181 |     groundtruth_length = [gt.shape[0] for gt in groundtruth_list]
182 |     max_groundtruth_length = max(groundtruth_length) if max_groundtruth_length is None else max_groundtruth_length
183 |     groundtruth = np.zeros((len(groundtruth_list), max_groundtruth_length * 4))
184 |     groundtruth_mask = np.zeros((len(groundtruth_list), max_groundtruth_length * 4))
185 | 
186 |     for i, gt in enumerate(groundtruth_list):
187 |         groundtruth[i, : gt.shape[0] * 4] = gt[:, :4].reshape(-1, 1).squeeze(axis=1)
188 |         groundtruth_mask[i, : gt.shape[0] * 4] = 1
189 | 
190 |     return groundtruth, groundtruth_mask, max_groundtruth_length
191 | 
192 | 
193 | def load_raw_data(root_dir: str, save_path: str = "./save/features"):
194 |     '''
195 |     Save the raw csv data as np.ndarray trajectory lists
196 | 
197 |     @input root_dir (string): root directory containing the .csv data
198 | 
199 |     @input save_path (string): save features to this path, default "./save/features"
200 |     '''
201 |     afl = ArgoverseForecastingLoader(root_dir)
202 |     files = os.listdir(root_dir)
203 |     for f in files:
204 |         if not f.endswith(".csv"):
205 |             continue
206 |         seq_path = os.path.join(root_dir, f)
207 |         print("Processing " + seq_path)
208 |         id_list = afl.get(seq_path).track_id_list
209 |         agent_traj = afl.get(seq_path).agent_traj
210 |         df = afl.get(seq_path).seq_df
211 |         traj_list = []
212 |         df['TIMESTAMP'] -= df['TIMESTAMP'].min()
213 |         for id in id_list:
214 |             subdf = df.loc[df['TRACK_ID'] == id]
215 |             traj_list.append(subdf.drop(columns='CITY_NAME').sort_values(by=['TIMESTAMP']).to_numpy())
216 |         with open(os.path.join(save_path, f[:-4] + ".save"), "wb") as fout:
217 |             pickle.dump(traj_list, fout)
218 | 
219 | 
220 | if __name__ == "__main__":
221 |     pass
222 | 
223 | 
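A preprocessing-only sketch of `load_features`, assuming an Argoverse forecasting sample directory; the paths are the same placeholders used in the README. Each feature vector row holds `(x_start, y_start, x_end, y_end, obj_type, time_stamp, polyline_id)` as produced by `build_vector`.

```python
# Feature-loading sketch (assumes data_process.py above and the argoverse API are installed).
from data_process import load_features

features, subgraph_mask, attention_mask, groundtruth, groundtruth_mask, max_len = load_features(
    root_dir="../forecasting_sample/data/",   # raw Argoverse .csv files (placeholder path)
    feature_path="./save/features/",          # cached .save pickles are read from / written to here
)

print("features:       ", features.shape)        # (num_seqs, max_polylines, max_vectors, 7)
print("subgraph mask:  ", subgraph_mask.shape)   # (num_seqs, max_polylines, max_vectors)
print("attention mask: ", attention_mask.shape)  # (num_seqs, max_polylines)
print("groundtruth:    ", groundtruth.shape)     # (num_seqs, max_len * 4)
print("max gt length:  ", max_len)
```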
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import argparse
4 | from model import *
5 | from torch.utils.data import RandomSampler, DataLoader, SequentialSampler
6 | from torch.utils.data.distributed import DistributedSampler
7 | from tqdm import tqdm
8 | from data_process import *
9 | from sklearn.model_selection import train_test_split
10 | from transformers import get_linear_schedule_with_warmup
11 | import math
12 | import os
13 | 
14 | 
15 | def train(args, train_dataset, test_dataset, device):
16 |     subgraph = SubGraph()
17 |     globalgraph = GlobalGraph()
18 |     decoder = TrajectoryDecoder(out_features=args.max_groundtruth_length * 4)
19 | 
20 |     subgraph.to(device)
21 |     subgraph.train()
22 |     subgraph.zero_grad()
23 |     globalgraph.to(device)
24 |     globalgraph.train()
25 |     globalgraph.zero_grad()
26 |     decoder.to(device)
27 |     decoder.train()
28 |     decoder.zero_grad()
29 | 
30 |     subgraph_optimizer = torch.optim.AdamW(
31 |         subgraph.parameters(), lr=args.subgraph_learning_rate, weight_decay=1e-6)
32 |     globalgraph_optimizer = torch.optim.AdamW(
33 |         globalgraph.parameters(), lr=args.globalgraph_learning_rate, weight_decay=1e-6)
34 |     decoder_optimizer = torch.optim.AdamW(
35 |         decoder.parameters(), lr=args.decoder_learning_rate, weight_decay=1e-6)
36 | 
37 |     train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
38 |     train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
39 | 
40 |     t_total = len(train_dataloader) * args.epochs
41 |     subgraph_scheduler = get_linear_schedule_with_warmup(subgraph_optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
42 |     globalgraph_scheduler = get_linear_schedule_with_warmup(globalgraph_optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
43 |     decoder_scheduler = get_linear_schedule_with_warmup(decoder_optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
44 | 
45 |     if args.saving_path is not None:
46 |         print("*** Loading model from {} ***".format(args.saving_path))
47 |         if os.path.isfile(os.path.join(args.saving_path, "subgraph.pt")):
48 |             subgraph.load_state_dict(torch.load(os.path.join(args.saving_path, "subgraph.pt")))
49 |         if os.path.isfile(os.path.join(args.saving_path, "globalgraph.pt")):
50 |             globalgraph.load_state_dict(torch.load(os.path.join(args.saving_path, "globalgraph.pt")))
51 |         if os.path.isfile(os.path.join(args.saving_path, "decoder.pt")):
52 |             decoder.load_state_dict(torch.load(os.path.join(args.saving_path, "decoder.pt")))
53 |         if os.path.isfile(os.path.join(args.saving_path, "subgraph_optimizer.pt")):
54 |             subgraph_optimizer.load_state_dict(torch.load(os.path.join(args.saving_path, "subgraph_optimizer.pt")))
55 |         if os.path.isfile(os.path.join(args.saving_path, "globalgraph_optimizer.pt")):
56 |             globalgraph_optimizer.load_state_dict(torch.load(os.path.join(args.saving_path, "globalgraph_optimizer.pt")))
57 |         if os.path.isfile(os.path.join(args.saving_path, "decoder_optimizer.pt")):
58 |             decoder_optimizer.load_state_dict(torch.load(os.path.join(args.saving_path, "decoder_optimizer.pt")))
59 |         if os.path.isfile(os.path.join(args.saving_path, "subgraph_scheduler.pt")):
60 |             subgraph_scheduler.load_state_dict(torch.load(os.path.join(args.saving_path, "subgraph_scheduler.pt")))
61 |         if os.path.isfile(os.path.join(args.saving_path, "globalgraph_scheduler.pt")):
62 |             globalgraph_scheduler.load_state_dict(torch.load(os.path.join(args.saving_path, "globalgraph_scheduler.pt")))
63 |         if os.path.isfile(os.path.join(args.saving_path, "decoder_scheduler.pt")):
64 |             decoder_scheduler.load_state_dict(torch.load(os.path.join(args.saving_path, "decoder_scheduler.pt")))
65 | 
66 |     mse_loss = nn.MSELoss(reduction="mean")
67 |     total_loss, logging_loss = 0.0, 0.0
68 |     global_steps = 1
69 |     print("-" * 80)
70 |     print("*** Begin training ***")
71 | 
72 |     for i in tqdm(range(args.epochs), desc='Epoch: '):
73 |         subgraph.zero_grad()
74 |         globalgraph.zero_grad()
75 |         decoder.zero_grad()
76 |         epoch_iterator = tqdm(train_dataloader, desc="Iteration")
77 | 
78 |         for step, batch in enumerate(epoch_iterator):
79 |             subgraph.train()
80 |             globalgraph.train()
81 |             decoder.train()
82 | 
83 |             features, subgraph_mask, attention_mask, groundtruth, groundtruth_mask = batch
84 | 
85 |             features = features.to(device)
86 |             subgraph_mask = subgraph_mask.to(device)
87 |             attention_mask = attention_mask.to(device)
88 |             groundtruth = groundtruth.to(device)
89 |             groundtruth_mask = groundtruth_mask.to(device)
90 | 
91 |             out = subgraph.forward(features, subgraph_mask)
92 |             out = globalgraph.forward(out[:, 0, :].unsqueeze(dim=1), out, attention_mask)
93 | 
94 |             pred = decoder.forward(out).squeeze(1)
95 |             loss = mse_loss.forward(pred * groundtruth_mask, groundtruth)
96 |             loss.backward()
97 | 
98 |             total_loss += loss.item()
99 |             if args.local_rank in [-1, 0] and args.enable_logging and global_steps % args.logging_steps == 0:
100 |                 print("\n\nLoss:\t {}".format(
101 |                     (total_loss-logging_loss)/args.logging_steps))
102 |                 logging_loss = total_loss
103 | 
104 |             if args.local_rank in [-1, 0] and args.evaluate_during_training and global_steps % args.logging_steps == 0:
105 |                 evaluate(args, (subgraph, globalgraph, decoder), test_dataset, batch_size=args.eval_batch_size, device=device)
106 | 
107 |             if args.local_rank in [-1, 0] and global_steps % args.saving_steps == 0:
108 |                 save_model((subgraph, globalgraph, decoder, subgraph_optimizer, globalgraph_optimizer, decoder_optimizer, subgraph_scheduler, globalgraph_scheduler, decoder_scheduler))
109 | 
110 |             subgraph_optimizer.step()
111 |             globalgraph_optimizer.step()
112 |             decoder_optimizer.step()
113 |             subgraph_scheduler.step()
114 |             globalgraph_scheduler.step()
115 |             decoder_scheduler.step()
116 |             subgraph.zero_grad()
117 |             globalgraph.zero_grad()
118 |             decoder.zero_grad()
119 | 
120 |             global_steps += 1
121 | 
122 |     if test_dataset is not None:
123 |         evaluate(args, (subgraph, globalgraph, decoder), test_dataset, device=device, batch_size=args.eval_batch_size)
124 |     save_model((subgraph, globalgraph, decoder, subgraph_optimizer, globalgraph_optimizer,
125 |                 decoder_optimizer, subgraph_scheduler, globalgraph_scheduler, decoder_scheduler))
126 | 
127 | 
128 | def evaluate(args, models, dataset: torch.utils.data.TensorDataset, device, batch_size=1):
129 |     print("\n*** Evaluating ***\n")
130 |     eval_sampler = SequentialSampler(dataset)
131 |     eval_dataloader = DataLoader(
132 |         dataset, sampler=eval_sampler, batch_size=batch_size)
133 | 
134 |     subgraph, globalgraph, decoder = models
135 |     subgraph.eval()
136 |     globalgraph.eval()
137 |     decoder.eval()
138 | 
139 |     mse_loss = nn.MSELoss()
140 |     total_loss = 0.0
141 | 
142 |     for batch in tqdm(eval_dataloader, desc="Evaluating"):
143 |         with torch.no_grad():
144 |             features, subgraph_mask, attention_mask, groundtruth, groundtruth_mask = batch
145 | 
146 |             features = features.to(device)
147 |             subgraph_mask = subgraph_mask.to(device)
148 |             attention_mask = attention_mask.to(device)
149 |             groundtruth = groundtruth.to(device)
150 |             groundtruth_mask = groundtruth_mask.to(device)
151 | 
152 |             out = subgraph.forward(features, subgraph_mask)
153 |             out = globalgraph.forward(out[:, 0, :].unsqueeze(dim=1), out, attention_mask)
154 | 
155 |             pred = decoder.forward(out).squeeze(1)
156 |             loss = mse_loss.forward(pred * groundtruth_mask, groundtruth)
157 |             total_loss += loss.item()
158 | 
159 |     print("Eval mse loss (per point): {}".format(math.sqrt(total_loss / (len(dataset) // batch_size * args.max_groundtruth_length))))
160 |     print("-" * 80)
161 | 
162 | 
163 | def save_model(models: tuple):
164 |     subgraph, globalgraph, decoder, subgraph_optimizer, globalgraph_optimizer, decoder_optimizer, subgraph_scheduler, globalgraph_scheduler, decoder_scheduler = models
165 |     torch.save(subgraph.state_dict(), "./save/models/subgraph.pt")
166 |     torch.save(globalgraph.state_dict(), "./save/models/globalgraph.pt")
167 |     torch.save(decoder.state_dict(), "./save/models/decoder.pt")
168 |     torch.save(subgraph_optimizer.state_dict(), "./save/models/subgraph_optimizer.pt")
169 |     torch.save(globalgraph_optimizer.state_dict(), "./save/models/globalgraph_optimizer.pt")
170 |     torch.save(decoder_optimizer.state_dict(), "./save/models/decoder_optimizer.pt")
171 |     torch.save(subgraph_scheduler.state_dict(), "./save/models/subgraph_scheduler.pt")
172 |     torch.save(globalgraph_scheduler.state_dict(), "./save/models/globalgraph_scheduler.pt")
173 |     torch.save(decoder_scheduler.state_dict(), "./save/models/decoder_scheduler.pt")
174 | 
175 | 
176 | def build_dataset(features: np.ndarray, subgraph_mask: np.ndarray, attention_mask: np.ndarray, groundtruth: np.ndarray, groundtruth_mask: np.ndarray):
177 |     print("-" * 80)
178 |     print("*** Building dataset ***")
179 | 
180 |     train_features, test_features, train_subgraph_mask, test_subgraph_mask, train_attention_mask, test_attention_mask, train_groundtruth, test_groundtruth, train_groundtruth_mask, test_groundtruth_mask = train_test_split(features, subgraph_mask, attention_mask, groundtruth, groundtruth_mask, train_size=0.8)
181 | 
182 |     train_features = torch.from_numpy(train_features).to(dtype=torch.float)
183 |     train_subgraph_mask = torch.from_numpy(
184 |         train_subgraph_mask).to(dtype=torch.float)
185 |     train_attention_mask = torch.from_numpy(
186 |         train_attention_mask).to(dtype=torch.float)
187 |     train_groundtruth = torch.from_numpy(
188 |         train_groundtruth).to(dtype=torch.float)
189 |     train_groundtruth_mask = torch.from_numpy(
190 |         train_groundtruth_mask).to(dtype=torch.float)
191 | 
192 |     test_features = torch.from_numpy(test_features).to(dtype=torch.float)
193 |     test_subgraph_mask = torch.from_numpy(
194 |         test_subgraph_mask).to(dtype=torch.float)
195 |     test_attention_mask = torch.from_numpy(
196 |         test_attention_mask).to(dtype=torch.float)
197 |     test_groundtruth = torch.from_numpy(test_groundtruth).to(dtype=torch.float)
198 |     test_groundtruth_mask = torch.from_numpy(
199 |         test_groundtruth_mask).to(dtype=torch.float)
200 |     train_dataset = torch.utils.data.TensorDataset(
201 |         train_features,
202 |         train_subgraph_mask,
203 |         train_attention_mask,
204 |         train_groundtruth,
205 |         train_groundtruth_mask
206 |     )
207 | 
208 |     test_dataset = torch.utils.data.TensorDataset(
209 |         test_features,
210 |         test_subgraph_mask,
211 |         test_attention_mask,
212 |         test_groundtruth,
213 |         test_groundtruth_mask
214 |     )
215 |     print("*** Finish building dataset ***")
216 |     return train_dataset, test_dataset
217 | 
218 | 
219 | def main():
220 |     parser = argparse.ArgumentParser(
221 |         description="Run VectorNet training and evaluating")
222 |     parser.add_argument(
223 |         "--epochs",
224 |         default=5,
225 |         type=int,
226 |         help="Number of training epochs"
227 |     )
228 |     parser.add_argument(
229 |         "--subgraph_learning_rate",
230 |         default=1e-3,
231 |         type=float,
232 |         help="Learning rate for subgraph"
233 |     )
234 |     parser.add_argument(
235 |         "--globalgraph_learning_rate",
236 |         default=1e-3,
237 |         type=float,
238 |         help="Learning rate for globalgraph"
239 |     )
240 |     parser.add_argument(
241 |         "--decoder_learning_rate",
242 |         default=1e-3,
243 |         type=float,
244 |         help="Learning rate for decoder"
245 |     )
246 |     parser.add_argument(
247 |         "--root_dir",
248 |         default=None,
249 |         required=True,
250 |         type=str,
251 |         help="Path to data root directory"
252 |     )
253 |     parser.add_argument(
254 |         "--feature_path",
255 |         default=None,
256 |         type=str,
257 |         help="Path to feature directory"
258 |     )
259 |     parser.add_argument(
260 |         "--saving_path",
261 |         default=None,
262 |         type=str,
263 |         help="Path to save model"
264 |     )
265 |     parser.add_argument(
266 |         "--logging_steps",
267 |         default=10,
268 |         type=int,
269 |         help="Number of logging steps"
270 |     )
271 |     parser.add_argument(
272 |         "--warmup_steps",
273 |         default=0,
274 |         type=int,
275 |         help="Number of warmup steps"
276 |     )
277 |     parser.add_argument(
278 |         "--saving_steps",
279 |         default=100,
280 |         type=int,
281 |         help="Number of saving steps"
282 |     )
283 |     parser.add_argument(
| "--no_cuda", 285 | action="store_true", 286 | help="Whether not to use CUDA when available" 287 | ) 288 | parser.add_argument( 289 | "--max_groundtruth_length", 290 | default=30, 291 | help="Maximum length of groundtruth" 292 | ) 293 | parser.add_argument( 294 | "--train_batch_size", 295 | type=int, 296 | default=2, 297 | help="train batch size" 298 | ) 299 | parser.add_argument( 300 | "--eval_batch_size", 301 | type=int, 302 | default=1, 303 | help="eval batch size" 304 | ) 305 | parser.add_argument( 306 | "--enable_logging", 307 | action="store_true", 308 | help="whether enable logging" 309 | ) 310 | parser.add_argument( 311 | "--local_rank", 312 | default=-1, 313 | help="local rank for distributed training" 314 | ) 315 | parser.add_argument( 316 | "--evaluate_during_training", 317 | action="store_true", 318 | help="Run evaluation during training at each logging step" 319 | ) 320 | 321 | args = parser.parse_args() 322 | 323 | if args.local_rank == -1 or args.no_cuda: 324 | # Data parallel or CPU training 325 | device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 326 | args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count() 327 | else: 328 | torch.cuda.set_device(args.local_rank) 329 | device = torch.device("cuda", args.local_rank) 330 | torch.distributed.init_process_group(backend="nccl") 331 | args.n_gpu = 1 332 | 333 | print("*** Process rank: {}, device: {}, n_gpu: {}, distributed training: {} ***".format(args.local_rank, device, args.n_gpu, bool(args.local_rank != -1))) 334 | 335 | print("*** Loading features ***") 336 | features, subgraph_mask, attention_mask, groundtruth, groundtruth_mask, max_groundtruth_length = load_features(root_dir=args.root_dir, feature_path=args.feature_path) 337 | args.max_groundtruth_length = max_groundtruth_length 338 | print("*** Finish loading features ***") 339 | 340 | train_dataset, test_dataset = build_dataset( 341 | features, subgraph_mask, attention_mask, groundtruth, groundtruth_mask) 342 | 343 | train(args, train_dataset, test_dataset, device) 344 | 345 | 346 | if __name__ == "__main__": 347 | main() 348 | 349 | --------------------------------------------------------------------------------