├── ModelNet ├── __pycache__ │ ├── msma.cpython-37.pyc │ └── utils.cpython-37.pyc ├── msma.py └── utils.py ├── README.md ├── assets ├── arch.png ├── delay.png ├── noise.png ├── performance_compr.png ├── s700_mpr0.png ├── s700_mpr2.png ├── s700_mpr4.png ├── s700_mpr6.png └── s700_mpr8.png ├── carla_data └── Town03.osm ├── dataloader ├── __pycache__ │ └── carla_scene_process.cpython-37.pyc ├── carla_scene_mining.py ├── carla_scene_process.py ├── utils │ ├── __pycache__ │ │ ├── lane_sampling.cpython-37.pyc │ │ ├── lane_segment.cpython-37.pyc │ │ └── load_xml.cpython-37.pyc │ ├── lane_sampling.py │ ├── lane_segment.py │ └── load_xml.py └── visualization.py ├── losses ├── __pycache__ │ ├── get_anchors.cpython-37.pyc │ ├── loss.cpython-37.pyc │ ├── msma_loss.cpython-37.pyc │ ├── mtp_loss.cpython-37.pyc │ └── multipath_loss.cpython-37.pyc ├── get_anchors.py ├── hivt_loss.py ├── msma_loss.py ├── mtp_loss.py └── multipath_loss.py ├── metrics ├── __pycache__ │ ├── ade.cpython-37.pyc │ ├── fde.cpython-37.pyc │ ├── metric.cpython-37.pyc │ └── mr.cpython-37.pyc ├── ade.py ├── fde.py ├── metric.py └── mr.py ├── train.py └── utils ├── __pycache__ └── optim_schedule.cpython-37.pyc ├── optim_schedule.py └── viz.py /ModelNet/__pycache__/msma.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/ModelNet/__pycache__/msma.cpython-37.pyc -------------------------------------------------------------------------------- /ModelNet/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/ModelNet/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /ModelNet/msma.py: -------------------------------------------------------------------------------- 1 | #test overall model architecture 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | import sys 7 | 8 | from ModelNet.utils import MLP, bivariate_gaussian_activation 9 | # from utils import MLP 10 | from typing import Optional, Tuple, Union, Dict 11 | import math 12 | from torch_scatter import scatter_mean, scatter_add 13 | 14 | from dataloader.carla_scene_process import CarlaData 15 | from itertools import product 16 | from torch_geometric.utils import subgraph, add_self_loops 17 | 18 | class Base_Net(nn.Module): 19 | def __init__(self, 20 | ip_dim: int=2, 21 | historical_steps: int=30, 22 | embed_dim: int=16, 23 | temp_ff: int=64, 24 | spat_hidden_dim: int=64, 25 | spat_out_dim: int=64, 26 | edge_attr_dim: int=2, 27 | map_out_dim: int=64, 28 | lane_dim: int = 2, 29 | map_local_radius: float=30., 30 | decoder_hidden_dim: int=64, 31 | num_heads: int = 8, 32 | dropout: float = 0.1, 33 | num_temporal_layers: int = 4, 34 | use_variance: bool = False, 35 | device = 'cpu', 36 | commu_only = False, 37 | sensor_only = False, 38 | prediction_mode = None, 39 | ) -> None: 40 | super(Base_Net, self).__init__() 41 | self.ip_dim = ip_dim 42 | self.historical_steps = historical_steps 43 | self.embed_dim = embed_dim 44 | self.device = device 45 | self.local_radius = map_local_radius 46 | self.commu_only = commu_only 47 | self.sensor_only = sensor_only 48 | self.prediction_mode = prediction_mode 49 | 50 | if self.prediction_mode == "temp_only": 51 | decoder_in_dim = embed_dim 52 | elif self.prediction_mode == 
"temp_spat": 53 | decoder_in_dim = spat_out_dim 54 | else: 55 | decoder_in_dim = spat_out_dim+map_out_dim 56 | 57 | #input embedding 58 | self.ip_emb_cav = MLP(ip_dim, embed_dim) 59 | self.ip_emb_commu = MLP(ip_dim, embed_dim) 60 | self.ip_emb_sensor = MLP(ip_dim, embed_dim) 61 | self.ip_emb_fuse = MLP(ip_dim, embed_dim) 62 | #temporal encoders 63 | self.temp_encoder = TemporalEncoder(historical_steps=historical_steps, 64 | embed_dim=embed_dim, 65 | device=device, 66 | num_heads=num_heads, 67 | num_layers=num_temporal_layers, 68 | temp_ff=temp_ff, 69 | dropout=dropout) 70 | self.feature_fuse = FeatureFuse(embed_dim=embed_dim, 71 | num_heads=num_heads, 72 | dropout=dropout) 73 | self.spat_encoder = GAT(in_dim=embed_dim, 74 | hidden_dim=spat_hidden_dim, 75 | out_dim=spat_out_dim, 76 | edge_attr_dim=edge_attr_dim, 77 | device=device, 78 | num_heads=num_heads, 79 | dropout=dropout) 80 | self.map_encoder = MapEncoder(lane_dim=lane_dim, 81 | v_dim=spat_out_dim, 82 | out_dim=map_out_dim, 83 | edge_attr_dim=edge_attr_dim, 84 | num_heads=num_heads, 85 | device=device, 86 | dropout=dropout) 87 | self.decoder = PredictionDecoder(encoding_size=decoder_in_dim, 88 | hidden_size=decoder_hidden_dim, 89 | num_modes=5, 90 | op_len=50, 91 | use_variance=use_variance) 92 | 93 | def forward(self, data: CarlaData): 94 | 95 | #temporal encoding 96 | x_cav, x_commu, x_sensor = data.x_cav, data.x_commu, data.x_sensor #overlapping among different modes 97 | cav_mask, commu_mask, sensor_mask = data.cav_mask, data.commu_mask, data.sensor_mask 98 | rotate_imat = data.rotate_imat 99 | x_cav = torch.bmm(x_cav, rotate_imat[cav_mask]) 100 | x_commu = torch.bmm(x_commu, rotate_imat[commu_mask]) 101 | x_sensor = torch.bmm(x_sensor, rotate_imat[sensor_mask]) 102 | 103 | x_cav_, x_commu_, x_sensor_ = self.ip_emb_cav(x_cav), self.ip_emb_commu(x_commu), self.ip_emb_sensor(x_sensor) 104 | cav_out, commu_out, sensor_out = self.temp_encoder(x_cav_, x_commu_, x_sensor_) 105 | 106 | #convert back to original num_nodes given masks 107 | node_features_all = torch.zeros((data.num_nodes, self.embed_dim)).to(self.device) 108 | node_features_all[cav_mask] = cav_out 109 | node_features_all[commu_mask] = commu_out 110 | node_features_all[sensor_mask] = sensor_out 111 | #fuse sensor&commu encodings 112 | mask_fuse = (commu_mask & sensor_mask) 113 | commu_emd, sensor_emd = self.get_overlap_feature(data, commu_out, sensor_out, mask_fuse, self.embed_dim) 114 | # commu_relpos, sensor_relpos = self.get_overlap_feature(data, data.x_commu_ori, data.x_sensor_ori, mask_fuse, self.ip_dim) 115 | # relpos_emd = self.ip_emb_fuse(sensor_relpos-commu_relpos) 116 | 117 | if self.commu_only: 118 | node_features_all[commu_mask] = commu_out 119 | # data.y[commu_mask] = data.y_commu 120 | elif self.sensor_only: 121 | node_features_all[sensor_mask] = sensor_out 122 | elif sum(mask_fuse)>0: 123 | node_features_all[mask_fuse] = self.feature_fuse(commu_emd, sensor_emd) 124 | 125 | mask_all = (cav_mask | commu_mask | sensor_mask) 126 | 127 | if self.prediction_mode == "temp_only": 128 | predictions = self.decoder(node_features_all[mask_all]) #'traj':[nodes_of_interest, 5, 50, 2], 'log_probs':[nodes_of_interest, 5] 129 | return predictions, mask_all 130 | 131 | edge_index, _ = subgraph(subset=mask_all, edge_index=data.edge_index) 132 | edge_index, _ = add_self_loops(edge_index, num_nodes=data.num_nodes) 133 | edge_attr = data['positions'][edge_index[0], 49] - data['positions'][edge_index[1], 49] 134 | # edge_attr = torch.bmm(edge_attr.unsqueeze(-2), 
rotate_imat[edge_index[1]]).squeeze(-2) 135 | spat_out = self.spat_encoder(node_features_all.view(data.num_nodes,-1), edge_index, edge_attr) #[num_nodes, 64] 136 | 137 | if self.prediction_mode == "temp_spat": 138 | predictions = self.decoder(spat_out[mask_all]) #'traj':[nodes_of_interest, 5, 50, 2], 'log_probs':[nodes_of_interest, 5] 139 | return predictions, mask_all 140 | #AL encoding 141 | map_out = self.map_encoder(data, spat_out, mask_all) #[num_nodes, 64] 142 | final_emd = torch.cat((spat_out, map_out), dim=-1) #[num_nodes, 128] 143 | 144 | predictions = self.decoder(final_emd[mask_all]) #'traj':[nodes_of_interest, 5, 50, 2], 'log_probs':[nodes_of_interest, 5] 145 | return predictions, mask_all 146 | 147 | def get_overlap_feature(self, data, commu_f, sensor_f, mask_fuse, dim): 148 | commu_mask, sensor_mask = data.commu_mask, data.sensor_mask 149 | commu_feature = torch.zeros((data.num_nodes, dim)).to(self.device) 150 | sensor_feature = torch.zeros((data.num_nodes, dim)).to(self.device) 151 | commu_feature[commu_mask] = commu_f 152 | sensor_feature[sensor_mask] = sensor_f 153 | 154 | return commu_feature[mask_fuse], sensor_feature[mask_fuse] 155 | 156 | 157 | class TemporalEncoder(nn.Module): 158 | ''' 159 | for each agent, only one fused channel instead of three 160 | ''' 161 | def __init__(self, 162 | historical_steps: int, 163 | embed_dim: int, 164 | device, 165 | num_heads: int=8, 166 | num_layers: int=4, 167 | temp_ff: int=64, 168 | dropout: float=0.1) -> None: 169 | super(TemporalEncoder, self).__init__() 170 | self.embed_dim = embed_dim 171 | self.device = device 172 | self.historical_steps = historical_steps 173 | encoder_layer_cav = nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=temp_ff, dropout=dropout, batch_first=True) 174 | self.transformer_encoder_cav = nn.TransformerEncoder(encoder_layer=encoder_layer_cav, num_layers=num_layers, 175 | norm=nn.LayerNorm(embed_dim)) 176 | encoder_layer_sensor = nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=temp_ff, dropout=dropout, batch_first=True) 177 | self.transformer_encoder_sensor = nn.TransformerEncoder(encoder_layer=encoder_layer_sensor, num_layers=num_layers, 178 | norm=nn.LayerNorm(embed_dim)) 179 | encoder_layer_commu = nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=temp_ff, dropout=dropout, batch_first=True) 180 | self.transformer_encoder_commu= nn.TransformerEncoder(encoder_layer=encoder_layer_commu, num_layers=num_layers, 181 | norm=nn.LayerNorm(embed_dim)) 182 | self.cls_token_cav = nn.Parameter(torch.Tensor(1, 1, embed_dim)) 183 | self.cls_token_commu = nn.Parameter(torch.Tensor(1, 1, embed_dim)) 184 | self.cls_token_sensor = nn.Parameter(torch.Tensor(1, 1, embed_dim)) 185 | 186 | self.pos_embed_cav = nn.Parameter(torch.Tensor(1, historical_steps + 1, embed_dim)) 187 | self.pos_embed_commu = nn.Parameter(torch.Tensor(1, historical_steps + 1, embed_dim)) 188 | self.pos_embed_sensor = nn.Parameter(torch.Tensor(1, historical_steps + 1, embed_dim)) 189 | 190 | nn.init.normal_(self.cls_token_cav, mean=0., std=.02) 191 | nn.init.normal_(self.cls_token_commu, mean=0., std=.02) 192 | nn.init.normal_(self.cls_token_sensor, mean=0., std=.02) 193 | nn.init.normal_(self.pos_embed_cav, mean=0., std=.02) 194 | nn.init.normal_(self.pos_embed_commu, mean=0., std=.02) 195 | nn.init.normal_(self.pos_embed_sensor, mean=0., std=.02) 196 | # self.apply(init_weights) 197 | self.dropout = nn.Dropout(dropout) 198 | self.layer_norm = nn.LayerNorm(embed_dim) 199 | self.linear = 
nn.Linear(embed_dim * 2, embed_dim) 200 | 201 | def forward(self, x_cav, x_commu, x_sensor): 202 | """ 203 | input [batch, seq, feature] 204 | """ 205 | num_sensor, seq_len = x_sensor.shape[0], x_sensor.shape[1] 206 | assert seq_len == self.historical_steps 207 | 208 | x_cav, x_commu, x_sensor = self._expand_cls_token(x_cav, x_commu, x_sensor) 209 | 210 | x_cav = x_cav + self.pos_embed_cav 211 | x_sensor = x_sensor + self.pos_embed_sensor 212 | x_commu = x_commu + self.pos_embed_commu 213 | 214 | # Apply dropout and layer normalization 215 | x_cav_t = self.layer_norm(self.dropout(x_cav)) 216 | x_sensor_t = self.layer_norm(self.dropout(x_sensor)) 217 | x_commu_t = self.layer_norm(self.dropout(x_commu)) 218 | 219 | # Apply the transformers 220 | x_cav_temp = self.transformer_encoder_cav(x_cav_t) 221 | x_commu_temp = self.transformer_encoder_commu(x_commu_t) 222 | x_sensor_temp = self.transformer_encoder_sensor(x_sensor_t) 223 | 224 | return x_cav_temp[:,-1,:], x_commu_temp[:,-1,:], x_sensor_temp[:,-1,:] #encoding at last timestep 225 | 226 | def _expand_cls_token(self, x_cav, x_commu, x_sensor): 227 | expand_cls_token_cav= self.cls_token_cav.expand(x_cav.shape[0], -1, -1) 228 | expand_cls_token_commu= self.cls_token_commu.expand(x_commu.shape[0], -1, -1) 229 | expand_cls_token_sensor= self.cls_token_sensor.expand(x_sensor.shape[0], -1, -1) 230 | 231 | x_cav = torch.cat((x_cav, expand_cls_token_cav), dim=1) 232 | x_commu = torch.cat((x_commu, expand_cls_token_commu), dim=1) 233 | x_sensor = torch.cat((x_sensor, expand_cls_token_sensor), dim=1) 234 | 235 | return x_cav, x_commu, x_sensor 236 | 237 | class FeatureFuse(nn.Module): 238 | """ 239 | cross attention module 240 | """ 241 | def __init__(self, 242 | embed_dim, 243 | num_heads, 244 | dropout=0.1): 245 | super(FeatureFuse, self).__init__() 246 | self.embed_dim = embed_dim 247 | self.num_heads = num_heads 248 | self.lin_q = nn.Linear(embed_dim, embed_dim) 249 | self.lin_k = nn.Linear(embed_dim, embed_dim) 250 | self.lin_v = nn.Linear(embed_dim, embed_dim) 251 | self.lin_self = nn.Linear(embed_dim, embed_dim) 252 | self.lin_ih = nn.Linear(embed_dim, embed_dim) 253 | self.lin_hh = nn.Linear(embed_dim, embed_dim) 254 | self.attn_drop = nn.Dropout(dropout) 255 | self.softmax = nn.Softmax(dim=1) 256 | 257 | def forward(self, commu_enc, sensor_enc): 258 | query = self.lin_q(sensor_enc).view(-1, self.num_heads, self.embed_dim // self.num_heads) 259 | key = self.lin_k(commu_enc).view(-1, self.num_heads, self.embed_dim // self.num_heads) 260 | value = self.lin_v(commu_enc).view(-1, self.num_heads, self.embed_dim // self.num_heads) 261 | scale = (self.embed_dim // self.num_heads) ** 0.5 262 | alpha = (query * key).sum(dim=-1) / scale 263 | alpha = self.softmax(alpha) 264 | alpha = self.attn_drop(alpha) 265 | commu_att = (value * alpha.unsqueeze(-1)).reshape(-1, self.embed_dim) 266 | w = torch.sigmoid(self.lin_ih(sensor_enc) + self.lin_hh(commu_att)) 267 | fused_enc = w * self.lin_self(sensor_enc) + (1-w) * commu_att 268 | return fused_enc 269 | 270 | class GAT(nn.Module): 271 | def __init__(self, in_dim, hidden_dim, out_dim, edge_attr_dim, device, num_heads=8, dropout=0.1): 272 | super(GAT, self).__init__() 273 | 274 | self.device = device 275 | self.attention_layers = nn.ModuleList( 276 | [GATlayer(in_dim, hidden_dim, edge_attr_dim) for _ in range(num_heads)] 277 | ) 278 | self.out_att = GATlayer(hidden_dim*num_heads, out_dim, edge_attr_dim) 279 | self.dropout = nn.Dropout(dropout) 280 | 281 | def forward(self, X, edge_index, edge_attr): 282 | x = 
X 283 | 284 | # Concatenate multi-head attentions 285 | x = torch.cat([att(x, edge_index, edge_attr) for att in self.attention_layers], dim=1) 286 | x = F.elu(x) 287 | x = self.dropout(x) 288 | x = self.out_att(x, edge_index, edge_attr) # Final attention aggregation 289 | return F.log_softmax(x, dim=1) 290 | 291 | class GATlayer(nn.Module): 292 | def __init__(self, 293 | embed_dim: int, 294 | out_dim: int, 295 | edge_attr_dim: int, 296 | dropout: float=0.1) -> None: 297 | super(GATlayer, self).__init__() 298 | 299 | self.W = nn.Linear(embed_dim, out_dim, bias=False) 300 | self.a = nn.Linear(2*out_dim + edge_attr_dim, 1, bias=False) 301 | self.edge_attr_dim = edge_attr_dim 302 | self.dropout = nn.Dropout(dropout) 303 | self.out_transform = nn.Linear(out_dim, out_dim, bias=False) 304 | 305 | def forward(self, 306 | X: torch.Tensor, 307 | edge_index: torch.Tensor, 308 | edge_attr: torch.Tensor): 309 | #transform node features 310 | h = self.W(X) 311 | N = h.size(0) 312 | attn_input = self._prepare_attention_input(h, edge_index, edge_attr) 313 | score_per_edge = F.leaky_relu(self.a(attn_input)).squeeze(1) # Calculate attention coefficients 314 | 315 | #apply dropout to attention weights 316 | score_per_edge = self.dropout(score_per_edge) 317 | # softmax 318 | # Calculate the numerator. Make logits <= 0 so that e^logit <= 1 (this will improve the numerical stability) 319 | score_per_edge = score_per_edge - score_per_edge.max() 320 | exp_score_per_edge = score_per_edge.exp() 321 | 322 | neigborhood_aware_denominator = scatter_add(exp_score_per_edge, edge_index[0], dim=0, dim_size=N) 323 | neigborhood_aware_denominator = neigborhood_aware_denominator.index_select(0, edge_index[0]) 324 | attentions_per_edge = exp_score_per_edge / (neigborhood_aware_denominator + 1e-16) 325 | 326 | # Apply attention weights to source node features and perform message passing 327 | out_src = h.index_select(0,edge_index[1]) * attentions_per_edge.unsqueeze(dim=1) 328 | h_prime = scatter_add(out_src, edge_index[0], dim=0, dim_size=N) 329 | 330 | # Apply activation function 331 | out = F.elu(h_prime) 332 | return out 333 | 334 | def _prepare_attention_input(self, h, edge_index, edge_attr): 335 | ''' 336 | h has shape [N, out_dim] 337 | ''' 338 | src, tgt = edge_index 339 | attn_input = torch.cat([h.index_select(0,src), h.index_select(0,tgt), edge_attr], dim=1) 340 | 341 | return attn_input 342 | 343 | class MapEncoder(nn.Module): 344 | def __init__(self, 345 | lane_dim: int, 346 | v_dim: int, 347 | out_dim: int, 348 | edge_attr_dim: int, 349 | num_heads: int, 350 | device: str, 351 | local_radius: float=30., 352 | dropout: float=0.1) -> None: 353 | super(MapEncoder, self).__init__() 354 | self.local_radius = local_radius 355 | self.device = device 356 | self.attention_layers = nn.ModuleList( 357 | [MapEncoderLayer(out_dim, v_dim, edge_attr_dim) for _ in range(num_heads)] 358 | ) 359 | self.lane_emb = MLP(lane_dim, v_dim) #out_dim = v_enc.size(1) 360 | self.edge_attr_dim = edge_attr_dim 361 | self.dropout = nn.Dropout(dropout) 362 | self.out_transform = nn.Linear(out_dim*num_heads, out_dim, bias=False) 363 | 364 | def forward(self, data: CarlaData, v_enc: torch.Tensor, v_mask: torch.Tensor): 365 | 366 | lane = data.lane_vectors 367 | 368 | lane_actor_mask = torch.cat((v_mask, (torch.ones(lane.size(0))==1).to(self.device)), dim=0) 369 | data.lane_actor_index[0] += data.num_nodes #lane_actor_index[0]:lane index, lane_actor_index[1]:actor index 370 | lane_actor_index, lane_actor_attr = subgraph(subset=lane_actor_mask, 371 | 
edge_index=data.lane_actor_index, edge_attr=data.lane_actor_attr) 372 | lane = torch.bmm(lane[lane_actor_index[0]-data.num_nodes].unsqueeze(-2), data.rotate_imat[lane_actor_index[1]]).squeeze(-2) 373 | 374 | lane_enc = self.lane_emb(lane) 375 | lane_actor_enc = torch.cat((v_enc, lane_enc), dim=0) #shape:[num_veh+num_lane, v_dim] 376 | # Concat multi-head attentions 377 | out = torch.cat([att(lane_actor_enc, data.num_nodes, lane.size(0), lane_actor_index, lane_actor_attr) for att in self.attention_layers], dim=1) 378 | out = F.elu(out) 379 | out = self.dropout(out) 380 | out = self.out_transform(out) 381 | 382 | return out 383 | 384 | class MapEncoderLayer(nn.Module): 385 | def __init__(self, 386 | v_dim: int, 387 | out_dim: int, 388 | edge_attr_dim: int, 389 | dropout: float=0.1) -> None: 390 | super(MapEncoderLayer, self).__init__() 391 | 392 | self.W = nn.Linear(v_dim, out_dim, bias=False) 393 | self.a = nn.Linear(2*out_dim + edge_attr_dim, 1, bias=False) 394 | self.dropout = nn.Dropout(dropout) 395 | 396 | def forward(self, 397 | lane_actor_enc: torch.Tensor, 398 | num_veh: int, 399 | num_lane: int, 400 | lane_actor_index: torch.Tensor, 401 | lane_actor_attr: torch.Tensor): 402 | #transform node features 403 | h = self.W(lane_actor_enc) 404 | N = h.size(0) 405 | assert N == num_veh+num_lane 406 | 407 | attn_input = self._prepare_attention_input(h, num_veh,lane_actor_index, lane_actor_attr) 408 | score_per_edge = F.leaky_relu(self.a(attn_input)).squeeze(1) # Calculate attention coefficients 409 | 410 | #apply dropout to attention weights 411 | score_per_edge = self.dropout(score_per_edge) 412 | # softmax 413 | # Calculate the numerator. Make logits <= 0 so that e^logit <= 1 (this will improve the numerical stability) 414 | score_per_edge = score_per_edge - score_per_edge.max() 415 | exp_score_per_edge = score_per_edge.exp() 416 | 417 | neigborhood_aware_denominator = scatter_add(exp_score_per_edge, lane_actor_index[1], dim=0, dim_size=num_veh) 418 | neigborhood_aware_denominator = neigborhood_aware_denominator.index_select(0, lane_actor_index[1]) 419 | attentions_per_edge = exp_score_per_edge / (neigborhood_aware_denominator + 1e-16) 420 | 421 | out_src = h[num_veh:] * attentions_per_edge.unsqueeze(dim=1) #shape[num_lane] 422 | out = scatter_add(out_src, lane_actor_index[1], dim=0, dim_size=num_veh) 423 | assert out.shape[0] == num_veh 424 | 425 | # Apply activation function 426 | out = F.elu(out) 427 | return out 428 | 429 | def _prepare_attention_input(self, h, num_v, edge_index, edge_attr): 430 | ''' 431 | h has shape [N, out_dim] 432 | ''' 433 | src, tgt = edge_index 434 | attn_input = torch.cat([h[num_v:], h[:num_v].index_select(0,tgt), edge_attr], dim=1) 435 | 436 | return attn_input 437 | 438 | class PredictionDecoder(nn.Module): 439 | 440 | def __init__(self, 441 | encoding_size: int, 442 | hidden_size: int=64, 443 | num_modes: int=5, 444 | op_len: int=50, 445 | use_variance: bool=False) -> None: 446 | super(PredictionDecoder, self).__init__() 447 | 448 | self.op_dim = 5 if use_variance else 2 449 | self.op_len = op_len 450 | self.num_modes = num_modes 451 | self.use_variance = use_variance 452 | self.hidden = nn.Linear(encoding_size, hidden_size) 453 | self.traj_op = nn.Sequential( 454 | nn.Linear(hidden_size, hidden_size), 455 | nn.LayerNorm(hidden_size), 456 | nn.ReLU(inplace=True), 457 | nn.Linear(hidden_size, hidden_size), 458 | nn.LayerNorm(hidden_size), 459 | nn.ReLU(inplace=True), 460 | nn.Linear(hidden_size, self.op_len * self.op_dim * self.num_modes)) 461 | self.prob_op 
= nn.Sequential( 462 | nn.Linear(hidden_size, hidden_size), 463 | nn.LayerNorm(hidden_size), 464 | nn.ReLU(inplace=True), 465 | nn.Linear(hidden_size, hidden_size), 466 | nn.LayerNorm(hidden_size), 467 | nn.ReLU(inplace=True), 468 | nn.Linear(hidden_size, self.num_modes)) 469 | 470 | self.leaky_relu = nn.LeakyReLU(0.01) 471 | self.log_softmax = nn.LogSoftmax(dim=1) 472 | 473 | 474 | def forward(self, agg_encoding: torch.Tensor) -> Dict: 475 | """ 476 | Forward pass for prediction decoder 477 | :param agg_encoding: aggregated context encoding 478 | :return predictions: dictionary with 'traj': K predicted trajectories and 479 | 'probs': K corresponding probabilities 480 | """ 481 | 482 | h = self.leaky_relu(self.hidden(agg_encoding)) 483 | num_vehs = h.shape[0] #n_v 484 | traj = self.traj_op(h) #[n_v, 1250] 485 | probs = self.log_softmax(self.prob_op(h)) #[n_v, 5] 486 | traj = traj.reshape(num_vehs, self.num_modes, self.op_len, self.op_dim) 487 | probs = probs.squeeze(dim=-1) 488 | traj = bivariate_gaussian_activation(traj) if self.use_variance else traj 489 | 490 | predictions = {'traj':traj, 'log_probs':probs} 491 | 492 | return predictions 493 | -------------------------------------------------------------------------------- /ModelNet/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class MLP(nn.Module): 5 | def __init__(self, 6 | in_dim: int, 7 | out_dim: int) -> None: 8 | super(MLP, self).__init__() 9 | self.embed = nn.Sequential( 10 | nn.Linear(in_dim, out_dim), 11 | nn.LayerNorm(out_dim), 12 | nn.ReLU(inplace=True), 13 | nn.Linear(out_dim, out_dim), 14 | nn.LayerNorm(out_dim), 15 | nn.ReLU(inplace=True), 16 | nn.Linear(out_dim, out_dim), 17 | nn.LayerNorm(out_dim) 18 | ) 19 | self.apply(init_weights) 20 | 21 | def forward(self, x: torch.Tensor) -> torch.Tensor: 22 | return self.embed(x) 23 | 24 | def init_weights(m: nn.Module) -> None: 25 | if isinstance(m, nn.Linear): 26 | nn.init.xavier_uniform_(m.weight) 27 | if m.bias is not None: 28 | nn.init.zeros_(m.bias) 29 | elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | fan_in = m.in_channels / m.groups 31 | fan_out = m.out_channels / m.groups 32 | bound = (6.0 / (fan_in + fan_out)) ** 0.5 33 | nn.init.uniform_(m.weight, -bound, bound) 34 | if m.bias is not None: 35 | nn.init.zeros_(m.bias) 36 | elif isinstance(m, nn.Embedding): 37 | nn.init.normal_(m.weight, mean=0.0, std=0.02) 38 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 39 | nn.init.ones_(m.weight) 40 | nn.init.zeros_(m.bias) 41 | elif isinstance(m, nn.LayerNorm): 42 | nn.init.ones_(m.weight) 43 | nn.init.zeros_(m.bias) 44 | elif isinstance(m, nn.MultiheadAttention): 45 | if m.in_proj_weight is not None: 46 | fan_in = m.embed_dim 47 | fan_out = m.embed_dim 48 | bound = (6.0 / (fan_in + fan_out)) ** 0.5 49 | nn.init.uniform_(m.in_proj_weight, -bound, bound) 50 | else: 51 | nn.init.xavier_uniform_(m.q_proj_weight) 52 | nn.init.xavier_uniform_(m.k_proj_weight) 53 | nn.init.xavier_uniform_(m.v_proj_weight) 54 | if m.in_proj_bias is not None: 55 | nn.init.zeros_(m.in_proj_bias) 56 | nn.init.xavier_uniform_(m.out_proj.weight) 57 | if m.out_proj.bias is not None: 58 | nn.init.zeros_(m.out_proj.bias) 59 | if m.bias_k is not None: 60 | nn.init.normal_(m.bias_k, mean=0.0, std=0.02) 61 | if m.bias_v is not None: 62 | nn.init.normal_(m.bias_v, mean=0.0, std=0.02) 63 | elif isinstance(m, nn.LSTM): 64 | for name, param in m.named_parameters(): 65 | if 'weight_ih' in 
name: 66 | for ih in param.chunk(4, 0): 67 | nn.init.xavier_uniform_(ih) 68 | elif 'weight_hh' in name: 69 | for hh in param.chunk(4, 0): 70 | nn.init.orthogonal_(hh) 71 | elif 'weight_hr' in name: 72 | nn.init.xavier_uniform_(param) 73 | elif 'bias_ih' in name: 74 | nn.init.zeros_(param) 75 | elif 'bias_hh' in name: 76 | nn.init.zeros_(param) 77 | nn.init.ones_(param.chunk(4, 0)[1]) 78 | elif isinstance(m, nn.GRU): 79 | for name, param in m.named_parameters(): 80 | if 'weight_ih' in name: 81 | for ih in param.chunk(3, 0): 82 | nn.init.xavier_uniform_(ih) 83 | elif 'weight_hh' in name: 84 | for hh in param.chunk(3, 0): 85 | nn.init.orthogonal_(hh) 86 | elif 'bias_ih' in name: 87 | nn.init.zeros_(param) 88 | elif 'bias_hh' in name: 89 | nn.init.zeros_(param) 90 | 91 | def bivariate_gaussian_activation(ip: torch.Tensor) -> torch.Tensor: 92 | """ 93 | Activation function to output parameters of bivariate Gaussian distribution 94 | """ 95 | mu_x = ip[..., 0:1] 96 | mu_y = ip[..., 1:2] 97 | sig_x = ip[..., 2:3] 98 | sig_y = ip[..., 3:4] 99 | rho = ip[..., 4:5] 100 | sig_x = torch.exp(sig_x) 101 | sig_y = torch.exp(sig_y) 102 | rho = torch.tanh(rho) 103 | out = torch.cat([mu_x, mu_y, sig_x, sig_y, rho], dim = -1) 104 | 105 | return out -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MSMA 2 | 3 | We focus on traffic scenarios where a connected and autonomous vehicle (CAV) serves as the central agent, utilizing both sensors and communication technologies to perceive its surrounding traffic, which consists of autonomous vehicles, connected vehicles, and human-driven vehicles. 4 | 5 | ## Overview 6 | ![](assets/arch.png) 7 | 8 | ## Getting Started 9 | 10 | 1\. Clone this repository: 11 | ``` 12 | git clone https://github.com/xichennn/MSMA.git 13 | cd MSMA 14 | ``` 15 | 16 | 2\. Create a conda environment and install the dependencies: 17 | ``` 18 | conda create -n MSMA python=3.8 19 | conda activate MSMA 20 | conda install pytorch==1.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge 21 | 22 | # install other dependencies 23 | pip install pytorch-lightning 24 | pip install torch-scatter torch-geometric -f https://pytorch-geometric.com/whl/torch-2.1.0+cu121.html 25 | ``` 26 | 3\. Download the [CARLA simulation data](https://drive.google.com/file/d/1bxIS4O1ZF3AvKqnsRTYzy5xg7bVwvL-w/view?usp=drive_link) and move it to the carla_data dir. 27 | 28 | ## Training 29 | In train.py, there are three hyperparameters that control the data processing: 30 | - mpr: determines the market penetration rate (MPR) of connected vehicles in the dataset 31 | - delay_frame: determines the communication latency, ranging from 1 to 15 frames (0.1~1.5s) 32 | - noise_var: determines the Gaussian noise variance, ranging from 0 to 0.5 \ 33 | 34 | and there are two model arguments that control the data fusion: 35 | - commu_only: when set to True, only data from connected vehicles is utilized 36 | - sensor_only: when set to True, only data from AV sensors is utilized \ 37 | When both commu_only and sensor_only are set to False, data from both sources will be integrated. 38 | 39 | ## Results 40 | 41 | ### Quantitative Results 42 |
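The MPR sweep below is produced by varying the data-processing and fusion options described under Training. As a minimal sketch of how those options might be wired together (argument names are taken from scene_processed_dataset in dataloader/carla_scene_process.py and Base_Net in ModelNet/msma.py; the directory paths and the exact orchestration in train.py are illustrative assumptions, not the authoritative pipeline):

```
from dataloader.carla_scene_process import scene_processed_dataset
from ModelNet.msma import Base_Net

# Hypothetical configuration; adjust paths and values to match your setup.
dataset = scene_processed_dataset(
    root="carla_data",          # folder that also holds Town03.osm
    split="train",
    mpr=0.4,                    # market penetration rate of connected vehicles
    delay_frame=5,              # communication latency of 5 frames (0.5 s)
    noise_var=0.1,              # Gaussian noise variance on sensor tracks
    source_dir="scene_mining",  # assumed location of the mined .csv scenes
    save_dir="processed")       # where the processed .pt files are written

model = Base_Net(
    device="cpu",
    commu_only=False,           # True -> use only communicated (CV) data
    sensor_only=False)          # True -> use only AV sensor data; both False -> fuse both sources
```

Note that mpr should be one of the processed MPR levels (0, 0.2, 0.4, 0.6, 0.8), since the data loader builds its masks from those discrete settings.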
46 | 47 | | Metrics | MPR=0 | MPR=0.2 | MPR=0.4 | MPR=0.6 |MPR=0.8 | 48 | | :--- | :---: | :---: | :---: |:---: |:---: | 49 | | ADE | 0.62 | 0.61 | 0.59 | 0.59 | 0.56 | 50 | | FDE | 1.48 | 1.47 | 1.40 | 1.37 | 1.33 | 51 | | MR | 0.23 | 0.22 | 0.22 | 0.21 | 0.20 | 52 | ### Qualitative Results 53 | 54 | | MPR=0 | MPR=0.4 |MPR=0.8 | 55 | | -------------------------- | -------------------------- |-------------------------- | 56 | | ![MPR=0](assets/s700_mpr0.png) | ![MPR=0.4](assets/s700_mpr4.png) | ![MPR=0.8](assets/s700_mpr8.png) | 57 | 58 | ## Citation 59 | 60 | If you found this repository useful, please cite as: 61 | 62 | ``` 63 | @article{chen2024msma, 64 | title={MSMA: Multi-agent Trajectory Prediction in Connected and Autonomous Vehicle Environment with Multi-source Data Integration}, 65 | author={Chen, Xi and Bhadani, Rahul and Sun, Zhanbo and Head, Larry}, 66 | journal={arXiv preprint arXiv:2407.21310}, 67 | year={2024} 68 | } 69 | ``` 70 | 71 | ## License 72 | 73 | This repository is licensed under [Apache 2.0](LICENSE). 74 | -------------------------------------------------------------------------------- /assets/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/arch.png -------------------------------------------------------------------------------- /assets/delay.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/delay.png -------------------------------------------------------------------------------- /assets/noise.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/noise.png -------------------------------------------------------------------------------- /assets/performance_compr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/performance_compr.png -------------------------------------------------------------------------------- /assets/s700_mpr0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/s700_mpr0.png -------------------------------------------------------------------------------- /assets/s700_mpr2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/s700_mpr2.png -------------------------------------------------------------------------------- /assets/s700_mpr4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/s700_mpr4.png -------------------------------------------------------------------------------- /assets/s700_mpr6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/s700_mpr6.png -------------------------------------------------------------------------------- /assets/s700_mpr8.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/s700_mpr8.png -------------------------------------------------------------------------------- /dataloader/__pycache__/carla_scene_process.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/dataloader/__pycache__/carla_scene_process.cpython-37.pyc -------------------------------------------------------------------------------- /dataloader/carla_scene_mining.py: -------------------------------------------------------------------------------- 1 | """mine CAV scenarios from logged carla data""" 2 | import pandas as pd 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | # import torch 6 | import random 7 | # import math 8 | import os 9 | import copy 10 | 11 | def get_obj_type_at_mpr(data_df, vids, cav_id, mpr=0.2): 12 | others = list(set(vids) - set([cav_id])) 13 | #keep random seed here to ensure the same seed 14 | random.seed(30) 15 | cv_ids = random.sample(list(others), int(mpr*len(others))) 16 | vid_df = data_df["vid"].values 17 | obj_type_mpr = [] 18 | for v in vid_df: 19 | if v == cav_id: 20 | obj_type_mpr.append("cav") 21 | elif v in cv_ids: 22 | obj_type_mpr.append("cv") 23 | else: 24 | obj_type_mpr.append("ncv") 25 | return obj_type_mpr 26 | 27 | # read the data 28 | data_raw = pd.read_csv("../carla_data/Location.csv", header=None) 29 | header = ["frame","time","vid","type_id","position_x","position_y","position_z","rotation_x","rotation_y","rotation_z","vel_x","vel_y","angular_z"] 30 | map = {idx:header[idx] for idx in range(13)} 31 | data_raw = data_raw.rename(columns = map) 32 | # make pos_y consistent with map 33 | data_raw["position_y"] = -data_raw["position_y"] 34 | # %% 35 | vids = list(data_raw["vid"].unique()) 36 | ts = np.sort(np.unique(data_raw['frame'].values)) 37 | random.seed(30) 38 | cv_range = 50 39 | av_range = 30 40 | 41 | data_df = data_raw.copy(deep=True) 42 | # segment the scenes into 10s 43 | min_ts = ts[0] 44 | max_ts = ts[-1] 45 | # 5s overlapping among scenes 10s = 100 steps/frames 46 | # remove the scenes where cav is not moving 47 | for cav_id in vids: 48 | for frame in range(min_ts+50,max_ts-50,50): 49 | 50 | vehicles_at_frame = data_df[data_df["frame"] == frame] 51 | cav_entry = data_df[(data_df["frame"]==frame) & (data_df["vid"]==cav_id)] 52 | cav_entry_previous = data_df[(data_df["frame"]==frame-1) & (data_df["vid"]==cav_id)] 53 | if (cav_entry.position_x.values == cav_entry_previous.position_x.values) and \ 54 | (cav_entry.position_y.values == cav_entry_previous.position_y.values): 55 | continue 56 | 57 | dist = ((vehicles_at_frame["position_x"].values - cav_entry["position_x"].values)**2 58 | + (vehicles_at_frame["position_y"].values - cav_entry["position_y"].values)**2)**0.5 59 | cv_idx = np.where((dist<cv_range) & (dist>0))[0] 60 | cv_neighbors = vehicles_at_frame["vid"].values[cv_idx] 61 | #remove unmoved surrounding vehicles 62 | vid_ngbr_unmove = [] 63 | for i in range(len(cv_neighbors)): 64 | vid_ngbr = cv_neighbors[i] 65 | ngbr_entry = data_df[(data_df["frame"]==frame) & (data_df["vid"]==vid_ngbr)] 66 | ngbr_entry_previous = data_df[(data_df["frame"]==frame-1) & (data_df["vid"]==vid_ngbr)] 67 | if (ngbr_entry.position_x.values == ngbr_entry_previous.position_x.values) and \ 68 | (ngbr_entry.position_y.values == 
ngbr_entry_previous.position_y.values): 69 | vid_ngbr_unmove.append(vid_ngbr) 70 | cv_ngbrs_move=list(set(cv_neighbors)-set(vid_ngbr_unmove)) 71 | 72 | av_idx = np.where((dist<av_range) & (dist>0))[0] 73 | av_neighbors = vehicles_at_frame["vid"].values[av_idx] 74 | av_ngbrs_move = list(set(av_neighbors)-set(vid_ngbr_unmove)) 75 | 76 | scene_frames = list(range(frame-50,frame+50)) 77 | scene_vids = [cav_id]+cv_ngbrs_move 78 | scene_data = copy.deepcopy(data_df[data_df["vid"].isin(scene_vids) & data_df["frame"].isin(scene_frames)]) 79 | 80 | #mprs 81 | obj_type_mpr_02 = get_obj_type_at_mpr(scene_data, scene_vids, cav_id, mpr=0.2) 82 | obj_type_mpr_04 = get_obj_type_at_mpr(scene_data, scene_vids, cav_id, mpr=0.4) 83 | obj_type_mpr_06 = get_obj_type_at_mpr(scene_data, scene_vids, cav_id, mpr=0.6) 84 | obj_type_mpr_08 = get_obj_type_at_mpr(scene_data, scene_vids, cav_id, mpr=0.8) 85 | scene_data["obj_type_mpr_02"] = obj_type_mpr_02 86 | scene_data["obj_type_mpr_04"] = obj_type_mpr_04 87 | scene_data["obj_type_mpr_06"] = obj_type_mpr_06 88 | scene_data["obj_type_mpr_08"] = obj_type_mpr_08 89 | 90 | scene_data["in_av_range"] = scene_data["vid"].isin([cav_id]+av_ngbrs_move).values 91 | scene_data.to_csv("scene_mining/scene_{}_{}".format(frame, cav_id),index=False) -------------------------------------------------------------------------------- /dataloader/carla_scene_process.py: -------------------------------------------------------------------------------- 1 | """ 2 | process the '.csv' files, save as '.pt' files 3 | """ 4 | import os 5 | import sys 6 | import numpy as np 7 | import pandas as pd 8 | import copy 9 | from os.path import join as pjoin 10 | 11 | from dataloader.utils import lane_segment, load_xml 12 | from dataloader.utils.lane_sampling import Spline2D, visualize_centerline 13 | import matplotlib.pyplot as plt 14 | 15 | from typing import List, Optional, Tuple 16 | 17 | import torch 18 | torch.manual_seed(30) 19 | import torch.nn as nn 20 | from torch_geometric.data import Data, HeteroData 21 | from torch_geometric.data import Dataset 22 | from typing import Callable, Dict, List, Optional, Tuple, Union 23 | from itertools import permutations, product 24 | from tqdm import tqdm 25 | 26 | class scene_processed_dataset(Dataset): 27 | def __init__(self, 28 | root:str, 29 | split:str, 30 | radius:float = 75, 31 | local_radius:float = 30, 32 | transform: Optional[Callable] = None, 33 | mpr:float = 0., 34 | obs_len:float=50, 35 | fut_len:float=50, 36 | cv_range:float=50, 37 | av_range:float=30, 38 | noise_var:float=0.1, 39 | delay_frame:float=1, 40 | normalized=True, 41 | source_dir:str = None, 42 | save_dir:str = None) ->None: 43 | 44 | self._split = split 45 | self._radius = radius 46 | self._local_radius = local_radius 47 | self.obs_len = obs_len 48 | self.fut_len = fut_len 49 | self.cv_range = cv_range 50 | self.av_range = av_range 51 | self.mpr = mpr 52 | self.noise_var = noise_var 53 | self.delay_frame = delay_frame 54 | self.normalized = normalized 55 | self.source_dir = source_dir 56 | self.save_dir = save_dir 57 | 58 | self.root = root 59 | self._raw_file_names = os.listdir(self.raw_dir) 60 | self._processed_file_names = [os.path.splitext(f)[0] + '.pt' for f in self.raw_file_names] 61 | self._processed_paths = [os.path.join(self.processed_dir, f) for f in self._processed_file_names] 62 | super(scene_processed_dataset, self).__init__(root) 63 | 64 | @property 65 | def raw_dir(self) -> str: 66 | return os.path.join(self.root, self.source_dir, self._split) 67 | 68 | @property 69 | def 
processed_dir(self) -> str: 70 | return os.path.join(self.root, self.save_dir, self._split) 71 | 72 | @property 73 | def raw_file_names(self) -> Union[str, List[str], Tuple]: 74 | return self._raw_file_names 75 | 76 | @property 77 | def processed_file_names(self) -> Union[str, List[str], Tuple]: 78 | return self._processed_file_names 79 | 80 | @property 81 | def processed_paths(self) -> List[str]: 82 | return self._processed_paths 83 | 84 | def process(self) -> None: 85 | self.get_map_polygon_bbox() 86 | for raw_path in tqdm(self.raw_paths): 87 | kwargs = self.get_scene_feats(raw_path, self._radius, self._local_radius, self._split) 88 | data = CarlaData(**kwargs) 89 | torch.save(data, os.path.join(self.processed_dir, str(kwargs['seq_id']) + '.pt')) 90 | 91 | def len(self) -> int: 92 | return len(self._raw_file_names) 93 | 94 | def get(self, idx) -> Data: 95 | return torch.load(self.processed_paths[idx]) 96 | 97 | def get_map_polygon_bbox(self): 98 | rel_path = "Town03.osm" 99 | roads = load_xml.load_lane_segments_from_xml(pjoin(self.root, rel_path)) 100 | polygon_bboxes, lane_starts, lane_ends = load_xml.build_polygon_bboxes(roads) 101 | self.roads = roads 102 | self.polygon_bboxes = polygon_bboxes 103 | self.lane_starts = lane_starts 104 | self.lane_ends = lane_ends 105 | 106 | def get_scene_feats(self, raw_path, radius, local_radius, split="train"): 107 | 108 | df = pd.read_csv(raw_path) 109 | # filter out actors that are unseen during the historical time steps 110 | timestamps = list(np.sort(df['frame'].unique())) 111 | historical_timestamps = timestamps[: 50] 112 | historical_df = df[df['frame'].isin(historical_timestamps)] 113 | actor_ids = list(historical_df['vid'].unique()) 114 | 115 | # # filter out unmoved actors 116 | # actor_ids = self.remove_unmoved_ids(df, actor_ids) 117 | 118 | df = df[df['vid'].isin(actor_ids)] 119 | num_nodes = len(actor_ids) 120 | 121 | objs = df.groupby(['vid', 'obj_type_mpr_02', 'obj_type_mpr_04', 'obj_type_mpr_06', 'obj_type_mpr_08', 'in_av_range']).groups 122 | keys = list(objs.keys()) 123 | 124 | vids = [x[0] for x in keys] 125 | actor_indices = [vids.index(x) for x in actor_ids] 126 | obj_type_02 = [keys[i][1] for i in actor_indices] 127 | obj_type_04 = [keys[i][2] for i in actor_indices] 128 | obj_type_06 = [keys[i][3] for i in actor_indices] 129 | obj_type_08 = [keys[i][4] for i in actor_indices] 130 | in_av_range = [keys[i][5] for i in actor_indices] 131 | 132 | cav_idx = np.where(np.asarray(obj_type_02)=="cav")[0] #np array 133 | cav_df = df[df['obj_type_mpr_02'] == 'cav'].iloc 134 | 135 | # make the scene centered at CAV 136 | origin = torch.tensor([cav_df[49]['position_x'], cav_df[49]['position_y']], dtype=torch.float) 137 | cav_heading_vector = origin - torch.tensor([cav_df[48]['position_x'], cav_df[48]['position_y']], dtype=torch.float) 138 | theta = torch.atan2(cav_heading_vector[1], cav_heading_vector[0]) 139 | rotate_mat = torch.tensor([[torch.cos(theta), -torch.sin(theta)], 140 | [torch.sin(theta), torch.cos(theta)]]) 141 | 142 | # initialization 143 | x = torch.zeros(num_nodes, 100, 2, dtype=torch.float) 144 | edge_index = torch.LongTensor(list(permutations(range(num_nodes), 2))).t().contiguous() 145 | padding_mask = torch.ones(num_nodes, 100, dtype=torch.bool) 146 | bos_mask = torch.zeros(num_nodes, 50, dtype=torch.bool) 147 | rotate_angles = torch.zeros(num_nodes, dtype=torch.float) 148 | 149 | for actor_id, actor_df in df.groupby('vid'): 150 | node_idx = actor_ids.index(actor_id) 151 | node_steps = [timestamps.index(timestamp) for 
timestamp in actor_df['frame']] 152 | padding_mask[node_idx, node_steps] = False 153 | if padding_mask[node_idx, 49]: # make no predictions for actors that are unseen at the current time step 154 | padding_mask[node_idx, 50:] = True 155 | xy = torch.from_numpy(np.stack([actor_df['position_x'].values, actor_df['position_y'].values], axis=-1)).float() 156 | x[node_idx, node_steps] = torch.matmul(rotate_mat, (xy - origin.reshape(-1, 2)).T).T 157 | node_historical_steps = list(filter(lambda node_step: node_step < 50, node_steps)) 158 | if len(node_historical_steps) > 1: # calculate the heading of the actor (approximately) 159 | heading_vector = x[node_idx, node_historical_steps[-1]] - x[node_idx, node_historical_steps[-2]] 160 | rotate_angles[node_idx] = torch.atan2(heading_vector[1], heading_vector[0]) 161 | else: # make no predictions for the actor if the number of valid time steps is less than 2 162 | padding_mask[node_idx, 50:] = True 163 | 164 | # bos_mask is True if time step t is valid and time step t-1 is invalid 165 | bos_mask[:, 0] = ~padding_mask[:, 0] 166 | bos_mask[:, 1: 50] = padding_mask[:, : 49] & ~padding_mask[:, 1: 50] 167 | 168 | #positions are transformed absolute x, y coordinates 169 | positions = x.clone() 170 | 171 | #reformat encode strs and bools, CAV:1, CV:2, NCV:3 172 | obj_type_mapping = {"cav":1, "cv":2, "ncv":3} 173 | obj_type_02_ = torch.tensor([obj_type_mapping[x] for x in obj_type_02]) 174 | obj_type_04_ = torch.tensor([obj_type_mapping[x] for x in obj_type_04]) 175 | obj_type_06_ = torch.tensor([obj_type_mapping[x] for x in obj_type_06]) 176 | obj_type_08_ = torch.tensor([obj_type_mapping[x] for x in obj_type_08]) 177 | in_av_range_ = torch.tensor([1 if in_av_range[i]==True else 0 for i in range(len(in_av_range))]) 178 | 179 | #get masks for different data sources 180 | types = [obj_type_02_, obj_type_04_, obj_type_06_, obj_type_08_] 181 | mprs = [0.2, 0.4, 0.6, 0.8] 182 | cav_mask, commu_mask, sensor_mask = self.get_masks(self.mpr, mprs, types, in_av_range_) 183 | positions_hist = positions[:,:50,:].clone() 184 | x_cav = positions_hist[cav_mask][:,20:50,:] 185 | x_commu = positions_hist[commu_mask] 186 | x_sensor = positions_hist[sensor_mask] 187 | 188 | #inject errors to different data sources 189 | x_sensor_noise, padding_mask_noise = self.get_noisy_x(x_sensor, padding_mask[sensor_mask], self.noise_var) 190 | x_commu_delay, padding_mask_delay = self.get_delayed_x(x_commu, padding_mask[commu_mask], self.delay_frame) 191 | 192 | #get vectorized x 193 | x_cav_vec = self.get_vectorized_x(x_cav, padding_mask[cav_mask][:,20:50]) 194 | x_commu_delay_vec = self.get_vectorized_x(x_commu_delay, padding_mask_delay) 195 | x_sensor_noise_vec = self.get_vectorized_x(x_sensor_noise, padding_mask_noise) 196 | 197 | 198 | y = torch.where((padding_mask[:, 49].unsqueeze(-1) | padding_mask[:, 50:]).unsqueeze(-1), 199 | torch.zeros(num_nodes, 50, 2), 200 | x[:, 50:] - x[:, 49].unsqueeze(-2)) 201 | 202 | y_commu = torch.where((padding_mask[:, 49].unsqueeze(-1) | padding_mask[:, 50:]).unsqueeze(-1), 203 | torch.zeros(num_nodes, 50, 2), 204 | x[:, 50:] - x[:, 49-self.delay_frame].unsqueeze(-2))[commu_mask] 205 | 206 | lane_pos, lane_vectors, lane_idcs,lane_actor_index, lane_actor_attr = \ 207 | self.get_lane_feats(origin, rotate_mat, num_nodes, positions, radius, local_radius) 208 | 209 | #get rotate-invariant matrix 210 | rotate_imat = torch.empty(num_nodes, 2, 2) 211 | sin_vals = torch.sin(rotate_angles) 212 | cos_vals = torch.cos(rotate_angles) 213 | rotate_imat[:, 0, 0] = 
cos_vals 214 | rotate_imat[:, 0, 1] = -sin_vals 215 | rotate_imat[:, 1, 0] = sin_vals 216 | rotate_imat[:, 1, 1] = cos_vals 217 | 218 | seq_id = os.path.splitext(os.path.basename(raw_path))[0] 219 | 220 | return { 221 | 'x_cav': x_cav_vec, # [1, 30, 2] 222 | 'x_commu': x_commu_delay_vec, # [N1, 30, 2] 223 | 'x_sensor': x_sensor_noise_vec, # [N2, 30, 2] 224 | 'cav_mask': cav_mask, # [N] 225 | 'commu_mask': commu_mask, # [N] 226 | 'sensor_mask': sensor_mask, # [N] 227 | 'positions': positions, # [N, 100, 2] 228 | 'edge_index': edge_index, # [2, N x (N - 1)] 229 | 'y': y, # [N, 50, 2] 230 | 'y_commu': y_commu, #[M, 50, 2] 231 | 'x_commu_ori': x_commu_delay[:,-1,:], #abs starting pos of delayed traj 232 | 'x_sensor_ori': x_sensor_noise[:,-1,:], #abs starting pos of nosiy traj 233 | 'seq_id': seq_id, #str, file_name 234 | 'num_nodes': num_nodes, 235 | 'padding_mask': padding_mask, # [N, 100] 236 | 'bos_mask': bos_mask, # [N, 50] 237 | 'rotate_angles': rotate_angles, # [N] 238 | 'rotate_imat': rotate_imat, #[N, 2, 2] 239 | 'lane_vectors': lane_vectors, # [L, 2] 240 | 'lane_pos': lane_pos, #[L, 2] 241 | 'lane_idcs': lane_idcs, #[L] 242 | 'lane_actor_index': lane_actor_index, 243 | 'lane_actor_attr': lane_actor_attr, 244 | 'mpr': self.mpr, 245 | 'origin': origin.unsqueeze(0), 246 | 'theta': theta, 247 | 'rotate_mat': rotate_mat 248 | } 249 | 250 | def get_lane_feats(self, origin, rotate_mat, num_nodes, positions, radius=75, local_radius=30): 251 | 252 | road_ids = load_xml.get_road_ids_in_xy_bbox(self.polygon_bboxes, self.lane_starts, self.lane_ends, self.roads, origin[0], origin[1], radius) 253 | road_ids = copy.deepcopy(road_ids) 254 | 255 | lanes=dict() 256 | for road_id in road_ids: 257 | road = self.roads[road_id] 258 | ctr_line = torch.from_numpy(np.stack(((self.roads[road_id].l_bound[:,0]+self.roads[road_id].r_bound[:,0])/2, 259 | (self.roads[road_id].l_bound[:,1]+self.roads[road_id].r_bound[:,1])/2),axis=-1)) 260 | ctr_line = torch.matmul(rotate_mat, (ctr_line - origin.reshape(-1, 2)).T.float()).T 261 | 262 | x, y = ctr_line[:,0], ctr_line[:,1] 263 | # if x.max() < x_min or x.min() > x_max or y.max() < y_min or y.min() > y_max: 264 | # continue 265 | # else: 266 | """getting polygons requires original centerline""" 267 | polygon, _, _ = load_xml.build_polygon_bboxes({road_id: self.roads[road_id]}) 268 | polygon_x = torch.from_numpy(np.array([polygon[:,0],polygon[:,0],polygon[:,2],polygon[:,2],polygon[:,0]])) 269 | polygon_y = torch.from_numpy(np.array([polygon[:,1],polygon[:,3],polygon[:,3],polygon[:,1],polygon[:,1]])) 270 | polygon_reshape = torch.cat([polygon_x,polygon_y],dim=-1) #shape(5,2) 271 | 272 | road.centerline = ctr_line 273 | road.polygon = torch.matmul(rotate_mat, (polygon_reshape.float() - origin.reshape(-1, 2)).T).T 274 | lanes[road_id] = road 275 | 276 | lane_ids = list(lanes.keys()) 277 | lane_pos, lane_vectors = [], [] 278 | for lane_id in lane_ids: 279 | lane = lanes[lane_id] 280 | ctrln = lane.centerline 281 | 282 | # lane_ctrs.append(torch.from_numpy(np.asarray((ctrln[:-1]+ctrln[1:])/2.0, np.float32)))#lane center point 283 | # lane_vectors.append(torch.from_numpy(np.asarray(ctrln[1:]-ctrln[:-1], np.float32))) #length between waypoints 284 | lane_pos.append(ctrln[:-1]) #lane center point 285 | lane_vectors.append(ctrln[1:]-ctrln[:-1])#length between waypoints 286 | 287 | lane_idcs = [] 288 | count = 0 289 | for i, position in enumerate(lane_pos): 290 | lane_idcs.append(i*torch.ones(len(position))) 291 | count += len(position) 292 | 293 | lane_idcs = 
torch.cat(lane_idcs, dim=0) 294 | lane_pos = torch.cat(lane_pos, dim=0) 295 | lane_vectors = torch.cat(lane_vectors, dim=0) 296 | 297 | lane_actor_index = torch.LongTensor(list(product(torch.arange(lane_vectors.size(0)), \ 298 | torch.arange(num_nodes)))).t().contiguous() 299 | lane_actor_attr = \ 300 | lane_pos[lane_actor_index[0]] - positions[:,49,:][lane_actor_index[1]] 301 | mask = torch.norm(lane_actor_attr, p=2, dim=-1) < local_radius 302 | lane_actor_index = lane_actor_index[:, mask] 303 | lane_actor_attr = lane_actor_attr[mask] 304 | 305 | 306 | return lane_pos, lane_vectors, lane_idcs, lane_actor_index, lane_actor_attr 307 | 308 | def get_vectorized_x(self, x0, padding_mask): 309 | ''' 310 | x: torch.Tensor: [n, 30, 2] 311 | padding_mask: torch.Tensor:[n, 30] 312 | ''' 313 | x = x0.clone() 314 | x[:, 1: 30] = torch.where((padding_mask[:, : 29] | padding_mask[:, 1: 30]).unsqueeze(-1), 315 | torch.zeros(x.shape[0], 29, 2), 316 | x[:, 1: 30] - x[:, : 29]) 317 | x[:, 0] = torch.zeros(x.shape[0], 2) 318 | 319 | return x 320 | 321 | def get_masks(self, mpr, mprs, types, in_av_range): 322 | #ncv in av range 323 | #and all cv 324 | if mpr == 0: 325 | cav_mask = types[0]==1 326 | commu_mask = torch.zeros(cav_mask.shape)==True 327 | sensor_mask = (types[0]!=1) & (in_av_range==1) 328 | else: 329 | type_idx = mprs.index(mpr) 330 | cav_mask = types[type_idx]==1 331 | commu_mask = types[type_idx]==2 332 | sensor_mask = (types[type_idx]!=1) & (in_av_range==1) 333 | 334 | return cav_mask, commu_mask, sensor_mask 335 | 336 | def get_noisy_x(self, x, padding_mask, var=0.1): 337 | """ 338 | get noisy feats for sensor data 339 | x: torch.Tensor of shape(n, 50, 2) 340 | 341 | return 342 | noise_x: torch.Tensor of shape(n, 30, 2) 343 | """ 344 | noise = torch.normal(0, var, x.shape) 345 | 346 | return (x+noise)[:,20:,:], padding_mask[:,20:50] 347 | 348 | def get_delayed_x(self, x, padding_mask, lag=1): 349 | """ 350 | get delayed feats of communication data 351 | x: torch tensor of shape(n, 50, 2) 352 | lag: number of frames in [0:20] 353 | 354 | return 355 | delayed_x: torch.Tensor of shape(n, 30, 2) 356 | """ 357 | if lag<0 or lag>20: 358 | raise Exception("lag must be in the range(0,20)") 359 | 360 | delayed_x = x[:,20-lag:-lag,:] 361 | 362 | return delayed_x, padding_mask[:, 20-lag:50-lag] 363 | 364 | class CarlaData(Data): 365 | 366 | def __init__(self, 367 | x_cav: Optional[torch.Tensor] = None, 368 | x_commu: Optional[torch.Tensor] = None, 369 | x_sensor: Optional[torch.Tensor] = None, 370 | cav_mask: Optional[torch.Tensor] = None, 371 | commu_mask: Optional[torch.Tensor] = None, 372 | sensor_mask: Optional[torch.Tensor] = None, 373 | positions: Optional[torch.Tensor] = None, 374 | edge_index: Optional[torch.Tensor] = None, 375 | edge_attrs: Optional[List[torch.Tensor]] = None, 376 | lane_actor_index: Optional[torch.Tensor] = None, 377 | lane_actor_attr: Optional[torch.Tensor] = None, 378 | y: Optional[torch.Tensor] = None, 379 | y_commu: Optional[torch.Tensor] = None, 380 | x_commu_ori: Optional[torch.Tensor] = None, 381 | x_sensor_ori: Optional[torch.Tensor] = None, 382 | seq_id: Optional[str] = None, 383 | num_nodes: Optional[int] = None, 384 | padding_mask: Optional[torch.Tensor] = None, 385 | bos_mask: Optional[torch.Tensor] = None, 386 | rotate_angles: Optional[torch.Tensor] = None, 387 | rotate_imat: Optional[torch.Tensor] = None, 388 | lane_vectors: Optional[torch.Tensor] = None, 389 | lane_pos: Optional[torch.Tensor] = None, 390 | lane_idcs: Optional[torch.Tensor] = None, 391 | mpr: 
Optional[torch.Tensor] = None, 392 | origin: Optional[torch.Tensor] = None, 393 | theta: Optional[torch.Tensor] = None, 394 | rotate_mat: Optional[torch.Tensor] = None, 395 | # obj_type_02: Optional[torch.Tensor] = None, 396 | # obj_type_04: Optional[torch.Tensor] = None, 397 | # obj_type_06: Optional[torch.Tensor] = None, 398 | # obj_type_08: Optional[torch.Tensor] = None, 399 | # in_av_range: Optional[torch.Tensor] = None, 400 | **kwargs) -> None: 401 | if x_cav is None: 402 | super(CarlaData, self).__init__() 403 | return 404 | super(CarlaData, self).__init__(x_cav=x_cav, x_commu=x_commu, x_sensor=x_sensor, mpr=mpr, 405 | cav_mask=cav_mask, commu_mask=commu_mask, sensor_mask=sensor_mask, 406 | positions=positions, edge_index=edge_index, edge_attrs=edge_attrs, 407 | lane_actor_index=lane_actor_index, lane_actor_attr=lane_actor_attr, 408 | y=y, y_commu=y_commu, x_commu_ori=x_commu_ori, x_sensor_ori=x_sensor_ori, 409 | seq_id=seq_id, num_nodes=num_nodes, padding_mask=padding_mask, 410 | bos_mask=bos_mask, rotate_angles=rotate_angles, rotate_imat=rotate_imat, 411 | lane_vectors=lane_vectors, lane_pos=lane_pos, lane_idcs=lane_idcs, 412 | theta=theta, rotate_mat=rotate_mat, 413 | **kwargs) 414 | if edge_attrs is not None: 415 | for t in range(self.x.size(1)): 416 | self[f'edge_attr_{t}'] = edge_attrs[t] 417 | 418 | def __inc__(self, key, value, *args, **kwargs): 419 | if key == 'lane_actor_index': 420 | return torch.tensor([[self['lane_vectors'].size(0)], [self.num_nodes]]) 421 | else: 422 | return super().__inc__(key, value) 423 | 424 | -------------------------------------------------------------------------------- /dataloader/utils/__pycache__/lane_sampling.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/dataloader/utils/__pycache__/lane_sampling.cpython-37.pyc -------------------------------------------------------------------------------- /dataloader/utils/__pycache__/lane_segment.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/dataloader/utils/__pycache__/lane_segment.cpython-37.pyc -------------------------------------------------------------------------------- /dataloader/utils/__pycache__/load_xml.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/dataloader/utils/__pycache__/load_xml.cpython-37.pyc -------------------------------------------------------------------------------- /dataloader/utils/lane_sampling.py: -------------------------------------------------------------------------------- 1 | #implement equal sampling of map vector 2 | import numpy as np 3 | import math 4 | import matplotlib.pyplot as plt 5 | 6 | class Spline: 7 | """ 8 | Cubic Spline class 9 | """ 10 | 11 | def __init__(self, x, y): 12 | self.b, self.c, self.d, self.w = [], [], [], [] 13 | 14 | self.x = np.array(x) 15 | self.y = np.array(y) 16 | 17 | self.eps = np.finfo(float).eps 18 | 19 | self.nx = len(x) # dimension of x 20 | h = np.diff(x) 21 | 22 | # calc coefficient c 23 | self.a = np.array([iy for iy in y]) 24 | 25 | # calc coefficient c 26 | A = self.__calc_A(h) 27 | B = self.__calc_B(h) 28 | self.c = np.linalg.solve(A, B) 29 | # print(self.c1) 30 | 31 | # calc spline coefficient b and d 32 | for i 
in range(self.nx - 1): 33 | self.d.append((self.c[i + 1] - self.c[i]) / (3.0 * h[i] + self.eps)) 34 | tb = (self.a[i + 1] - self.a[i]) / (h[i] + self.eps) - h[i] * \ 35 | (self.c[i + 1] + 2.0 * self.c[i]) / 3.0 36 | self.b.append(tb) 37 | self.b = np.array(self.b) 38 | self.d = np.array(self.d) 39 | 40 | def calc(self, t): 41 | """ 42 | Calc position 43 | if t is outside of the input x, return None 44 | """ 45 | t = np.asarray(t) 46 | mask = np.logical_and(t < self.x[0], t > self.x[-1]) 47 | t[mask] = self.x[0] 48 | 49 | i = self.__search_index(t) 50 | dx = t - self.x[i.astype(int)] 51 | result = self.a[i] + self.b[i] * dx + \ 52 | self.c[i] * dx ** 2.0 + self.d[i] * dx ** 3.0 53 | 54 | result = np.asarray(result) 55 | result[mask] = None 56 | return result 57 | 58 | def calcd(self, t): 59 | """ 60 | Calc first derivative 61 | if t is outside of the input x, return None 62 | """ 63 | t = np.asarray(t) 64 | mask = np.logical_and(t < self.x[0], t > self.x[-1]) 65 | t[mask] = 0 66 | 67 | i = self.__search_index(t) 68 | dx = t - self.x[i] 69 | result = self.b[i] + 2.0 * self.c[i] * dx + 3.0 * self.d[i] * dx ** 2.0 70 | 71 | result = np.asarray(result) 72 | result[mask] = None 73 | return result 74 | 75 | def calcdd(self, t): 76 | """ 77 | Calc second derivative 78 | """ 79 | t = np.asarray(t) 80 | mask = np.logical_and(t < self.x[0], t > self.x[-1]) 81 | t[mask] = 0 82 | 83 | i = self.__search_index(t) 84 | dx = t - self.x[i] 85 | result = 2.0 * self.c[i] + 6.0 * self.d[i] * dx 86 | 87 | result = np.asarray(result) 88 | result[mask] = None 89 | return result 90 | 91 | def __search_index(self, x): 92 | """ 93 | search data segment index 94 | """ 95 | indices = np.asarray(np.searchsorted(self.x, x, "left") - 1) 96 | indices[indices <= 0] = 0 97 | return indices 98 | 99 | def __calc_A(self, h): 100 | """ 101 | calc matrix A for spline coefficient c 102 | """ 103 | A = np.zeros((self.nx, self.nx)) 104 | A[0, 0] = 1.0 105 | for i in range(self.nx - 1): 106 | if i != (self.nx - 2): 107 | A[i + 1, i + 1] = 2.0 * (h[i] + h[i + 1]) 108 | A[i + 1, i] = h[i] 109 | A[i, i + 1] = h[i] 110 | 111 | A[0, 1] = 0.0 112 | A[self.nx - 1, self.nx - 2] = 0.0 113 | A[self.nx - 1, self.nx - 1] = 1.0 114 | # print(A) 115 | return A 116 | 117 | def __calc_B(self, h): 118 | """ 119 | calc matrix B for spline coefficient c 120 | """ 121 | B = np.zeros(self.nx) 122 | for i in range(self.nx - 2): 123 | B[i + 1] = 3.0 * (self.a[i + 2] - self.a[i + 1]) / (h[i + 1] + self.eps) \ 124 | - 3.0 * (self.a[i + 1] - self.a[i]) / (h[i] + self.eps) 125 | return B 126 | class Spline2D: 127 | """ 128 | 2D Cubic Spline class 129 | """ 130 | 131 | def __init__(self, x, y, resolution=0.1): 132 | self.s = self.__calc_s(x, y) 133 | self.sx = Spline(self.s, x) 134 | self.sy = Spline(self.s, y) 135 | 136 | self.s_fine = np.arange(0, self.s[-1], resolution) 137 | xy = np.array([self.calc_global_position_online(s_i) for s_i in self.s_fine]) 138 | 139 | self.x_fine = xy[:, 0] 140 | self.y_fine = xy[:, 1] 141 | 142 | def __calc_s(self, x, y): 143 | dx = np.diff(x) 144 | dy = np.diff(y) 145 | self.ds = np.hypot(dx, dy) 146 | s = [0] 147 | s.extend(np.cumsum(self.ds)) 148 | return s 149 | 150 | def calc_global_position_online(self, s): 151 | """ 152 | calc global position of points on the line, s: float 153 | return: x: float; y: float; the global coordinate of given s on the spline 154 | """ 155 | x = self.sx.calc(s) 156 | y = self.sy.calc(s) 157 | 158 | return x, y 159 | 160 | def calc_global_position_offline(self, s, d): 161 | """ 162 | calc 
global position of points in the frenet coordinate w.r.t. the line. 163 | s: float, longitudinal; d: float, lateral; 164 | return: x, float; y, float; 165 | """ 166 | s_x = self.sx.calc(s) 167 | s_y = self.sy.calc(s) 168 | 169 | theta = math.atan2(self.sy.calcd(s), self.sx.calcd(s)) 170 | x = s_x - math.sin(theta) * d 171 | y = s_y + math.cos(theta) * d 172 | return x, y 173 | 174 | def calc_frenet_position(self, x, y): 175 | """ 176 | cal the frenet position of given global coordinate (x, y) 177 | return s: the longitudinal; d: the lateral 178 | """ 179 | # find nearst x, y 180 | diff = np.hypot(self.x_fine - x, self.y_fine - y) 181 | idx = np.argmin(diff) 182 | [x_s, y_s] = self.x_fine[idx], self.y_fine[idx] 183 | s = self.s_fine[idx] 184 | 185 | # compute theta 186 | theta = math.atan2(self.sy.calcd(s), self.sx.calcd(s)) 187 | d_x, d_y = x - x_s, y - y_s 188 | cross_rd_nd = math.cos(theta) * d_y - math.sin(theta) * d_x 189 | d = math.copysign(np.hypot(d_x, d_y), cross_rd_nd) 190 | return s, d 191 | 192 | def calc_curvature(self, s): 193 | """ 194 | calc curvature 195 | """ 196 | dx = self.sx.calcd(s) 197 | ddx = self.sx.calcdd(s) 198 | dy = self.sy.calcd(s) 199 | ddy = self.sy.calcdd(s) 200 | k = (ddy * dx - ddx * dy) / ((dx ** 2 + dy ** 2)**(3 / 2)) 201 | return k 202 | 203 | def calc_yaw(self, s): 204 | """ 205 | calc yaw 206 | """ 207 | dx = self.sx.calcd(s) 208 | dy = self.sy.calcd(s) 209 | yaw = np.arctan2(dy, dx) 210 | return yaw 211 | 212 | def visualize_centerline(centerline) -> None: 213 | """Visualize the computed centerline. 214 | Args: 215 | centerline: Sequence of coordinates forming the centerline 216 | """ 217 | line_coords = list(zip(*centerline)) 218 | lineX = line_coords[0] 219 | lineY = line_coords[1] 220 | plt.plot(lineX, lineY, "--", color="grey", alpha=1, linewidth=1, zorder=0) 221 | plt.text(lineX[0], lineY[0], "s") 222 | plt.text(lineX[-1], lineY[-1], "e") 223 | plt.axis("equal") -------------------------------------------------------------------------------- /dataloader/utils/lane_segment.py: -------------------------------------------------------------------------------- 1 | # 2 | from typing import List, Optional 3 | 4 | import numpy as np 5 | 6 | 7 | class LaneSegment: 8 | def __init__( 9 | self, 10 | id: int, 11 | l_neighbor_id: Optional[int], 12 | r_neighbor_id: Optional[int], 13 | centerline: np.ndarray, 14 | ) -> None: 15 | """ 16 | Initialize the lane segment. 17 | 18 | Args: 19 | id: Unique lane ID that serves as identifier for this "Way" 20 | l_neighbor_id: Unique ID for left neighbor 21 | r_neighbor_id: Unique ID for right neighbor 22 | centerline: The coordinates of the lane segment's center line. 23 | """ 24 | self.id = id 25 | self.l_neighbor_id = l_neighbor_id 26 | self.r_neighbor_id = r_neighbor_id 27 | self.centerline = centerline 28 | 29 | class Road: 30 | def __init__( 31 | self, 32 | id: int, 33 | l_bound: np.ndarray, 34 | r_bound: np.ndarray, 35 | ) -> None: 36 | """Initialize the lane segment. 37 | 38 | Args: 39 | id: Unique lane ID that serves as identifier for this "Way". 40 | l_bound: The coordinates of the lane segment's left bound. 41 | r_bound: The coordinates of the lane segment's right bound. 
42 | """ 43 | self.id = id 44 | self.l_bound = l_bound 45 | self.r_bound = r_bound 46 | 47 | -------------------------------------------------------------------------------- /dataloader/utils/load_xml.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 3 | 4 | """ 5 | Utility to load the Argoverse vector map from disk, where it is stored in an XML format. 6 | 7 | We release our Argoverse vector map in a modified OpenStreetMap (OSM) form. We also provide 8 | the map data loader. OpenStreetMap (OSM) provides XML data and relies upon "Nodes" and "Ways" as 9 | its fundamental element. 10 | 11 | A "Node" is a point of interest, or a constituent point of a line feature such as a road. 12 | In OpenStreetMap, a `Node` has tags, which might be 13 | -natural: If it's a natural feature, indicates the type (hill summit, etc) 14 | -man_made: If it's a man made feature, indicates the type (water tower, mast etc) 15 | -amenity: If it's an amenity (e.g. a pub, restaurant, recycling 16 | centre etc) indicates the type 17 | 18 | In OSM, a "Way" is most often a road centerline, composed of an ordered list of "Nodes". 19 | An OSM way often represents a line or polygon feature, e.g. a road, a stream, a wood, a lake. 20 | Ways consist of two or more nodes. Tags for a Way might be: 21 | -highway: the class of road (motorway, primary,secondary etc) 22 | -maxspeed: maximum speed in km/h 23 | -ref: the road reference number 24 | -oneway: is it a one way road? (boolean) 25 | 26 | However, in Argoverse, a "Way" corresponds to a LANE segment centerline. An Argoverse Way has the 27 | following 9 attributes: 28 | - id: integer, unique lane ID that serves as identifier for this "Way" 29 | - has_traffic_control: boolean 30 | - turn_direction: string, 'RIGHT', 'LEFT', or 'NONE' 31 | - is_intersection: boolean 32 | - l_neighbor_id: integer, unique ID for left neighbor 33 | - r_neighbor_id: integer, unique ID for right neighbor 34 | - predecessors: list of integers or None 35 | - successors: list of integers or None 36 | - centerline_node_ids: list 37 | 38 | In Argoverse, a `LaneSegment` object is derived from a combination of a single `Way` and two or more 39 | `Node` objects. 40 | """ 41 | 42 | import logging 43 | import os 44 | import xml.etree.ElementTree as ET 45 | from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Union, cast 46 | 47 | import numpy as np 48 | import matplotlib.pyplot as plt 49 | 50 | from dataloader.utils.lane_segment import LaneSegment, Road 51 | 52 | logger = logging.getLogger(__name__) 53 | 54 | 55 | _PathLike = Union[str, "os.PathLike[str]"] 56 | 57 | 58 | class Node: 59 | """ 60 | e.g. 
a point of interest, or a constituent point of a 61 | line feature such as a road 62 | """ 63 | 64 | def __init__(self, id: int, x: float, y: float, height: Optional[float] = None): 65 | """ 66 | Args: 67 | id: representing unique node ID 68 | x: x-coordinate in city reference system 69 | y: y-coordinate in city reference system 70 | 71 | Returns: 72 | None 73 | """ 74 | self.id = id 75 | self.x = x 76 | self.y = y 77 | self.height = height 78 | 79 | 80 | def str_to_bool(s: str) -> bool: 81 | """ 82 | Args: 83 | s: string representation of boolean, either 'True' or 'False' 84 | 85 | Returns: 86 | boolean 87 | """ 88 | if s == "True": 89 | return True 90 | assert s == "False" 91 | return False 92 | 93 | 94 | def convert_dictionary_to_lane_segment_obj(lane_id: int, lane_dictionary: Mapping[str, Any]) -> LaneSegment: 95 | """ 96 | Not all lanes have predecessors and successors. 97 | 98 | Args: 99 | lane_id: representing unique lane ID 100 | lane_dictionary: dictionary with LaneSegment attributes, not yet in object instance form 101 | 102 | Returns: 103 | ls: LaneSegment object 104 | """ 105 | 106 | l_neighbor_id = None 107 | r_neighbor_id = None 108 | ls = LaneSegment( 109 | lane_id, 110 | l_neighbor_id, 111 | r_neighbor_id, 112 | lane_dictionary["centerline"], 113 | ) 114 | return ls 115 | 116 | 117 | def append_additional_key_value_pair(lane_obj: MutableMapping[str, Any], way_field: List[Tuple[str, str]]) -> None: 118 | """ 119 | Key name was either 'predecessor' or 'successor', for which we can have multiple. 120 | Thus we append them to a list. They should be integers, as lane IDs. 121 | 122 | Args: 123 | lane_obj: lane object 124 | way_field: key and value pair to append 125 | 126 | Returns: 127 | None 128 | """ 129 | assert len(way_field) == 2 130 | k = way_field[0][1] 131 | v = int(way_field[1][1]) 132 | lane_obj.setdefault(k, []).append(v) 133 | 134 | 135 | def append_unique_key_value_pair(lane_obj: MutableMapping[str, Any], way_field: List[Tuple[str, str]]) -> None: 136 | """ 137 | For the following types of Way "tags", the key, value pair is defined only once within 138 | the object: 139 | - has_traffic_control, turn_direction, is_intersection, l_neighbor_id, r_neighbor_id 140 | 141 | Args: 142 | lane_obj: lane object 143 | way_field: key and value pair to append 144 | 145 | Returns: 146 | None 147 | """ 148 | assert len(way_field) == 2 149 | k = way_field[0][1] 150 | v = way_field[1][1] 151 | lane_obj[k] = v 152 | 153 | 154 | def extract_node_waypt(way_field: List[Tuple[str, str]]) -> int: 155 | """ 156 | Given a list with a reference node such as [('ref', '0')], extract out the lane ID. 157 | 158 | Args: 159 | way_field: key and node id pair to extract 160 | 161 | Returns: 162 | node_id: unique ID for a node waypoint 163 | """ 164 | key = way_field[0][0] 165 | node_id = way_field[0][1] 166 | assert key == "ref" 167 | return int(node_id) 168 | 169 | 170 | def get_lane_identifier(child: ET.Element) -> int: 171 | """ 172 | Fetch lane ID from XML ET.Element. 
173 | 174 | Args: 175 | child: ET.Element with information about Way 176 | 177 | Returns: 178 | unique lane ID 179 | """ 180 | return int(child.attrib["id"]) 181 | 182 | 183 | def convert_node_id_list_to_xy(node_id_list: List[int], all_graph_nodes: Mapping[int, Node]) -> np.ndarray: 184 | """ 185 | convert node id list to centerline xy coordinate 186 | 187 | Args: 188 | node_id_list: list of node_id's 189 | all_graph_nodes: dictionary mapping node_ids to Node 190 | 191 | Returns: 192 | centerline 193 | """ 194 | num_nodes = len(node_id_list) 195 | 196 | if all_graph_nodes[node_id_list[0]].height is not None: 197 | centerline = np.zeros((num_nodes, 3)) 198 | else: 199 | centerline = np.zeros((num_nodes, 2)) 200 | for i, node_id in enumerate(node_id_list): 201 | if all_graph_nodes[node_id].height is not None: 202 | centerline[i] = np.array( 203 | [ 204 | all_graph_nodes[node_id].x, 205 | all_graph_nodes[node_id].y, 206 | all_graph_nodes[node_id].height, 207 | ] 208 | ) 209 | else: 210 | centerline[i] = np.array([all_graph_nodes[node_id].x, all_graph_nodes[node_id].y]) 211 | 212 | return centerline 213 | 214 | 215 | def extract_node_from_ET_element(child: ET.Element) -> Node: 216 | """ 217 | Given a line of XML, build a node object. The "node_fields" dictionary will hold "id", "x", "y". 218 | The XML will resemble: 219 | 220 | 221 | 222 | Args: 223 | child: xml.etree.ElementTree element 224 | 225 | Returns: 226 | Node object 227 | """ 228 | node_fields = child.attrib 229 | node_id = int(node_fields["id"]) 230 | for element in child: 231 | way_field = cast(List[Tuple[str, str]], list(element.items())) 232 | key = way_field[0][1] 233 | if key == "local_x": 234 | x = float(way_field[1][1]) 235 | elif key == "local_y": 236 | y = float(way_field[1][1]) 237 | 238 | return Node(id=node_id, x=x, y=y) 239 | 240 | 241 | def extract_lane_segment_from_ET_element( 242 | child: ET.Element, all_graph_nodes: Mapping[int, Node] 243 | ) -> Tuple[LaneSegment, int]: 244 | """ 245 | We build a lane segment from an XML element. A lane segment is equivalent 246 | to a "Way" in our XML file. Each Lane Segment has a polyline representing its centerline. 247 | The relevant XML data might resemble:: 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | ... 257 | 258 | 259 | ... 260 | 261 | 262 | 263 | Args: 264 | child: xml.etree.ElementTree element 265 | all_graph_nodes 266 | 267 | Returns: 268 | lane_segment: LaneSegment object 269 | lane_id 270 | """ 271 | lane_obj: Dict[str, Any] = {} 272 | lane_id = get_lane_identifier(child) 273 | node_id_list: List[int] = [] 274 | for element in child: 275 | # The cast on the next line is the result of a typeshed bug. This really is a List and not a ItemsView. 
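# For reference (illustrative values only): a <way> element contains <tag k="..." v="..."/> children, whose
# element.items() look like [('k', 'predecessor'), ('v', '9605070')], and <nd ref="..."/> children, whose
# element.items() look like [('ref', '0')].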
276 | way_field = cast(List[Tuple[str, str]], list(element.items())) 277 | field_name = way_field[0][0] 278 | if field_name == "k": 279 | key = way_field[0][1] 280 | if key in {"predecessor", "successor"}: 281 | append_additional_key_value_pair(lane_obj, way_field) 282 | else: 283 | append_unique_key_value_pair(lane_obj, way_field) 284 | else: 285 | node_id_list.append(extract_node_waypt(way_field)) 286 | 287 | lane_obj["centerline"] = convert_node_id_list_to_xy(node_id_list, all_graph_nodes) 288 | lane_segment = convert_dictionary_to_lane_segment_obj(lane_id, lane_obj) 289 | return lane_segment, lane_id 290 | 291 | def construct_road_from_ET_element( 292 | child: ET.Element, lane_objs: Mapping[int, LaneSegment] 293 | ): 294 | road_id = int(child.attrib["id"]) 295 | for element in child: 296 | if element.tag == "member": 297 | relation_field = cast(List[Tuple[str, str]], list(element.items())) 298 | if relation_field[2][1] == "right": 299 | r_bound_idx = int(relation_field[1][1]) 300 | elif relation_field[2][1] == "left": 301 | l_bound_idx = int(relation_field[1][1]) 302 | l_bound = lane_objs[l_bound_idx].centerline 303 | r_bound = lane_objs[r_bound_idx].centerline 304 | road = Road( 305 | road_id, 306 | l_bound, 307 | r_bound 308 | ) 309 | return road, road_id 310 | 311 | 312 | def load_lane_segments_from_xml(map_fpath: _PathLike) -> Mapping[int, LaneSegment]: 313 | """ 314 | Load lane segment object from xml file 315 | 316 | Args: 317 | map_fpath: path to xml file 318 | 319 | Returns: 320 | lane_objs: List of LaneSegment objects 321 | """ 322 | tree = ET.parse(os.fspath(map_fpath)) 323 | root = tree.getroot() 324 | 325 | logger.info(f"Loaded root: {root.tag}") 326 | 327 | all_graph_nodes = {} 328 | lane_objs = {} 329 | roads = {} 330 | # all children are either Nodes or Ways or relations 331 | for child in root: 332 | if child.tag == "node": 333 | node_obj = extract_node_from_ET_element(child) 334 | all_graph_nodes[node_obj.id] = node_obj 335 | elif child.tag == "way": 336 | lane_obj, lane_id = extract_lane_segment_from_ET_element(child, all_graph_nodes) 337 | lane_objs[lane_id] = lane_obj 338 | elif child.tag == "relation": 339 | road, road_id = construct_road_from_ET_element(child, lane_objs) 340 | roads[road_id] = road 341 | else: 342 | logger.error("Unknown XML item encountered.") 343 | raise ValueError("Unknown XML item encountered.") 344 | return roads 345 | 346 | def build_polygon_bboxes(roads): 347 | """ 348 | roads: dict, key: road id; value field: l_bound, r_bound 349 | polygon_bboxes: An array of shape (K,), each array element is a NumPy array of shape (4,) representing 350 | the bounding box for a polygon or point cloud. 
351 | each road_id corresponds to a polygon_bbox 352 | lane_start: An array of shape (,4), indicating (x_l, y_l, x_r, y_r) 353 | lane_end: An array of shape (,4), indicating (x_l, y_l, x_r, y_r) 354 | """ 355 | polygon_bboxes = [] 356 | lane_starts = [] 357 | lane_ends = [] 358 | for road_id in roads.keys(): 359 | x = np.concatenate((roads[road_id].l_bound[:,0], roads[road_id].r_bound[:,0])) 360 | xmin = np.min(x) 361 | xmax = np.max(x) 362 | y = np.concatenate((roads[road_id].l_bound[:,1], roads[road_id].r_bound[:,1])) 363 | ymin = np.min(y) 364 | ymax = np.max(y) 365 | polygon_bbox = np.array([xmin, ymin, xmax, ymax]) 366 | polygon_bboxes.append(polygon_bbox) 367 | 368 | lane_start = np.array([roads[road_id].l_bound[0,0], roads[road_id].l_bound[0,1], 369 | roads[road_id].r_bound[0,0], roads[road_id].r_bound[0,1]]) 370 | lane_end = np.array([roads[road_id].l_bound[-1,0], roads[road_id].l_bound[-1,1], 371 | roads[road_id].r_bound[-1,0], roads[road_id].r_bound[-1,1]]) 372 | lane_starts.append(lane_start) 373 | lane_ends.append(lane_end) 374 | 375 | return np.array(polygon_bboxes), np.array(lane_starts), np.array(lane_ends) 376 | 377 | def find_all_polygon_bboxes_overlapping_query_bbox(polygon_bboxes: np.ndarray, 378 | query_bbox: np.ndarray, 379 | lane_starts: np.ndarray, 380 | lane_ends: np.ndarray) -> np.ndarray: 381 | """Find all the overlapping polygon bounding boxes. 382 | Each bounding box has the following structure: 383 | bbox = np.array([x_min,y_min,x_max,y_max]) 384 | In 3D space, if the coordinates are equal (polygon bboxes touch), then these are considered overlapping. 385 | We have a guarantee that the cropped image will have any sort of overlap with the zero'th object bounding box 386 | inside of the image e.g. along the x-dimension, either the left or right side of the bounding box lies between the 387 | edges of the query bounding box, or the bounding box completely engulfs the query bounding box. 388 | Args: 389 | polygon_bboxes: An array of shape (K, 4), each array element is a NumPy array of shape (4,) representing 390 | the bounding box for a polygon or point cloud. 391 | query_bbox: An array of shape (4,) representing a 2d axis-aligned bounding box, with order 392 | [min_x,min_y,max_x,max_y]. 393 | lane_starts: An array of shape (, 4), representing the start point of lane left bound and right bound 394 | lane_ends: An array of shape (, 4), representing the end point of lane left bound and right bound 395 | Returns: 396 | An integer array of shape (K,) representing indices where overlap occurs. 
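Example (illustrative): with polygon_bboxes = np.array([[0, 0, 10, 10], [20, 20, 30, 30]]) and query_bbox = np.array([5, 5, 25, 25]), both boxes intersect the query box, so the returned indices are array([0, 1]).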
397 | """ 398 | query_min_x = query_bbox[0] 399 | query_min_y = query_bbox[1] 400 | 401 | query_max_x = query_bbox[2] 402 | query_max_y = query_bbox[3] 403 | 404 | bboxes_x1 = polygon_bboxes[:, 0] 405 | bboxes_x2 = polygon_bboxes[:, 2] 406 | 407 | bboxes_y1 = polygon_bboxes[:, 1] 408 | bboxes_y2 = polygon_bboxes[:, 3] 409 | 410 | # check if falls within range 411 | overlaps_left = (query_min_x <= bboxes_x2) & (bboxes_x2 <= query_max_x) 412 | overlaps_right = (query_min_x <= bboxes_x1) & (bboxes_x1 <= query_max_x) 413 | 414 | x_check1 = bboxes_x1 <= query_min_x 415 | x_check2 = query_min_x <= query_max_x 416 | x_check3 = query_max_x <= bboxes_x2 417 | x_subsumed = x_check1 & x_check2 & x_check3 418 | 419 | x_in_range = overlaps_left | overlaps_right | x_subsumed 420 | 421 | overlaps_below = (query_min_y <= bboxes_y2) & (bboxes_y2 <= query_max_y) 422 | overlaps_above = (query_min_y <= bboxes_y1) & (bboxes_y1 <= query_max_y) 423 | 424 | y_check1 = bboxes_y1 <= query_min_y 425 | y_check2 = query_min_y <= query_max_y 426 | y_check3 = query_max_y <= bboxes_y2 427 | y_subsumed = y_check1 & y_check2 & y_check3 428 | y_in_range = overlaps_below | overlaps_above | y_subsumed 429 | 430 | # at least one lane endpoint in range 431 | # xy_check1 = (query_min_x <= lane_starts[:,0]) & (lane_starts[:,0] <= query_max_x) & \ 432 | # (query_min_y <= lane_starts[:,1]) & (lane_starts[:,1] <= query_max_y) 433 | # xy_check2 = (query_min_x <= lane_starts[:,2]) & (lane_starts[:,2] <= query_max_x) & \ 434 | # (query_min_y <= lane_starts[:,3]) & (lane_starts[:,3] <= query_max_y) 435 | # xy_check3 = (query_min_x <= lane_ends[:,0]) & (lane_ends[:,0] <= query_max_x) & \ 436 | # (query_min_y <= lane_ends[:,1]) & (lane_ends[:,1] <= query_max_y) 437 | # xy_check4 = (query_min_x <= lane_ends[:,2]) & (lane_ends[:,2] <= query_max_x) & \ 438 | # (query_min_y <= lane_ends[:,3]) & (lane_ends[:,3] <= query_max_y) 439 | # xy_in_range = xy_check1 | xy_check2 | xy_check3 | xy_check4 440 | 441 | # overlap_indxs = np.where(x_in_range & y_in_range & xy_in_range)[0] 442 | 443 | overlap_indxs = np.where(x_in_range & y_in_range)[0] 444 | return overlap_indxs 445 | 446 | def get_road_ids_in_xy_bbox( 447 | polygon_bboxes, 448 | lane_starts, 449 | lane_ends, 450 | roads, 451 | query_x: float, 452 | query_y: float, 453 | query_search_range_manhattan: float = 50.0, 454 | ): 455 | """ 456 | Prune away all lane segments based on Manhattan distance. We vectorize this instead 457 | of using a for-loop. Get all lane IDs within a bounding box in the xy plane. 458 | This is a approximation of a bubble search for point-to-polygon distance. 459 | The bounding boxes of small point clouds (lane centerline waypoints) are precomputed in the map. 460 | We then can perform an efficient search based on manhattan distance search radius from a 461 | given 2D query point. 462 | We pre-assign lane segment IDs to indices inside a big lookup array, with precomputed 463 | hallucinated lane polygon extents. 
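Example (illustrative): with the default 50 m search range, a query at (5.0, 120.0) returns the IDs of every road whose precomputed bounding box overlaps the square [-45, 55] x [70, 170].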
464 | Args: 465 | query_x: representing x coordinate of xy query location 466 | query_y: representing y coordinate of xy query location 467 | city_name: either 'MIA' for Miami or 'PIT' for Pittsburgh 468 | query_search_range_manhattan: search radius along axes 469 | Returns: 470 | lane_ids: lane segment IDs that live within a bubble 471 | """ 472 | query_min_x = query_x - query_search_range_manhattan 473 | query_max_x = query_x + query_search_range_manhattan 474 | query_min_y = query_y - query_search_range_manhattan 475 | query_max_y = query_y + query_search_range_manhattan 476 | 477 | overlap_indxs = find_all_polygon_bboxes_overlapping_query_bbox( 478 | polygon_bboxes, 479 | np.array([query_min_x, query_min_y, query_max_x, query_max_y],), 480 | lane_starts, 481 | lane_ends 482 | ) 483 | 484 | if len(overlap_indxs) == 0: 485 | return [] 486 | 487 | neighborhood_road_ids = [] 488 | for overlap_idx in overlap_indxs: 489 | lane_segment_id = list(roads.keys())[overlap_idx] 490 | neighborhood_road_ids.append(lane_segment_id) 491 | 492 | return neighborhood_road_ids 493 | 494 | if __name__ == "__main__": 495 | roads = load_lane_segments_from_xml("Town03.osm") 496 | polygon_bboxes = build_polygon_bboxes(roads) 497 | query_x = 5.772 498 | query_y = 119.542 499 | cv_range = 50 500 | neighborhood_road_ids = get_road_ids_in_xy_bbox(polygon_bboxes, query_x, query_y, cv_range) 501 | 502 | 503 | # # %% 504 | # plt.figure(dpi=200) 505 | # fig, (ax1,ax2) = plt.subplots(1,2) 506 | # fig.set_figheight(2) 507 | # fig.set_figwidth(4) 508 | # for i in roads.keys(): 509 | 510 | # road_id = i 511 | # ax1.plot(roads[road_id].l_bound[:,0], roads[road_id].l_bound[:,1], color='k')#, marker='o', markerfacecolor='blue', markersize=5) 512 | # ax1.plot(roads[road_id].r_bound[:,0], roads[road_id].r_bound[:,1], color='k')#, marker='o', markerfacecolor='red', markersize=5) 513 | # ax1.plot((roads[road_id].l_bound[:,0]+roads[road_id].r_bound[:,0])/2, (roads[road_id].l_bound[:,1]+roads[road_id].r_bound[:,1])/2, color="0.7",linestyle='dashed') 514 | # ax2.plot(roads[road_id].l_bound[:,0], roads[road_id].l_bound[:,1], color='k')#, marker='o', markerfacecolor='blue', markersize=5) 515 | # ax2.plot(roads[road_id].r_bound[:,0], roads[road_id].r_bound[:,1], color='k')#, marker='o', markerfacecolor='red', markersize=5) 516 | # ax2.plot((roads[road_id].l_bound[:,0]+roads[road_id].r_bound[:,0])/2, (roads[road_id].l_bound[:,1]+roads[road_id].r_bound[:,1])/2, color="0.7",linestyle='dashed') 517 | 518 | # ax1.set_xlim([-60,60]) 519 | # ax1.set_ylim([-60,60]) 520 | # ax2.set_xlim([60,120]) 521 | # ax2.set_ylim([80,180]) 522 | # ax1.axis("off") 523 | # ax2.axis("off") 524 | # # plt.show() 525 | # plt.savefig("town03_lane_segment.jpg") 526 | # # %% 527 | # # plot one lane segment 528 | # for i in roads.keys(): 529 | # road_id = i 530 | # if min(roads[road_id].l_bound[:,0])>60 and max(roads[road_id].l_bound[:,1])>-20 and max(roads[road_id].r_bound[:,0])<120 and max(roads[road_id].r_bound[:,1])<70: 531 | 532 | # plt.plot(roads[road_id].l_bound[:,0], roads[road_id].l_bound[:,1], color='0.7')#, marker='o', markerfacecolor='blue', markersize=5) 533 | # plt.plot(roads[road_id].r_bound[:,0], roads[road_id].r_bound[:,1], color='0.7')#, marker='o', markerfacecolor='red', markersize=5) 534 | # # plt. 
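# A corrected usage sketch for the __main__ block above (hypothetical query values): build_polygon_bboxes
# returns three arrays that get_road_ids_in_xy_bbox expects as separate arguments, e.g.
# polygon_bboxes, lane_starts, lane_ends = build_polygon_bboxes(roads)
# road_ids = get_road_ids_in_xy_bbox(polygon_bboxes, lane_starts, lane_ends, roads, query_x, query_y, cv_range)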
535 | # # plt.xlim((60,120)) 536 | # # plt.ylim((80,180)) 537 | # # plt.axis("off") 538 | # plt.show() 539 | 540 | # # %% 541 | # for i in roads.keys(): 542 | # road_id = i 543 | # plt.plot(roads[road_id].l_bound[:,0], roads[road_id].l_bound[:,1], color='0.7')#, marker='o', markerfacecolor='blue', markersize=5) 544 | # plt.plot(roads[road_id].r_bound[:,0], roads[road_id].r_bound[:,1], color='0.7')#, marker='o', markerfacecolor='red', markersize=5) 545 | # plt.show() 546 | # # %% 547 | -------------------------------------------------------------------------------- /dataloader/visualization.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from carla_scene_process import CarlaData 3 | import torch 4 | import numpy as np 5 | 6 | def visualize_centerline(centerline) -> None: 7 | """Visualize the computed centerline. 8 | Args: 9 | centerline: Sequence of coordinates forming the centerline 10 | """ 11 | line_coords = list(zip(*centerline)) 12 | lineX = line_coords[0] 13 | lineY = line_coords[1] 14 | plt.plot(lineX, lineY, "--", color="grey", alpha=1, linewidth=1, zorder=0) 15 | # plt.text(lineX[0], lineY[0], "s") 16 | # plt.text(lineX[-1], lineY[-1], "e") 17 | plt.axis("equal") 18 | 19 | def get_rotate_invariant_trajs(data: CarlaData): 20 | 21 | rotate_mat = torch.empty(data.num_nodes, 2, 2) 22 | sin_vals = torch.sin(data['rotate_angles']) 23 | cos_vals = torch.cos(data['rotate_angles']) 24 | rotate_mat[:, 0, 0] = cos_vals 25 | rotate_mat[:, 0, 1] = -sin_vals 26 | rotate_mat[:, 1, 0] = sin_vals 27 | rotate_mat[:, 1, 1] = cos_vals 28 | 29 | xrot = torch.bmm(data.positions[:,20:50,:], rotate_mat) 30 | yrot = torch.bmm(data.y, rotate_mat) 31 | # for i in range(xrot.shape[0]): 32 | # plt.plot(xrot[i,:,0], xrot[i,:,1]) 33 | # plt.plot(data.x_sensor[i,:,0], data.x_sensor[i,:,1],'--') 34 | # for i in range(yrot.shape[0]): 35 | # plt.plot(yrot[i,:,0], yrot[i,:,1]) 36 | # plt.plot(data.y[i,:,0], data.y[i,:,1],'--') 37 | 38 | return xrot, yrot, rotate_mat 39 | def viz_devectorize(xrot_vec): 40 | """ 41 | xrot_vec: rotated vector [N,30,2] 42 | """ 43 | x_devec = torch.cumsum(xrot_vec, dim=1) 44 | # translate back to original location 45 | x_devec_ori = x_devec - x_devec[:,-1,:] 46 | for i in range(x_devec_ori.shape[0]): 47 | plt.plot(x_devec_ori[i,:,0], x_devec_ori[i,:,1]) 48 | 49 | def local_invariant_scenes(data: CarlaData): 50 | xrot, yrot, rotate_mat = get_rotate_invariant_trajs(data) 51 | lane_str, lane_vectors = data.lane_pos, data.lane_vectors 52 | lane_idcs = data.lane_idcs 53 | # # visualize the centerlines 54 | # lane_pos = data.lane_pos 55 | # lane_vectors = data.lane_vectors 56 | # lane_idcs = data.lane_idcs 57 | # for i in torch.unique(lane_idcs): 58 | # lane_str = lane_pos[lane_idcs == i] 59 | # lane_vector = lane_vectors[lane_idcs == i] 60 | # lane_end = lane_str + lane_vector 61 | # lane = torch.vstack([lane_str, lane_end[-1,:].reshape(-1, 2)]) 62 | # visualize_centerline(lane) 63 | 64 | #rotate locally 65 | edge_index = data.lane_actor_index 66 | 67 | lane_rotate_mat = rotate_mat[edge_index[1]] 68 | lane_vectors_rot = torch.bmm(lane_vectors[edge_index[0]].unsqueeze(-2), lane_rotate_mat).squeeze(-2) #[#, 2] 69 | lane_pos_rot = torch.bmm(lane_str[edge_index[0]].unsqueeze(-2), lane_rotate_mat).squeeze(-2) #[#, 2] 70 | 71 | #viz local map and traj 72 | for i in range(data.num_nodes): 73 | #traj viz 74 | plt.plot(xrot[i,:,0], xrot[i,:,1]) 75 | plt.text(xrot[i,-1,0], xrot[i,-1,1], "q") 76 | #map viz 77 | lane_idx_i = 
(edge_index[1] == i).nonzero().squeeze() 78 | for j in lane_idx_i: 79 | # lane_str_i = lane_pos_rot[edge_index[1] == i] 80 | lane_str_i = lane_pos_rot[j].unsqueeze(0) #[1,2] 81 | # lane_vector_i = lane_vectors_rot[edge_index[1] == i] 82 | lane_vector_i = lane_vectors_rot[j].unsqueeze(0) 83 | lane_end_i = lane_str_i + lane_vector_i 84 | lane_i = torch.vstack([lane_str_i, lane_end_i]) 85 | visualize_centerline(lane_i) 86 | 87 | 88 | #for each agent, get self-centered maps 89 | for i in range(xrot.shape[0]): 90 | lane_vector_i = lane_vectors_rot[edge_index[1]==i] 91 | lane_pos_i = lane_pos_rot[edge_index[1]==i] 92 | lane_end_i = lane_vector_i + lane_pos_i 93 | lane_i = torch.vstack([lane_pos_i, lane_end_i[-1,:].reshape(-1, 2)]) #[L, 2] 94 | 95 | visualize_centerline(lane_i) 96 | 97 | # visualize the centerlines 98 | lane_pos = data.lane_pos 99 | lane_vectors = data.lane_vectors 100 | lane_idcs = data.lane_idcs 101 | for i in torch.unique(lane_idcs): 102 | lane_str = lane_pos[lane_idcs == i] 103 | lane_vector = lane_vectors[lane_idcs == i] 104 | lane_end = lane_str + lane_vector 105 | lane = torch.vstack([lane_str, lane_end[-1,:].reshape(-1, 2)]) 106 | visualize_centerline(lane) 107 | 108 | for i in range(data.x.shape[0]): 109 | lane_vector_i = lane_vectors[edge_index[0]][edge_index[1]==i] 110 | lane_pos_i = lane_str[edge_index[0]][edge_index[1]==i] 111 | lane_end_i = lane_vector_i + lane_pos_i 112 | lane_i = torch.vstack([lane_pos_i, lane_end_i[-1,:].reshape(-1, 2)]) #[L, 2] 113 | 114 | visualize_centerline(lane_i) 115 | 116 | def viz_lane_rot(): 117 | pass 118 | 119 | def tensor_viz(node_features_all, cav_mask, commu_mask, sensor_mask): 120 | 121 | axes = [8, 16, 3] 122 | filled = np.ones(axes, dtype=np.bool) 123 | colors = np.empty(axes + [4], dtype=np.float32) 124 | alpha = 0.5 125 | colors[:] = [1, 1, 1, alpha] 126 | colors[cav_mask,:,0] = [1, 0, 0, alpha] 127 | colors[commu_mask,:,1] = [0, 1, 0, alpha] 128 | colors[sensor_mask,:,2] = [0, 0, 1, alpha] 129 | 130 | fig = plt.figure() 131 | ax = fig.add_subplot(projection='3d') 132 | ax.voxels(filled, facecolors=colors, edgecolors='grey',shade=True) 133 | plt.show() 134 | plt.axis('off') 135 | 136 | 137 | -------------------------------------------------------------------------------- /losses/__pycache__/get_anchors.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/losses/__pycache__/get_anchors.cpython-37.pyc -------------------------------------------------------------------------------- /losses/__pycache__/loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/losses/__pycache__/loss.cpython-37.pyc -------------------------------------------------------------------------------- /losses/__pycache__/msma_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/losses/__pycache__/msma_loss.cpython-37.pyc -------------------------------------------------------------------------------- /losses/__pycache__/mtp_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/losses/__pycache__/mtp_loss.cpython-37.pyc 
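A minimal usage sketch for tensor_viz above (hypothetical masks over the 8 agent slots, assuming the function is in scope; note that node_features_all is accepted but unused by the current implementation, and dtype=np.bool may need to become bool on recent NumPy):

import numpy as np
cav_mask = np.array([True] + [False] * 7)
commu_mask = np.array([False, True, True, True] + [False] * 4)
sensor_mask = np.array([False] * 4 + [True] * 4)
tensor_viz(None, cav_mask, commu_mask, sensor_mask)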
-------------------------------------------------------------------------------- /losses/__pycache__/multipath_loss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/losses/__pycache__/multipath_loss.cpython-37.pyc -------------------------------------------------------------------------------- /losses/get_anchors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sklearn.cluster import KMeans 3 | # import psutil 4 | # import ray 5 | # from scipy.spatial.distance import cdist 6 | 7 | #Initialize device: 8 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 9 | 10 | # # Initialize ray: 11 | # num_cpus = psutil.cpu_count(logical=False) 12 | # ray.init(num_cpus=num_cpus, log_to_driver=False) 13 | 14 | def k_means_anchors(k, train_loader): 15 | """ 16 | Extract anchors for multipath/covernet using k-means on train set trajectories 17 | gt_y: [num_v, op_len, 2] 18 | train_loader: CarlaData 19 | """ 20 | 21 | trajectories = [] 22 | rotate_imat= [] 23 | for i, data in enumerate(train_loader): 24 | trajectories.append(data.y) 25 | rotate_imat.append(data.rotate_imat) 26 | 27 | traj_all = torch.cat(trajectories, dim=0) 28 | rotate_imat_all = torch.cat(rotate_imat, dim=0) 29 | traj_all_rot = torch.matmul(traj_all, rotate_imat_all) 30 | 31 | clustering = KMeans(n_clusters=k).fit(traj_all_rot.reshape((traj_all_rot.shape[0], -1))) 32 | op_len, op_dim = traj_all_rot.shape[1], traj_all_rot.shape[2] 33 | anchors = torch.zeros((k, op_len, op_dim)).to(device) 34 | for i in range(k): 35 | anchors[i] = torch.mean(traj_all_rot[clustering.labels_==i], axis=0) 36 | # for i in range(traj_all_rot.shape[0]): 37 | # plt.plot(traj_all_rot[i, :, 0], traj_all_rot[i, :, 1]) 38 | # for i in range(anchors.shape[0]): 39 | # plt.plot(anchors[i, :, 0], anchors[i, :, 1]) 40 | 41 | return anchors 42 | 43 | 44 | def bivariate_gaussian_activation(ip: torch.Tensor) -> torch.Tensor: 45 | """ 46 | Activation function to output parameters of bivariate Gaussian distribution 47 | """ 48 | mu_x = ip[..., 0:1] 49 | mu_y = ip[..., 1:2] 50 | sig_x = ip[..., 2:3] 51 | sig_y = ip[..., 3:4] 52 | rho = ip[..., 4:5] 53 | sig_x = torch.exp(sig_x) 54 | sig_y = torch.exp(sig_y) 55 | rho = torch.tanh(rho) 56 | out = torch.cat([mu_x, mu_y, sig_x, sig_y, rho], dim = -1) 57 | 58 | return out -------------------------------------------------------------------------------- /losses/hivt_loss.py: -------------------------------------------------------------------------------- 1 | # source: https://github.com/ZikangZhou/HiVT/blob/main/ 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class LaplaceNLLLoss(nn.Module): 7 | 8 | def __init__(self, 9 | eps: float = 1e-6, 10 | reduction: str = 'mean') -> None: 11 | super(LaplaceNLLLoss, self).__init__() 12 | self.eps = eps 13 | self.reduction = reduction 14 | 15 | def forward(self, 16 | y_hat: torch.Tensor, 17 | y_gt: torch.Tensor, 18 | pi: torch.Tensor) -> torch.Tensor: 19 | loc, scale = y_hat.chunk(2, dim=-1) 20 | scale = scale.clone() 21 | with torch.no_grad(): 22 | scale.clamp_(min=self.eps) 23 | nll = torch.log(2 * scale) + torch.abs(y_gt - loc) / scale 24 | if self.reduction == 'mean': 25 | return nll.mean() 26 | elif self.reduction == 'sum': 27 | return nll.sum() 28 | elif self.reduction == 'none': 29 | return nll 30 | else: 31 | raise ValueError('{} is
not a valid value for reduction'.format(self.reduction)) 32 | 33 | class SoftTargetCrossEntropyLoss(nn.Module): 34 | 35 | def __init__(self, reduction: str = 'mean') -> None: 36 | super(SoftTargetCrossEntropyLoss, self).__init__() 37 | self.reduction = reduction 38 | 39 | def forward(self, 40 | pred: torch.Tensor, 41 | target: torch.Tensor) -> torch.Tensor: 42 | cross_entropy = torch.sum(-target * F.log_softmax(pred, dim=-1), dim=-1) 43 | if self.reduction == 'mean': 44 | return cross_entropy.mean() 45 | elif self.reduction == 'sum': 46 | return cross_entropy.sum() 47 | elif self.reduction == 'none': 48 | return cross_entropy 49 | else: 50 | raise ValueError('{} is not a valid value for reduction'.format(self.reduction)) -------------------------------------------------------------------------------- /losses/msma_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from metrics.metric import min_ade, traj_nll 4 | 5 | class NLLloss(nn.Module): 6 | """ 7 | MTP loss modified to include variances. Uses MSE for mode selection. Can also be used with 8 | Multipath outputs, with residuals added to anchors. 9 | """ 10 | def __init__(self, alpha=0.2, use_variance=True, device='cpu'): 11 | """ 12 | Initialize MSMA loss 13 | :param args: Dictionary with the following (optional) keys 14 | use_variance: bool, whether or not to use variances for computing regression component of loss, 15 | default: False 16 | alpha: float, relative weight assigned to classification component, compared to regression component 17 | of loss, default: 1 18 | """ 19 | super(NLLloss, self).__init__() 20 | self.use_variance = use_variance 21 | self.alpha = alpha 22 | self.device = device 23 | 24 | def forward(self, y_pred, y_true, log_probs): 25 | """ 26 | params: 27 | :y_pred: [num_nodes, num_modes, op_len, 2] 28 | :y_true: [num_nodes, op_len, 2] 29 | :log_probs: probability for each mode [N_B, N_M] 30 | where N_B is batch_size, N_M is num_modes, op_len is target_len 31 | """ 32 | 33 | 34 | num_nodes = y_true.shape[0] 35 | l2_norm = (torch.norm(y_pred - y_true.unsqueeze(1), p=2, dim=-1)).sum(dim=-1) 36 | best_mode = l2_norm.argmin(dim=1) 37 | pred_best = y_pred[torch.arange(num_nodes), best_mode, :, :] 38 | 39 | 40 | loss_cls = (-log_probs[torch.arange(num_nodes).to(self.device), best_mode].squeeze()).mean() #[N_B] 41 | 42 | loss_reg = (torch.norm(pred_best-y_true, p=2, dim=-1)).mean() 43 | 44 | 45 | loss = loss_reg + self.alpha * loss_cls 46 | 47 | return loss -------------------------------------------------------------------------------- /losses/mtp_loss.py: -------------------------------------------------------------------------------- 1 | # source: https://github.com/nachiket92/PGP/blob/main/metrics/mtp_loss.py 2 | import torch 3 | import torch.nn as nn 4 | from metrics.metric import min_ade, traj_nll 5 | 6 | class NLLloss(nn.Module): 7 | """ 8 | MTP loss modified to include variances. Uses MSE for mode selection. Can also be used with 9 | Multipath outputs, with residuals added to anchors. 
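Concretely, the best mode here is selected by minimum ADE against the ground truth (see metrics.metric.min_ade); the regression term is that mode's ADE, or its Gaussian NLL when use_variance is True, and the classification term is the negative log-probability of the selected mode.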
10 | """ 11 | def __init__(self, alpha=0.2, use_variance=True): 12 | """ 13 | Initialize MTP loss 14 | :param args: Dictionary with the following (optional) keys 15 | use_variance: bool, whether or not to use variances for computing regression component of loss, 16 | default: False 17 | alpha: float, relative weight assigned to classification component, compared to regression component 18 | of loss, default: 1 19 | """ 20 | super(NLLloss, self).__init__() 21 | self.use_variance = use_variance 22 | self.alpha = alpha 23 | 24 | def forward(self, y_pred, y_gt, log_probs): 25 | """ 26 | params: 27 | :y_pred: [num_vehs, num_modes, op_len, op_dim] 28 | :y_gt: [num_vehs, op_len, 2] 29 | :log_probs: probability for each mode [num_vehs, num_modes] 30 | :alpha: float, relative weight assigned to classification component, compared to regression component 31 | of loss, default: 1 32 | """ 33 | alpha = self.alpha 34 | use_variance = self.use_variance 35 | # Obtain mode with minimum ADE with respect to ground truth: 36 | op_len = y_pred.shape[2] 37 | pred_params = 5 if use_variance else 2 38 | 39 | errs, inds = min_ade(y_pred, y_gt) 40 | inds_rep = inds.repeat(op_len, pred_params, 1, 1).permute(3, 2, 0, 1) 41 | 42 | # Calculate MSE or NLL loss for trajectories corresponding to selected outputs: 43 | traj_best = y_pred.gather(1, inds_rep).squeeze(dim=1) 44 | # # devectorize traj_best 45 | # for i in range(1,50): 46 | # traj_best[:,i,:] += traj_best[:,i-1,:] 47 | 48 | if use_variance: 49 | l_reg = traj_nll(traj_best, y_gt) 50 | else: 51 | l_reg = errs 52 | 53 | # Compute classification loss 54 | l_class = - torch.squeeze(log_probs.gather(1, inds.unsqueeze(1))) 55 | 56 | loss = l_reg + alpha * l_class 57 | loss = torch.mean(loss) 58 | 59 | return loss 60 | 61 | 62 | 63 | -------------------------------------------------------------------------------- /losses/multipath_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from metrics.metric import min_ade, traj_nll 4 | 5 | class NLLloss(nn.Module): 6 | """ 7 | MTP loss modified to include variances. Uses MSE for mode selection. Can also be used with 8 | Multipath outputs, with residuals added to anchors. 
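In this anchor-based variant, y_pred holds per-mode residuals that are added to the k-means anchors before the regression term is computed, and the classification target is the anchor whose trajectory is closest (summed L2 distance over time) to the ground truth.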
9 | """ 10 | def __init__(self, alpha=0.2, use_variance=True): 11 | """ 12 | Initialize MTP loss 13 | :param args: Dictionary with the following (optional) keys 14 | use_variance: bool, whether or not to use variances for computing regression component of loss, 15 | default: False 16 | alpha: float, relative weight assigned to classification component, compared to regression component 17 | of loss, default: 1 18 | """ 19 | super(NLLloss, self).__init__() 20 | self.use_variance = use_variance 21 | self.alpha = alpha 22 | 23 | def forward(self, y_pred, y_true, log_probs, anchors): 24 | """ 25 | params: 26 | :y_pred: [num_nodes, num_modes, op_len, 2] 27 | :y_true: [num_nodes, op_len, 2] 28 | :log_probs: probability for each mode [N_B, N_M] 29 | :anchors: [num_modes, op_len, 2] 30 | where N_B is batch_size, N_M is num_modes, N_T is target_len 31 | """ 32 | 33 | 34 | num_nodes = y_true.shape[0] 35 | trajectories = y_pred 36 | anchor_probs = log_probs 37 | 38 | #find the nearest anchor mode to y_true 39 | #[1, num_modes, op_len, 2] - [num_nodes, 1, op_len, 2] = [num_nodes, num_modes, op_len, 2] 40 | distance_to_anchors = torch.sum(torch.linalg.vector_norm(anchors.unsqueeze(0) - y_true.unsqueeze(1), 41 | dim=-1),dim=-1) #[num_nodes, num_modes] 42 | 43 | nearest_mode = distance_to_anchors.argmin(dim=-1) #[num_nodes] 44 | nearest_mode_indices = torch.stack([torch.arange(num_nodes,dtype=torch.int64),nearest_mode],dim=-1) 45 | 46 | loss_cls = -log_probs[torch.arange(num_nodes),nearest_mode].squeeze() #[N_B] 47 | 48 | trajectories_xy = y_pred + anchors.unsqueeze(0) 49 | # l2_norm = (torch.norm(trajectories_xy[:, :, :, :2] - y_true.unsqueeze(1), p=2, dim=-1)).sum(dim=-1) # [num_nodes, num_modes] 50 | 51 | nearest_trajs = trajectories_xy[torch.arange(num_nodes),nearest_mode,:,:].squeeze() 52 | residual_trajs = y_true - nearest_trajs 53 | 54 | loss_reg = torch.mean(torch.square(residual_trajs[:,:,0])+torch.square(residual_trajs[:,:,1]), dim=-1) 55 | dx = residual_trajs[:,:,0] 56 | dy = residual_trajs[:,:,1] 57 | 58 | loss = loss_reg + self.alpha * loss_cls 59 | loss = torch.mean(loss) 60 | 61 | return loss 62 | 63 | 64 | -------------------------------------------------------------------------------- /metrics/__pycache__/ade.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/metrics/__pycache__/ade.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/fde.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/metrics/__pycache__/fde.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/metric.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/metrics/__pycache__/metric.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/mr.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/metrics/__pycache__/mr.cpython-37.pyc -------------------------------------------------------------------------------- 
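A minimal smoke test for the anchor-based loss above (illustrative shapes and random tensors; assumes the repository root is on PYTHONPATH):

import torch
from losses.multipath_loss import NLLloss

num_nodes, num_modes, op_len = 4, 5, 50
y_pred = torch.randn(num_nodes, num_modes, op_len, 2)   # residuals w.r.t. anchors
y_true = torch.randn(num_nodes, op_len, 2)              # ground-truth trajectories
log_probs = torch.log_softmax(torch.randn(num_nodes, num_modes), dim=-1)
anchors = torch.randn(num_modes, op_len, 2)             # e.g. from k_means_anchors(5, train_loader)

criterion = NLLloss(alpha=0.2, use_variance=False)
loss = criterion(y_pred, y_true, log_probs, anchors)    # scalar tensor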
/metrics/ade.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | import torch 4 | from torchmetrics import Metric 5 | 6 | 7 | class ADE(Metric): 8 | 9 | def __init__(self, 10 | compute_on_step: bool = True, 11 | dist_sync_on_step: bool = False, 12 | process_group: Optional[Any] = None, 13 | dist_sync_fn: Callable = None) -> None: 14 | super(ADE, self).__init__(compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, 15 | process_group=process_group, dist_sync_fn=dist_sync_fn) 16 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') 17 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') 18 | 19 | def update(self, 20 | pred: torch.Tensor, 21 | target: torch.Tensor) -> None: 22 | self.sum += torch.norm(pred - target, p=2, dim=-1).mean(dim=-1).sum() 23 | self.count += pred.size(0) 24 | 25 | def compute(self) -> torch.Tensor: 26 | return self.sum / self.count -------------------------------------------------------------------------------- /metrics/fde.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | import torch 4 | from torchmetrics import Metric 5 | 6 | 7 | class FDE(Metric): 8 | 9 | def __init__(self, 10 | compute_on_step: bool = True, 11 | dist_sync_on_step: bool = False, 12 | process_group: Optional[Any] = None, 13 | dist_sync_fn: Callable = None) -> None: 14 | super(FDE, self).__init__(compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, 15 | process_group=process_group, dist_sync_fn=dist_sync_fn) 16 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') 17 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') 18 | 19 | def update(self, 20 | pred: torch.Tensor, 21 | target: torch.Tensor) -> None: 22 | self.sum += torch.norm(pred[:, -1] - target[:, -1], p=2, dim=-1).sum() 23 | self.count += pred.size(0) 24 | 25 | def compute(self) -> torch.Tensor: 26 | return self.sum / self.count -------------------------------------------------------------------------------- /metrics/metric.py: -------------------------------------------------------------------------------- 1 | #source: https://github.com/nachiket92/PGP/blob/main/metrics/utils.py 2 | import torch 3 | from typing import Tuple 4 | 5 | def ade(traj: torch.Tensor, traj_gt: torch.Tensor): 6 | ls = torch.norm(traj - traj_gt, p=2, dim=-1).mean(dim=-1).mean() 7 | 8 | return ls 9 | 10 | def fde(traj: torch.Tensor, traj_gt: torch.Tensor): 11 | ls = torch.norm(traj[:, -1] - traj_gt[:, -1], p=2, dim=-1).mean() 12 | 13 | return ls 14 | 15 | def mr(traj: torch.Tensor, traj_gt: torch.Tensor, miss_threshold: torch.Tensor): 16 | ls = (torch.norm(traj[:, -1] - traj_gt[:, -1], p=2, dim=-1) > miss_threshold).sum() 17 | 18 | return ls/traj.shape[0] 19 | 20 | 21 | def min_ade(traj: torch.Tensor, traj_gt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 22 | """ 23 | Computes average displacement error for the best trajectory in a set, with respect to ground truth 24 | :param traj: predictions, shape [num_vehs, num_modes, op_len, 2] 25 | :param traj_gt: ground truth trajectory, shape [num_vehs, op_len, 2] 26 | :return errs, inds: errors and indices for modes with min error, shape [num_vehs] 27 | """ 28 | num_modes = traj.shape[1] 29 | op_len = traj.shape[2] 30 | 31 | traj_gt_rpt = traj_gt.unsqueeze(1).repeat(1, num_modes, 1, 1) 32 | # masks_rpt = masks.unsqueeze(1).repeat(1, 
num_modes, 1) 33 | 34 | err = (traj_gt_rpt - traj[:, :, :, 0:2]) 35 | err = torch.pow(err, exponent=2) 36 | err = torch.sum(err, dim=3) 37 | err = torch.pow(err, exponent=0.5) 38 | err = torch.sum(err, dim=2) / op_len 39 | 40 | # err[stat_idx,:] = err[stat_idx,:]*10000 41 | 42 | err, inds = torch.min(err, dim=1) 43 | 44 | return err, inds 45 | 46 | def traj_nll(pred_dist: torch.Tensor, traj_gt: torch.Tensor): 47 | """ 48 | Computes negative log likelihood of ground truth trajectory under a predictive distribution with a single mode, 49 | with a bivariate Gaussian distribution predicted at each time in the prediction horizon 50 | 51 | :param pred_dist: parameters of a bivariate Gaussian distribution, shape [num_vehs, op_len, 5] 52 | :param traj_gt: ground truth trajectory, shape [num_vehs, op_len, 2] 53 | :return: 54 | """ 55 | # op_len = pred_dist.shape[1] 56 | # mu_x = pred_dist[:, :, 0] 57 | # mu_y = pred_dist[:, :, 1] 58 | # x = traj_gt[:, :, 0] 59 | # y = traj_gt[:, :, 1] 60 | 61 | # sig_x = pred_dist[:, :, 2] 62 | # sig_y = pred_dist[:, :, 3] 63 | # rho = pred_dist[:, :, 4] 64 | # ohr = torch.pow(1 - torch.pow(rho, 2), -0.5) 65 | 66 | # nll = 0.5 * torch.pow(ohr, 2) * \ 67 | # (torch.pow(sig_x, 2) * torch.pow(x - mu_x, 2) + 68 | # torch.pow(sig_y, 2) * torch.pow(y - mu_y, 2) - 69 | # 2 * rho * torch.pow(sig_x, 1) * torch.pow(sig_y, 1) * (x - mu_x) * (y - mu_y))\ 70 | # - torch.log(sig_x * sig_y * ohr) + 1.8379 71 | 72 | # nll[nll.isnan()] = 0 73 | # nll[nll.isinf()] = 0 74 | 75 | # nll = torch.sum(nll, dim=1) / op_len 76 | pred_loc = pred_dist[:,:,:2] 77 | pred_var = pred_dist[:,:,2:4] 78 | 79 | nll = torch.sum(0.5 * torch.log(pred_var) + 0.5 * torch.div(torch.square(traj_gt - pred_loc), pred_var) +\ 80 | 0.5 * torch.log(2 * torch.tensor(3.14159265358979323846))) 81 | 82 | 83 | return nll 84 | 85 | def NLLloss(y_pred, y_true, log_probs, anchors): 86 | """ 87 | params: 88 | :y_pred: [N_T, N_M, N_B, 2] 89 | :y_true: [N_T, N_B, 2] 90 | :log_probs: probability for each mode [N_B, N_M] 91 | :anchors: [N_M, N_T,2] 92 | where N_B is batch_size, N_M is num_modes, N_T is target_len 93 | """ 94 | 95 | 96 | batch_size = y_true.shape[1] 97 | trajectories = y_pred 98 | anchor_probs = log_probs 99 | 100 | #find the nearest anchor mode to y_true 101 | #[1, N_M, N_T,2] - [N_B, N_M, N_T, 2] = [N_B, N_M, N_T, 2] 102 | distance_to_anchors = torch.sum(torch.linalg.vector_norm(anchors.unsqueeze(0) - y_true.permute(1,0,2).unsqueeze(1), 103 | dim=(-1)),dim=-1) #[N_B, N_M] 104 | 105 | nearest_mode = distance_to_anchors.argmin(dim=-1) #[N_B] 106 | nearest_mode_indices = torch.stack([torch.arange(batch_size,dtype=torch.int64),nearest_mode],dim=-1) 107 | 108 | loss_cls = -log_probs[torch.arange(batch_size),nearest_mode].squeeze() #[N_B] 109 | 110 | #trajectories_xy: [N_B, N_M, N_T, 2] 111 | #nearest_trajs: [N_B, N_T, 2] 112 | #residual_trajs: [N_B, N_T, 2] 113 | trajectories_xy = y_pred.permute(2,1,0,3)[...,:2] + anchors.unsqueeze(0) 114 | nearest_trajs = trajectories_xy[torch.arange(batch_size),nearest_mode,:,:].squeeze() 115 | residual_trajs = y_true.permute(1,0,2) - nearest_trajs 116 | 117 | loss_reg = torch.mean(torch.square(residual_trajs[:,:,0])+torch.square(residual_trajs[:,:,1]), dim=-1) 118 | dx = residual_trajs[:,:,0] 119 | dy = residual_trajs[:,:,1] 120 | 121 | total_loss = torch.mean(loss_cls+loss_reg) 122 | 123 | return loss_cls, loss_reg 124 | -------------------------------------------------------------------------------- /metrics/mr.py: 
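A quick sanity check for min_ade in metrics/metric.py above (illustrative tensors): the mode with the smallest average displacement is selected per vehicle.

import torch
from metrics.metric import min_ade

traj_gt = torch.zeros(2, 50, 2)              # two vehicles, 50 future steps
mode_a = traj_gt + 1.0                       # constant 1 m offset in x and y
mode_b = traj_gt.clone()                     # exact match
traj = torch.stack([mode_a, mode_b], dim=1)  # [2, 2, 50, 2]

errs, inds = min_ade(traj, traj_gt)
# errs is approximately tensor([0., 0.]) and inds == tensor([1, 1]): mode 1 is chosen for both vehicles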
-------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Optional 2 | 3 | import torch 4 | from torchmetrics import Metric 5 | 6 | 7 | class MR(Metric): 8 | 9 | def __init__(self, 10 | miss_threshold: float = 2.0, 11 | compute_on_step: bool = True, 12 | dist_sync_on_step: bool = False, 13 | process_group: Optional[Any] = None, 14 | dist_sync_fn: Callable = None) -> None: 15 | super(MR, self).__init__(compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, 16 | process_group=process_group, dist_sync_fn=dist_sync_fn) 17 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') 18 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') 19 | self.miss_threshold = miss_threshold 20 | 21 | def update(self, 22 | pred: torch.Tensor, 23 | target: torch.Tensor) -> None: 24 | self.sum += (torch.norm(pred[:, -1] - target[:, -1], p=2, dim=-1) > self.miss_threshold).sum() 25 | self.count += pred.size(0) 26 | 27 | def compute(self) -> torch.Tensor: 28 | return self.sum / self.count -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from os.path import join as pjoin 4 | 5 | import torch 6 | from torch_geometric.loader import DataLoader 7 | from torch.optim import Adam, AdamW 8 | from tqdm import tqdm 9 | import math 10 | 11 | from dataloader.carla_scene_process import CarlaData, scene_processed_dataset 12 | from ModelNet.msma import Base_Net 13 | from torch_geometric.utils import subgraph 14 | from losses.msma_loss import NLLloss 15 | from utils.optim_schedule import ScheduledOptim 16 | 17 | #load/process the data 18 | root = "../carla_data/" 19 | source_dir = "scene_mining" 20 | mpr = 0.8 21 | delay_frame = 1 22 | noise_var = 0.1 23 | save_dir = "scene_mining_cav/mpr8_delay{}_noise{}".format(delay_frame, noise_var) 24 | 25 | train_set = scene_processed_dataset(root, 26 | "train", 27 | mpr=mpr, 28 | delay_frame=delay_frame, 29 | noise_var=noise_var, 30 | source_dir=source_dir, 31 | save_dir=save_dir) 32 | val_set = scene_processed_dataset(root, 33 | "val", 34 | mpr=mpr, 35 | delay_frame=delay_frame, 36 | noise_var=noise_var, 37 | source_dir=source_dir, 38 | save_dir=save_dir) 39 | test_set = scene_processed_dataset(root, 40 | "test", 41 | mpr=mpr, 42 | delay_frame=delay_frame, 43 | noise_var=noise_var, 44 | source_dir=source_dir, 45 | save_dir=save_dir) 46 | #args 47 | batch_size = 64 48 | num_workers = 4 49 | horizon = 50 50 | lr = 1e-3 51 | betas=(0.9, 0.999) 52 | weight_decay = 0.0001 53 | warmup_epoch=10 54 | lr_update_freq=10 55 | lr_decay_rate=0.9 56 | 57 | 58 | log_freq = 10 59 | save_folder = "" 60 | model_path = '../carla_data/scene_mining_cav' 61 | ckpt_path = None 62 | verbose = True 63 | 64 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 65 | 66 | model = Base_Net(ip_dim=2, 67 | historical_steps=30, 68 | embed_dim=16, 69 | temp_ff=64, 70 | spat_hidden_dim=64, 71 | spat_out_dim=64, 72 | edge_attr_dim=2, 73 | map_out_dim=64, 74 | lane_dim=2, 75 | map_local_radius=30, 76 | decoder_hidden_dim=64, 77 | num_heads=8, 78 | dropout=0.1, 79 | num_temporal_layers=4, 80 | use_variance=False, 81 | device="cpu", 82 | commu_only=False, 83 | sensor_only=False, 84 | prediction_mode="all") 85 | 86 | #dataloader 87 | train_loader = DataLoader( 88 | train_set, 89 | batch_size=batch_size, 90 | num_workers=num_workers, 91 | 
pin_memory=True, 92 | shuffle=True, 93 | persistent_workers=True 94 | ) 95 | eval_loader = DataLoader(val_set, batch_size=batch_size, num_workers=num_workers, persistent_workers=True, shuffle=False) 96 | test_loader = DataLoader(test_set, batch_size=batch_size, num_workers=num_workers, persistent_workers=True, shuffle=False) 97 | 98 | #loss 99 | criterion = NLLloss(alpha=0.5, use_variance=False, device=device) 100 | # anchors = k_means_anchors(5, train_loader) 101 | 102 | # init optimizer 103 | optim = AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay) 104 | optm_schedule = ScheduledOptim( 105 | optim, 106 | lr, 107 | n_warmup_epoch=warmup_epoch, 108 | update_rate=lr_update_freq, 109 | decay_rate=lr_decay_rate 110 | ) 111 | 112 | model = model.to(device) 113 | if verbose: 114 | print("[MSMATrainer]: Train the mode with single device on {}.".format(device)) 115 | 116 | # model.load_state_dict(torch.load('{}/trained_models_review/model_mpr{}_noise{}_fuse_{}_2.tar'.format(model_path, mpr, noise_var, model.prediction_mode))) 117 | 118 | # iteration 119 | training = model.training 120 | avg_loss = 0.0 121 | avg_loss_val = 0.0 122 | losses_train =[] 123 | losses_val = [] 124 | 125 | epochs = 100 126 | minVal = math.inf 127 | 128 | # %% 129 | 130 | for epoch in range(epochs): 131 | avg_loss = 0.0 132 | ## Train:_______________________________________________________________________________________________________________________________ 133 | training = True 134 | # model.train() 135 | data_iter = tqdm( 136 | enumerate(train_loader), 137 | desc="{}_Ep_{}: loss: {:.5e}; avg_loss: {:.5e}".format("train" if training else "eval", 138 | epoch, 139 | 0.0, 140 | avg_loss), 141 | total=len(train_loader), 142 | bar_format="{l_bar}{r_bar}" 143 | ) 144 | count = 0 145 | 146 | for i, data in data_iter: #next(iter(train_loader)) 147 | data = data.to(device) 148 | 149 | if training: 150 | optm_schedule.zero_grad() 151 | predictions, mask = model(data) 152 | gt = torch.matmul(data.y, data.rotate_imat)[mask] 153 | loss = criterion(predictions['traj'], gt, predictions['log_probs']) 154 | loss.backward() 155 | losses_train.append(loss.detach().item()) 156 | 157 | torch.nn.utils.clip_grad_norm_(model.parameters(), 100) 158 | optim.step() 159 | # write_log("Train Loss", loss.detach().item() / n_graph, i + epoch * len(train_loader)) 160 | avg_loss += loss.detach().item() 161 | count += 1 162 | 163 | # print log info 164 | desc_str = "[Info: Device_{}: {}_Ep_{}: loss: {:.5e}; avg_loss: {:.5e}]".format( 165 | 0, 166 | "train" if training else "eval", 167 | epoch, 168 | loss.item(), 169 | avg_loss / count) 170 | data_iter.set_description(desc=desc_str, refresh=True) 171 | 172 | if training: 173 | learning_rate = optm_schedule.step_and_update_lr() 174 | if epoch%10==0: 175 | print("learning_rate: ", learning_rate) 176 | # write_log("LR", learning_rate, epoch) 177 | 178 | 179 | ## Val:_______________________________________________________________________________________________________________________________ 180 | training = False 181 | # model.eval() 182 | avg_loss_val = 0.0 183 | count_val = 0 184 | data_iter_val = tqdm(enumerate(eval_loader), desc="{}_Ep_{}: loss: {:.5e}; avg_loss: {:.5e}".format("eval", 185 | epoch, 186 | 0.0, 187 | avg_loss_val), 188 | total=len(eval_loader), 189 | bar_format="{l_bar}{r_bar}" 190 | ) 191 | for i, data_val in data_iter_val: 192 | data_val = data_val.to(device) 193 | 194 | with torch.no_grad(): 195 | predictions_val, mask_val = model(data_val) 196 | gt_val = 
torch.matmul(data_val.y, data_val.rotate_imat)[mask_val] 197 | loss_val = criterion(predictions_val['traj'], 198 | gt_val, predictions_val['log_probs']) 199 | 200 | losses_val.append(loss_val.detach().item()) 201 | avg_loss_val += loss_val.detach().item() 202 | count_val += 1 203 | 204 | # print log info 205 | desc_str_val = "[Info: Device_{}: {}_Ep_{}: loss: {:.5e}; avg_loss: {:.5e}]".format( 206 | 0, 207 | "eval", 208 | epoch, 209 | loss_val.item(), 210 | avg_loss_val / count_val) 211 | data_iter_val.set_description(desc=desc_str_val, refresh=True) 212 | 213 | if loss_val.item() < minVal: 214 | minVal = loss_val.item() 215 | torch.save(model.state_dict(), '{}/trained_models_review/model_mpr{}_noise{}_fuse_{}_3.tar'.format(model_path, mpr, noise_var, model.prediction_mode)) 216 | 217 | # %% 218 | ## Test:___________________________________________________________________________________________________________________________________ 219 | def test(model, test_loader, epoch): 220 | """ 221 | make predictions on test dataset 222 | 223 | """ 224 | training = model.training 225 | training = False 226 | # model.training = False 227 | count_test = 0 228 | avg_loss_test = 0.0 229 | predictions_test = {} 230 | gts_test = {} 231 | batch_info = {} 232 | probs = {} 233 | masks = {} 234 | sensor_masks = {} 235 | 236 | data_iter_test = tqdm(enumerate(test_loader), desc="{}_Ep_{}: loss: {:.5e}; avg_loss: {:.5e}".format("test", 237 | epoch, 238 | 0.0, 239 | avg_loss_test), 240 | total=len(test_loader), 241 | bar_format="{l_bar}{r_bar}" 242 | ) 243 | for i, data_test in data_iter_test: 244 | data_test = data_test.to(device) 245 | 246 | with torch.no_grad(): 247 | pred_test, mask_test = model(data_test) #pred_test: offset to anchors 248 | gt_test = torch.matmul(data_test.y, data_test.rotate_imat)[mask_test] #aligned at +x axis 249 | #sum of reg and cls loss for all detected vehs 250 | loss_test = criterion(pred_test['traj'], \ 251 | gt_test, pred_test['log_probs']) 252 | 253 | count_test += 1 254 | avg_loss_test += loss_test.detach().item() 255 | #compare predictions for vehs in sensor range when centered at [0,0] but not aligned with x-axis 256 | predictions_test_i = torch.zeros((mask_test.shape[0], 5, 50, 2)).to(device) 257 | predictions_test_i[mask_test]= pred_test["traj"] 258 | predictions_test[i] = torch.matmul(predictions_test_i, \ 259 | torch.inverse(data_test.rotate_imat.unsqueeze(1))) 260 | # predictions_test[i] = torch.matmul(pred_test["traj"] + anchors.unsqueeze(0), \ 261 | # torch.inverse(data_test.rotate_imat[mask_test])) 262 | batch_info[i] = data_test.batch 263 | probs_i = torch.zeros((mask_test.shape[0], 5)).to(device) 264 | probs_i[mask_test] = torch.exp(pred_test['log_probs']) 265 | probs[i] = probs_i 266 | # probs[i] = torch.exp(pred_test['log_probs']) 267 | masks[i] = mask_test 268 | sensor_masks[i] = data_test.sensor_mask 269 | gts_test[i] = data_test.y 270 | 271 | # print log info 272 | desc_str_test = "[Info: Device_{}: {}_Ep_{}: loss: {:.5e}; avg_loss: {:.5e}]".format( 273 | 0, 274 | "test", 275 | epoch, 276 | loss_test.item(), 277 | avg_loss_test / count_test) 278 | data_iter_test.set_description(desc=desc_str_test, refresh=True) 279 | 280 | return predictions_test, gts_test, probs, batch_info, masks, sensor_masks 281 | 282 | predictions_av_av, gt_av_av, probs_av_av, batch_av_av, mask_av_av, sensor_mask_av_av = test(model, test_loader, 100) 283 | -------------------------------------------------------------------------------- /utils/__pycache__/optim_schedule.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/utils/__pycache__/optim_schedule.cpython-37.pyc -------------------------------------------------------------------------------- /utils/optim_schedule.py: -------------------------------------------------------------------------------- 1 | # A wrapper class for optimizer 2 | # source: https://github.com/codertimo/BERT-pytorch/blob/master/bert_pytorch/trainer/optim_schedule.py 3 | import numpy as np 4 | 5 | 6 | class ScheduledOptim: 7 | """ A simple wrapper class for learning rate scheduling 8 | """ 9 | 10 | def __init__(self, optimizer, init_lr, n_warmup_epoch=10, update_rate=5, decay_rate=0.9): 11 | self._optimizer = optimizer 12 | self.n_warmup_epoch = n_warmup_epoch 13 | self.n_current_steps = 0 14 | self.init_lr = init_lr 15 | self.update_rate = update_rate 16 | self.decay_rate = decay_rate 17 | 18 | def step_and_update_lr(self): 19 | """Advance the epoch counter and update the learning rate of the inner optimizer""" 20 | self.n_current_steps += 1 21 | rate = self._update_learning_rate() 22 | 23 | return rate 24 | # self._optimizer.step() 25 | 26 | def zero_grad(self): 27 | "Zero out the gradients of the inner optimizer" 28 | self._optimizer.zero_grad() 29 | 30 | def _get_lr_scale(self): 31 | return np.power(self.decay_rate, max((self.n_current_steps - self.n_warmup_epoch + 1) // self.update_rate + 1, 0)) 32 | 33 | def _update_learning_rate(self): 34 | """ Learning rate scheduling per step """ 35 | 36 | lr = self.init_lr * self._get_lr_scale() 37 | 38 | for param_group in self._optimizer.param_groups: 39 | param_group['lr'] = lr 40 | return lr 41 | 42 | if __name__ == "__main__": 43 | lr = 1e-3 44 | betas=(0.9, 0.999) 45 | weight_decay = 0.0001 46 | warmup_epoch=150000 47 | lr_update_freq=5 48 | lr_decay_rate=0.3 49 | -------------------------------------------------------------------------------- /utils/viz.py: -------------------------------------------------------------------------------- 1 | #architecture picture in test.py on colab 2 | import torch; import matplotlib.pyplot as plt # torch is used by the plotting helpers below 3 | def visualize_centerline(centerline) -> None: 4 | """Visualize the computed centerline. 
5 | Args: 6 | centerline: Sequence of coordinates forming the centerline 7 | """ 8 | line_coords = list(zip(*centerline)) 9 | lineX = line_coords[0] 10 | lineY = line_coords[1] 11 | plt.plot(lineX, lineY, "--", color="grey", alpha=1, linewidth=1, zorder=0) 12 | # plt.text(lineX[0], lineY[0], "s") 13 | # plt.text(lineX[-1], lineY[-1], "e") 14 | plt.axis("equal") 15 | 16 | def visualize_map(lane_strs, lane_vecs, lane_idcs): 17 | for i in range(1, len(lane_idcs.unique())): 18 | lane_start = lane_strs[lane_idcs == i] 19 | vecs = lane_vecs[lane_idcs == i] 20 | lane_end = lane_start + vecs 21 | lane = torch.vstack([lane_start, lane_end[-1,:].reshape(-1, 2)]) 22 | visualize_centerline(lane) 23 | 24 | def visualize_traj(prediction, gt, prob, best_mode=True): 25 | """ 26 | prediction: [num_nodes, num_modes, op_len, 2] 27 | gt: [num_nodes, op_len, 2] 28 | prob: [num_nodes, num_modes] 29 | """ 30 | n, m = prediction.shape[0], prediction.shape[1] 31 | 32 | if best_mode: 33 | # prs, inds = torch.max(prob, dim=1) 34 | 35 | # for i in range(n): 36 | # plt.plot(prediction[i,inds[i],:,0], prediction[i,inds[i],:,1]) 37 | # plt.text(prediction[i,inds[i],-1,0], prediction[i,inds[i],-1,1], 38 | # "{:.2f}".format(prs[i].item())) 39 | # plt.plot(gt[i,:,0], gt[i,:,1],'--') 40 | l2_norm = (torch.norm(prediction[:, :, :, : 2] - \ 41 | gt.unsqueeze(1), p=2, dim=-1)).sum(dim=-1) 42 | best_mode = l2_norm.argmin(dim=-1) 43 | y_pred_best = prediction[torch.arange(gt.shape[0]), best_mode, :, : 2] 44 | for i in range(n): 45 | plt.plot(y_pred_best[i,:,0], y_pred_best[i,:,1],'b') 46 | plt.plot(gt[i,:,0], gt[i,:,1], c='orange', linestyle='--') 47 | # circle_ncv = plt.Circle((gt[i,0,0], gt[i,0,1]), 48 | # 1, color='orange') 49 | # plt.gca().add_patch(circle_ncv) 50 | 51 | else: 52 | for i in range(n): 53 | for j in range(m): 54 | plt.plot(prediction[i,j,:,0], prediction[i,j,:,1]) 55 | plt.plot(gt[i,:,0], gt[i,:,1], c='orange', linestyle='--') 56 | circle_ncv = plt.Circle((gt[i,0,0], gt[i,0,1]), 57 | 1, color='orange') 58 | plt.gca().add_patch(circle_ncv) 59 | 60 | def visualize_gt_traj(gt): 61 | for i in range(gt.shape[0]): 62 | plt.plot(gt[i,:,0], gt[i,:,1], c='orange', linestyle='--') 63 | def visualize_pred_traj(pred, prob, best_mode=True): 64 | n, m = pred.shape[0], pred.shape[1] 65 | if best_mode: 66 | prs, inds = torch.max(prob, dim=1) 67 | for i in range(n): 68 | plt.plot(pred[i,inds[i],:,0], pred[i,inds[i],:,1]) 69 | plt.text(pred[i,inds[i],-1,0], pred[i,inds[i],-1,1], 70 | "{:.2f}".format(prs[i].item())) 71 | else: 72 | for i in range(n): 73 | for j in range(m): 74 | plt.plot(pred[i,j,:,0], pred[i,j,:,1]) 75 | 76 | def prediction_viz(sample, batch_size, test_set, predictions, probs, batch, masks, mpr=0): 77 | """ 78 | prediction: [num_nodes, num_modes, op_len, 2] 79 | gt: [num_nodes, op_len, 2] 80 | prob: [num_nodes, num_modes] 81 | """ 82 | s0, s1 = divmod(sample, batch_size) 83 | 84 | #map viz 85 | lane_vecs = test_set.get(sample).lane_vectors 86 | lane_strs = test_set.get(sample).lane_pos 87 | lane_idcs = test_set.get(sample).lane_idcs 88 | # visualize_map(lane_strs, lane_vecs, lane_idcs) 89 | #traj viz 90 | prediction = predictions[s0][batch[s0]==s1,:].cpu() #[num_nodes, num_modes, op_len, 2] 91 | prob = probs[s0][batch[s0]==s1,:].cpu() #[num_nodes, num_modes] 92 | mask = masks[s0][batch[s0]==s1].cpu() #[num_nodes] 93 | gt = test_set.get(sample).y.cpu() #[num_nodes, op_len, 2] 94 | orig = test_set.get(sample).positions[:,49,:].unsqueeze(1) #[num_nodes, 1, 2] 95 | # 
visualize_traj((prediction+orig.unsqueeze(1))[mask], (gt+orig)[mask], prob[mask], best_mode=True) 96 | #cav 97 | cav_ori = (gt+orig)[test_set.get(sample).cav_mask] 98 | cav_mask = test_set.get(sample).cav_mask 99 | visualize_traj((prediction+orig.unsqueeze(1))[cav_mask], (gt+orig)[cav_mask], prob[cav_mask], best_mode=True) 100 | for i in range(cav_ori.shape[0]): 101 | # plt.plot(cav_ori[i,:,0], cav_ori[i,:,1], 'r') 102 | # circle_cav = plt.Circle((cav_ori[i,0,0], cav_ori[i,0,1]), 103 | # 1, color='r') 104 | l1, = plt.plot(cav_ori[i,0,0], cav_ori[i,0,1], marker=(4, 0, 90), color="r",markersize=5) 105 | circle_commu = plt.Circle((cav_ori[i,0,0], cav_ori[i,0,1]), 106 | 65, color='honeydew') 107 | circle_sensor = plt.Circle((cav_ori[i,0,0], cav_ori[i,0,1]), 108 | 40, color='bisque') 109 | plt.gca().add_patch(circle_commu) 110 | plt.gca().add_patch(circle_sensor) 111 | # plt.gca().add_patch(circle_cav) 112 | #ncv 113 | ncv_ori = (gt+orig)[test_set.get(sample).sensor_mask] 114 | ncv_mask = test_set.get(sample).sensor_mask 115 | # for i in range(ncv_ori.shape[0]): 116 | for i in [0,2,3,4,5,7]: 117 | # plt.plot(ncv_ori[i,:,0], ncv_ori[i,:,1], c='orange') 118 | l2, = plt.plot(ncv_ori[i,0,0], ncv_ori[i,0,1], marker="o",color="darkorange",markersize=5) 119 | visualize_traj((prediction+orig.unsqueeze(1))[ncv_mask][i].unsqueeze(0), (gt+orig)[ncv_mask][i].unsqueeze(0), prob[ncv_mask][i].unsqueeze(0), best_mode=True) 120 | circle_ncv = plt.Circle((ncv_ori[i,0,0], ncv_ori[i,0,1]), 121 | 1, color='orange') 122 | plt.gca().add_patch(circle_ncv) 123 | #cv 124 | cv_ori = (gt+orig)[test_set.get(sample).commu_mask] 125 | cv_mask = test_set.get(sample).commu_mask 126 | 127 | for i in range(1, cv_ori.shape[0]): 128 | # plt.plot(cv_ori[i,:,0], cv_ori[i,:,1], 'g') 129 | l3, = plt.plot(cv_ori[i,0,0], cv_ori[i,0,1], marker="*",color="g",markersize=5) 130 | visualize_traj((prediction+orig.unsqueeze(1))[cv_mask][i].unsqueeze(0), (gt+orig)[cv_mask][i].unsqueeze(0), prob[cv_mask][i].unsqueeze(0), best_mode=True) 131 | # circle_cv = plt.Circle((cv_ori[i,0,0], cv_ori[i,0,1]), 132 | # 1, color='g') 133 | # plt.gca().add_patch(circle_cv) 134 | 135 | # #hist_cav 136 | # positions_cav = test_set.get(sample).positions[[test_set.get(sample).cav_mask]] 137 | # for i in range(positions_cav.shape[0]): 138 | # plt.plot(positions_cav[i,20:50,0], positions_cav[i,20:50,1], 'r--',linewidth=2) 139 | # #hist_ncv 140 | # positions_ncv = test_set.get(sample).positions[[test_set.get(sample).sensor_mask]] 141 | # for i in range(positions_ncv.shape[0]): 142 | # plt.plot(positions_ncv[i,20:50,0], positions_ncv[i,20:50,1], c='orange', linestyle='--',linewidth=2) 143 | # #hist_cv 144 | # positions_cv = test_set.get(sample).positions[[test_set.get(sample).commu_mask]] 145 | # for i in range(1, positions_cv.shape[0]): 146 | # plt.plot(positions_cv[i,20:50,0], positions_cv[i,20:50,1], 'g--',linewidth=2) 147 | # # visualize_gt_traj(gt+orig) 148 | # # visualize_pred_traj((prediction+orig.unsqueeze(1))[mask], prob[mask]) 149 | plt.axis('equal') 150 | plt.axis('off') 151 | # # plt.ylim((-60,80)) 152 | # # plt.xlim((-80,60)) 153 | # # plt.xlabel("position_x(m)") 154 | # # plt.ylabel("position_y(m)") 155 | # # plt.title('mpr={}'.format(mpr)) 156 | sample=452 157 | prediction_viz(sample, batch_size, test4, predictions_cav4_cav4, probs_cav4_cav4, batch_cav4_cav4, mask_cav4_cav4, mpr=0.4) 158 | 159 | --------------------------------------------------------------------------------
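Usage note: the dictionaries returned by test() in train.py (predictions_av_av, gt_av_av, probs_av_av, mask_av_av) can be scored with the MR metric defined in metrics/mr.py. The snippet below is a minimal sketch, not part of the original pipeline: it assumes those variables are still in scope after running train.py, selects the highest-probability mode per agent, and compares trajectory endpoints against the default 2.0 m miss threshold; the import path is inferred from the repository layout shown above.

import torch
from metrics.mr import MR  # import path inferred from the repo layout

mr_metric = MR(miss_threshold=2.0)
for i in predictions_av_av:                        # one entry per test batch
    mask = mask_av_av[i]                           # [num_nodes] bool mask of predicted agents
    preds = predictions_av_av[i][mask].cpu()       # [N, 5, 50, 2], rotated back to the original frame
    probs = probs_av_av[i][mask].cpu()             # [N, 5] mode probabilities
    gts = gt_av_av[i][mask].cpu()                  # [N, 50, 2] ground-truth futures
    best_mode = probs.argmax(dim=-1)               # most likely mode per agent
    best_preds = preds[torch.arange(preds.size(0)), best_mode]  # [N, 50, 2]
    mr_metric.update(best_preds, gts)
print("miss rate:", mr_metric.compute().item())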
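Similarly, the warm-up/decay behaviour of ScheduledOptim in utils/optim_schedule.py can be traced in isolation. This is only an illustrative sketch under the train.py settings (lr=1e-3, warmup_epoch=10, lr_update_freq=10, lr_decay_rate=0.9); the single-parameter SGD optimizer exists solely to satisfy the wrapper's constructor.

import torch
from utils.optim_schedule import ScheduledOptim

dummy_param = torch.nn.Parameter(torch.zeros(1))
dummy_optim = torch.optim.SGD([dummy_param], lr=1e-3)
sched = ScheduledOptim(dummy_optim, init_lr=1e-3, n_warmup_epoch=10, update_rate=10, decay_rate=0.9)

for epoch in range(50):
    lr = sched.step_and_update_lr()    # train.py calls this once per epoch
    if epoch % 10 == 0:
        print("epoch {:3d}: lr = {:.6f}".format(epoch, lr))
# The rate stays near init_lr through the warm-up epochs and is then multiplied
# by decay_rate roughly every update_rate epochs.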