├── ModelNet
│   ├── __pycache__
│   │   ├── msma.cpython-37.pyc
│   │   └── utils.cpython-37.pyc
│   ├── msma.py
│   └── utils.py
├── README.md
├── assets
│   ├── arch.png
│   ├── delay.png
│   ├── noise.png
│   ├── performance_compr.png
│   ├── s700_mpr0.png
│   ├── s700_mpr2.png
│   ├── s700_mpr4.png
│   ├── s700_mpr6.png
│   └── s700_mpr8.png
├── carla_data
│   └── Town03.osm
├── dataloader
│   ├── __pycache__
│   │   └── carla_scene_process.cpython-37.pyc
│   ├── carla_scene_mining.py
│   ├── carla_scene_process.py
│   ├── utils
│   │   ├── __pycache__
│   │   │   ├── lane_sampling.cpython-37.pyc
│   │   │   ├── lane_segment.cpython-37.pyc
│   │   │   └── load_xml.cpython-37.pyc
│   │   ├── lane_sampling.py
│   │   ├── lane_segment.py
│   │   └── load_xml.py
│   └── visualization.py
├── losses
│   ├── __pycache__
│   │   ├── get_anchors.cpython-37.pyc
│   │   ├── loss.cpython-37.pyc
│   │   ├── msma_loss.cpython-37.pyc
│   │   ├── mtp_loss.cpython-37.pyc
│   │   └── multipath_loss.cpython-37.pyc
│   ├── get_anchors.py
│   ├── hivt_loss.py
│   ├── msma_loss.py
│   ├── mtp_loss.py
│   └── multipath_loss.py
├── metrics
│   ├── __pycache__
│   │   ├── ade.cpython-37.pyc
│   │   ├── fde.cpython-37.pyc
│   │   ├── metric.cpython-37.pyc
│   │   └── mr.cpython-37.pyc
│   ├── ade.py
│   ├── fde.py
│   ├── metric.py
│   └── mr.py
├── train.py
└── utils
    ├── __pycache__
    │   └── optim_schedule.cpython-37.pyc
    ├── optim_schedule.py
    └── viz.py
/ModelNet/__pycache__/msma.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/ModelNet/__pycache__/msma.cpython-37.pyc
--------------------------------------------------------------------------------
/ModelNet/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/ModelNet/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/ModelNet/msma.py:
--------------------------------------------------------------------------------
1 | #test overall model architecture
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | import sys
7 |
8 | from ModelNet.utils import MLP, bivariate_gaussian_activation
9 | # from utils import MLP
10 | from typing import Optional, Tuple, Union, Dict
11 | import math
12 | from torch_scatter import scatter_mean, scatter_add
13 |
14 | from dataloader.carla_scene_process import CarlaData
15 | from itertools import product
16 | from torch_geometric.utils import subgraph, add_self_loops
17 |
18 | class Base_Net(nn.Module):
19 | def __init__(self,
20 | ip_dim: int=2,
21 | historical_steps: int=30,
22 | embed_dim: int=16,
23 | temp_ff: int=64,
24 | spat_hidden_dim: int=64,
25 | spat_out_dim: int=64,
26 | edge_attr_dim: int=2,
27 | map_out_dim: int=64,
28 | lane_dim: int = 2,
29 | map_local_radius: float=30.,
30 | decoder_hidden_dim: int=64,
31 | num_heads: int = 8,
32 | dropout: float = 0.1,
33 | num_temporal_layers: int = 4,
34 | use_variance: bool = False,
35 | device = 'cpu',
36 | commu_only = False,
37 | sensor_only = False,
38 | prediction_mode = None,
39 | ) -> None:
40 | super(Base_Net, self).__init__()
41 | self.ip_dim = ip_dim
42 | self.historical_steps = historical_steps
43 | self.embed_dim = embed_dim
44 | self.device = device
45 | self.local_radius = map_local_radius
46 | self.commu_only = commu_only
47 | self.sensor_only = sensor_only
48 | self.prediction_mode = prediction_mode
49 |
50 | if self.prediction_mode == "temp_only":
51 | decoder_in_dim = embed_dim
52 | elif self.prediction_mode == "temp_spat":
53 | decoder_in_dim = spat_out_dim
54 | else:
55 | decoder_in_dim = spat_out_dim+map_out_dim
56 |
57 | #input embedding
58 | self.ip_emb_cav = MLP(ip_dim, embed_dim)
59 | self.ip_emb_commu = MLP(ip_dim, embed_dim)
60 | self.ip_emb_sensor = MLP(ip_dim, embed_dim)
61 | self.ip_emb_fuse = MLP(ip_dim, embed_dim)
62 | #temporal encoders
63 | self.temp_encoder = TemporalEncoder(historical_steps=historical_steps,
64 | embed_dim=embed_dim,
65 | device=device,
66 | num_heads=num_heads,
67 | num_layers=num_temporal_layers,
68 | temp_ff=temp_ff,
69 | dropout=dropout)
70 | self.feature_fuse = FeatureFuse(embed_dim=embed_dim,
71 | num_heads=num_heads,
72 | dropout=dropout)
73 | self.spat_encoder = GAT(in_dim=embed_dim,
74 | hidden_dim=spat_hidden_dim,
75 | out_dim=spat_out_dim,
76 | edge_attr_dim=edge_attr_dim,
77 | device=device,
78 | num_heads=num_heads,
79 | dropout=dropout)
80 | self.map_encoder = MapEncoder(lane_dim=lane_dim,
81 | v_dim=spat_out_dim,
82 | out_dim=map_out_dim,
83 | edge_attr_dim=edge_attr_dim,
84 | num_heads=num_heads,
85 | device=device,
86 | dropout=dropout)
87 | self.decoder = PredictionDecoder(encoding_size=decoder_in_dim,
88 | hidden_size=decoder_hidden_dim,
89 | num_modes=5,
90 | op_len=50,
91 | use_variance=use_variance)
92 |
93 | def forward(self, data: CarlaData):
94 |
95 | #temporal encoding
96 | x_cav, x_commu, x_sensor = data.x_cav, data.x_commu, data.x_sensor #overlapping among different modes
97 | cav_mask, commu_mask, sensor_mask = data.cav_mask, data.commu_mask, data.sensor_mask
98 | rotate_imat = data.rotate_imat
99 | x_cav = torch.bmm(x_cav, rotate_imat[cav_mask])
100 | x_commu = torch.bmm(x_commu, rotate_imat[commu_mask])
101 | x_sensor = torch.bmm(x_sensor, rotate_imat[sensor_mask])
102 |
103 | x_cav_, x_commu_, x_sensor_ = self.ip_emb_cav(x_cav), self.ip_emb_commu(x_commu), self.ip_emb_sensor(x_sensor)
104 | cav_out, commu_out, sensor_out = self.temp_encoder(x_cav_, x_commu_, x_sensor_)
105 |
106 | #convert back to original num_nodes given masks
107 | node_features_all = torch.zeros((data.num_nodes, self.embed_dim)).to(self.device)
108 | node_features_all[cav_mask] = cav_out
109 | node_features_all[commu_mask] = commu_out
110 | node_features_all[sensor_mask] = sensor_out
111 | #fuse sensor&commu encodings
112 | mask_fuse = (commu_mask & sensor_mask)
113 | commu_emd, sensor_emd = self.get_overlap_feature(data, commu_out, sensor_out, mask_fuse, self.embed_dim)
114 | # commu_relpos, sensor_relpos = self.get_overlap_feature(data, data.x_commu_ori, data.x_sensor_ori, mask_fuse, self.ip_dim)
115 | # relpos_emd = self.ip_emb_fuse(sensor_relpos-commu_relpos)
116 |
117 | if self.commu_only:
118 | node_features_all[commu_mask] = commu_out
119 | # data.y[commu_mask] = data.y_commu
120 | elif self.sensor_only:
121 | node_features_all[sensor_mask] = sensor_out
122 | elif sum(mask_fuse)>0:
123 | node_features_all[mask_fuse] = self.feature_fuse(commu_emd, sensor_emd)
124 |
125 | mask_all = (cav_mask | commu_mask | sensor_mask)
126 |
127 | if self.prediction_mode == "temp_only":
128 | predictions = self.decoder(node_features_all[mask_all]) #'traj':[nodes_of_interest, 5, 50, 2], 'log_probs':[nodes_of_interest, 5]
129 | return predictions, mask_all
130 |
131 | edge_index, _ = subgraph(subset=mask_all, edge_index=data.edge_index)
132 | edge_index, _ = add_self_loops(edge_index, num_nodes=data.num_nodes)
133 | edge_attr = data['positions'][edge_index[0], 49] - data['positions'][edge_index[1], 49]
134 | # edge_attr = torch.bmm(edge_attr.unsqueeze(-2), rotate_imat[edge_index[1]]).squeeze(-2)
135 | spat_out = self.spat_encoder(node_features_all.view(data.num_nodes,-1), edge_index, edge_attr) #[num_nodes, 64]
136 |
137 | if self.prediction_mode == "temp_spat":
138 | predictions = self.decoder(spat_out[mask_all]) #'traj':[nodes_of_interest, 5, 50, 2], 'log_probs':[nodes_of_interest, 5]
139 | return predictions, mask_all
140 | #AL encoding
141 | map_out = self.map_encoder(data, spat_out, mask_all) #[num_nodes, 64]
142 | final_emd = torch.cat((spat_out, map_out), dim=-1) #[num_nodes, 128]
143 |
144 | predictions = self.decoder(final_emd[mask_all]) #'traj':[nodes_of_interest, 5, 50, 2], 'log_probs':[nodes_of_interest, 5]
145 | return predictions, mask_all
146 |
147 | def get_overlap_feature(self, data, commu_f, sensor_f, mask_fuse, dim):
148 | commu_mask, sensor_mask = data.commu_mask, data.sensor_mask
149 | commu_feature = torch.zeros((data.num_nodes, dim)).to(self.device)
150 | sensor_feature = torch.zeros((data.num_nodes, dim)).to(self.device)
151 | commu_feature[commu_mask] = commu_f
152 | sensor_feature[sensor_mask] = sensor_f
153 |
154 | return commu_feature[mask_fuse], sensor_feature[mask_fuse]
155 |
156 |
157 | class TemporalEncoder(nn.Module):
158 | '''
159 | for each agent, only one fused channel instead of three
160 | '''
161 | def __init__(self,
162 | historical_steps: int,
163 | embed_dim: int,
164 | device,
165 | num_heads: int=8,
166 | num_layers: int=4,
167 | temp_ff: int=64,
168 | dropout: float=0.1) -> None:
169 | super(TemporalEncoder, self).__init__()
170 | self.embed_dim = embed_dim
171 | self.device = device
172 | self.historical_steps = historical_steps
173 | encoder_layer_cav = nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=temp_ff, dropout=dropout, batch_first=True)
174 | self.transformer_encoder_cav = nn.TransformerEncoder(encoder_layer=encoder_layer_cav, num_layers=num_layers,
175 | norm=nn.LayerNorm(embed_dim))
176 | encoder_layer_sensor = nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=temp_ff, dropout=dropout, batch_first=True)
177 | self.transformer_encoder_sensor = nn.TransformerEncoder(encoder_layer=encoder_layer_sensor, num_layers=num_layers,
178 | norm=nn.LayerNorm(embed_dim))
179 | encoder_layer_commu = nn.TransformerEncoderLayer(embed_dim, num_heads, dim_feedforward=temp_ff, dropout=dropout, batch_first=True)
180 | self.transformer_encoder_commu= nn.TransformerEncoder(encoder_layer=encoder_layer_commu, num_layers=num_layers,
181 | norm=nn.LayerNorm(embed_dim))
182 | self.cls_token_cav = nn.Parameter(torch.Tensor(1, 1, embed_dim))
183 | self.cls_token_commu = nn.Parameter(torch.Tensor(1, 1, embed_dim))
184 | self.cls_token_sensor = nn.Parameter(torch.Tensor(1, 1, embed_dim))
185 |
186 | self.pos_embed_cav = nn.Parameter(torch.Tensor(1, historical_steps + 1, embed_dim))
187 | self.pos_embed_commu = nn.Parameter(torch.Tensor(1, historical_steps + 1, embed_dim))
188 | self.pos_embed_sensor = nn.Parameter(torch.Tensor(1, historical_steps + 1, embed_dim))
189 |
190 | nn.init.normal_(self.cls_token_cav, mean=0., std=.02)
191 | nn.init.normal_(self.cls_token_commu, mean=0., std=.02)
192 | nn.init.normal_(self.cls_token_sensor, mean=0., std=.02)
193 | nn.init.normal_(self.pos_embed_cav, mean=0., std=.02)
194 | nn.init.normal_(self.pos_embed_commu, mean=0., std=.02)
195 | nn.init.normal_(self.pos_embed_sensor, mean=0., std=.02)
196 | # self.apply(init_weights)
197 | self.dropout = nn.Dropout(dropout)
198 | self.layer_norm = nn.LayerNorm(embed_dim)
199 | self.linear = nn.Linear(embed_dim * 2, embed_dim)
200 |
201 | def forward(self, x_cav, x_commu, x_sensor):
202 | """
203 | input [batch, seq, feature]
204 | """
205 | num_sensor, seq_len = x_sensor.shape[0], x_sensor.shape[1]
206 | assert seq_len == self.historical_steps
207 |
208 | x_cav, x_commu, x_sensor = self._expand_cls_token(x_cav, x_commu, x_sensor)
209 |
210 | x_cav = x_cav + self.pos_embed_cav
211 | x_sensor = x_sensor + self.pos_embed_sensor
212 | x_commu = x_commu + self.pos_embed_commu
213 |
214 | # Apply dropout and layer normalization
215 | x_cav_t = self.layer_norm(self.dropout(x_cav))
216 | x_sensor_t = self.layer_norm(self.dropout(x_sensor))
217 | x_commu_t = self.layer_norm(self.dropout(x_commu))
218 |
219 | # Apply the transformers
220 | x_cav_temp = self.transformer_encoder_cav(x_cav_t)
221 | x_commu_temp = self.transformer_encoder_commu(x_commu_t)
222 | x_sensor_temp = self.transformer_encoder_sensor(x_sensor_t)
223 |
224 | return x_cav_temp[:,-1,:], x_commu_temp[:,-1,:], x_sensor_temp[:,-1,:] #encoding at last timestep
225 |
226 | def _expand_cls_token(self, x_cav, x_commu, x_sensor):
227 | expand_cls_token_cav= self.cls_token_cav.expand(x_cav.shape[0], -1, -1)
228 | expand_cls_token_commu= self.cls_token_commu.expand(x_commu.shape[0], -1, -1)
229 | expand_cls_token_sensor= self.cls_token_sensor.expand(x_sensor.shape[0], -1, -1)
230 |
231 | x_cav = torch.cat((x_cav, expand_cls_token_cav), dim=1)
232 | x_commu = torch.cat((x_commu, expand_cls_token_commu), dim=1)
233 | x_sensor = torch.cat((x_sensor, expand_cls_token_sensor), dim=1)
234 |
235 | return x_cav, x_commu, x_sensor
236 |
237 | class FeatureFuse(nn.Module):
238 | """
239 | cross attention module
240 | """
241 | def __init__(self,
242 | embed_dim,
243 | num_heads,
244 | dropout=0.1):
245 | super(FeatureFuse, self).__init__()
246 | self.embed_dim = embed_dim
247 | self.num_heads = num_heads
248 | self.lin_q = nn.Linear(embed_dim, embed_dim)
249 | self.lin_k = nn.Linear(embed_dim, embed_dim)
250 | self.lin_v = nn.Linear(embed_dim, embed_dim)
251 | self.lin_self = nn.Linear(embed_dim, embed_dim)
252 | self.lin_ih = nn.Linear(embed_dim, embed_dim)
253 | self.lin_hh = nn.Linear(embed_dim, embed_dim)
254 | self.attn_drop = nn.Dropout(dropout)
255 | self.softmax = nn.Softmax(dim=1)
256 |
257 | def forward(self, commu_enc, sensor_enc):
258 | query = self.lin_q(sensor_enc).view(-1, self.num_heads, self.embed_dim // self.num_heads)
259 | key = self.lin_k(commu_enc).view(-1, self.num_heads, self.embed_dim // self.num_heads)
260 | value = self.lin_v(commu_enc).view(-1, self.num_heads, self.embed_dim // self.num_heads)
261 | scale = (self.embed_dim // self.num_heads) ** 0.5
262 | alpha = (query * key).sum(dim=-1) / scale
263 | alpha = self.softmax(alpha)
264 | alpha = self.attn_drop(alpha)
265 | commu_att = (value * alpha.unsqueeze(-1)).reshape(-1, self.embed_dim)
266 | w = torch.sigmoid(self.lin_ih(sensor_enc) + self.lin_hh(commu_att))
267 | fused_enc = w * self.lin_self(sensor_enc) + (1-w) * commu_att
268 | return fused_enc
269 |
270 | class GAT(nn.Module):
271 | def __init__(self, in_dim, hidden_dim, out_dim, edge_attr_dim, device, num_heads=8, dropout=0.1):
272 | super(GAT, self).__init__()
273 |
274 | self.device = device
275 | self.attention_layers = nn.ModuleList(
276 | [GATlayer(in_dim, hidden_dim, edge_attr_dim) for _ in range(num_heads)]
277 | )
278 | self.out_att = GATlayer(hidden_dim*num_heads, out_dim, edge_attr_dim)
279 | self.dropout = nn.Dropout(dropout)
280 |
281 | def forward(self, X, edge_index, edge_attr):
282 | x = X
283 |
284 | # Concatenate multi-head attentions
285 | x = torch.cat([att(x, edge_index, edge_attr) for att in self.attention_layers], dim=1)
286 | x = F.elu(x)
287 | x = self.dropout(x)
288 | x = self.out_att(x, edge_index, edge_attr) # Final attention aggregation
289 | return F.log_softmax(x, dim=1)
290 |
291 | class GATlayer(nn.Module):
292 | def __init__(self,
293 | embed_dim: int,
294 | out_dim: int,
295 | edge_attr_dim: int,
296 | dropout: float=0.1) -> None:
297 | super(GATlayer, self).__init__()
298 |
299 | self.W = nn.Linear(embed_dim, out_dim, bias=False)
300 | self.a = nn.Linear(2*out_dim + edge_attr_dim, 1, bias=False)
301 | self.edge_attr_dim = edge_attr_dim
302 | self.dropout = nn.Dropout(dropout)
303 | self.out_transform = nn.Linear(out_dim, out_dim, bias=False)
304 |
305 | def forward(self,
306 | X: torch.Tensor,
307 | edge_index: torch.Tensor,
308 | edge_attr: torch.Tensor):
309 | #transform node features
310 | h = self.W(X)
311 | N = h.size(0)
312 | attn_input = self._prepare_attention_input(h, edge_index, edge_attr)
313 | score_per_edge = F.leaky_relu(self.a(attn_input)).squeeze(1) # Calculate attention coefficients
314 |
315 | #apply dropout to attention weights
316 | score_per_edge = self.dropout(score_per_edge)
317 | # softmax
318 | # Calculate the numerator. Make logits <= 0 so that e^logit <= 1 (this will improve the numerical stability)
319 | score_per_edge = score_per_edge - score_per_edge.max()
320 | exp_score_per_edge = score_per_edge.exp()
321 |
322 | neigborhood_aware_denominator = scatter_add(exp_score_per_edge, edge_index[0], dim=0, dim_size=N)
323 | neigborhood_aware_denominator = neigborhood_aware_denominator.index_select(0, edge_index[0])
324 | attentions_per_edge = exp_score_per_edge / (neigborhood_aware_denominator + 1e-16)
325 |
326 | # Apply attention weights to source node features and perform message passing
327 | out_src = h.index_select(0,edge_index[1]) * attentions_per_edge.unsqueeze(dim=1)
328 | h_prime = scatter_add(out_src, edge_index[0], dim=0, dim_size=N)
329 |
330 | # Apply activation function
331 | out = F.elu(h_prime)
332 | return out
333 |
334 | def _prepare_attention_input(self, h, edge_index, edge_attr):
335 | '''
336 | h has shape [N, out_dim]
337 | '''
338 | src, tgt = edge_index
339 | attn_input = torch.cat([h.index_select(0,src), h.index_select(0,tgt), edge_attr], dim=1)
340 |
341 | return attn_input
342 |
343 | class MapEncoder(nn.Module):
344 | def __init__(self,
345 | lane_dim: int,
346 | v_dim: int,
347 | out_dim: int,
348 | edge_attr_dim: int,
349 | num_heads: int,
350 | device: str,
351 | local_radius: float=30.,
352 | dropout: float=0.1) -> None:
353 | super(MapEncoder, self).__init__()
354 | self.local_radius = local_radius
355 | self.device = device
356 | self.attention_layers = nn.ModuleList(
357 | [MapEncoderLayer(out_dim, v_dim, edge_attr_dim) for _ in range(num_heads)]
358 | )
359 | self.lane_emb = MLP(lane_dim, v_dim) #out_dim = v_enc.size(1)
360 | self.edge_attr_dim = edge_attr_dim
361 | self.dropout = nn.Dropout(dropout)
362 | self.out_transform = nn.Linear(out_dim*num_heads, out_dim, bias=False)
363 |
364 | def forward(self, data: CarlaData, v_enc: torch.Tensor, v_mask: torch.Tensor):
365 |
366 | lane = data.lane_vectors
367 |
368 | lane_actor_mask = torch.cat((v_mask, (torch.ones(lane.size(0))==1).to(self.device)), dim=0)
369 | data.lane_actor_index[0] += data.num_nodes #lane_actor_index[0]:lane index, lane_actor_index[1]:actor index
370 | lane_actor_index, lane_actor_attr = subgraph(subset=lane_actor_mask,
371 | edge_index=data.lane_actor_index, edge_attr=data.lane_actor_attr)
372 | lane = torch.bmm(lane[lane_actor_index[0]-data.num_nodes].unsqueeze(-2), data.rotate_imat[lane_actor_index[1]]).squeeze(-2)
373 |
374 | lane_enc = self.lane_emb(lane)
375 | lane_actor_enc = torch.cat((v_enc, lane_enc), dim=0) #shape:[num_veh+num_lane, v_dim]
376 | # Concat multi-head attentions
377 | out = torch.cat([att(lane_actor_enc, data.num_nodes, lane.size(0), lane_actor_index, lane_actor_attr) for att in self.attention_layers], dim=1)
378 | out = F.elu(out)
379 | out = self.dropout(out)
380 | out = self.out_transform(out)
381 |
382 | return out
383 |
384 | class MapEncoderLayer(nn.Module):
385 | def __init__(self,
386 | v_dim: int,
387 | out_dim: int,
388 | edge_attr_dim: int,
389 | dropout: float=0.1) -> None:
390 | super(MapEncoderLayer, self).__init__()
391 |
392 | self.W = nn.Linear(v_dim, out_dim, bias=False)
393 | self.a = nn.Linear(2*out_dim + edge_attr_dim, 1, bias=False)
394 | self.dropout = nn.Dropout(dropout)
395 |
396 | def forward(self,
397 | lane_actor_enc: torch.Tensor,
398 | num_veh: int,
399 | num_lane: int,
400 | lane_actor_index: torch.Tensor,
401 | lane_actor_attr: torch.Tensor):
402 | #transform node features
403 | h = self.W(lane_actor_enc)
404 | N = h.size(0)
405 | assert N == num_veh+num_lane
406 |
407 | attn_input = self._prepare_attention_input(h, num_veh,lane_actor_index, lane_actor_attr)
408 | score_per_edge = F.leaky_relu(self.a(attn_input)).squeeze(1) # Calculate attention coefficients
409 |
410 | #apply dropout to attention weights
411 | score_per_edge = self.dropout(score_per_edge)
412 | # softmax
413 | # Calculate the numerator. Make logits <= 0 so that e^logit <= 1 (this will improve the numerical stability)
414 | score_per_edge = score_per_edge - score_per_edge.max()
415 | exp_score_per_edge = score_per_edge.exp()
416 |
417 | neigborhood_aware_denominator = scatter_add(exp_score_per_edge, lane_actor_index[1], dim=0, dim_size=num_veh)
418 | neigborhood_aware_denominator = neigborhood_aware_denominator.index_select(0, lane_actor_index[1])
419 | attentions_per_edge = exp_score_per_edge / (neigborhood_aware_denominator + 1e-16)
420 |
421 | out_src = h[num_veh:] * attentions_per_edge.unsqueeze(dim=1) #shape[num_lane]
422 | out = scatter_add(out_src, lane_actor_index[1], dim=0, dim_size=num_veh)
423 | assert out.shape[0] == num_veh
424 |
425 | # Apply activation function
426 | out = F.elu(out)
427 | return out
428 |
429 | def _prepare_attention_input(self, h, num_v, edge_index, edge_attr):
430 | '''
431 | h has shape [N, out_dim]
432 | '''
433 | src, tgt = edge_index
434 | attn_input = torch.cat([h[num_v:], h[:num_v].index_select(0,tgt), edge_attr], dim=1)
435 |
436 | return attn_input
437 |
438 | class PredictionDecoder(nn.Module):
439 |
440 | def __init__(self,
441 | encoding_size: int,
442 | hidden_size: int=64,
443 | num_modes: int=5,
444 | op_len: int=50,
445 | use_variance: bool=False) -> None:
446 | super(PredictionDecoder, self).__init__()
447 |
448 | self.op_dim = 5 if use_variance else 2
449 | self.op_len = op_len
450 | self.num_modes = num_modes
451 | self.use_variance = use_variance
452 | self.hidden = nn.Linear(encoding_size, hidden_size)
453 | self.traj_op = nn.Sequential(
454 | nn.Linear(hidden_size, hidden_size),
455 | nn.LayerNorm(hidden_size),
456 | nn.ReLU(inplace=True),
457 | nn.Linear(hidden_size, hidden_size),
458 | nn.LayerNorm(hidden_size),
459 | nn.ReLU(inplace=True),
460 | nn.Linear(hidden_size, self.op_len * self.op_dim * self.num_modes))
461 | self.prob_op = nn.Sequential(
462 | nn.Linear(hidden_size, hidden_size),
463 | nn.LayerNorm(hidden_size),
464 | nn.ReLU(inplace=True),
465 | nn.Linear(hidden_size, hidden_size),
466 | nn.LayerNorm(hidden_size),
467 | nn.ReLU(inplace=True),
468 | nn.Linear(hidden_size, self.num_modes))
469 |
470 | self.leaky_relu = nn.LeakyReLU(0.01)
471 | self.log_softmax = nn.LogSoftmax(dim=1)
472 |
473 |
474 | def forward(self, agg_encoding: torch.Tensor) -> Dict:
475 | """
476 | Forward pass for prediction decoder
477 | :param agg_encoding: aggregated context encoding
478 | :return predictions: dictionary with 'traj': K predicted trajectories and
479 | 'probs': K corresponding probabilities
480 | """
481 |
482 | h = self.leaky_relu(self.hidden(agg_encoding))
483 | num_vehs = h.shape[0] #n_v
484 | traj = self.traj_op(h) #[n_v, 1250]
485 | probs = self.log_softmax(self.prob_op(h)) #[n_v, 5]
486 | traj = traj.reshape(num_vehs, self.num_modes, self.op_len, self.op_dim)
487 | probs = probs.squeeze(dim=-1)
488 | traj = bivariate_gaussian_activation(traj) if self.use_variance else traj
489 |
490 | predictions = {'traj':traj, 'log_probs':probs}
491 |
492 | return predictions
493 |
--------------------------------------------------------------------------------
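
The encoders defined in msma.py can be exercised in isolation. Below is a minimal sketch (not part of the repository) that pushes random tensors through TemporalEncoder and FeatureFuse; the shapes, the CPU device, and the assumption that the repository root is on PYTHONPATH (with torch, torch-scatter and torch-geometric installed) are all illustrative choices.

```python
import torch
from ModelNet.msma import TemporalEncoder, FeatureFuse

# Toy setup: 4 overlapping agents, 30 historical steps, 16-dim embeddings
# (i.e., features already passed through one of the input MLPs).
temp_enc = TemporalEncoder(historical_steps=30, embed_dim=16, device="cpu")
x = torch.randn(4, 30, 16)                          # [batch, seq, feature]
cav_out, commu_out, sensor_out = temp_enc(x, x, x)  # each [4, 16]

# Cross-attention fusion of the communication and sensor encodings.
fuse = FeatureFuse(embed_dim=16, num_heads=8)
fused = fuse(commu_out, sensor_out)                 # [4, 16]
print(fused.shape)
```
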
/ModelNet/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class MLP(nn.Module):
5 | def __init__(self,
6 | in_dim: int,
7 | out_dim: int) -> None:
8 | super(MLP, self).__init__()
9 | self.embed = nn.Sequential(
10 | nn.Linear(in_dim, out_dim),
11 | nn.LayerNorm(out_dim),
12 | nn.ReLU(inplace=True),
13 | nn.Linear(out_dim, out_dim),
14 | nn.LayerNorm(out_dim),
15 | nn.ReLU(inplace=True),
16 | nn.Linear(out_dim, out_dim),
17 | nn.LayerNorm(out_dim)
18 | )
19 | self.apply(init_weights)
20 |
21 | def forward(self, x: torch.Tensor) -> torch.Tensor:
22 | return self.embed(x)
23 |
24 | def init_weights(m: nn.Module) -> None:
25 | if isinstance(m, nn.Linear):
26 | nn.init.xavier_uniform_(m.weight)
27 | if m.bias is not None:
28 | nn.init.zeros_(m.bias)
29 | elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30 | fan_in = m.in_channels / m.groups
31 | fan_out = m.out_channels / m.groups
32 | bound = (6.0 / (fan_in + fan_out)) ** 0.5
33 | nn.init.uniform_(m.weight, -bound, bound)
34 | if m.bias is not None:
35 | nn.init.zeros_(m.bias)
36 | elif isinstance(m, nn.Embedding):
37 | nn.init.normal_(m.weight, mean=0.0, std=0.02)
38 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
39 | nn.init.ones_(m.weight)
40 | nn.init.zeros_(m.bias)
41 | elif isinstance(m, nn.LayerNorm):
42 | nn.init.ones_(m.weight)
43 | nn.init.zeros_(m.bias)
44 | elif isinstance(m, nn.MultiheadAttention):
45 | if m.in_proj_weight is not None:
46 | fan_in = m.embed_dim
47 | fan_out = m.embed_dim
48 | bound = (6.0 / (fan_in + fan_out)) ** 0.5
49 | nn.init.uniform_(m.in_proj_weight, -bound, bound)
50 | else:
51 | nn.init.xavier_uniform_(m.q_proj_weight)
52 | nn.init.xavier_uniform_(m.k_proj_weight)
53 | nn.init.xavier_uniform_(m.v_proj_weight)
54 | if m.in_proj_bias is not None:
55 | nn.init.zeros_(m.in_proj_bias)
56 | nn.init.xavier_uniform_(m.out_proj.weight)
57 | if m.out_proj.bias is not None:
58 | nn.init.zeros_(m.out_proj.bias)
59 | if m.bias_k is not None:
60 | nn.init.normal_(m.bias_k, mean=0.0, std=0.02)
61 | if m.bias_v is not None:
62 | nn.init.normal_(m.bias_v, mean=0.0, std=0.02)
63 | elif isinstance(m, nn.LSTM):
64 | for name, param in m.named_parameters():
65 | if 'weight_ih' in name:
66 | for ih in param.chunk(4, 0):
67 | nn.init.xavier_uniform_(ih)
68 | elif 'weight_hh' in name:
69 | for hh in param.chunk(4, 0):
70 | nn.init.orthogonal_(hh)
71 | elif 'weight_hr' in name:
72 | nn.init.xavier_uniform_(param)
73 | elif 'bias_ih' in name:
74 | nn.init.zeros_(param)
75 | elif 'bias_hh' in name:
76 | nn.init.zeros_(param)
77 | nn.init.ones_(param.chunk(4, 0)[1])
78 | elif isinstance(m, nn.GRU):
79 | for name, param in m.named_parameters():
80 | if 'weight_ih' in name:
81 | for ih in param.chunk(3, 0):
82 | nn.init.xavier_uniform_(ih)
83 | elif 'weight_hh' in name:
84 | for hh in param.chunk(3, 0):
85 | nn.init.orthogonal_(hh)
86 | elif 'bias_ih' in name:
87 | nn.init.zeros_(param)
88 | elif 'bias_hh' in name:
89 | nn.init.zeros_(param)
90 |
91 | def bivariate_gaussian_activation(ip: torch.Tensor) -> torch.Tensor:
92 | """
93 | Activation function to output parameters of bivariate Gaussian distribution
94 | """
95 | mu_x = ip[..., 0:1]
96 | mu_y = ip[..., 1:2]
97 | sig_x = ip[..., 2:3]
98 | sig_y = ip[..., 3:4]
99 | rho = ip[..., 4:5]
100 | sig_x = torch.exp(sig_x)
101 | sig_y = torch.exp(sig_y)
102 | rho = torch.tanh(rho)
103 | out = torch.cat([mu_x, mu_y, sig_x, sig_y, rho], dim = -1)
104 |
105 | return out
--------------------------------------------------------------------------------
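
For intuition, bivariate_gaussian_activation leaves the two mean channels untouched, exponentiates the two scale channels into positive standard deviations, and squashes the last channel into a correlation in (-1, 1). A tiny sketch with made-up numbers:

```python
import torch
from ModelNet.utils import bivariate_gaussian_activation

raw = torch.tensor([[1.0, -2.0, 0.0, 0.5, 3.0]])  # [mu_x, mu_y, sig_x, sig_y, rho] before activation
out = bivariate_gaussian_activation(raw)
print(out)  # approx. [[1.0, -2.0, 1.0, 1.6487, 0.9951]]
```
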
/README.md:
--------------------------------------------------------------------------------
1 | # MSMA
2 |
3 | We focus on traffic scenarios where a connected and autonomous vehicle (CAV) serves as the central agent, utilizing both sensors and communication technologies to perceive its surrounding traffic, which consists of autonomous vehicles, connected vehicles, and human-driven vehicles.
4 |
5 | ## Overview
6 | 
7 |
8 | ## Getting Started
9 |
10 | 1\. Clone this repository:
11 | ```
12 | git clone https://github.com/xichennn/MSMA.git
13 | cd MSMA
14 | ```
15 |
16 | 2\. Create a conda environment and install the dependencies:
17 | ```
18 | conda create -n MSMA python=3.8
19 | conda activate MSMA
20 | conda install pytorch==1.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge
21 |
22 | # install other dependencies
23 | pip install pytorch-lightning
24 | pip install torch-scatter torch-geometric -f https://pytorch-geometric.com/whl/torch-2.1.0+cu121.html
25 | ```
26 | 3\. Download the [CARLA simulation data](https://drive.google.com/file/d/1bxIS4O1ZF3AvKqnsRTYzy5xg7bVwvL-w/view?usp=drive_link) and move it to the carla_data dir.
27 |
28 | ## Training
29 | In train.py, there are three hyperparameters that control the data processing:
30 | - mpr: the market penetration rate (MPR) of connected vehicles in the dataset
31 | - delay_frame: the communication latency, ranging from 1 to 15 frames (0.1~1.5s)
32 | - noise_var: the variance of the Gaussian noise added to sensor data, ranging from 0 to 0.5
33 |
34 | and there are two model arguments that control the data fusion:
35 | - commu_only: when set to True, only data from connected vehicles are utilized
36 | - sensor_only: when set to True, only data from AV sensors are utilized \
37 | When both commu_only and sensor_only are set to False, data from both sources are integrated.
38 |
39 | ## Results
40 |
41 | ### Quantitative Results
42 |
43 |
44 |
45 |
46 |
47 | | Metrics | MPR=0 | MPR=0.2 | MPR=0.4 | MPR=0.6 |MPR=0.8 |
48 | | :--- | :---: | :---: | :---: |:---: |:---: |
49 | | ADE | 0.62 | 0.61 | 0.59 | 0.59 | 0.56 |
50 | | FDE | 1.48 | 1.47 | 1.40 | 1.37 | 1.33 |
51 | | MR | 0.23 | 0.22 | 0.22 | 0.21 | 0.20 |
52 | ### Qualitative Results
53 |
54 | | MPR=0 | MPR=0.4 |MPR=0.8 |
55 | | -------------------------- | -------------------------- |-------------------------- |
56 | |  |  |  |
57 |
58 | ## Citation
59 |
60 | If you found this repository useful, please cite as:
61 |
62 | ```
63 | @article{chen2024msma,
64 | title={MSMA: Multi-agent Trajectory Prediction in Connected and Autonomous Vehicle Environment with Multi-source Data Integration},
65 | author={Chen, Xi and Bhadani, Rahul and Sun, Zhanbo and Head, Larry},
66 | journal={arXiv preprint arXiv:2407.21310},
67 | year={2024}
68 | }
69 | ```
70 |
71 | ## License
72 |
73 | This repository is licensed under [Apache 2.0](LICENSE).
74 |
--------------------------------------------------------------------------------
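
For reference, the training knobs described in the README map onto constructor arguments found in this repository. The snippet below is an illustrative sketch only (paths and values are placeholders; it mirrors the signatures of scene_processed_dataset in dataloader/carla_scene_process.py and Base_Net in ModelNet/msma.py, not the actual contents of train.py):

```python
from dataloader.carla_scene_process import scene_processed_dataset
from ModelNet.msma import Base_Net

# Data processing: penetration rate, communication latency (frames), sensor noise variance.
train_set = scene_processed_dataset(
    root="carla_data", split="train",
    mpr=0.4, delay_frame=5, noise_var=0.1,
    source_dir="scene_mining", save_dir="scene_mining_processed")  # placeholder dirs

# Data fusion: with both flags False, communication and sensor data are integrated.
model = Base_Net(commu_only=False, sensor_only=False, device="cpu")
```
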
/assets/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/arch.png
--------------------------------------------------------------------------------
/assets/delay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/delay.png
--------------------------------------------------------------------------------
/assets/noise.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/noise.png
--------------------------------------------------------------------------------
/assets/performance_compr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/performance_compr.png
--------------------------------------------------------------------------------
/assets/s700_mpr0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/s700_mpr0.png
--------------------------------------------------------------------------------
/assets/s700_mpr2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/s700_mpr2.png
--------------------------------------------------------------------------------
/assets/s700_mpr4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/s700_mpr4.png
--------------------------------------------------------------------------------
/assets/s700_mpr6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/s700_mpr6.png
--------------------------------------------------------------------------------
/assets/s700_mpr8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/assets/s700_mpr8.png
--------------------------------------------------------------------------------
/dataloader/__pycache__/carla_scene_process.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/dataloader/__pycache__/carla_scene_process.cpython-37.pyc
--------------------------------------------------------------------------------
/dataloader/carla_scene_mining.py:
--------------------------------------------------------------------------------
1 | """mine CAV scenarios from logged carla data"""
2 | import pandas as pd
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | # import torch
6 | import random
7 | # import math
8 | import os
9 | import copy
10 |
11 | def get_obj_type_at_mpr(data_df, vids, cav_id, mpr=0.2):
12 | others = list(set(vids) - set([cav_id]))
13 | #keep random seed here to ensure the same seed
14 | random.seed(30)
15 | cv_ids = random.sample(list(others), int(mpr*len(others)))
16 | vid_df = data_df["vid"].values
17 | obj_type_mpr = []
18 | for v in vid_df:
19 | if v == cav_id:
20 | obj_type_mpr.append("cav")
21 | elif v in cv_ids:
22 | obj_type_mpr.append("cv")
23 | else:
24 | obj_type_mpr.append("ncv")
25 | return obj_type_mpr
26 |
27 | # read the data
28 | data_raw = pd.read_csv("../carla_data/Location.csv", header=None)
29 | header = ["frame","time","vid","type_id","position_x","position_y","position_z","rotation_x","rotation_y","rotation_z","vel_x","vel_y","angular_z"]
30 | col_map = {idx: header[idx] for idx in range(13)}
31 | data_raw = data_raw.rename(columns=col_map)
32 | # make pos_y consistent with map
33 | data_raw["position_y"] = -data_raw["position_y"]
34 | # %%
35 | vids = list(data_raw["vid"].unique())
36 | ts = np.sort(np.unique(data_raw['frame'].values))
37 | random.seed(30)
38 | cv_range = 50
39 | av_range = 30
40 |
41 | data_df = data_raw.copy(deep=True)
42 | # segment the scenes into 10s
43 | min_ts = ts[0]
44 | max_ts = ts[-1]
45 | # 5s overlapping among scenes 10s = 100 steps/frames
46 | # remove the scenes where cav is not moving
47 | for cav_id in vids:
48 | for frame in range(min_ts+50,max_ts-50,50):
49 |
50 | vehicles_at_frame = data_df[data_df["frame"] == frame]
51 | cav_entry = data_df[(data_df["frame"]==frame) & (data_df["vid"]==cav_id)]
52 | cav_entry_previous = data_df[(data_df["frame"]==frame-1) & (data_df["vid"]==cav_id)]
53 | if (cav_entry.position_x.values == cav_entry_previous.position_x.values) and \
54 | (cav_entry.position_y.values == cav_entry_previous.position_y.values):
55 | continue
56 |
57 | dist = ((vehicles_at_frame["position_x"].values - cav_entry["position_x"].values)**2
58 | + (vehicles_at_frame["position_y"].values - cav_entry["position_y"].values)**2)**0.5
59 | cv_idx = np.where((dist<cv_range) & (dist>0))[0]
60 | cv_neighbors = vehicles_at_frame["vid"].values[cv_idx]
61 | #remove unmoved surrounding vehicles
62 | vid_ngbr_unmove = []
63 | for i in range(len(cv_neighbors)):
64 | vid_ngbr = cv_neighbors[i]
65 | ngbr_entry = data_df[(data_df["frame"]==frame) & (data_df["vid"]==vid_ngbr)]
66 | ngbr_entry_previous = data_df[(data_df["frame"]==frame-1) & (data_df["vid"]==vid_ngbr)]
67 | if (ngbr_entry.position_x.values == ngbr_entry_previous.position_x.values) and \
68 | (ngbr_entry.position_y.values == ngbr_entry_previous.position_y.values):
69 | vid_ngbr_unmove.append(vid_ngbr)
70 | cv_ngbrs_move=list(set(cv_neighbors)-set(vid_ngbr_unmove))
71 |
72 | av_idx = np.where((dist<av_range) & (dist>0))[0]
73 | av_neighbors = vehicles_at_frame["vid"].values[av_idx]
74 | av_ngbrs_move = list(set(av_neighbors)-set(vid_ngbr_unmove))
75 |
76 | scene_frames = list(range(frame-50,frame+50))
77 | scene_vids = [cav_id]+cv_ngbrs_move
78 | scene_data = copy.deepcopy(data_df[data_df["vid"].isin(scene_vids) & data_df["frame"].isin(scene_frames)])
79 |
80 | #mprs
81 | obj_type_mpr_02 = get_obj_type_at_mpr(scene_data, scene_vids, cav_id, mpr=0.2)
82 | obj_type_mpr_04 = get_obj_type_at_mpr(scene_data, scene_vids, cav_id, mpr=0.4)
83 | obj_type_mpr_06 = get_obj_type_at_mpr(scene_data, scene_vids, cav_id, mpr=0.6)
84 | obj_type_mpr_08 = get_obj_type_at_mpr(scene_data, scene_vids, cav_id, mpr=0.8)
85 | scene_data["obj_type_mpr_02"] = obj_type_mpr_02
86 | scene_data["obj_type_mpr_04"] = obj_type_mpr_04
87 | scene_data["obj_type_mpr_06"] = obj_type_mpr_06
88 | scene_data["obj_type_mpr_08"] = obj_type_mpr_08
89 |
90 | scene_data["in_av_range"] = scene_data["vid"].isin([cav_id]+av_ngbrs_move).values
91 | scene_data.to_csv("scene_mining/scene_{}_{}".format(frame, cav_id),index=False)
92 |
--------------------------------------------------------------------------------
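
Each mined scene is saved as a plain CSV named scene_<frame>_<cav_id> containing the per-MPR object-type columns and the in_av_range flag added above. A quick way to inspect one (the file name here is hypothetical):

```python
import pandas as pd

scene = pd.read_csv("scene_mining/scene_700_42")  # hypothetical output file
# obj_type_mpr_* is one of {"cav", "cv", "ncv"}; in_av_range flags vehicles within AV sensor range.
print(scene[["frame", "vid", "obj_type_mpr_02", "obj_type_mpr_08", "in_av_range"]].head())
```
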
/dataloader/carla_scene_process.py:
--------------------------------------------------------------------------------
1 | """
2 | process the '.csv' files, save as '.pt' files
3 | """
4 | import os
5 | import sys
6 | import numpy as np
7 | import pandas as pd
8 | import copy
9 | from os.path import join as pjoin
10 |
11 | from dataloader.utils import lane_segment, load_xml
12 | from dataloader.utils.lane_sampling import Spline2D, visualize_centerline
13 | import matplotlib.pyplot as plt
14 |
15 | from typing import List, Optional, Tuple
16 |
17 | import torch
18 | torch.manual_seed(30)
19 | import torch.nn as nn
20 | from torch_geometric.data import Data, HeteroData
21 | from torch_geometric.data import Dataset
22 | from typing import Callable, Dict, List, Optional, Tuple, Union
23 | from itertools import permutations, product
24 | from tqdm import tqdm
25 |
26 | class scene_processed_dataset(Dataset):
27 | def __init__(self,
28 | root:str,
29 | split:str,
30 | radius:float = 75,
31 | local_radius:float = 30,
32 | transform: Optional[Callable] = None,
33 | mpr:float = 0.,
34 | obs_len:float=50,
35 | fut_len:float=50,
36 | cv_range:float=50,
37 | av_range:float=30,
38 | noise_var:float=0.1,
39 | delay_frame:float=1,
40 | normalized=True,
41 | source_dir:str = None,
42 | save_dir:str = None) ->None:
43 |
44 | self._split = split
45 | self._radius = radius
46 | self._local_radius = local_radius
47 | self.obs_len = obs_len
48 | self.fut_len = fut_len
49 | self.cv_range = cv_range
50 | self.av_range = av_range
51 | self.mpr = mpr
52 | self.noise_var = noise_var
53 | self.delay_frame = delay_frame
54 | self.normalized = normalized
55 | self.source_dir = source_dir
56 | self.save_dir = save_dir
57 |
58 | self.root = root
59 | self._raw_file_names = os.listdir(self.raw_dir)
60 | self._processed_file_names = [os.path.splitext(f)[0] + '.pt' for f in self.raw_file_names]
61 | self._processed_paths = [os.path.join(self.processed_dir, f) for f in self._processed_file_names]
62 | super(scene_processed_dataset, self).__init__(root)
63 |
64 | @property
65 | def raw_dir(self) -> str:
66 | return os.path.join(self.root, self.source_dir, self._split)
67 |
68 | @property
69 | def processed_dir(self) -> str:
70 | return os.path.join(self.root, self.save_dir, self._split)
71 |
72 | @property
73 | def raw_file_names(self) -> Union[str, List[str], Tuple]:
74 | return self._raw_file_names
75 |
76 | @property
77 | def processed_file_names(self) -> Union[str, List[str], Tuple]:
78 | return self._processed_file_names
79 |
80 | @property
81 | def processed_paths(self) -> List[str]:
82 | return self._processed_paths
83 |
84 | def process(self) -> None:
85 | self.get_map_polygon_bbox()
86 | for raw_path in tqdm(self.raw_paths):
87 | kwargs = self.get_scene_feats(raw_path, self._radius, self._local_radius, self._split)
88 | data = CarlaData(**kwargs)
89 | torch.save(data, os.path.join(self.processed_dir, str(kwargs['seq_id']) + '.pt'))
90 |
91 | def len(self) -> int:
92 | return len(self._raw_file_names)
93 |
94 | def get(self, idx) -> Data:
95 | return torch.load(self.processed_paths[idx])
96 |
97 | def get_map_polygon_bbox(self):
98 | rel_path = "Town03.osm"
99 | roads = load_xml.load_lane_segments_from_xml(pjoin(self.root, rel_path))
100 | polygon_bboxes, lane_starts, lane_ends = load_xml.build_polygon_bboxes(roads)
101 | self.roads = roads
102 | self.polygon_bboxes = polygon_bboxes
103 | self.lane_starts = lane_starts
104 | self.lane_ends = lane_ends
105 |
106 | def get_scene_feats(self, raw_path, radius, local_radius, split="train"):
107 |
108 | df = pd.read_csv(raw_path)
109 | # filter out actors that are unseen during the historical time steps
110 | timestamps = list(np.sort(df['frame'].unique()))
111 | historical_timestamps = timestamps[: 50]
112 | historical_df = df[df['frame'].isin(historical_timestamps)]
113 | actor_ids = list(historical_df['vid'].unique())
114 |
115 | # # filter out unmoved actors
116 | # actor_ids = self.remove_unmoved_ids(df, actor_ids)
117 |
118 | df = df[df['vid'].isin(actor_ids)]
119 | num_nodes = len(actor_ids)
120 |
121 | objs = df.groupby(['vid', 'obj_type_mpr_02', 'obj_type_mpr_04', 'obj_type_mpr_06', 'obj_type_mpr_08', 'in_av_range']).groups
122 | keys = list(objs.keys())
123 |
124 | vids = [x[0] for x in keys]
125 | actor_indices = [vids.index(x) for x in actor_ids]
126 | obj_type_02 = [keys[i][1] for i in actor_indices]
127 | obj_type_04 = [keys[i][2] for i in actor_indices]
128 | obj_type_06 = [keys[i][3] for i in actor_indices]
129 | obj_type_08 = [keys[i][4] for i in actor_indices]
130 | in_av_range = [keys[i][5] for i in actor_indices]
131 |
132 | cav_idx = np.where(np.asarray(obj_type_02)=="cav")[0] #np array
133 | cav_df = df[df['obj_type_mpr_02'] == 'cav'].iloc
134 |
135 | # make the scene centered at CAV
136 | origin = torch.tensor([cav_df[49]['position_x'], cav_df[49]['position_y']], dtype=torch.float)
137 | cav_heading_vector = origin - torch.tensor([cav_df[48]['position_x'], cav_df[48]['position_y']], dtype=torch.float)
138 | theta = torch.atan2(cav_heading_vector[1], cav_heading_vector[0])
139 | rotate_mat = torch.tensor([[torch.cos(theta), -torch.sin(theta)],
140 | [torch.sin(theta), torch.cos(theta)]])
141 |
142 | # initialization
143 | x = torch.zeros(num_nodes, 100, 2, dtype=torch.float)
144 | edge_index = torch.LongTensor(list(permutations(range(num_nodes), 2))).t().contiguous()
145 | padding_mask = torch.ones(num_nodes, 100, dtype=torch.bool)
146 | bos_mask = torch.zeros(num_nodes, 50, dtype=torch.bool)
147 | rotate_angles = torch.zeros(num_nodes, dtype=torch.float)
148 |
149 | for actor_id, actor_df in df.groupby('vid'):
150 | node_idx = actor_ids.index(actor_id)
151 | node_steps = [timestamps.index(timestamp) for timestamp in actor_df['frame']]
152 | padding_mask[node_idx, node_steps] = False
153 | if padding_mask[node_idx, 49]: # make no predictions for actors that are unseen at the current time step
154 | padding_mask[node_idx, 50:] = True
155 | xy = torch.from_numpy(np.stack([actor_df['position_x'].values, actor_df['position_y'].values], axis=-1)).float()
156 | x[node_idx, node_steps] = torch.matmul(rotate_mat, (xy - origin.reshape(-1, 2)).T).T
157 | node_historical_steps = list(filter(lambda node_step: node_step < 50, node_steps))
158 | if len(node_historical_steps) > 1: # calculate the heading of the actor (approximately)
159 | heading_vector = x[node_idx, node_historical_steps[-1]] - x[node_idx, node_historical_steps[-2]]
160 | rotate_angles[node_idx] = torch.atan2(heading_vector[1], heading_vector[0])
161 | else: # make no predictions for the actor if the number of valid time steps is less than 2
162 | padding_mask[node_idx, 50:] = True
163 |
164 | # bos_mask is True if time step t is valid and time step t-1 is invalid
165 | bos_mask[:, 0] = ~padding_mask[:, 0]
166 | bos_mask[:, 1: 50] = padding_mask[:, : 49] & ~padding_mask[:, 1: 50]
167 |
168 | #positions are transformed absolute x, y coordinates
169 | positions = x.clone()
170 |
171 | #reformat encode strs and bools, CAV:1, CV:2, NCV:3
172 | obj_type_mapping = {"cav":1, "cv":2, "ncv":3}
173 | obj_type_02_ = torch.tensor([obj_type_mapping[x] for x in obj_type_02])
174 | obj_type_04_ = torch.tensor([obj_type_mapping[x] for x in obj_type_04])
175 | obj_type_06_ = torch.tensor([obj_type_mapping[x] for x in obj_type_06])
176 | obj_type_08_ = torch.tensor([obj_type_mapping[x] for x in obj_type_08])
177 | in_av_range_ = torch.tensor([1 if in_av_range[i]==True else 0 for i in range(len(in_av_range))])
178 |
179 | #get masks for different data sources
180 | types = [obj_type_02_, obj_type_04_, obj_type_06_, obj_type_08_]
181 | mprs = [0.2, 0.4, 0.6, 0.8]
182 | cav_mask, commu_mask, sensor_mask = self.get_masks(self.mpr, mprs, types, in_av_range_)
183 | positions_hist = positions[:,:50,:].clone()
184 | x_cav = positions_hist[cav_mask][:,20:50,:]
185 | x_commu = positions_hist[commu_mask]
186 | x_sensor = positions_hist[sensor_mask]
187 |
188 | #inject errors to different data sources
189 | x_sensor_noise, padding_mask_noise = self.get_noisy_x(x_sensor, padding_mask[sensor_mask], self.noise_var)
190 | x_commu_delay, padding_mask_delay = self.get_delayed_x(x_commu, padding_mask[commu_mask], self.delay_frame)
191 |
192 | #get vectorized x
193 | x_cav_vec = self.get_vectorized_x(x_cav, padding_mask[cav_mask][:,20:50])
194 | x_commu_delay_vec = self.get_vectorized_x(x_commu_delay, padding_mask_delay)
195 | x_sensor_noise_vec = self.get_vectorized_x(x_sensor_noise, padding_mask_noise)
196 |
197 |
198 | y = torch.where((padding_mask[:, 49].unsqueeze(-1) | padding_mask[:, 50:]).unsqueeze(-1),
199 | torch.zeros(num_nodes, 50, 2),
200 | x[:, 50:] - x[:, 49].unsqueeze(-2))
201 |
202 | y_commu = torch.where((padding_mask[:, 49].unsqueeze(-1) | padding_mask[:, 50:]).unsqueeze(-1),
203 | torch.zeros(num_nodes, 50, 2),
204 | x[:, 50:] - x[:, 49-self.delay_frame].unsqueeze(-2))[commu_mask]
205 |
206 | lane_pos, lane_vectors, lane_idcs,lane_actor_index, lane_actor_attr = \
207 | self.get_lane_feats(origin, rotate_mat, num_nodes, positions, radius, local_radius)
208 |
209 | #get rotate-invariant matrix
210 | rotate_imat = torch.empty(num_nodes, 2, 2)
211 | sin_vals = torch.sin(rotate_angles)
212 | cos_vals = torch.cos(rotate_angles)
213 | rotate_imat[:, 0, 0] = cos_vals
214 | rotate_imat[:, 0, 1] = -sin_vals
215 | rotate_imat[:, 1, 0] = sin_vals
216 | rotate_imat[:, 1, 1] = cos_vals
217 |
218 | seq_id = os.path.splitext(os.path.basename(raw_path))[0]
219 |
220 | return {
221 | 'x_cav': x_cav_vec, # [1, 30, 2]
222 | 'x_commu': x_commu_delay_vec, # [N1, 30, 2]
223 | 'x_sensor': x_sensor_noise_vec, # [N2, 30, 2]
224 | 'cav_mask': cav_mask, # [N]
225 | 'commu_mask': commu_mask, # [N]
226 | 'sensor_mask': sensor_mask, # [N]
227 | 'positions': positions, # [N, 100, 2]
228 | 'edge_index': edge_index, # [2, N x (N - 1)]
229 | 'y': y, # [N, 50, 2]
230 | 'y_commu': y_commu, #[M, 50, 2]
231 | 'x_commu_ori': x_commu_delay[:,-1,:], #abs starting pos of delayed traj
232 | 'x_sensor_ori': x_sensor_noise[:,-1,:], #abs starting pos of noisy traj
233 | 'seq_id': seq_id, #str, file_name
234 | 'num_nodes': num_nodes,
235 | 'padding_mask': padding_mask, # [N, 100]
236 | 'bos_mask': bos_mask, # [N, 50]
237 | 'rotate_angles': rotate_angles, # [N]
238 | 'rotate_imat': rotate_imat, #[N, 2, 2]
239 | 'lane_vectors': lane_vectors, # [L, 2]
240 | 'lane_pos': lane_pos, #[L, 2]
241 | 'lane_idcs': lane_idcs, #[L]
242 | 'lane_actor_index': lane_actor_index,
243 | 'lane_actor_attr': lane_actor_attr,
244 | 'mpr': self.mpr,
245 | 'origin': origin.unsqueeze(0),
246 | 'theta': theta,
247 | 'rotate_mat': rotate_mat
248 | }
249 |
250 | def get_lane_feats(self, origin, rotate_mat, num_nodes, positions, radius=75, local_radius=30):
251 |
252 | road_ids = load_xml.get_road_ids_in_xy_bbox(self.polygon_bboxes, self.lane_starts, self.lane_ends, self.roads, origin[0], origin[1], radius)
253 | road_ids = copy.deepcopy(road_ids)
254 |
255 | lanes=dict()
256 | for road_id in road_ids:
257 | road = self.roads[road_id]
258 | ctr_line = torch.from_numpy(np.stack(((self.roads[road_id].l_bound[:,0]+self.roads[road_id].r_bound[:,0])/2,
259 | (self.roads[road_id].l_bound[:,1]+self.roads[road_id].r_bound[:,1])/2),axis=-1))
260 | ctr_line = torch.matmul(rotate_mat, (ctr_line - origin.reshape(-1, 2)).T.float()).T
261 |
262 | x, y = ctr_line[:,0], ctr_line[:,1]
263 | # if x.max() < x_min or x.min() > x_max or y.max() < y_min or y.min() > y_max:
264 | # continue
265 | # else:
266 | """getting polygons requires original centerline"""
267 | polygon, _, _ = load_xml.build_polygon_bboxes({road_id: self.roads[road_id]})
268 | polygon_x = torch.from_numpy(np.array([polygon[:,0],polygon[:,0],polygon[:,2],polygon[:,2],polygon[:,0]]))
269 | polygon_y = torch.from_numpy(np.array([polygon[:,1],polygon[:,3],polygon[:,3],polygon[:,1],polygon[:,1]]))
270 | polygon_reshape = torch.cat([polygon_x,polygon_y],dim=-1) #shape(5,2)
271 |
272 | road.centerline = ctr_line
273 | road.polygon = torch.matmul(rotate_mat, (polygon_reshape.float() - origin.reshape(-1, 2)).T).T
274 | lanes[road_id] = road
275 |
276 | lane_ids = list(lanes.keys())
277 | lane_pos, lane_vectors = [], []
278 | for lane_id in lane_ids:
279 | lane = lanes[lane_id]
280 | ctrln = lane.centerline
281 |
282 | # lane_ctrs.append(torch.from_numpy(np.asarray((ctrln[:-1]+ctrln[1:])/2.0, np.float32)))#lane center point
283 | # lane_vectors.append(torch.from_numpy(np.asarray(ctrln[1:]-ctrln[:-1], np.float32))) #length between waypoints
284 | lane_pos.append(ctrln[:-1]) #lane center point
285 | lane_vectors.append(ctrln[1:]-ctrln[:-1])#length between waypoints
286 |
287 | lane_idcs = []
288 | count = 0
289 | for i, position in enumerate(lane_pos):
290 | lane_idcs.append(i*torch.ones(len(position)))
291 | count += len(position)
292 |
293 | lane_idcs = torch.cat(lane_idcs, dim=0)
294 | lane_pos = torch.cat(lane_pos, dim=0)
295 | lane_vectors = torch.cat(lane_vectors, dim=0)
296 |
297 | lane_actor_index = torch.LongTensor(list(product(torch.arange(lane_vectors.size(0)), \
298 | torch.arange(num_nodes)))).t().contiguous()
299 | lane_actor_attr = \
300 | lane_pos[lane_actor_index[0]] - positions[:,49,:][lane_actor_index[1]]
301 | mask = torch.norm(lane_actor_attr, p=2, dim=-1) < local_radius
302 | lane_actor_index = lane_actor_index[:, mask]
303 | lane_actor_attr = lane_actor_attr[mask]
304 |
305 |
306 | return lane_pos, lane_vectors, lane_idcs, lane_actor_index, lane_actor_attr
307 |
308 | def get_vectorized_x(self, x0, padding_mask):
309 | '''
310 | x: torch.Tensor: [n, 30, 2]
311 | padding_mask: torch.Tensor:[n, 30]
312 | '''
313 | x = x0.clone()
314 | x[:, 1: 30] = torch.where((padding_mask[:, : 29] | padding_mask[:, 1: 30]).unsqueeze(-1),
315 | torch.zeros(x.shape[0], 29, 2),
316 | x[:, 1: 30] - x[:, : 29])
317 | x[:, 0] = torch.zeros(x.shape[0], 2)
318 |
319 | return x
320 |
321 | def get_masks(self, mpr, mprs, types, in_av_range):
322 | #ncv in av range
323 | #and all cv
324 | if mpr == 0:
325 | cav_mask = types[0]==1
326 | commu_mask = torch.zeros(cav_mask.shape)==True
327 | sensor_mask = (types[0]!=1) & (in_av_range==1)
328 | else:
329 | type_idx = mprs.index(mpr)
330 | cav_mask = types[type_idx]==1
331 | commu_mask = types[type_idx]==2
332 | sensor_mask = (types[type_idx]!=1) & (in_av_range==1)
333 |
334 | return cav_mask, commu_mask, sensor_mask
335 |
336 | def get_noisy_x(self, x, padding_mask, var=0.1):
337 | """
338 | get noisy feats for sensor data
339 | x: torch.Tensor of shape(n, 50, 2)
340 |
341 | return
342 | noise_x: torch.Tensor of shape(n, 30, 2)
343 | """
344 | noise = torch.normal(0, var, x.shape)
345 |
346 | return (x+noise)[:,20:,:], padding_mask[:,20:50]
347 |
348 | def get_delayed_x(self, x, padding_mask, lag=1):
349 | """
350 | get delayed feats of communication data
351 | x: torch tensor of shape(n, 50, 2)
352 | lag: number of frames in [0:20]
353 |
354 | return
355 | delayed_x: torch.Tensor of shape(n, 30, 2)
356 | """
357 | if lag<0 or lag>20:
358 | raise Exception("lag must be in the range(0,20)")
359 |
360 | delayed_x = x[:,20-lag:50-lag,:]
361 |
362 | return delayed_x, padding_mask[:, 20-lag:50-lag]
363 |
364 | class CarlaData(Data):
365 |
366 | def __init__(self,
367 | x_cav: Optional[torch.Tensor] = None,
368 | x_commu: Optional[torch.Tensor] = None,
369 | x_sensor: Optional[torch.Tensor] = None,
370 | cav_mask: Optional[torch.Tensor] = None,
371 | commu_mask: Optional[torch.Tensor] = None,
372 | sensor_mask: Optional[torch.Tensor] = None,
373 | positions: Optional[torch.Tensor] = None,
374 | edge_index: Optional[torch.Tensor] = None,
375 | edge_attrs: Optional[List[torch.Tensor]] = None,
376 | lane_actor_index: Optional[torch.Tensor] = None,
377 | lane_actor_attr: Optional[torch.Tensor] = None,
378 | y: Optional[torch.Tensor] = None,
379 | y_commu: Optional[torch.Tensor] = None,
380 | x_commu_ori: Optional[torch.Tensor] = None,
381 | x_sensor_ori: Optional[torch.Tensor] = None,
382 | seq_id: Optional[str] = None,
383 | num_nodes: Optional[int] = None,
384 | padding_mask: Optional[torch.Tensor] = None,
385 | bos_mask: Optional[torch.Tensor] = None,
386 | rotate_angles: Optional[torch.Tensor] = None,
387 | rotate_imat: Optional[torch.Tensor] = None,
388 | lane_vectors: Optional[torch.Tensor] = None,
389 | lane_pos: Optional[torch.Tensor] = None,
390 | lane_idcs: Optional[torch.Tensor] = None,
391 | mpr: Optional[torch.Tensor] = None,
392 | origin: Optional[torch.Tensor] = None,
393 | theta: Optional[torch.Tensor] = None,
394 | rotate_mat: Optional[torch.Tensor] = None,
395 | # obj_type_02: Optional[torch.Tensor] = None,
396 | # obj_type_04: Optional[torch.Tensor] = None,
397 | # obj_type_06: Optional[torch.Tensor] = None,
398 | # obj_type_08: Optional[torch.Tensor] = None,
399 | # in_av_range: Optional[torch.Tensor] = None,
400 | **kwargs) -> None:
401 | if x_cav is None:
402 | super(CarlaData, self).__init__()
403 | return
404 | super(CarlaData, self).__init__(x_cav=x_cav, x_commu=x_commu, x_sensor=x_sensor, mpr=mpr,
405 | cav_mask=cav_mask, commu_mask=commu_mask, sensor_mask=sensor_mask,
406 | positions=positions, edge_index=edge_index, edge_attrs=edge_attrs,
407 | lane_actor_index=lane_actor_index, lane_actor_attr=lane_actor_attr,
408 | y=y, y_commu=y_commu, x_commu_ori=x_commu_ori, x_sensor_ori=x_sensor_ori,
409 | seq_id=seq_id, num_nodes=num_nodes, padding_mask=padding_mask,
410 | bos_mask=bos_mask, rotate_angles=rotate_angles, rotate_imat=rotate_imat,
411 | lane_vectors=lane_vectors, lane_pos=lane_pos, lane_idcs=lane_idcs,
412 | theta=theta, rotate_mat=rotate_mat,
413 | **kwargs)
414 | if edge_attrs is not None:
415 | for t in range(self.x.size(1)):
416 | self[f'edge_attr_{t}'] = edge_attrs[t]
417 |
418 | def __inc__(self, key, value, *args, **kwargs):
419 | if key == 'lane_actor_index':
420 | return torch.tensor([[self['lane_vectors'].size(0)], [self.num_nodes]])
421 | else:
422 | return super().__inc__(key, value)
423 |
424 |
--------------------------------------------------------------------------------
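
Because CarlaData overrides __inc__ for lane_actor_index, the lane/actor bipartite indices are shifted correctly when several processed scenes are collated into one batch. A minimal sketch, assuming a PyG 2.x DataLoader and placeholder directory names:

```python
from torch_geometric.loader import DataLoader
from dataloader.carla_scene_process import scene_processed_dataset

dataset = scene_processed_dataset(
    root="carla_data", split="val", mpr=0.2,
    source_dir="scene_mining", save_dir="scene_mining_processed")  # placeholder dirs

loader = DataLoader(dataset, batch_size=8, shuffle=False)
batch = next(iter(loader))
# lane_actor_index[0] indexes lanes, lane_actor_index[1] indexes actors; both are offset per sample.
print(batch.num_graphs, batch.lane_actor_index.shape)
```
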
/dataloader/utils/__pycache__/lane_sampling.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/dataloader/utils/__pycache__/lane_sampling.cpython-37.pyc
--------------------------------------------------------------------------------
/dataloader/utils/__pycache__/lane_segment.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/dataloader/utils/__pycache__/lane_segment.cpython-37.pyc
--------------------------------------------------------------------------------
/dataloader/utils/__pycache__/load_xml.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/dataloader/utils/__pycache__/load_xml.cpython-37.pyc
--------------------------------------------------------------------------------
/dataloader/utils/lane_sampling.py:
--------------------------------------------------------------------------------
1 | #implement equal sampling of map vector
2 | import numpy as np
3 | import math
4 | import matplotlib.pyplot as plt
5 |
6 | class Spline:
7 | """
8 | Cubic Spline class
9 | """
10 |
11 | def __init__(self, x, y):
12 | self.b, self.c, self.d, self.w = [], [], [], []
13 |
14 | self.x = np.array(x)
15 | self.y = np.array(y)
16 |
17 | self.eps = np.finfo(float).eps
18 |
19 | self.nx = len(x) # dimension of x
20 | h = np.diff(x)
21 |
22 | # calc coefficient c
23 | self.a = np.array([iy for iy in y])
24 |
25 | # calc coefficient c
26 | A = self.__calc_A(h)
27 | B = self.__calc_B(h)
28 | self.c = np.linalg.solve(A, B)
29 | # print(self.c1)
30 |
31 | # calc spline coefficient b and d
32 | for i in range(self.nx - 1):
33 | self.d.append((self.c[i + 1] - self.c[i]) / (3.0 * h[i] + self.eps))
34 | tb = (self.a[i + 1] - self.a[i]) / (h[i] + self.eps) - h[i] * \
35 | (self.c[i + 1] + 2.0 * self.c[i]) / 3.0
36 | self.b.append(tb)
37 | self.b = np.array(self.b)
38 | self.d = np.array(self.d)
39 |
40 | def calc(self, t):
41 | """
42 | Calc position
43 | if t is outside of the input x, return None
44 | """
45 | t = np.asarray(t)
46 | mask = np.logical_or(t < self.x[0], t > self.x[-1])
47 | t[mask] = self.x[0]
48 |
49 | i = self.__search_index(t)
50 | dx = t - self.x[i.astype(int)]
51 | result = self.a[i] + self.b[i] * dx + \
52 | self.c[i] * dx ** 2.0 + self.d[i] * dx ** 3.0
53 |
54 | result = np.asarray(result)
55 | result[mask] = None
56 | return result
57 |
58 | def calcd(self, t):
59 | """
60 | Calc first derivative
61 | if t is outside of the input x, return None
62 | """
63 | t = np.asarray(t)
64 |         mask = np.logical_or(t < self.x[0], t > self.x[-1])
65 | t[mask] = 0
66 |
67 | i = self.__search_index(t)
68 | dx = t - self.x[i]
69 | result = self.b[i] + 2.0 * self.c[i] * dx + 3.0 * self.d[i] * dx ** 2.0
70 |
71 | result = np.asarray(result)
72 | result[mask] = None
73 | return result
74 |
75 | def calcdd(self, t):
76 | """
77 | Calc second derivative
78 | """
79 | t = np.asarray(t)
80 |         mask = np.logical_or(t < self.x[0], t > self.x[-1])
81 | t[mask] = 0
82 |
83 | i = self.__search_index(t)
84 | dx = t - self.x[i]
85 | result = 2.0 * self.c[i] + 6.0 * self.d[i] * dx
86 |
87 | result = np.asarray(result)
88 | result[mask] = None
89 | return result
90 |
91 | def __search_index(self, x):
92 | """
93 | search data segment index
94 | """
95 | indices = np.asarray(np.searchsorted(self.x, x, "left") - 1)
96 | indices[indices <= 0] = 0
97 | return indices
98 |
99 | def __calc_A(self, h):
100 | """
101 | calc matrix A for spline coefficient c
102 | """
103 | A = np.zeros((self.nx, self.nx))
104 | A[0, 0] = 1.0
105 | for i in range(self.nx - 1):
106 | if i != (self.nx - 2):
107 | A[i + 1, i + 1] = 2.0 * (h[i] + h[i + 1])
108 | A[i + 1, i] = h[i]
109 | A[i, i + 1] = h[i]
110 |
111 | A[0, 1] = 0.0
112 | A[self.nx - 1, self.nx - 2] = 0.0
113 | A[self.nx - 1, self.nx - 1] = 1.0
114 | # print(A)
115 | return A
116 |
117 | def __calc_B(self, h):
118 | """
119 | calc matrix B for spline coefficient c
120 | """
121 | B = np.zeros(self.nx)
122 | for i in range(self.nx - 2):
123 | B[i + 1] = 3.0 * (self.a[i + 2] - self.a[i + 1]) / (h[i + 1] + self.eps) \
124 | - 3.0 * (self.a[i + 1] - self.a[i]) / (h[i] + self.eps)
125 | return B
126 | class Spline2D:
127 | """
128 | 2D Cubic Spline class
129 | """
130 |
131 | def __init__(self, x, y, resolution=0.1):
132 | self.s = self.__calc_s(x, y)
133 | self.sx = Spline(self.s, x)
134 | self.sy = Spline(self.s, y)
135 |
136 | self.s_fine = np.arange(0, self.s[-1], resolution)
137 | xy = np.array([self.calc_global_position_online(s_i) for s_i in self.s_fine])
138 |
139 | self.x_fine = xy[:, 0]
140 | self.y_fine = xy[:, 1]
141 |
142 | def __calc_s(self, x, y):
143 | dx = np.diff(x)
144 | dy = np.diff(y)
145 | self.ds = np.hypot(dx, dy)
146 | s = [0]
147 | s.extend(np.cumsum(self.ds))
148 | return s
149 |
150 | def calc_global_position_online(self, s):
151 | """
152 | calc global position of points on the line, s: float
153 | return: x: float; y: float; the global coordinate of given s on the spline
154 | """
155 | x = self.sx.calc(s)
156 | y = self.sy.calc(s)
157 |
158 | return x, y
159 |
160 | def calc_global_position_offline(self, s, d):
161 | """
162 | calc global position of points in the frenet coordinate w.r.t. the line.
163 | s: float, longitudinal; d: float, lateral;
164 | return: x, float; y, float;
165 | """
166 | s_x = self.sx.calc(s)
167 | s_y = self.sy.calc(s)
168 |
169 | theta = math.atan2(self.sy.calcd(s), self.sx.calcd(s))
170 | x = s_x - math.sin(theta) * d
171 | y = s_y + math.cos(theta) * d
172 | return x, y
173 |
174 | def calc_frenet_position(self, x, y):
175 | """
176 |         calc the frenet position of a given global coordinate (x, y)
177 | return s: the longitudinal; d: the lateral
178 | """
179 |         # find nearest x, y
180 | diff = np.hypot(self.x_fine - x, self.y_fine - y)
181 | idx = np.argmin(diff)
182 | [x_s, y_s] = self.x_fine[idx], self.y_fine[idx]
183 | s = self.s_fine[idx]
184 |
185 | # compute theta
186 | theta = math.atan2(self.sy.calcd(s), self.sx.calcd(s))
187 | d_x, d_y = x - x_s, y - y_s
188 | cross_rd_nd = math.cos(theta) * d_y - math.sin(theta) * d_x
189 | d = math.copysign(np.hypot(d_x, d_y), cross_rd_nd)
190 | return s, d
191 |
192 | def calc_curvature(self, s):
193 | """
194 | calc curvature
195 | """
196 | dx = self.sx.calcd(s)
197 | ddx = self.sx.calcdd(s)
198 | dy = self.sy.calcd(s)
199 | ddy = self.sy.calcdd(s)
200 | k = (ddy * dx - ddx * dy) / ((dx ** 2 + dy ** 2)**(3 / 2))
201 | return k
202 |
203 | def calc_yaw(self, s):
204 | """
205 | calc yaw
206 | """
207 | dx = self.sx.calcd(s)
208 | dy = self.sy.calcd(s)
209 | yaw = np.arctan2(dy, dx)
210 | return yaw
211 |
212 | def visualize_centerline(centerline) -> None:
213 | """Visualize the computed centerline.
214 | Args:
215 | centerline: Sequence of coordinates forming the centerline
216 | """
217 | line_coords = list(zip(*centerline))
218 | lineX = line_coords[0]
219 | lineY = line_coords[1]
220 | plt.plot(lineX, lineY, "--", color="grey", alpha=1, linewidth=1, zorder=0)
221 | plt.text(lineX[0], lineY[0], "s")
222 | plt.text(lineX[-1], lineY[-1], "e")
223 | plt.axis("equal")
--------------------------------------------------------------------------------
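
A minimal usage sketch for the Spline2D utilities above (not part of the repository; the waypoints, 2 m spacing, and query point are made-up values, and the import path is assumed from the tree layout):

import numpy as np
from dataloader.utils.lane_sampling import Spline2D

# made-up centerline waypoints
waypoints_x = [0.0, 5.0, 10.0, 15.0, 20.0]
waypoints_y = [0.0, 1.0, 0.5, -0.5, 0.0]

sp = Spline2D(waypoints_x, waypoints_y, resolution=0.1)

# equal sampling of the map vector: one point every 2 m of arc length
s_samples = np.arange(0.0, sp.s[-1], 2.0)
centerline = np.array([sp.calc_global_position_online(s) for s in s_samples])

# Frenet (longitudinal s, lateral d) coordinates of a nearby query point
s, d = sp.calc_frenet_position(7.3, 1.2)
yaw = sp.calc_yaw(s)
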
/dataloader/utils/lane_segment.py:
--------------------------------------------------------------------------------
1 | #
2 | from typing import List, Optional
3 |
4 | import numpy as np
5 |
6 |
7 | class LaneSegment:
8 | def __init__(
9 | self,
10 | id: int,
11 | l_neighbor_id: Optional[int],
12 | r_neighbor_id: Optional[int],
13 | centerline: np.ndarray,
14 | ) -> None:
15 | """
16 | Initialize the lane segment.
17 |
18 | Args:
19 | id: Unique lane ID that serves as identifier for this "Way"
20 | l_neighbor_id: Unique ID for left neighbor
21 | r_neighbor_id: Unique ID for right neighbor
22 | centerline: The coordinates of the lane segment's center line.
23 | """
24 | self.id = id
25 | self.l_neighbor_id = l_neighbor_id
26 | self.r_neighbor_id = r_neighbor_id
27 | self.centerline = centerline
28 |
29 | class Road:
30 | def __init__(
31 | self,
32 | id: int,
33 | l_bound: np.ndarray,
34 | r_bound: np.ndarray,
35 | ) -> None:
36 |         """Initialize the road.
37 | 
38 |         Args:
39 |             id: Unique road ID that serves as identifier for this "Relation".
40 |             l_bound: The coordinates of the road's left bound.
41 |             r_bound: The coordinates of the road's right bound.
42 | """
43 | self.id = id
44 | self.l_bound = l_bound
45 | self.r_bound = r_bound
46 |
47 |
--------------------------------------------------------------------------------
/dataloader/utils/load_xml.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | #
3 |
4 | """
5 | Utility to load the Argoverse vector map from disk, where it is stored in an XML format.
6 |
7 | We release our Argoverse vector map in a modified OpenStreetMap (OSM) form. We also provide
8 | the map data loader. OpenStreetMap (OSM) provides XML data and relies upon "Nodes" and "Ways" as
9 | its fundamental element.
10 |
11 | A "Node" is a point of interest, or a constituent point of a line feature such as a road.
12 | In OpenStreetMap, a `Node` has tags, which might be
13 | -natural: If it's a natural feature, indicates the type (hill summit, etc)
14 | -man_made: If it's a man made feature, indicates the type (water tower, mast etc)
15 | -amenity: If it's an amenity (e.g. a pub, restaurant, recycling
16 | centre etc) indicates the type
17 |
18 | In OSM, a "Way" is most often a road centerline, composed of an ordered list of "Nodes".
19 | An OSM way often represents a line or polygon feature, e.g. a road, a stream, a wood, a lake.
20 | Ways consist of two or more nodes. Tags for a Way might be:
21 | -highway: the class of road (motorway, primary,secondary etc)
22 | -maxspeed: maximum speed in km/h
23 | -ref: the road reference number
24 | -oneway: is it a one way road? (boolean)
25 |
26 | However, in Argoverse, a "Way" corresponds to a LANE segment centerline. An Argoverse Way has the
27 | following 9 attributes:
28 | - id: integer, unique lane ID that serves as identifier for this "Way"
29 | - has_traffic_control: boolean
30 | - turn_direction: string, 'RIGHT', 'LEFT', or 'NONE'
31 | - is_intersection: boolean
32 | - l_neighbor_id: integer, unique ID for left neighbor
33 | - r_neighbor_id: integer, unique ID for right neighbor
34 | - predecessors: list of integers or None
35 | - successors: list of integers or None
36 | - centerline_node_ids: list
37 |
38 | In Argoverse, a `LaneSegment` object is derived from a combination of a single `Way` and two or more
39 | `Node` objects.
40 | """
41 |
42 | import logging
43 | import os
44 | import xml.etree.ElementTree as ET
45 | from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Union, cast
46 |
47 | import numpy as np
48 | import matplotlib.pyplot as plt
49 |
50 | from dataloader.utils.lane_segment import LaneSegment, Road
51 |
52 | logger = logging.getLogger(__name__)
53 |
54 |
55 | _PathLike = Union[str, "os.PathLike[str]"]
56 |
57 |
58 | class Node:
59 | """
60 | e.g. a point of interest, or a constituent point of a
61 | line feature such as a road
62 | """
63 |
64 | def __init__(self, id: int, x: float, y: float, height: Optional[float] = None):
65 | """
66 | Args:
67 | id: representing unique node ID
68 | x: x-coordinate in city reference system
69 | y: y-coordinate in city reference system
70 |
71 | Returns:
72 | None
73 | """
74 | self.id = id
75 | self.x = x
76 | self.y = y
77 | self.height = height
78 |
79 |
80 | def str_to_bool(s: str) -> bool:
81 | """
82 | Args:
83 | s: string representation of boolean, either 'True' or 'False'
84 |
85 | Returns:
86 | boolean
87 | """
88 | if s == "True":
89 | return True
90 | assert s == "False"
91 | return False
92 |
93 |
94 | def convert_dictionary_to_lane_segment_obj(lane_id: int, lane_dictionary: Mapping[str, Any]) -> LaneSegment:
95 | """
96 | Not all lanes have predecessors and successors.
97 |
98 | Args:
99 | lane_id: representing unique lane ID
100 | lane_dictionary: dictionary with LaneSegment attributes, not yet in object instance form
101 |
102 | Returns:
103 | ls: LaneSegment object
104 | """
105 |
106 | l_neighbor_id = None
107 | r_neighbor_id = None
108 | ls = LaneSegment(
109 | lane_id,
110 | l_neighbor_id,
111 | r_neighbor_id,
112 | lane_dictionary["centerline"],
113 | )
114 | return ls
115 |
116 |
117 | def append_additional_key_value_pair(lane_obj: MutableMapping[str, Any], way_field: List[Tuple[str, str]]) -> None:
118 | """
119 | Key name was either 'predecessor' or 'successor', for which we can have multiple.
120 | Thus we append them to a list. They should be integers, as lane IDs.
121 |
122 | Args:
123 | lane_obj: lane object
124 | way_field: key and value pair to append
125 |
126 | Returns:
127 | None
128 | """
129 | assert len(way_field) == 2
130 | k = way_field[0][1]
131 | v = int(way_field[1][1])
132 | lane_obj.setdefault(k, []).append(v)
133 |
134 |
135 | def append_unique_key_value_pair(lane_obj: MutableMapping[str, Any], way_field: List[Tuple[str, str]]) -> None:
136 | """
137 | For the following types of Way "tags", the key, value pair is defined only once within
138 | the object:
139 | - has_traffic_control, turn_direction, is_intersection, l_neighbor_id, r_neighbor_id
140 |
141 | Args:
142 | lane_obj: lane object
143 | way_field: key and value pair to append
144 |
145 | Returns:
146 | None
147 | """
148 | assert len(way_field) == 2
149 | k = way_field[0][1]
150 | v = way_field[1][1]
151 | lane_obj[k] = v
152 |
153 |
154 | def extract_node_waypt(way_field: List[Tuple[str, str]]) -> int:
155 | """
156 | Given a list with a reference node such as [('ref', '0')], extract out the lane ID.
157 |
158 | Args:
159 | way_field: key and node id pair to extract
160 |
161 | Returns:
162 | node_id: unique ID for a node waypoint
163 | """
164 | key = way_field[0][0]
165 | node_id = way_field[0][1]
166 | assert key == "ref"
167 | return int(node_id)
168 |
169 |
170 | def get_lane_identifier(child: ET.Element) -> int:
171 | """
172 | Fetch lane ID from XML ET.Element.
173 |
174 | Args:
175 | child: ET.Element with information about Way
176 |
177 | Returns:
178 | unique lane ID
179 | """
180 | return int(child.attrib["id"])
181 |
182 |
183 | def convert_node_id_list_to_xy(node_id_list: List[int], all_graph_nodes: Mapping[int, Node]) -> np.ndarray:
184 | """
185 | convert node id list to centerline xy coordinate
186 |
187 | Args:
188 | node_id_list: list of node_id's
189 | all_graph_nodes: dictionary mapping node_ids to Node
190 |
191 | Returns:
192 | centerline
193 | """
194 | num_nodes = len(node_id_list)
195 |
196 | if all_graph_nodes[node_id_list[0]].height is not None:
197 | centerline = np.zeros((num_nodes, 3))
198 | else:
199 | centerline = np.zeros((num_nodes, 2))
200 | for i, node_id in enumerate(node_id_list):
201 | if all_graph_nodes[node_id].height is not None:
202 | centerline[i] = np.array(
203 | [
204 | all_graph_nodes[node_id].x,
205 | all_graph_nodes[node_id].y,
206 | all_graph_nodes[node_id].height,
207 | ]
208 | )
209 | else:
210 | centerline[i] = np.array([all_graph_nodes[node_id].x, all_graph_nodes[node_id].y])
211 |
212 | return centerline
213 |
214 |
215 | def extract_node_from_ET_element(child: ET.Element) -> Node:
216 | """
217 | Given a line of XML, build a node object. The "node_fields" dictionary will hold "id", "x", "y".
218 | The XML will resemble:
219 |
220 |
221 |
222 | Args:
223 | child: xml.etree.ElementTree element
224 |
225 | Returns:
226 | Node object
227 | """
228 | node_fields = child.attrib
229 | node_id = int(node_fields["id"])
230 | for element in child:
231 | way_field = cast(List[Tuple[str, str]], list(element.items()))
232 | key = way_field[0][1]
233 | if key == "local_x":
234 | x = float(way_field[1][1])
235 | elif key == "local_y":
236 | y = float(way_field[1][1])
237 |
238 | return Node(id=node_id, x=x, y=y)
239 |
240 |
241 | def extract_lane_segment_from_ET_element(
242 | child: ET.Element, all_graph_nodes: Mapping[int, Node]
243 | ) -> Tuple[LaneSegment, int]:
244 | """
245 | We build a lane segment from an XML element. A lane segment is equivalent
246 | to a "Way" in our XML file. Each Lane Segment has a polyline representing its centerline.
247 | The relevant XML data might resemble::
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 | ...
257 |
258 |
259 | ...
260 |
261 |
262 |
263 | Args:
264 | child: xml.etree.ElementTree element
265 | all_graph_nodes
266 |
267 | Returns:
268 | lane_segment: LaneSegment object
269 | lane_id
270 | """
271 | lane_obj: Dict[str, Any] = {}
272 | lane_id = get_lane_identifier(child)
273 | node_id_list: List[int] = []
274 | for element in child:
275 | # The cast on the next line is the result of a typeshed bug. This really is a List and not a ItemsView.
276 | way_field = cast(List[Tuple[str, str]], list(element.items()))
277 | field_name = way_field[0][0]
278 | if field_name == "k":
279 | key = way_field[0][1]
280 | if key in {"predecessor", "successor"}:
281 | append_additional_key_value_pair(lane_obj, way_field)
282 | else:
283 | append_unique_key_value_pair(lane_obj, way_field)
284 | else:
285 | node_id_list.append(extract_node_waypt(way_field))
286 |
287 | lane_obj["centerline"] = convert_node_id_list_to_xy(node_id_list, all_graph_nodes)
288 | lane_segment = convert_dictionary_to_lane_segment_obj(lane_id, lane_obj)
289 | return lane_segment, lane_id
290 |
291 | def construct_road_from_ET_element(
292 | child: ET.Element, lane_objs: Mapping[int, LaneSegment]
293 | ):
294 | road_id = int(child.attrib["id"])
295 | for element in child:
296 | if element.tag == "member":
297 | relation_field = cast(List[Tuple[str, str]], list(element.items()))
298 | if relation_field[2][1] == "right":
299 | r_bound_idx = int(relation_field[1][1])
300 | elif relation_field[2][1] == "left":
301 | l_bound_idx = int(relation_field[1][1])
302 | l_bound = lane_objs[l_bound_idx].centerline
303 | r_bound = lane_objs[r_bound_idx].centerline
304 | road = Road(
305 | road_id,
306 | l_bound,
307 | r_bound
308 | )
309 | return road, road_id
310 |
311 |
312 | def load_lane_segments_from_xml(map_fpath: _PathLike) -> Mapping[int, LaneSegment]:
313 | """
314 | Load lane segment object from xml file
315 |
316 | Args:
317 | map_fpath: path to xml file
318 |
319 | Returns:
320 |         roads: mapping from road ID to Road objects (left/right bound centerlines)
321 | """
322 | tree = ET.parse(os.fspath(map_fpath))
323 | root = tree.getroot()
324 |
325 | logger.info(f"Loaded root: {root.tag}")
326 |
327 | all_graph_nodes = {}
328 | lane_objs = {}
329 | roads = {}
330 | # all children are either Nodes or Ways or relations
331 | for child in root:
332 | if child.tag == "node":
333 | node_obj = extract_node_from_ET_element(child)
334 | all_graph_nodes[node_obj.id] = node_obj
335 | elif child.tag == "way":
336 | lane_obj, lane_id = extract_lane_segment_from_ET_element(child, all_graph_nodes)
337 | lane_objs[lane_id] = lane_obj
338 | elif child.tag == "relation":
339 | road, road_id = construct_road_from_ET_element(child, lane_objs)
340 | roads[road_id] = road
341 | else:
342 | logger.error("Unknown XML item encountered.")
343 | raise ValueError("Unknown XML item encountered.")
344 | return roads
345 |
346 | def build_polygon_bboxes(roads):
347 | """
348 | roads: dict, key: road id; value field: l_bound, r_bound
349 |     polygon_bboxes: An array of shape (K, 4); each row is the bounding box [x_min, y_min, x_max, y_max]
350 |                     of one road's left and right bounds.
351 |     Each road_id corresponds to one polygon_bbox.
352 |     lane_starts: An array of shape (K, 4), indicating (x_l, y_l, x_r, y_r) at the start of each road's bounds
353 |     lane_ends: An array of shape (K, 4), indicating (x_l, y_l, x_r, y_r) at the end of each road's bounds
354 | """
355 | polygon_bboxes = []
356 | lane_starts = []
357 | lane_ends = []
358 | for road_id in roads.keys():
359 | x = np.concatenate((roads[road_id].l_bound[:,0], roads[road_id].r_bound[:,0]))
360 | xmin = np.min(x)
361 | xmax = np.max(x)
362 | y = np.concatenate((roads[road_id].l_bound[:,1], roads[road_id].r_bound[:,1]))
363 | ymin = np.min(y)
364 | ymax = np.max(y)
365 | polygon_bbox = np.array([xmin, ymin, xmax, ymax])
366 | polygon_bboxes.append(polygon_bbox)
367 |
368 | lane_start = np.array([roads[road_id].l_bound[0,0], roads[road_id].l_bound[0,1],
369 | roads[road_id].r_bound[0,0], roads[road_id].r_bound[0,1]])
370 | lane_end = np.array([roads[road_id].l_bound[-1,0], roads[road_id].l_bound[-1,1],
371 | roads[road_id].r_bound[-1,0], roads[road_id].r_bound[-1,1]])
372 | lane_starts.append(lane_start)
373 | lane_ends.append(lane_end)
374 |
375 | return np.array(polygon_bboxes), np.array(lane_starts), np.array(lane_ends)
376 |
377 | def find_all_polygon_bboxes_overlapping_query_bbox(polygon_bboxes: np.ndarray,
378 | query_bbox: np.ndarray,
379 | lane_starts: np.ndarray,
380 | lane_ends: np.ndarray) -> np.ndarray:
381 | """Find all the overlapping polygon bounding boxes.
382 | Each bounding box has the following structure:
383 | bbox = np.array([x_min,y_min,x_max,y_max])
384 |     Bounding boxes that merely touch (coordinates are equal) are still considered overlapping.
385 |     A polygon bbox overlaps the query bbox along the x-dimension when its left or right edge lies between
386 |     the edges of the query bounding box, or when it completely engulfs the query bounding box; the same
387 |     test is applied along the y-dimension.
388 | Args:
389 | polygon_bboxes: An array of shape (K, 4), each array element is a NumPy array of shape (4,) representing
390 | the bounding box for a polygon or point cloud.
391 | query_bbox: An array of shape (4,) representing a 2d axis-aligned bounding box, with order
392 | [min_x,min_y,max_x,max_y].
393 |         lane_starts: An array of shape (K, 4), representing the start points of each lane's left and right bounds
394 |         lane_ends: An array of shape (K, 4), representing the end points of each lane's left and right bounds
395 | Returns:
396 |         An integer array of indices into polygon_bboxes where overlap occurs.
397 | """
398 | query_min_x = query_bbox[0]
399 | query_min_y = query_bbox[1]
400 |
401 | query_max_x = query_bbox[2]
402 | query_max_y = query_bbox[3]
403 |
404 | bboxes_x1 = polygon_bboxes[:, 0]
405 | bboxes_x2 = polygon_bboxes[:, 2]
406 |
407 | bboxes_y1 = polygon_bboxes[:, 1]
408 | bboxes_y2 = polygon_bboxes[:, 3]
409 |
410 | # check if falls within range
411 | overlaps_left = (query_min_x <= bboxes_x2) & (bboxes_x2 <= query_max_x)
412 | overlaps_right = (query_min_x <= bboxes_x1) & (bboxes_x1 <= query_max_x)
413 |
414 | x_check1 = bboxes_x1 <= query_min_x
415 | x_check2 = query_min_x <= query_max_x
416 | x_check3 = query_max_x <= bboxes_x2
417 | x_subsumed = x_check1 & x_check2 & x_check3
418 |
419 | x_in_range = overlaps_left | overlaps_right | x_subsumed
420 |
421 | overlaps_below = (query_min_y <= bboxes_y2) & (bboxes_y2 <= query_max_y)
422 | overlaps_above = (query_min_y <= bboxes_y1) & (bboxes_y1 <= query_max_y)
423 |
424 | y_check1 = bboxes_y1 <= query_min_y
425 | y_check2 = query_min_y <= query_max_y
426 | y_check3 = query_max_y <= bboxes_y2
427 | y_subsumed = y_check1 & y_check2 & y_check3
428 | y_in_range = overlaps_below | overlaps_above | y_subsumed
429 |
430 | # at least one lane endpoint in range
431 | # xy_check1 = (query_min_x <= lane_starts[:,0]) & (lane_starts[:,0] <= query_max_x) & \
432 | # (query_min_y <= lane_starts[:,1]) & (lane_starts[:,1] <= query_max_y)
433 | # xy_check2 = (query_min_x <= lane_starts[:,2]) & (lane_starts[:,2] <= query_max_x) & \
434 | # (query_min_y <= lane_starts[:,3]) & (lane_starts[:,3] <= query_max_y)
435 | # xy_check3 = (query_min_x <= lane_ends[:,0]) & (lane_ends[:,0] <= query_max_x) & \
436 | # (query_min_y <= lane_ends[:,1]) & (lane_ends[:,1] <= query_max_y)
437 | # xy_check4 = (query_min_x <= lane_ends[:,2]) & (lane_ends[:,2] <= query_max_x) & \
438 | # (query_min_y <= lane_ends[:,3]) & (lane_ends[:,3] <= query_max_y)
439 | # xy_in_range = xy_check1 | xy_check2 | xy_check3 | xy_check4
440 |
441 | # overlap_indxs = np.where(x_in_range & y_in_range & xy_in_range)[0]
442 |
443 | overlap_indxs = np.where(x_in_range & y_in_range)[0]
444 | return overlap_indxs
445 |
446 | def get_road_ids_in_xy_bbox(
447 | polygon_bboxes,
448 | lane_starts,
449 | lane_ends,
450 | roads,
451 | query_x: float,
452 | query_y: float,
453 | query_search_range_manhattan: float = 50.0,
454 | ):
455 | """
456 | Prune away all lane segments based on Manhattan distance. We vectorize this instead
457 | of using a for-loop. Get all lane IDs within a bounding box in the xy plane.
458 |     This is an approximation of a bubble search for point-to-polygon distance.
459 | The bounding boxes of small point clouds (lane centerline waypoints) are precomputed in the map.
460 | We then can perform an efficient search based on manhattan distance search radius from a
461 | given 2D query point.
462 | We pre-assign lane segment IDs to indices inside a big lookup array, with precomputed
463 | hallucinated lane polygon extents.
464 |     Args:
465 |         polygon_bboxes, lane_starts, lane_ends: precomputed road bboxes and bound endpoints (see build_polygon_bboxes)
466 |         roads: mapping from road ID to Road, in the same order used to build polygon_bboxes
467 |         query_x, query_y: xy coordinates of the query location
468 |         query_search_range_manhattan: search radius along each axis
469 |     Returns:
470 |         neighborhood_road_ids: road IDs whose bounding boxes overlap the query bubble
471 | """
472 | query_min_x = query_x - query_search_range_manhattan
473 | query_max_x = query_x + query_search_range_manhattan
474 | query_min_y = query_y - query_search_range_manhattan
475 | query_max_y = query_y + query_search_range_manhattan
476 |
477 | overlap_indxs = find_all_polygon_bboxes_overlapping_query_bbox(
478 | polygon_bboxes,
479 | np.array([query_min_x, query_min_y, query_max_x, query_max_y],),
480 | lane_starts,
481 | lane_ends
482 | )
483 |
484 | if len(overlap_indxs) == 0:
485 | return []
486 |
487 | neighborhood_road_ids = []
488 | for overlap_idx in overlap_indxs:
489 | lane_segment_id = list(roads.keys())[overlap_idx]
490 | neighborhood_road_ids.append(lane_segment_id)
491 |
492 | return neighborhood_road_ids
493 |
494 | if __name__ == "__main__":
495 | roads = load_lane_segments_from_xml("Town03.osm")
496 |     polygon_bboxes, lane_starts, lane_ends = build_polygon_bboxes(roads)
497 | query_x = 5.772
498 | query_y = 119.542
499 | cv_range = 50
500 |     neighborhood_road_ids = get_road_ids_in_xy_bbox(polygon_bboxes, lane_starts, lane_ends, roads, query_x, query_y, cv_range)
501 |
502 |
503 | # # %%
504 | # plt.figure(dpi=200)
505 | # fig, (ax1,ax2) = plt.subplots(1,2)
506 | # fig.set_figheight(2)
507 | # fig.set_figwidth(4)
508 | # for i in roads.keys():
509 |
510 | # road_id = i
511 | # ax1.plot(roads[road_id].l_bound[:,0], roads[road_id].l_bound[:,1], color='k')#, marker='o', markerfacecolor='blue', markersize=5)
512 | # ax1.plot(roads[road_id].r_bound[:,0], roads[road_id].r_bound[:,1], color='k')#, marker='o', markerfacecolor='red', markersize=5)
513 | # ax1.plot((roads[road_id].l_bound[:,0]+roads[road_id].r_bound[:,0])/2, (roads[road_id].l_bound[:,1]+roads[road_id].r_bound[:,1])/2, color="0.7",linestyle='dashed')
514 | # ax2.plot(roads[road_id].l_bound[:,0], roads[road_id].l_bound[:,1], color='k')#, marker='o', markerfacecolor='blue', markersize=5)
515 | # ax2.plot(roads[road_id].r_bound[:,0], roads[road_id].r_bound[:,1], color='k')#, marker='o', markerfacecolor='red', markersize=5)
516 | # ax2.plot((roads[road_id].l_bound[:,0]+roads[road_id].r_bound[:,0])/2, (roads[road_id].l_bound[:,1]+roads[road_id].r_bound[:,1])/2, color="0.7",linestyle='dashed')
517 |
518 | # ax1.set_xlim([-60,60])
519 | # ax1.set_ylim([-60,60])
520 | # ax2.set_xlim([60,120])
521 | # ax2.set_ylim([80,180])
522 | # ax1.axis("off")
523 | # ax2.axis("off")
524 | # # plt.show()
525 | # plt.savefig("town03_lane_segment.jpg")
526 | # # %%
527 | # # plot one lane segment
528 | # for i in roads.keys():
529 | # road_id = i
530 | # if min(roads[road_id].l_bound[:,0])>60 and max(roads[road_id].l_bound[:,1])>-20 and max(roads[road_id].r_bound[:,0])<120 and max(roads[road_id].r_bound[:,1])<70:
531 |
532 | # plt.plot(roads[road_id].l_bound[:,0], roads[road_id].l_bound[:,1], color='0.7')#, marker='o', markerfacecolor='blue', markersize=5)
533 | # plt.plot(roads[road_id].r_bound[:,0], roads[road_id].r_bound[:,1], color='0.7')#, marker='o', markerfacecolor='red', markersize=5)
534 | # # plt.
535 | # # plt.xlim((60,120))
536 | # # plt.ylim((80,180))
537 | # # plt.axis("off")
538 | # plt.show()
539 |
540 | # # %%
541 | # for i in roads.keys():
542 | # road_id = i
543 | # plt.plot(roads[road_id].l_bound[:,0], roads[road_id].l_bound[:,1], color='0.7')#, marker='o', markerfacecolor='blue', markersize=5)
544 | # plt.plot(roads[road_id].r_bound[:,0], roads[road_id].r_bound[:,1], color='0.7')#, marker='o', markerfacecolor='red', markersize=5)
545 | # plt.show()
546 | # # %%
547 |
--------------------------------------------------------------------------------
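
A hedged end-to-end sketch of the loader above on a toy map (not part of the repository). The XML layout mirrors what the parsing code expects: "node" elements carrying local_x/local_y tags, "way" elements referencing those nodes, and a "relation" whose members mark the left and right bound of one Road; all IDs and coordinates below are invented.

import tempfile
from dataloader.utils.load_xml import (load_lane_segments_from_xml,
                                       build_polygon_bboxes,
                                       get_road_ids_in_xy_bbox)

toy_osm = """<?xml version='1.0' encoding='UTF-8'?>
<osm>
  <node id="0"><tag k="local_x" v="0.0"/><tag k="local_y" v="0.0"/></node>
  <node id="1"><tag k="local_x" v="10.0"/><tag k="local_y" v="0.0"/></node>
  <node id="2"><tag k="local_x" v="0.0"/><tag k="local_y" v="3.5"/></node>
  <node id="3"><tag k="local_x" v="10.0"/><tag k="local_y" v="3.5"/></node>
  <way id="10"><nd ref="0"/><nd ref="1"/></way>
  <way id="11"><nd ref="2"/><nd ref="3"/></way>
  <relation id="100">
    <member type="way" ref="11" role="left"/>
    <member type="way" ref="10" role="right"/>
  </relation>
</osm>
"""

with tempfile.NamedTemporaryFile("w", suffix=".osm", delete=False) as f:
    f.write(toy_osm)
    fpath = f.name

roads = load_lane_segments_from_xml(fpath)  # {100: Road with l_bound / r_bound arrays}
polygon_bboxes, lane_starts, lane_ends = build_polygon_bboxes(roads)
nearby = get_road_ids_in_xy_bbox(polygon_bboxes, lane_starts, lane_ends, roads,
                                 query_x=5.0, query_y=1.0,
                                 query_search_range_manhattan=50.0)  # -> [100]
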
/dataloader/visualization.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from carla_scene_process import CarlaData
3 | import torch
4 | import numpy as np
5 |
6 | def visualize_centerline(centerline) -> None:
7 | """Visualize the computed centerline.
8 | Args:
9 | centerline: Sequence of coordinates forming the centerline
10 | """
11 | line_coords = list(zip(*centerline))
12 | lineX = line_coords[0]
13 | lineY = line_coords[1]
14 | plt.plot(lineX, lineY, "--", color="grey", alpha=1, linewidth=1, zorder=0)
15 | # plt.text(lineX[0], lineY[0], "s")
16 | # plt.text(lineX[-1], lineY[-1], "e")
17 | plt.axis("equal")
18 |
19 | def get_rotate_invariant_trajs(data: CarlaData):
20 |
21 | rotate_mat = torch.empty(data.num_nodes, 2, 2)
22 | sin_vals = torch.sin(data['rotate_angles'])
23 | cos_vals = torch.cos(data['rotate_angles'])
24 | rotate_mat[:, 0, 0] = cos_vals
25 | rotate_mat[:, 0, 1] = -sin_vals
26 | rotate_mat[:, 1, 0] = sin_vals
27 | rotate_mat[:, 1, 1] = cos_vals
28 |
29 | xrot = torch.bmm(data.positions[:,20:50,:], rotate_mat)
30 | yrot = torch.bmm(data.y, rotate_mat)
31 | # for i in range(xrot.shape[0]):
32 | # plt.plot(xrot[i,:,0], xrot[i,:,1])
33 | # plt.plot(data.x_sensor[i,:,0], data.x_sensor[i,:,1],'--')
34 | # for i in range(yrot.shape[0]):
35 | # plt.plot(yrot[i,:,0], yrot[i,:,1])
36 | # plt.plot(data.y[i,:,0], data.y[i,:,1],'--')
37 |
38 | return xrot, yrot, rotate_mat
39 | def viz_devectorize(xrot_vec):
40 | """
41 | xrot_vec: rotated vector [N,30,2]
42 | """
43 | x_devec = torch.cumsum(xrot_vec, dim=1)
44 | # translate back to original location
45 | x_devec_ori = x_devec - x_devec[:,-1,:]
46 | for i in range(x_devec_ori.shape[0]):
47 | plt.plot(x_devec_ori[i,:,0], x_devec_ori[i,:,1])
48 |
49 | def local_invariant_scenes(data: CarlaData):
50 | xrot, yrot, rotate_mat = get_rotate_invariant_trajs(data)
51 | lane_str, lane_vectors = data.lane_pos, data.lane_vectors
52 | lane_idcs = data.lane_idcs
53 | # # visualize the centerlines
54 | # lane_pos = data.lane_pos
55 | # lane_vectors = data.lane_vectors
56 | # lane_idcs = data.lane_idcs
57 | # for i in torch.unique(lane_idcs):
58 | # lane_str = lane_pos[lane_idcs == i]
59 | # lane_vector = lane_vectors[lane_idcs == i]
60 | # lane_end = lane_str + lane_vector
61 | # lane = torch.vstack([lane_str, lane_end[-1,:].reshape(-1, 2)])
62 | # visualize_centerline(lane)
63 |
64 | #rotate locally
65 | edge_index = data.lane_actor_index
66 |
67 | lane_rotate_mat = rotate_mat[edge_index[1]]
68 | lane_vectors_rot = torch.bmm(lane_vectors[edge_index[0]].unsqueeze(-2), lane_rotate_mat).squeeze(-2) #[#, 2]
69 | lane_pos_rot = torch.bmm(lane_str[edge_index[0]].unsqueeze(-2), lane_rotate_mat).squeeze(-2) #[#, 2]
70 |
71 | #viz local map and traj
72 | for i in range(data.num_nodes):
73 | #traj viz
74 | plt.plot(xrot[i,:,0], xrot[i,:,1])
75 | plt.text(xrot[i,-1,0], xrot[i,-1,1], "q")
76 | #map viz
77 | lane_idx_i = (edge_index[1] == i).nonzero().squeeze()
78 | for j in lane_idx_i:
79 | # lane_str_i = lane_pos_rot[edge_index[1] == i]
80 | lane_str_i = lane_pos_rot[j].unsqueeze(0) #[1,2]
81 | # lane_vector_i = lane_vectors_rot[edge_index[1] == i]
82 | lane_vector_i = lane_vectors_rot[j].unsqueeze(0)
83 | lane_end_i = lane_str_i + lane_vector_i
84 | lane_i = torch.vstack([lane_str_i, lane_end_i])
85 | visualize_centerline(lane_i)
86 |
87 |
88 | #for each agent, get self-centered maps
89 | for i in range(xrot.shape[0]):
90 | lane_vector_i = lane_vectors_rot[edge_index[1]==i]
91 | lane_pos_i = lane_pos_rot[edge_index[1]==i]
92 | lane_end_i = lane_vector_i + lane_pos_i
93 | lane_i = torch.vstack([lane_pos_i, lane_end_i[-1,:].reshape(-1, 2)]) #[L, 2]
94 |
95 | visualize_centerline(lane_i)
96 |
97 | # visualize the centerlines
98 | lane_pos = data.lane_pos
99 | lane_vectors = data.lane_vectors
100 | lane_idcs = data.lane_idcs
101 | for i in torch.unique(lane_idcs):
102 | lane_str = lane_pos[lane_idcs == i]
103 | lane_vector = lane_vectors[lane_idcs == i]
104 | lane_end = lane_str + lane_vector
105 | lane = torch.vstack([lane_str, lane_end[-1,:].reshape(-1, 2)])
106 | visualize_centerline(lane)
107 |
108 | for i in range(data.x.shape[0]):
109 | lane_vector_i = lane_vectors[edge_index[0]][edge_index[1]==i]
110 | lane_pos_i = lane_str[edge_index[0]][edge_index[1]==i]
111 | lane_end_i = lane_vector_i + lane_pos_i
112 | lane_i = torch.vstack([lane_pos_i, lane_end_i[-1,:].reshape(-1, 2)]) #[L, 2]
113 |
114 | visualize_centerline(lane_i)
115 |
116 | def viz_lane_rot():
117 | pass
118 |
119 | def tensor_viz(node_features_all, cav_mask, commu_mask, sensor_mask):
120 |
121 | axes = [8, 16, 3]
122 |     filled = np.ones(axes, dtype=bool)
123 | colors = np.empty(axes + [4], dtype=np.float32)
124 | alpha = 0.5
125 | colors[:] = [1, 1, 1, alpha]
126 | colors[cav_mask,:,0] = [1, 0, 0, alpha]
127 | colors[commu_mask,:,1] = [0, 1, 0, alpha]
128 | colors[sensor_mask,:,2] = [0, 0, 1, alpha]
129 |
130 | fig = plt.figure()
131 | ax = fig.add_subplot(projection='3d')
132 | ax.voxels(filled, facecolors=colors, edgecolors='grey',shade=True)
133 | plt.show()
134 | plt.axis('off')
135 |
136 |
137 |
--------------------------------------------------------------------------------
/losses/__pycache__/get_anchors.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/losses/__pycache__/get_anchors.cpython-37.pyc
--------------------------------------------------------------------------------
/losses/__pycache__/loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/losses/__pycache__/loss.cpython-37.pyc
--------------------------------------------------------------------------------
/losses/__pycache__/msma_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/losses/__pycache__/msma_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/losses/__pycache__/mtp_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/losses/__pycache__/mtp_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/losses/__pycache__/multipath_loss.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/losses/__pycache__/multipath_loss.cpython-37.pyc
--------------------------------------------------------------------------------
/losses/get_anchors.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sklearn.cluster import KMeans
3 | # import psutil
4 | # import ray
5 | # from scipy.spatial.distance import cdist
6 |
7 | #Initialize device:
8 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
9 |
10 | # # Initialize ray:
11 | # num_cpus = psutil.cpu_count(logical=False)
12 | # ray.init(num_cpus=num_cpus, log_to_driver=False)
13 |
14 | def k_means_anchors(k, train_loader):
15 | """
16 | Extract anchors for multipath/covernet using k-means on train set trajectories
17 | gt_y: [num_v, op_len, 2]
18 | train_loader: CarlaData
19 | """
20 |
21 | trajectories = []
22 | rotate_imat= []
23 | for i, data in enumerate(train_loader):
24 | trajectories.append(data.y)
25 | rotate_imat.append(data.rotate_imat)
26 |
27 | traj_all = torch.cat(trajectories, dim=0)
28 | rotate_imat_all = torch.cat(rotate_imat, dim=0)
29 | traj_all_rot = torch.matmul(traj_all, rotate_imat_all)
30 |
31 | clustering = KMeans(n_clusters=k).fit(traj_all_rot.reshape((traj_all_rot.shape[0], -1)))
32 | op_len, op_dim = traj_all_rot.shape[1], traj_all_rot.shape[2]
33 | anchors = torch.zeros((k, op_len, op_dim)).to(device)
34 | for i in range(k):
35 | anchors[i] = torch.mean(traj_all_rot[clustering.labels_==i], axis=0)
36 | # for i in range(traj_all_rot.shape[0]):
37 | # plt.plot(traj_all_rot[i, :, 0], traj_all_rot[i, :, 1])
38 | # for i in range(anchors.shape[0]):
39 | # plt.plot(anchors[i, :, 0], anchors[i, :, 1])
40 |
41 | return anchors
42 |
43 |
44 | def bivariate_gaussian_activation(ip: torch.Tensor) -> torch.Tensor:
45 | """
46 | Activation function to output parameters of bivariate Gaussian distribution
47 | """
48 | mu_x = ip[..., 0:1]
49 | mu_y = ip[..., 1:2]
50 | sig_x = ip[..., 2:3]
51 | sig_y = ip[..., 3:4]
52 | rho = ip[..., 4:5]
53 | sig_x = torch.exp(sig_x)
54 | sig_y = torch.exp(sig_y)
55 | rho = torch.tanh(rho)
56 | out = torch.cat([mu_x, mu_y, sig_x, sig_y, rho], dim = -1)
57 |
58 | return out
--------------------------------------------------------------------------------
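
A short sketch of bivariate_gaussian_activation on a raw decoder output (not part of the repository; the tensor shapes are made-up and the import path is assumed from the tree layout):

import torch
from losses.get_anchors import bivariate_gaussian_activation

raw = torch.randn(4, 6, 50, 5)               # [num_vehs, num_modes, op_len, 5]
params = bivariate_gaussian_activation(raw)  # same shape, constrained parameters
# channels 2:4 become positive standard deviations (exp); channel 4 becomes rho in (-1, 1) (tanh)
assert (params[..., 2:4] > 0).all() and params[..., 4].abs().max() < 1
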
/losses/hivt_loss.py:
--------------------------------------------------------------------------------
1 | # source: https://github.com/ZikangZhou/HiVT/blob/main/
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class LaplaceNLLLoss(nn.Module):
7 |
8 | def __init__(self,
9 | eps: float = 1e-6,
10 | reduction: str = 'mean') -> None:
11 | super(LaplaceNLLLoss, self).__init__()
12 | self.eps = eps
13 | self.reduction = reduction
14 |
15 | def forward(self,
16 | y_hat: torch.Tensor,
17 | y_gt: torch.Tensor,
18 | pi: torch.Tensor) -> torch.Tensor:
19 |         loc, scale = y_hat.chunk(2, dim=-1)
20 | scale = scale.clone()
21 | with torch.no_grad():
22 | scale.clamp_(min=self.eps)
23 |         nll = torch.log(2 * scale) + torch.abs(y_gt - loc) / scale
24 | if self.reduction == 'mean':
25 | return nll.mean()
26 | elif self.reduction == 'sum':
27 | return nll.sum()
28 | elif self.reduction == 'none':
29 | return nll
30 | else:
31 | raise ValueError('{} is not a valid value for reduction'.format(self.reduction))
32 |
33 | class SoftTargetCrossEntropyLoss(nn.Module):
34 |
35 | def __init__(self, reduction: str = 'mean') -> None:
36 | super(SoftTargetCrossEntropyLoss, self).__init__()
37 | self.reduction = reduction
38 |
39 | def forward(self,
40 | pred: torch.Tensor,
41 | target: torch.Tensor) -> torch.Tensor:
42 | cross_entropy = torch.sum(-target * F.log_softmax(pred, dim=-1), dim=-1)
43 | if self.reduction == 'mean':
44 | return cross_entropy.mean()
45 | elif self.reduction == 'sum':
46 | return cross_entropy.sum()
47 | elif self.reduction == 'none':
48 | return cross_entropy
49 | else:
50 | raise ValueError('{} is not a valid value for reduction'.format(self.reduction))
--------------------------------------------------------------------------------
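
A toy evaluation of LaplaceNLLLoss (not part of the repository; shapes and values are made-up). y_hat packs location and scale along the last dimension, so a 2-D output uses 4 channels, and the pi argument (mode probabilities) is not used by this single-mode loss:

import torch
from losses.hivt_loss import LaplaceNLLLoss

nll_loss = LaplaceNLLLoss(reduction='mean')
y_hat = torch.randn(8, 50, 4)          # [num_vehs, op_len, (loc_x, loc_y, scale_x, scale_y)]
y_hat[..., 2:] = y_hat[..., 2:].abs()  # scales must be positive
y_gt = torch.randn(8, 50, 2)
loss = nll_loss(y_hat, y_gt, pi=None)
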
/losses/msma_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from metrics.metric import min_ade, traj_nll
4 |
5 | class NLLloss(nn.Module):
6 | """
7 | MTP loss modified to include variances. Uses MSE for mode selection. Can also be used with
8 | Multipath outputs, with residuals added to anchors.
9 | """
10 | def __init__(self, alpha=0.2, use_variance=True, device='cpu'):
11 | """
12 | Initialize MSMA loss
13 | :param args: Dictionary with the following (optional) keys
14 | use_variance: bool, whether or not to use variances for computing regression component of loss,
15 |                 default: True
16 | alpha: float, relative weight assigned to classification component, compared to regression component
17 |                 of loss, default: 0.2
18 | """
19 | super(NLLloss, self).__init__()
20 | self.use_variance = use_variance
21 | self.alpha = alpha
22 | self.device = device
23 |
24 | def forward(self, y_pred, y_true, log_probs):
25 | """
26 | params:
27 | :y_pred: [num_nodes, num_modes, op_len, 2]
28 | :y_true: [num_nodes, op_len, 2]
29 | :log_probs: probability for each mode [N_B, N_M]
30 | where N_B is batch_size, N_M is num_modes, op_len is target_len
31 | """
32 |
33 |
34 | num_nodes = y_true.shape[0]
35 | l2_norm = (torch.norm(y_pred - y_true.unsqueeze(1), p=2, dim=-1)).sum(dim=-1)
36 | best_mode = l2_norm.argmin(dim=1)
37 | pred_best = y_pred[torch.arange(num_nodes), best_mode, :, :]
38 |
39 |
40 | loss_cls = (-log_probs[torch.arange(num_nodes).to(self.device), best_mode].squeeze()).mean() #[N_B]
41 |
42 | loss_reg = (torch.norm(pred_best-y_true, p=2, dim=-1)).mean()
43 |
44 |
45 | loss = loss_reg + self.alpha * loss_cls
46 |
47 | return loss
--------------------------------------------------------------------------------
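
A toy evaluation of the MSMA loss above (not part of the repository; shapes are made-up). It combines the L2 error of the best-matching mode with a classification term on that mode's log-probability:

import torch
from losses.msma_loss import NLLloss

criterion = NLLloss(alpha=0.5, use_variance=False, device='cpu')
y_pred = torch.randn(16, 6, 50, 2)                    # [num_nodes, num_modes, op_len, 2]
y_true = torch.randn(16, 50, 2)                       # [num_nodes, op_len, 2]
log_probs = torch.log_softmax(torch.randn(16, 6), dim=-1)
loss = criterion(y_pred, y_true, log_probs)
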
/losses/mtp_loss.py:
--------------------------------------------------------------------------------
1 | # source: https://github.com/nachiket92/PGP/blob/main/metrics/mtp_loss.py
2 | import torch
3 | import torch.nn as nn
4 | from metrics.metric import min_ade, traj_nll
5 |
6 | class NLLloss(nn.Module):
7 | """
8 | MTP loss modified to include variances. Uses MSE for mode selection. Can also be used with
9 | Multipath outputs, with residuals added to anchors.
10 | """
11 | def __init__(self, alpha=0.2, use_variance=True):
12 | """
13 | Initialize MTP loss
14 | :param args: Dictionary with the following (optional) keys
15 | use_variance: bool, whether or not to use variances for computing regression component of loss,
16 |                 default: True
17 | alpha: float, relative weight assigned to classification component, compared to regression component
18 |                 of loss, default: 0.2
19 | """
20 | super(NLLloss, self).__init__()
21 | self.use_variance = use_variance
22 | self.alpha = alpha
23 |
24 | def forward(self, y_pred, y_gt, log_probs):
25 | """
26 | params:
27 | :y_pred: [num_vehs, num_modes, op_len, op_dim]
28 | :y_gt: [num_vehs, op_len, 2]
29 | :log_probs: probability for each mode [num_vehs, num_modes]
30 | :alpha: float, relative weight assigned to classification component, compared to regression component
31 | of loss, default: 1
32 | """
33 | alpha = self.alpha
34 | use_variance = self.use_variance
35 | # Obtain mode with minimum ADE with respect to ground truth:
36 | op_len = y_pred.shape[2]
37 | pred_params = 5 if use_variance else 2
38 |
39 | errs, inds = min_ade(y_pred, y_gt)
40 | inds_rep = inds.repeat(op_len, pred_params, 1, 1).permute(3, 2, 0, 1)
41 |
42 | # Calculate MSE or NLL loss for trajectories corresponding to selected outputs:
43 | traj_best = y_pred.gather(1, inds_rep).squeeze(dim=1)
44 | # # devectorize traj_best
45 | # for i in range(1,50):
46 | # traj_best[:,i,:] += traj_best[:,i-1,:]
47 |
48 | if use_variance:
49 | l_reg = traj_nll(traj_best, y_gt)
50 | else:
51 | l_reg = errs
52 |
53 | # Compute classification loss
54 | l_class = - torch.squeeze(log_probs.gather(1, inds.unsqueeze(1)))
55 |
56 | loss = l_reg + alpha * l_class
57 | loss = torch.mean(loss)
58 |
59 | return loss
60 |
61 |
62 |
63 |
--------------------------------------------------------------------------------
/losses/multipath_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from metrics.metric import min_ade, traj_nll
4 |
5 | class NLLloss(nn.Module):
6 | """
7 | MTP loss modified to include variances. Uses MSE for mode selection. Can also be used with
8 | Multipath outputs, with residuals added to anchors.
9 | """
10 | def __init__(self, alpha=0.2, use_variance=True):
11 | """
12 | Initialize MTP loss
13 | :param args: Dictionary with the following (optional) keys
14 | use_variance: bool, whether or not to use variances for computing regression component of loss,
15 |                 default: True
16 | alpha: float, relative weight assigned to classification component, compared to regression component
17 |                 of loss, default: 0.2
18 | """
19 | super(NLLloss, self).__init__()
20 | self.use_variance = use_variance
21 | self.alpha = alpha
22 |
23 | def forward(self, y_pred, y_true, log_probs, anchors):
24 | """
25 | params:
26 | :y_pred: [num_nodes, num_modes, op_len, 2]
27 | :y_true: [num_nodes, op_len, 2]
28 | :log_probs: probability for each mode [N_B, N_M]
29 | :anchors: [num_modes, op_len, 2]
30 | where N_B is batch_size, N_M is num_modes, N_T is target_len
31 | """
32 |
33 |
34 | num_nodes = y_true.shape[0]
35 | trajectories = y_pred
36 | anchor_probs = log_probs
37 |
38 | #find the nearest anchor mode to y_true
39 | #[1, num_modes, op_len, 2] - [num_nodes, 1, op_len, 2] = [num_nodes, num_modes, op_len, 2]
40 | distance_to_anchors = torch.sum(torch.linalg.vector_norm(anchors.unsqueeze(0) - y_true.unsqueeze(1),
41 | dim=-1),dim=-1) #[num_nodes, num_modes]
42 |
43 | nearest_mode = distance_to_anchors.argmin(dim=-1) #[num_nodes]
44 | nearest_mode_indices = torch.stack([torch.arange(num_nodes,dtype=torch.int64),nearest_mode],dim=-1)
45 |
46 | loss_cls = -log_probs[torch.arange(num_nodes),nearest_mode].squeeze() #[N_B]
47 |
48 | trajectories_xy = y_pred + anchors.unsqueeze(0)
49 | # l2_norm = (torch.norm(trajectories_xy[:, :, :, :2] - y_true.unsqueeze(1), p=2, dim=-1)).sum(dim=-1) # [num_nodes, num_modes]
50 |
51 | nearest_trajs = trajectories_xy[torch.arange(num_nodes),nearest_mode,:,:].squeeze()
52 | residual_trajs = y_true - nearest_trajs
53 |
54 | loss_reg = torch.mean(torch.square(residual_trajs[:,:,0])+torch.square(residual_trajs[:,:,1]), dim=-1)
55 | dx = residual_trajs[:,:,0]
56 | dy = residual_trajs[:,:,1]
57 |
58 | loss = loss_reg + self.alpha * loss_cls
59 | loss = torch.mean(loss)
60 |
61 | return loss
62 |
63 |
64 |
--------------------------------------------------------------------------------
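
A toy evaluation of the Multipath-style loss above (not part of the repository; shapes are made-up). Predictions are residuals that get decoded by adding the anchors, and the classification term targets the anchor nearest to the ground truth; in training the anchors would come from k_means_anchors:

import torch
from losses.multipath_loss import NLLloss

criterion = NLLloss(alpha=0.5, use_variance=False)
anchors = torch.randn(6, 50, 2)                       # [num_modes, op_len, 2]
y_pred = torch.randn(16, 6, 50, 2)                    # residuals w.r.t. the anchors
y_true = torch.randn(16, 50, 2)
log_probs = torch.log_softmax(torch.randn(16, 6), dim=-1)
loss = criterion(y_pred, y_true, log_probs, anchors)
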
/metrics/__pycache__/ade.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/metrics/__pycache__/ade.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/fde.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/metrics/__pycache__/fde.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/metric.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/metrics/__pycache__/metric.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/__pycache__/mr.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/metrics/__pycache__/mr.cpython-37.pyc
--------------------------------------------------------------------------------
/metrics/ade.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional
2 |
3 | import torch
4 | from torchmetrics import Metric
5 |
6 |
7 | class ADE(Metric):
8 |
9 | def __init__(self,
10 | compute_on_step: bool = True,
11 | dist_sync_on_step: bool = False,
12 | process_group: Optional[Any] = None,
13 | dist_sync_fn: Callable = None) -> None:
14 | super(ADE, self).__init__(compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step,
15 | process_group=process_group, dist_sync_fn=dist_sync_fn)
16 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
17 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
18 |
19 | def update(self,
20 | pred: torch.Tensor,
21 | target: torch.Tensor) -> None:
22 | self.sum += torch.norm(pred - target, p=2, dim=-1).mean(dim=-1).sum()
23 | self.count += pred.size(0)
24 |
25 | def compute(self) -> torch.Tensor:
26 | return self.sum / self.count
--------------------------------------------------------------------------------
/metrics/fde.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional
2 |
3 | import torch
4 | from torchmetrics import Metric
5 |
6 |
7 | class FDE(Metric):
8 |
9 | def __init__(self,
10 | compute_on_step: bool = True,
11 | dist_sync_on_step: bool = False,
12 | process_group: Optional[Any] = None,
13 | dist_sync_fn: Callable = None) -> None:
14 | super(FDE, self).__init__(compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step,
15 | process_group=process_group, dist_sync_fn=dist_sync_fn)
16 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
17 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
18 |
19 | def update(self,
20 | pred: torch.Tensor,
21 | target: torch.Tensor) -> None:
22 | self.sum += torch.norm(pred[:, -1] - target[:, -1], p=2, dim=-1).sum()
23 | self.count += pred.size(0)
24 |
25 | def compute(self) -> torch.Tensor:
26 | return self.sum / self.count
--------------------------------------------------------------------------------
/metrics/metric.py:
--------------------------------------------------------------------------------
1 | #source: https://github.com/nachiket92/PGP/blob/main/metrics/utils.py
2 | import torch
3 | from typing import Tuple
4 |
5 | def ade(traj: torch.Tensor, traj_gt: torch.Tensor):
6 | ls = torch.norm(traj - traj_gt, p=2, dim=-1).mean(dim=-1).mean()
7 |
8 | return ls
9 |
10 | def fde(traj: torch.Tensor, traj_gt: torch.Tensor):
11 | ls = torch.norm(traj[:, -1] - traj_gt[:, -1], p=2, dim=-1).mean()
12 |
13 | return ls
14 |
15 | def mr(traj: torch.Tensor, traj_gt: torch.Tensor, miss_threshold: torch.Tensor):
16 | ls = (torch.norm(traj[:, -1] - traj_gt[:, -1], p=2, dim=-1) > miss_threshold).sum()
17 |
18 | return ls/traj.shape[0]
19 |
20 |
21 | def min_ade(traj: torch.Tensor, traj_gt: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
22 | """
23 | Computes average displacement error for the best trajectory in a set, with respect to ground truth
24 | :param traj: predictions, shape [num_vehs, num_modes, op_len, 2]
25 | :param traj_gt: ground truth trajectory, shape [num_vehs, op_len, 2]
26 | :return errs, inds: errors and indices for modes with min error, shape [num_vehs]
27 | """
28 | num_modes = traj.shape[1]
29 | op_len = traj.shape[2]
30 |
31 | traj_gt_rpt = traj_gt.unsqueeze(1).repeat(1, num_modes, 1, 1)
32 | # masks_rpt = masks.unsqueeze(1).repeat(1, num_modes, 1)
33 |
34 | err = (traj_gt_rpt - traj[:, :, :, 0:2])
35 | err = torch.pow(err, exponent=2)
36 | err = torch.sum(err, dim=3)
37 | err = torch.pow(err, exponent=0.5)
38 | err = torch.sum(err, dim=2) / op_len
39 |
40 | # err[stat_idx,:] = err[stat_idx,:]*10000
41 |
42 | err, inds = torch.min(err, dim=1)
43 |
44 | return err, inds
45 |
46 | def traj_nll(pred_dist: torch.Tensor, traj_gt: torch.Tensor):
47 | """
48 | Computes negative log likelihood of ground truth trajectory under a predictive distribution with a single mode,
49 | with a bivariate Gaussian distribution predicted at each time in the prediction horizon
50 |
51 | :param pred_dist: parameters of a bivariate Gaussian distribution, shape [num_vehs, op_len, 5]
52 | :param traj_gt: ground truth trajectory, shape [num_vehs, op_len, 2]
53 | :return:
54 | """
55 | # op_len = pred_dist.shape[1]
56 | # mu_x = pred_dist[:, :, 0]
57 | # mu_y = pred_dist[:, :, 1]
58 | # x = traj_gt[:, :, 0]
59 | # y = traj_gt[:, :, 1]
60 |
61 | # sig_x = pred_dist[:, :, 2]
62 | # sig_y = pred_dist[:, :, 3]
63 | # rho = pred_dist[:, :, 4]
64 | # ohr = torch.pow(1 - torch.pow(rho, 2), -0.5)
65 |
66 | # nll = 0.5 * torch.pow(ohr, 2) * \
67 | # (torch.pow(sig_x, 2) * torch.pow(x - mu_x, 2) +
68 | # torch.pow(sig_y, 2) * torch.pow(y - mu_y, 2) -
69 | # 2 * rho * torch.pow(sig_x, 1) * torch.pow(sig_y, 1) * (x - mu_x) * (y - mu_y))\
70 | # - torch.log(sig_x * sig_y * ohr) + 1.8379
71 |
72 | # nll[nll.isnan()] = 0
73 | # nll[nll.isinf()] = 0
74 |
75 | # nll = torch.sum(nll, dim=1) / op_len
76 | pred_loc = pred_dist[:,:,:2]
77 | pred_var = pred_dist[:,:,2:4]
78 |
79 | nll = torch.sum(0.5 * torch.log(pred_var) + 0.5 * torch.div(torch.square(traj_gt - pred_loc), pred_var) +\
80 | 0.5 * torch.log(2 * torch.tensor(3.14159265358979323846)))
81 |
82 |
83 | return nll
84 |
85 | def NLLloss(y_pred, y_true, log_probs, anchors):
86 | """
87 | params:
88 | :y_pred: [N_T, N_M, N_B, 2]
89 | :y_true: [N_T, N_B, 2]
90 | :log_probs: probability for each mode [N_B, N_M]
91 | :anchors: [N_M, N_T,2]
92 | where N_B is batch_size, N_M is num_modes, N_T is target_len
93 | """
94 |
95 |
96 | batch_size = y_true.shape[1]
97 | trajectories = y_pred
98 | anchor_probs = log_probs
99 |
100 | #find the nearest anchor mode to y_true
101 | #[1, N_M, N_T,2] - [N_B, N_M, N_T, 2] = [N_B, N_M, N_T, 2]
102 | distance_to_anchors = torch.sum(torch.linalg.vector_norm(anchors.unsqueeze(0) - y_true.permute(1,0,2).unsqueeze(1),
103 | dim=(-1)),dim=-1) #[N_B, N_M]
104 |
105 | nearest_mode = distance_to_anchors.argmin(dim=-1) #[N_B]
106 | nearest_mode_indices = torch.stack([torch.arange(batch_size,dtype=torch.int64),nearest_mode],dim=-1)
107 |
108 | loss_cls = -log_probs[torch.arange(batch_size),nearest_mode].squeeze() #[N_B]
109 |
110 | #trajectories_xy: [N_B, N_M, N_T, 2]
111 | #nearest_trajs: [N_B, N_T, 2]
112 | #residual_trajs: [N_B, N_T, 2]
113 | trajectories_xy = y_pred.permute(2,1,0,3)[...,:2] + anchors.unsqueeze(0)
114 | nearest_trajs = trajectories_xy[torch.arange(batch_size),nearest_mode,:,:].squeeze()
115 | residual_trajs = y_true.permute(1,0,2) - nearest_trajs
116 |
117 | loss_reg = torch.mean(torch.square(residual_trajs[:,:,0])+torch.square(residual_trajs[:,:,1]), dim=-1)
118 | dx = residual_trajs[:,:,0]
119 | dy = residual_trajs[:,:,1]
120 |
121 | total_loss = torch.mean(loss_cls+loss_reg)
122 |
123 | return loss_cls, loss_reg
124 |
--------------------------------------------------------------------------------
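
A short sketch of best-of-k evaluation with the helpers above (not part of the repository; shapes are made-up). min_ade returns each vehicle's best ADE together with the winning mode index, which can then gather the best trajectories for FDE and miss rate:

import torch
from metrics.metric import min_ade, fde, mr

traj = torch.randn(16, 6, 50, 2)       # [num_vehs, num_modes, op_len, 2]
traj_gt = torch.randn(16, 50, 2)
errs, inds = min_ade(traj, traj_gt)
best = traj[torch.arange(16), inds]    # [num_vehs, op_len, 2]
print(errs.mean(), fde(best, traj_gt), mr(best, traj_gt, 2.0))
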
/metrics/mr.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Optional
2 |
3 | import torch
4 | from torchmetrics import Metric
5 |
6 |
7 | class MR(Metric):
8 |
9 | def __init__(self,
10 | miss_threshold: float = 2.0,
11 | compute_on_step: bool = True,
12 | dist_sync_on_step: bool = False,
13 | process_group: Optional[Any] = None,
14 | dist_sync_fn: Callable = None) -> None:
15 | super(MR, self).__init__(compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step,
16 | process_group=process_group, dist_sync_fn=dist_sync_fn)
17 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum')
18 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum')
19 | self.miss_threshold = miss_threshold
20 |
21 | def update(self,
22 | pred: torch.Tensor,
23 | target: torch.Tensor) -> None:
24 | self.sum += (torch.norm(pred[:, -1] - target[:, -1], p=2, dim=-1) > self.miss_threshold).sum()
25 | self.count += pred.size(0)
26 |
27 | def compute(self) -> torch.Tensor:
28 | return self.sum / self.count
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from os.path import join as pjoin
4 |
5 | import torch
6 | from torch_geometric.loader import DataLoader
7 | from torch.optim import Adam, AdamW
8 | from tqdm import tqdm
9 | import math
10 |
11 | from dataloader.carla_scene_process import CarlaData, scene_processed_dataset
12 | from ModelNet.msma import Base_Net
13 | from torch_geometric.utils import subgraph
14 | from losses.msma_loss import NLLloss
15 | from utils.optim_schedule import ScheduledOptim
16 |
17 | #load/process the data
18 | root = "../carla_data/"
19 | source_dir = "scene_mining"
20 | mpr = 0.8
21 | delay_frame = 1
22 | noise_var = 0.1
23 | save_dir = "scene_mining_cav/mpr8_delay{}_noise{}".format(delay_frame, noise_var)
24 |
25 | train_set = scene_processed_dataset(root,
26 | "train",
27 | mpr=mpr,
28 | delay_frame=delay_frame,
29 | noise_var=noise_var,
30 | source_dir=source_dir,
31 | save_dir=save_dir)
32 | val_set = scene_processed_dataset(root,
33 | "val",
34 | mpr=mpr,
35 | delay_frame=delay_frame,
36 | noise_var=noise_var,
37 | source_dir=source_dir,
38 | save_dir=save_dir)
39 | test_set = scene_processed_dataset(root,
40 | "test",
41 | mpr=mpr,
42 | delay_frame=delay_frame,
43 | noise_var=noise_var,
44 | source_dir=source_dir,
45 | save_dir=save_dir)
46 | #args
47 | batch_size = 64
48 | num_workers = 4
49 | horizon = 50
50 | lr = 1e-3
51 | betas=(0.9, 0.999)
52 | weight_decay = 0.0001
53 | warmup_epoch=10
54 | lr_update_freq=10
55 | lr_decay_rate=0.9
56 |
57 |
58 | log_freq = 10
59 | save_folder = ""
60 | model_path = '../carla_data/scene_mining_cav'
61 | ckpt_path = None
62 | verbose = True
63 |
64 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
65 |
66 | model = Base_Net(ip_dim=2,
67 | historical_steps=30,
68 | embed_dim=16,
69 | temp_ff=64,
70 | spat_hidden_dim=64,
71 | spat_out_dim=64,
72 | edge_attr_dim=2,
73 | map_out_dim=64,
74 | lane_dim=2,
75 | map_local_radius=30,
76 | decoder_hidden_dim=64,
77 | num_heads=8,
78 | dropout=0.1,
79 | num_temporal_layers=4,
80 | use_variance=False,
81 |                  device=device,
82 | commu_only=False,
83 | sensor_only=False,
84 | prediction_mode="all")
85 |
86 | #dataloader
87 | train_loader = DataLoader(
88 | train_set,
89 | batch_size=batch_size,
90 | num_workers=num_workers,
91 | pin_memory=True,
92 | shuffle=True,
93 | persistent_workers=True
94 | )
95 | eval_loader = DataLoader(val_set, batch_size=batch_size, num_workers=num_workers, persistent_workers=True, shuffle=False)
96 | test_loader = DataLoader(test_set, batch_size=batch_size, num_workers=num_workers, persistent_workers=True, shuffle=False)
97 |
98 | #loss
99 | criterion = NLLloss(alpha=0.5, use_variance=False, device=device)
100 | # anchors = k_means_anchors(5, train_loader)
101 |
102 | # init optimizer
103 | optim = AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
104 | optm_schedule = ScheduledOptim(
105 | optim,
106 | lr,
107 | n_warmup_epoch=warmup_epoch,
108 | update_rate=lr_update_freq,
109 | decay_rate=lr_decay_rate
110 | )
111 |
112 | model = model.to(device)
113 | if verbose:
114 |     print("[MSMATrainer]: Train the model with a single device on {}.".format(device))
115 |
116 | # model.load_state_dict(torch.load('{}/trained_models_review/model_mpr{}_noise{}_fuse_{}_2.tar'.format(model_path, mpr, noise_var, model.prediction_mode)))
117 |
118 | # iteration
119 | training = model.training
120 | avg_loss = 0.0
121 | avg_loss_val = 0.0
122 | losses_train = []
123 | losses_val = []
124 |
125 | epochs = 100
126 | minVal = math.inf
127 |
128 | # %%
129 |
130 | for epoch in range(epochs):
131 | avg_loss = 0.0
132 | ## Train:_______________________________________________________________________________________________________________________________
133 | training = True
134 |     model.train()  # put dropout layers back into training mode each epoch
135 | data_iter = tqdm(
136 | enumerate(train_loader),
137 | desc="{}_Ep_{}: loss: {:.5e}; avg_loss: {:.5e}".format("train" if training else "eval",
138 | epoch,
139 | 0.0,
140 | avg_loss),
141 | total=len(train_loader),
142 | bar_format="{l_bar}{r_bar}"
143 | )
144 | count = 0
145 |
146 | for i, data in data_iter: #next(iter(train_loader))
147 | data = data.to(device)
148 |
149 | if training:
150 | optm_schedule.zero_grad()
151 | predictions, mask = model(data)
152 | gt = torch.matmul(data.y, data.rotate_imat)[mask]
153 | loss = criterion(predictions['traj'], gt, predictions['log_probs'])
154 | loss.backward()
155 | losses_train.append(loss.detach().item())
156 |
157 | torch.nn.utils.clip_grad_norm_(model.parameters(), 100)
158 | optim.step()
159 | # write_log("Train Loss", loss.detach().item() / n_graph, i + epoch * len(train_loader))
160 | avg_loss += loss.detach().item()
161 | count += 1
162 |
163 | # print log info
164 | desc_str = "[Info: Device_{}: {}_Ep_{}: loss: {:.5e}; avg_loss: {:.5e}]".format(
165 | 0,
166 | "train" if training else "eval",
167 | epoch,
168 | loss.item(),
169 | avg_loss / count)
170 | data_iter.set_description(desc=desc_str, refresh=True)
171 |
172 | if training:
173 | learning_rate = optm_schedule.step_and_update_lr()
174 |         if epoch % 10 == 0:
175 | print("learning_rate: ", learning_rate)
176 | # write_log("LR", learning_rate, epoch)
177 |
178 |
179 | ## Val:_______________________________________________________________________________________________________________________________
180 | training = False
181 |     model.eval()  # disable dropout for validation
182 | avg_loss_val = 0.0
183 | count_val = 0
184 | data_iter_val = tqdm(enumerate(eval_loader), desc="{}_Ep_{}: loss: {:.5e}; avg_loss: {:.5e}".format("eval",
185 | epoch,
186 | 0.0,
187 | avg_loss_val),
188 | total=len(eval_loader),
189 | bar_format="{l_bar}{r_bar}"
190 | )
191 | for i, data_val in data_iter_val:
192 | data_val = data_val.to(device)
193 |
194 | with torch.no_grad():
195 | predictions_val, mask_val = model(data_val)
196 | gt_val = torch.matmul(data_val.y, data_val.rotate_imat)[mask_val]
197 | loss_val = criterion(predictions_val['traj'],
198 | gt_val, predictions_val['log_probs'])
199 |
200 | losses_val.append(loss_val.detach().item())
201 | avg_loss_val += loss_val.detach().item()
202 | count_val += 1
203 |
204 | # print log info
205 | desc_str_val = "[Info: Device_{}: {}_Ep_{}: loss: {:.5e}; avg_loss: {:.5e}]".format(
206 | 0,
207 | "eval",
208 | epoch,
209 | loss_val.item(),
210 | avg_loss_val / count_val)
211 | data_iter_val.set_description(desc=desc_str_val, refresh=True)
212 |
213 |     if avg_loss_val / count_val < minVal:  # compare the epoch-average validation loss, not just the last batch
214 |         minVal = avg_loss_val / count_val
215 | torch.save(model.state_dict(), '{}/trained_models_review/model_mpr{}_noise{}_fuse_{}_3.tar'.format(model_path, mpr, noise_var, model.prediction_mode))
216 |
217 | # %%
218 | ## Test:___________________________________________________________________________________________________________________________________
219 | def test(model, test_loader, epoch):
220 | """
221 | make predictions on test dataset
222 |
223 | """
224 |     model.eval()  # run inference in eval mode; torch.no_grad() below disables gradients
225 |
226 |     # running loss counters and per-batch output containers (keyed by batch index)
227 | count_test = 0
228 | avg_loss_test = 0.0
229 | predictions_test = {}
230 | gts_test = {}
231 | batch_info = {}
232 | probs = {}
233 | masks = {}
234 | sensor_masks = {}
235 |
236 | data_iter_test = tqdm(enumerate(test_loader), desc="{}_Ep_{}: loss: {:.5e}; avg_loss: {:.5e}".format("test",
237 | epoch,
238 | 0.0,
239 | avg_loss_test),
240 | total=len(test_loader),
241 | bar_format="{l_bar}{r_bar}"
242 | )
243 | for i, data_test in data_iter_test:
244 | data_test = data_test.to(device)
245 |
246 | with torch.no_grad():
247 | pred_test, mask_test = model(data_test) #pred_test: offset to anchors
248 | gt_test = torch.matmul(data_test.y, data_test.rotate_imat)[mask_test] #aligned at +x axis
249 | #sum of reg and cls loss for all detected vehs
250 | loss_test = criterion(pred_test['traj'], \
251 | gt_test, pred_test['log_probs'])
252 |
253 | count_test += 1
254 | avg_loss_test += loss_test.detach().item()
255 | #compare predictions for vehs in sensor range when centered at [0,0] but not aligned with x-axis
256 | predictions_test_i = torch.zeros((mask_test.shape[0], 5, 50, 2)).to(device)
257 | predictions_test_i[mask_test]= pred_test["traj"]
258 | predictions_test[i] = torch.matmul(predictions_test_i, \
259 | torch.inverse(data_test.rotate_imat.unsqueeze(1)))
260 | # predictions_test[i] = torch.matmul(pred_test["traj"] + anchors.unsqueeze(0), \
261 | # torch.inverse(data_test.rotate_imat[mask_test]))
262 | batch_info[i] = data_test.batch
263 | probs_i = torch.zeros((mask_test.shape[0], 5)).to(device)
264 | probs_i[mask_test] = torch.exp(pred_test['log_probs'])
265 | probs[i] = probs_i
266 | # probs[i] = torch.exp(pred_test['log_probs'])
267 | masks[i] = mask_test
268 | sensor_masks[i] = data_test.sensor_mask
269 | gts_test[i] = data_test.y
270 |
271 | # print log info
272 | desc_str_test = "[Info: Device_{}: {}_Ep_{}: loss: {:.5e}; avg_loss: {:.5e}]".format(
273 | 0,
274 | "test",
275 | epoch,
276 | loss_test.item(),
277 | avg_loss_test / count_test)
278 | data_iter_test.set_description(desc=desc_str_test, refresh=True)
279 |
280 | return predictions_test, gts_test, probs, batch_info, masks, sensor_masks
281 |
282 | predictions_av_av, gt_av_av, probs_av_av, batch_av_av, mask_av_av, sensor_mask_av_av = test(model, test_loader, 100)
283 |
--------------------------------------------------------------------------------
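
A possible follow-up to the script above (not part of the original file): reload the best checkpoint written during validation and re-run the test pass. The path format simply mirrors the torch.save call in the validation loop, and all names used here are defined earlier in train.py.

best_ckpt = '{}/trained_models_review/model_mpr{}_noise{}_fuse_{}_3.tar'.format(
    model_path, mpr, noise_var, model.prediction_mode)
model.load_state_dict(torch.load(best_ckpt, map_location=device))
model.eval()
predictions_best, gts_best, probs_best, batch_best, masks_best, sensor_masks_best = \
    test(model, test_loader, epochs)
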
/utils/__pycache__/optim_schedule.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xichennn/MSMA/93e0814391d1ed0586c87f37ea5fc267f818cf16/utils/__pycache__/optim_schedule.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/optim_schedule.py:
--------------------------------------------------------------------------------
1 | # A wrapper class for optimizer
2 | # source: https://github.com/codertimo/BERT-pytorch/blob/master/bert_pytorch/trainer/optim_schedule.py
3 | import numpy as np
4 |
5 |
6 | class ScheduledOptim:
7 | """ A simple wrapper class for learning rate scheduling
8 | """
9 |
10 | def __init__(self, optimizer, init_lr, n_warmup_epoch=10, update_rate=5, decay_rate=0.9):
11 | self._optimizer = optimizer
12 | self.n_warmup_epoch = n_warmup_epoch
13 | self.n_current_steps = 0
14 | self.init_lr = init_lr
15 | self.update_rate = update_rate
16 | self.decay_rate = decay_rate
17 |
18 | def step_and_update_lr(self):
19 | """Step with the inner optimizer"""
20 | self.n_current_steps += 1
21 | rate = self._update_learning_rate()
22 |
23 | return rate
24 | # self._optimizer.step()
25 |
26 | def zero_grad(self):
27 | "Zero out the gradients by the inner optimizer"
28 | self._optimizer.zero_grad()
29 |
30 | def _get_lr_scale(self):
31 | return np.power(self.decay_rate, max((self.n_current_steps - self.n_warmup_epoch + 1) // self.update_rate + 1, 0))
32 |
33 | def _update_learning_rate(self):
34 | """ Learning rate scheduling per step """
35 |
36 | lr = self.init_lr * self._get_lr_scale()
37 |
38 | for param_group in self._optimizer.param_groups:
39 | param_group['lr'] = lr
40 | return lr
41 |
42 | if __name__ == "__main__":
43 | lr = 1e-3
44 | betas=(0.9, 0.999)
45 | weight_decay = 0.0001
46 | warmup_epoch=150000
47 | lr_update_freq=5
48 | lr_decay_rate=0.3
49 |
--------------------------------------------------------------------------------
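
A minimal sketch (not in the repo) of how ScheduledOptim behaves when driven once per epoch, as train.py does. The scale applied at epoch s is decay_rate ** max((s - n_warmup_epoch + 1) // update_rate + 1, 0), so the rate stays at init_lr for the first few warm-up epochs and then decays geometrically every update_rate epochs. The constants below are illustrative, not taken from the repo.

import torch
from utils.optim_schedule import ScheduledOptim

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-3)
scheduler = ScheduledOptim(optimizer, init_lr=1e-3, n_warmup_epoch=5, update_rate=2, decay_rate=0.9)
for epoch in range(10):
    # lr holds at 1e-3 early on, then drops by 10% every 2 epochs
    print(epoch, scheduler.step_and_update_lr())
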
/utils/viz.py:
--------------------------------------------------------------------------------
1 | #architecture picture in test.py on colab
2 | import torch, matplotlib.pyplot as plt  # torch is used below (torch.vstack, torch.norm, torch.max, torch.arange)
3 | def visualize_centerline(centerline) -> None:
4 | """Visualize the computed centerline.
5 | Args:
6 | centerline: Sequence of coordinates forming the centerline
7 | """
8 | line_coords = list(zip(*centerline))
9 | lineX = line_coords[0]
10 | lineY = line_coords[1]
11 | plt.plot(lineX, lineY, "--", color="grey", alpha=1, linewidth=1, zorder=0)
12 | # plt.text(lineX[0], lineY[0], "s")
13 | # plt.text(lineX[-1], lineY[-1], "e")
14 | plt.axis("equal")
15 |
16 | def visualize_map(lane_strs, lane_vecs, lane_idcs):
17 | for i in range(1, len(lane_idcs.unique())):
18 | lane_start = lane_strs[lane_idcs == i]
19 | vecs = lane_vecs[lane_idcs == i]
20 | lane_end = lane_start + vecs
21 | lane = torch.vstack([lane_start, lane_end[-1,:].reshape(-1, 2)])
22 | visualize_centerline(lane)
23 |
24 | def visualize_traj(prediction, gt, prob, best_mode=True):
25 | """
26 | prediction: [num_nodes, num_modes, op_len, 2]
27 | gt: [num_nodes, op_len, 2]
28 | prob: [num_nodes, num_modes]
29 | """
30 | n, m = prediction.shape[0], prediction.shape[1]
31 |
32 | if best_mode:
33 | # prs, inds = torch.max(prob, dim=1)
34 |
35 | # for i in range(n):
36 | # plt.plot(prediction[i,inds[i],:,0], prediction[i,inds[i],:,1])
37 | # plt.text(prediction[i,inds[i],-1,0], prediction[i,inds[i],-1,1],
38 | # "{:.2f}".format(prs[i].item()))
39 | # plt.plot(gt[i,:,0], gt[i,:,1],'--')
40 | l2_norm = (torch.norm(prediction[:, :, :, : 2] - \
41 | gt.unsqueeze(1), p=2, dim=-1)).sum(dim=-1)
42 |         best_mode_idx = l2_norm.argmin(dim=-1)  # avoid shadowing the boolean `best_mode` argument
43 |         y_pred_best = prediction[torch.arange(gt.shape[0]), best_mode_idx, :, : 2]
44 | for i in range(n):
45 | plt.plot(y_pred_best[i,:,0], y_pred_best[i,:,1],'b')
46 | plt.plot(gt[i,:,0], gt[i,:,1], c='orange', linestyle='--')
47 | # circle_ncv = plt.Circle((gt[i,0,0], gt[i,0,1]),
48 | # 1, color='orange')
49 | # plt.gca().add_patch(circle_ncv)
50 |
51 | else:
52 | for i in range(n):
53 | for j in range(m):
54 | plt.plot(prediction[i,j,:,0], prediction[i,j,:,1])
55 | plt.plot(gt[i,:,0], gt[i,:,1], c='orange', linestyle='--')
56 | circle_ncv = plt.Circle((gt[i,0,0], gt[i,0,1]),
57 | 1, color='orange')
58 | plt.gca().add_patch(circle_ncv)
59 |
60 | def visualize_gt_traj(gt):
61 | for i in range(gt.shape[0]):
62 | plt.plot(gt[i,:,0], gt[i,:,1], c='orange', linestyle='--')
63 | def visualize_pred_traj(pred, prob, best_mode=True):
64 | n, m = pred.shape[0], pred.shape[1]
65 | if best_mode:
66 | prs, inds = torch.max(prob, dim=1)
67 | for i in range(n):
68 | plt.plot(pred[i,inds[i],:,0], pred[i,inds[i],:,1])
69 | plt.text(pred[i,inds[i],-1,0], pred[i,inds[i],-1,1],
70 | "{:.2f}".format(prs[i].item()))
71 | else:
72 | for i in range(n):
73 | for j in range(m):
74 | plt.plot(pred[i,j,:,0], pred[i,j,:,1])
75 |
76 | def prediction_viz(sample, batch_size, test_set, predictions, probs, batch, masks, mpr=0):
77 | """
78 | prediction: [num_nodes, num_modes, op_len, 2]
79 | gt: [num_nodes, op_len, 2]
80 | prob: [num_nodes, num_modes]
81 | """
82 | s0, s1 = divmod(sample, batch_size)
83 |
84 | #map viz
85 | lane_vecs = test_set.get(sample).lane_vectors
86 | lane_strs = test_set.get(sample).lane_pos
87 | lane_idcs = test_set.get(sample).lane_idcs
88 | # visualize_map(lane_strs, lane_vecs, lane_idcs)
89 | #traj viz
90 | prediction = predictions[s0][batch[s0]==s1,:].cpu() #[num_nodes, num_modes, op_len, 2]
91 | prob = probs[s0][batch[s0]==s1,:].cpu() #[num_nodes, num_modes]
92 | mask = masks[s0][batch[s0]==s1].cpu() #[num_nodes]
93 | gt = test_set.get(sample).y.cpu() #[num_nodes, op_len, 2]
94 | orig = test_set.get(sample).positions[:,49,:].unsqueeze(1) #[num_nodes, 1, 2]
95 | # visualize_traj((prediction+orig.unsqueeze(1))[mask], (gt+orig)[mask], prob[mask], best_mode=True)
96 | #cav
97 | cav_ori = (gt+orig)[test_set.get(sample).cav_mask]
98 | cav_mask = test_set.get(sample).cav_mask
99 | visualize_traj((prediction+orig.unsqueeze(1))[cav_mask], (gt+orig)[cav_mask], prob[cav_mask], best_mode=True)
100 | for i in range(cav_ori.shape[0]):
101 | # plt.plot(cav_ori[i,:,0], cav_ori[i,:,1], 'r')
102 | # circle_cav = plt.Circle((cav_ori[i,0,0], cav_ori[i,0,1]),
103 | # 1, color='r')
104 | l1, = plt.plot(cav_ori[i,0,0], cav_ori[i,0,1], marker=(4, 0, 90), color="r",markersize=5)
105 | circle_commu = plt.Circle((cav_ori[i,0,0], cav_ori[i,0,1]),
106 | 65, color='honeydew')
107 | circle_sensor = plt.Circle((cav_ori[i,0,0], cav_ori[i,0,1]),
108 | 40, color='bisque')
109 | plt.gca().add_patch(circle_commu)
110 | plt.gca().add_patch(circle_sensor)
111 | # plt.gca().add_patch(circle_cav)
112 | #ncv
113 | ncv_ori = (gt+orig)[test_set.get(sample).sensor_mask]
114 | ncv_mask = test_set.get(sample).sensor_mask
115 | # for i in range(ncv_ori.shape[0]):
116 | for i in [0,2,3,4,5,7]:
117 | # plt.plot(ncv_ori[i,:,0], ncv_ori[i,:,1], c='orange')
118 | l2, = plt.plot(ncv_ori[i,0,0], ncv_ori[i,0,1], marker="o",color="darkorange",markersize=5)
119 | visualize_traj((prediction+orig.unsqueeze(1))[ncv_mask][i].unsqueeze(0), (gt+orig)[ncv_mask][i].unsqueeze(0), prob[ncv_mask][i].unsqueeze(0), best_mode=True)
120 | circle_ncv = plt.Circle((ncv_ori[i,0,0], ncv_ori[i,0,1]),
121 | 1, color='orange')
122 | plt.gca().add_patch(circle_ncv)
123 | #cv
124 | cv_ori = (gt+orig)[test_set.get(sample).commu_mask]
125 | cv_mask = test_set.get(sample).commu_mask
126 |
127 | for i in range(1, cv_ori.shape[0]):
128 | # plt.plot(cv_ori[i,:,0], cv_ori[i,:,1], 'g')
129 | l3, = plt.plot(cv_ori[i,0,0], cv_ori[i,0,1], marker="*",color="g",markersize=5)
130 | visualize_traj((prediction+orig.unsqueeze(1))[cv_mask][i].unsqueeze(0), (gt+orig)[cv_mask][i].unsqueeze(0), prob[cv_mask][i].unsqueeze(0), best_mode=True)
131 | # circle_cv = plt.Circle((cv_ori[i,0,0], cv_ori[i,0,1]),
132 | # 1, color='g')
133 | # plt.gca().add_patch(circle_cv)
134 |
135 | # #hist_cav
136 | # positions_cav = test_set.get(sample).positions[[test_set.get(sample).cav_mask]]
137 | # for i in range(positions_cav.shape[0]):
138 | # plt.plot(positions_cav[i,20:50,0], positions_cav[i,20:50,1], 'r--',linewidth=2)
139 | # #hist_ncv
140 | # positions_ncv = test_set.get(sample).positions[[test_set.get(sample).sensor_mask]]
141 | # for i in range(positions_ncv.shape[0]):
142 | # plt.plot(positions_ncv[i,20:50,0], positions_ncv[i,20:50,1], c='orange', linestyle='--',linewidth=2)
143 | # #hist_cv
144 | # positions_cv = test_set.get(sample).positions[[test_set.get(sample).commu_mask]]
145 | # for i in range(1, positions_cv.shape[0]):
146 | # plt.plot(positions_cv[i,20:50,0], positions_cv[i,20:50,1], 'g--',linewidth=2)
147 | # # visualize_gt_traj(gt+orig)
148 | # # visualize_pred_traj((prediction+orig.unsqueeze(1))[mask], prob[mask])
149 | plt.axis('equal')
150 | plt.axis('off')
151 | # # plt.ylim((-60,80))
152 | # # plt.xlim((-80,60))
153 | # # plt.xlabel("position_x(m)")
154 | # # plt.ylabel("position_y(m)")
155 | # # plt.title('mpr={}'.format(mpr))
156 | # example invocation from an interactive session; assumes the outputs of
157 | # train.py's test() for an mpr=0.4 run are already in scope
158 | sample = 452
159 | prediction_viz(sample, batch_size, test4, predictions_cav4_cav4, probs_cav4_cav4, batch_cav4_cav4, mask_cav4_cav4, mpr=0.4)
--------------------------------------------------------------------------------
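
A hedged usage sketch tying viz.py to train.py: it assumes the functions above and the outputs of train.py's test() call (test_set, predictions_av_av, probs_av_av, batch_av_av, mask_av_av) are in scope, e.g. when both scripts run in the same notebook session. batch_size must match the test DataLoader, and the sample index is arbitrary.

import matplotlib.pyplot as plt

prediction_viz(sample=10, batch_size=64, test_set=test_set,
               predictions=predictions_av_av, probs=probs_av_av,
               batch=batch_av_av, masks=mask_av_av, mpr=0.8)
plt.show()
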