├── model
│   ├── __init__.py
│   ├── aggregator.py
│   ├── InteractPlanner.py
│   ├── decoder.py
│   └── encoder.py
├── utils
│   ├── __init__.py
│   ├── waymo_tf_utils.py
│   ├── net_utils.py
│   ├── plan_utils.py
│   ├── occupancy_grid_utils.py
│   ├── train_utils.py
│   └── occupancy_render_utils.py
├── README.md
├── testing.py
├── training.py
├── planner.py
├── metric.py
└── preprocess.py
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/model/aggregator.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | import sys
5 | import os
6 |
7 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
8 | sys.path.append(os.path.dirname(SCRIPT_DIR))
9 |
10 |
11 | def temporal_upsample(inputs, size=(2, 2), mode='nearest'):
12 | assert len(inputs.shape) == 5
13 | b, c, t, h, w = inputs.shape
14 | inputs = inputs.permute(0, 2, 1, 3, 4).contiguous().reshape(b*t, c, h, w)
15 | inputs = F.interpolate(input=inputs, size=(2*h, 2*w), mode=mode, align_corners=True if mode != 'nearest' else None)  # align_corners is only valid for interpolating modes
16 | inputs = inputs.reshape(b, t, c, 2*h, 2*w).permute(0, 2, 1, 3, 4)
17 | return inputs
18 |
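# Shape sketch for temporal_upsample (illustrative sizes, not from the repo): each timestep
# is folded into the batch, so the spatial size doubles while the temporal length is kept:
#   x = torch.randn(2, 384, 5, 16, 16)           # (B, C, T, H, W)
#   temporal_upsample(x, mode='bilinear').shape  # -> (2, 384, 5, 32, 32)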
19 | class CrossAttention(nn.Module):
20 | def __init__(self, dim=384, heads=8, dropout=0.1):
21 | super(CrossAttention, self).__init__()
22 | self.cross_attention = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True,)
23 | self.ffn = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Dropout(0.1), nn.Linear(dim*4, dim), nn.Dropout(0.1))
24 | self.norm_0 = nn.LayerNorm(dim)
25 | self.norm_1 = nn.LayerNorm(dim)
26 |
27 | def forward(self, query, key, mask):
28 | output, _ = self.cross_attention(query, key, key, key_padding_mask=mask)
29 | attention_output = self.norm_0(output + query)
30 | n_output = self.ffn(attention_output)
31 | return self.norm_1(n_output + attention_output)
--------------------------------------------------------------------------------
/model/InteractPlanner.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
5 | sys.path.append(os.path.dirname(SCRIPT_DIR))
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from .encoder import *
11 | from .decoder import *
12 |
13 | class InteractPlanner(nn.Module):
14 | def __init__(self, config, dim=256, enc_layer=2, heads=8, dropout=0.1,
15 | timestep=5, decoder_dim=384, fpn_len=2, use_dynamic=True,
16 | large_scale=True,flow_pred=False):
17 |
18 | super(InteractPlanner, self).__init__()
19 |
20 | self.visual_encoder = VisualEncoder(config, input_resolution=(256, 256),
21 | patch_size=2,use_deformable_block=False,large_scale=large_scale)
22 |
23 | self.vector_encoder = VectorEncoder(dim, enc_layer, heads, dropout)
24 |
25 | self.ogm_decoder = STrajNetDecoder(decoder_dim, heads, len_fpn=fpn_len, timestep=timestep, dropout=dropout,
26 | flow_pred=flow_pred,large_scale=large_scale)
27 |
28 | self.plan_decoder = PlanningDecoder(dim, heads, dropout, use_dynamic=use_dynamic,
29 | timestep=timestep)
30 |
31 |
32 | def forward(self, inputs):
33 | encoder_outputs = self.vector_encoder(inputs)
34 | bev_list, _ = self.visual_encoder(inputs)
35 | encoder_outputs.update({
36 | 'bev_feature' : bev_list[-1]
37 | })
38 | bev_pred = self.ogm_decoder(bev_list, encoder_outputs['encodings'][:, 0], encoder_outputs['masks'][:, 0])
39 | traj, score = self.plan_decoder(encoder_outputs)
40 | return bev_pred, traj, score
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OPGP
2 |
3 | This repo is the implementation of:
4 |
5 | **Occupancy Prediction-Guided Neural Planner for Autonomous Driving**
6 | [Haochen Liu](https://scholar.google.com/citations?user=iizqKUsAAAAJ&hl=en), [Zhiyu Huang](https://mczhi.github.io/), [Chen Lv](https://scholar.google.com/citations?user=UKVs2CEAAAAJ&hl=en)
7 | [AutoMan Research Lab, Nanyang Technological University](https://lvchen.wixsite.com/automan)
8 | **[[Paper]](https://ieeexplore.ieee.org/abstract/document/10422055/)** **[[arXiv]](https://arxiv.org/abs/2305.03303)** **[[Zhihu]](https://zhuanlan.zhihu.com/p/680304839)**
9 |
10 | - Code is now released 😀!
11 |
12 | ## Overview
13 | In this repository, you can expect to find the following features 🤩:
14 | * Pipelines for data processing and training
15 | * Open-loop evaluations
16 |
17 | Not included 😵:
18 | * Model weights (due to the WOMD license)
19 | * Real-time planning (the code is not optimized for real-time performance)
20 |
21 | ## Experiment Pipelines
22 |
23 | ### Dataset and Environment
24 |
25 |
26 | - Download the [Waymo Open Motion Dataset](https://waymo.com/open/download/) v1.1. Use data from ```scenario/training_20s``` for the train set, and data from ```scenario/validation``` for val & test.
27 |
28 | - Clone this repository and install required packages.
29 |
30 | - **[NOTED]** For the [theseus](https://github.com/facebookresearch/theseus) library, you may need to build it from source and add its install location to the system path in ```planner.py```, as in the sketch below.
31 |
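A minimal sketch of that path tweak (the install location below is a placeholder; point it at wherever theseus was built):

```
import sys
sys.path.append('/path/to/theseus')  # placeholder: your local theseus build
import theseus as th
```
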
32 | ### Data Process
33 |
34 | - Preprocess data for training & testing:
35 |
36 | ```
37 | python preprocess.py \
38 | --root_dir path/to/your/Waymo_Dataset/scenario/ \
39 | --save_dir path/to/your/processed_data/ \
40 | --processes=16
41 | ```
42 |
43 | - You may also refer to [Waymo_candid_list](https://github.com/MCZhi/GameFormer/blob/main/open_loop_planning/waymo_candid_list.csv) for more interactive and safety-critical scenarios filtered from ```scenario/validation```.
44 |
45 | ### Training & Testing
46 |
47 | - Train & evaluate the model using the command (set ```--nproc_per_node``` to the number of GPUs):
48 |
49 | ```
50 | python -m torch.distributed.launch \
51 | --nproc_per_node 1 \
52 | --master_port 16666 \
53 | training.py \
54 | --data_dir path/to/your/processed_data/ \
55 | --save_dir path/to/save/your/logs/
56 | ```
57 |
58 | - Conduct Open-loop Testing using the command:
59 |
60 | ```
61 | python testing.py \
62 | --data_dir path/to/your/testing_data/ \
63 | --model_dir path/to/pretrained/model/
64 | ```
65 |
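- **[NOTED]** The ```--model_dir``` checkpoint is expected to follow the filename pattern written by ```training.py``` (```model_<epoch>_<train_loss>_<val_loss>.pth```), since ```testing.py``` parses the epoch number back out of the name.
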
66 | ## Citation
67 | If you find this repository useful for your research, please consider giving us a star 🌟 and citing our paper.
68 |
69 | ```
70 | @inproceedings{liu2023occupancy,
71 | title={Occupancy prediction-guided neural planner for autonomous driving},
72 | author={Liu, Haochen and Huang, Zhiyu and Lv, Chen},
73 | booktitle={2023 IEEE 26th International Conference on Intelligent Transportation Systems (ITSC)},
74 | pages={4859--4865},
75 | year={2023},
76 | organization={IEEE}
77 | }
78 | ```
--------------------------------------------------------------------------------
/testing.py:
--------------------------------------------------------------------------------
1 |
2 | import csv
3 | import argparse
4 | import time
5 | import sys
6 |
7 | import torch
8 | from torch import optim
9 | from torch.utils.data import DataLoader
10 |
11 | from waymo_open_dataset.protos import occupancy_flow_metrics_pb2
12 | from google.protobuf import text_format
13 |
14 | from model.InteractPlanner import InteractPlanner
15 | from utils.net_utils import *
16 | from metric import TestingMetrics
17 |
18 | from planner import Planner
19 |
20 |
21 | def test_modal_selection(traj, score, targets, level=-1):
22 | gt_future = targets['ego_future_states']
23 | if isinstance(traj, list):
24 | traj, score = traj[level], score[level]
25 | gt_modes = torch.argmax(score, dim=-1)
26 | B = traj.shape[0]
27 | selected_trajs = traj[torch.arange(B)[:, None], gt_modes.unsqueeze(-1)].squeeze(1)
28 | return selected_trajs, gt_modes
29 |
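# Note on flow_warp below: the decoder's last two channels are treated as a per-waypoint
# flow field; the occupancy from the previous step (or the observed current occupancy at
# the first waypoint) is resampled along that flow with grid_sample and multiplied
# element-wise with the predicted occupancy, giving a flow-traced grid for planning.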
30 | def flow_warp(bev_pred, current_ogm, occ=False):
31 | ogm_pred, pred_flow = bev_pred[:, :4].sigmoid(), bev_pred[:, -2:]
32 | if not occ:
33 | ogm_pred = torch.cat([(ogm_pred[:,0] + ogm_pred[:,-1]).clamp(0,1).unsqueeze(1),
34 | ogm_pred[:, 1:2], ogm_pred[:, 2:3]],dim=1)
35 |
36 | b, c, t, h, w = pred_flow.shape
37 | pred_flow = pred_flow.permute(0, 2, 3, 4, 1)
38 | x = torch.linspace(0, w - 1, w)
39 | y = torch.linspace(0, h - 1, h)
40 | grid = torch.stack(torch.meshgrid([x, y])).transpose(1, 2)
41 | grid = grid.permute(1, 2, 0).unsqueeze(0).unsqueeze(1).expand(b, t, -1, -1, -1).to(local_rank)
42 |
43 | flow_grid = grid + pred_flow + 0.5
44 | flow_grid = 2 * flow_grid / (h) - 1
45 |
46 | warped_flow = []
47 | for i in range(flow_grid.shape[1]):
48 | flow_origin_ogm = current_ogm if i==0 else ogm_pred[:, :, i-1]
49 | wf = F.grid_sample(flow_origin_ogm, flow_grid[:, i], mode='nearest', align_corners=False)
50 | warped_flow.append(wf)
51 |
52 | warped_flow = torch.stack(warped_flow, dim=2)
53 | warped_ogm = ogm_pred * warped_flow
54 | return warped_ogm
55 |
56 | def model_testing(valid_data):
57 |
58 | epoch_metrics = TestingMetrics(config)
59 | model.eval()
60 | current = 0
61 | start_time = time.time()
62 | size = len(valid_data)
63 |
64 | print(f'Testing....')
65 | for batch in valid_data:
66 | # prepare data
67 | inputs, target = batch_to_dict(batch, local_rank, use_flow)
68 |
69 | # query the model
70 | with torch.no_grad():
71 | bev_pred, traj, score = model(inputs)
72 | selected_trajs, gt_modes = test_modal_selection(traj, score, target, level=0)
73 | selected_ref = target['ref_line']
74 | b, h, w, d = inputs['hist_ogm'].shape
75 | types = inputs['hist_ogm'].reshape(b, h, w, d//3, 3)
76 | current_ogm = types[:, :, :, -1, :].permute(0, 3, 1, 2)
77 | type_mask = types[..., -1, :].sum(-2).sum(-2) > 0
78 | warped_ogm = flow_warp(bev_pred, current_ogm)
79 |
80 | planning_inputs = planner.preprocess(inputs['ego_state'], selected_trajs[:, :50, :2],
81 | selected_ref, warped_ogm[:, :, :5], type_mask, config,left=True)
82 | xy_plan = planner.plan(planning_inputs, selected_ref, inputs['ego_state'])
83 |
84 | epoch_metrics.update(xy_plan, score, warped_ogm, gt_modes, target, inputs['ego_state'][:,-1,:])
85 |
86 | current += args.batch_size
87 | sys.stdout.write(f'\rVal: [{current:>6d}/{size*args.batch_size:>6d}]|{(time.time()-start_time)/current:>.4f}s/sample')
88 | sys.stdout.flush()
89 |
90 | print('Calculating Open Loop Planning Results...')
91 | print(epoch_metrics.result())
92 |
93 |
94 | if __name__ == "__main__":
95 |
96 | parser = argparse.ArgumentParser()
97 | parser.add_argument("--local_rank", type=int)
98 | parser.add_argument("--batch_size", type=int,default=4)
99 | parser.add_argument("--dim", type=int,default=256)
100 | parser.add_argument("--use_flow", type=bool, action='store_true', default=True,
101 | help='whether to use flow warp')
102 | parser.add_argument("--data_dir", type=str, default='',
103 | help='path to load preprocessed data')
104 | parser.add_argument("--model_dir", type=str, default='',
105 | help='path to load pretrained IL model')
106 |
107 | args = parser.parse_args()
108 | local_rank = args.local_rank
109 |
110 | config = occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig()
111 | config_text = f"""
112 | num_past_steps: {10}
113 | num_future_steps: {50}
114 | num_waypoints: {5}
115 | cumulative_waypoints: {'true'}
116 | normalize_sdc_yaw: true
117 | grid_height_cells: {128}
118 | grid_width_cells: {128}
119 | sdc_y_in_grid: {int(128*0.75)}
120 | sdc_x_in_grid: {64}
121 | pixels_per_meter: {1.6}
122 | agent_points_per_side_length: 48
123 | agent_points_per_side_width: 16
124 | """
125 |
126 | text_format.Parse(config_text, config)
127 |
128 | use_flow = args.use_flow
129 |
130 | model = InteractPlanner(config, dim=args.dim, enc_layer=2, heads=8, dropout=0.1,
131 | timestep=5, decoder_dim=384, fpn_len=2, flow_pred=use_flow)
132 |
133 | local_rank = torch.device('cuda')
134 | print(local_rank)
135 |
136 | model = model.to(local_rank)
137 |
138 | planner = Planner(device=local_rank, g_length=1200, g_width=60, horizon=5, test_iters=50)
139 |
140 | assert args.model_dir != '', 'you must load a pretrained weights for OL testing!'
141 | kw_dict = {}
142 | for k,v in torch.load(args.model_dir,map_location=torch.device('cpu')).items():
143 | kw_dict[k[7:]] = v
144 | model.load_state_dict(kw_dict)
145 | continue_ep = int(args.model_dir.split('_')[-3]) - 1
146 | print(f'model loaded!:epoch {continue_ep + 1}')
147 |
148 | test_dataset = DrivingData(args.data_dir + f'*.npz', use_flow=True)
149 |
150 | training_size = len(test_dataset)
151 | print(f'Length test: {training_size}')
152 |
153 | test_data = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=8)
154 | model_testing(test_data)
155 |
--------------------------------------------------------------------------------
/model/decoder.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
5 | sys.path.append(os.path.dirname(SCRIPT_DIR))
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | from .aggregator import CrossAttention, temporal_upsample
12 |
13 | import math
14 |
15 |
16 | class DecodeUpsample(nn.Module):
17 | def __init__(self, input_dim, kernel, timestep):
18 | super(DecodeUpsample, self).__init__()
19 | self.conv = nn.Sequential(nn.Conv3d(input_dim, input_dim//2, (1, kernel, kernel), padding='same'), nn.GELU())
20 | self.residual_conv = nn.Sequential(nn.Conv3d(input_dim//2, input_dim//2, (timestep, 1, 1)), nn.GELU())
21 |
22 | def forward(self, inputs, res):
23 | #b, t, c, h, w = inputs.shape
24 | inputs = temporal_upsample(inputs, mode='bilinear')
25 | inputs = self.conv(inputs) + self.residual_conv(res)
26 | return inputs
27 |
28 |
29 | class PredFinalDecoder(nn.Module):
30 | def __init__(self, input_dim, kernel=3, large_scale=False,planning=True, use_flow=False):
31 | super(PredFinalDecoder, self).__init__()
32 | '''
33 | input h,w = 128
34 | dual deconv for flow and ogms
35 | '''
36 | self.input_dim = input_dim
37 | if large_scale:
38 | self.ogm_conv = nn.Conv3d(input_dim, 4, (1, kernel, kernel), padding='same')
39 | self.flow_conv = nn.Conv3d(input_dim, 2, (1, kernel, kernel), padding='same')
40 | else:
41 | self.ogm_conv = nn.Sequential(nn.Conv3d(input_dim, input_dim//2, (1, kernel, kernel), padding='same'),
42 | nn.GELU(), nn.Upsample(scale_factor=(1, 2, 2)),
43 | nn.Conv3d(input_dim//2, 4 if planning else 2, (1, kernel, kernel), padding='same'))
44 |
45 | self.flow_conv = nn.Sequential(nn.Conv3d(input_dim, input_dim//2, (1, kernel, kernel), padding='same'),
46 | nn.GELU(), nn.Upsample(scale_factor=(1, 2, 2)),
47 | nn.Conv3d(input_dim//2, 2, (1, kernel, kernel), padding='same'))
48 |
49 | def forward(self, inputs):
50 | ogms = self.ogm_conv(inputs)
51 | flows = self.flow_conv(inputs)
52 | return torch.cat([ogms, flows], dim=1)
53 |
54 | class STrajNetDecoder(nn.Module):
55 | def __init__(self, dim=384, heads=8, len_fpn=2, kernel=3, timestep=5, dropout=0.1,
56 | flow_pred=False, large_scale=False):
57 | super(STrajNetDecoder, self).__init__()
58 |
59 | self.timestep = timestep
60 | self.len_fpn = len_fpn
61 | self.residual_conv = nn.Sequential(nn.Conv3d(dim, dim, (timestep, 1, 1)), nn.GELU())
62 | self.aggregator = nn.ModuleList([CrossAttention(dim, heads, dropout) for _ in range(timestep)])
63 |
64 | self.actor_layer = nn.Sequential(nn.Linear(256, dim), nn.GELU())
65 |
66 | self.fpn_decoders = nn.ModuleList([
67 | DecodeUpsample(dim // (2 ** i), kernel, timestep) for i in range(len_fpn)
68 | ])
69 |
70 | self.upsample = nn.Upsample(scale_factor=(1, 2, 2))
71 | if flow_pred:
72 | self.output_conv = PredFinalDecoder(dim // (2 ** len_fpn),large_scale=large_scale)
73 | else:
74 | self.output_conv = nn.Conv3d(dim // (2 ** len_fpn), 4, (1, kernel, kernel), padding='same')
75 |
76 | def forward(self, output_list, actor, actor_mask):
77 | # Aggregations:
78 | enc_output = output_list[-1]
79 | b, c, h, w = enc_output.shape
80 | res_output = enc_output.unsqueeze(2).expand(-1, -1, self.timestep, -1, -1)
81 | enc_output = enc_output.reshape(b, c, h*w).permute(0, 2, 1)
82 | #[b, t, h*w, c]
83 | enc_output = enc_output.unsqueeze(1).expand(-1, self.timestep, -1, -1)
84 |
85 | actor = self.actor_layer(actor)
86 | actor_mask[:, 0] = False
87 | agg_output = torch.stack([self.aggregator[i](enc_output[:, i], actor, actor_mask) for i in range(self.timestep)], dim=2)
88 | agg_output = agg_output.permute(0, 3, 2, 1).reshape(b, -1, self.timestep, h, w)
89 | decode_output = agg_output + self.residual_conv(res_output)
90 | # fpn decoding:
91 | for j in range(self.len_fpn):
92 | decode_output = self.fpn_decoders[j](decode_output, output_list[-2-j].unsqueeze(2).expand(-1, -1, self.timestep, -1, -1))
93 | decode_output = self.output_conv(self.upsample(decode_output))
94 |
95 | #[b, t, c, h, w]
96 | return decode_output
97 |
98 |
99 |
100 | class EgoPlanner(nn.Module):
101 | def __init__(self, dim=256, use_dynamic=False,timestep=5):
102 | super(EgoPlanner,self).__init__()
103 | self.timestep = timestep
104 | self.out_step = 2
105 | self.planner = nn.Sequential(nn.Linear(256, 128), nn.ELU(),nn.Dropout(0.1),
106 | nn.Linear(128, timestep*self.out_step *10))
107 | self.scorer = nn.Sequential(nn.Linear(256, 128), nn.ELU(),nn.Dropout(0.1),
108 | nn.Linear(128, 1))
109 | self.use_dynamic = use_dynamic
110 |
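# physical() below rolls out each predicted (acceleration, yaw-rate) action pair with a
# simple unicycle model at dt = 0.1 s: speed and heading are integrated first, then x/y are
# accumulated, so trajectories stay kinematically consistent when use_dynamic is enabled.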
111 | def physical(self, action, last_state):
112 | d_t = 0.1
113 | d_v = action[:, :, :, 0].clamp(-5, 5)
114 | d_theta = action[:, :, :, 1].clamp(-1, 1)
115 |
116 | x_0 = last_state[:, 0]
117 | y_0 = last_state[:, 1]
118 | theta_0 = last_state[:, 4]
119 | v_0 = torch.hypot(last_state[:, 2], last_state[:, 3])
120 |
121 | v = v_0.reshape(-1,1,1) + torch.cumsum(d_v * d_t, dim=-1)
122 | v = torch.clamp(v, min=0)
123 | theta = theta_0.reshape(-1,1,1) + torch.cumsum(d_theta * d_t, dim=-1)
124 | theta = torch.fmod(theta, 2*torch.pi)
125 | x = x_0.reshape(-1,1,1) + torch.cumsum(v * torch.cos(theta) * d_t, dim=-1)
126 | y = y_0.reshape(-1,1,1) + torch.cumsum(v * torch.sin(theta) * d_t, dim=-1)
127 | traj = torch.stack([x, y, theta], dim=-1)
128 | return traj
129 |
130 | def forward(self, features, current_state):
131 | traj = self.planner(features).reshape(-1, 9, self.timestep*10, self.out_step)
132 | if self.use_dynamic:
133 | traj = self.physical(traj, current_state)
134 | score = self.scorer(features)
135 | return traj, score
136 |
137 |
138 | class PlanningDecoder(nn.Module):
139 | def __init__(self, dim=256, heads=8, dropout=0.1, use_dynamic=False, timestep=5):
140 | super(PlanningDecoder,self).__init__()
141 |
142 | self.region_embed = nn.Parameter(torch.zeros(1, 9, 256), requires_grad=True)
143 | nn.init.kaiming_uniform_(self.region_embed)
144 |
145 | self.self_attention = nn.MultiheadAttention(dim, heads, dropout, batch_first=True)
146 | self.cross_attention = nn.MultiheadAttention(dim, heads, dropout, batch_first=True)
147 | self.bev_attention = nn.MultiheadAttention(dim, heads, dropout, batch_first=True)
148 | self.bev_layer = nn.Sequential(nn.Linear(384, dim), nn.GELU())
149 |
150 | self.ffn = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Dropout(0.1), nn.Linear(dim*4, dim), nn.Dropout(0.1))
151 |
152 | self.norm_0 = nn.LayerNorm(dim)
153 | self.norm_1 = nn.LayerNorm(dim)
154 |
155 | self.planner = EgoPlanner(dim, use_dynamic, timestep)
156 |
157 | def forward(self, inputs):
158 | #encode the poly plan as ref:
159 | b = inputs['encodings'].shape[0]
160 | plan_query = self.region_embed.expand(b,-1,-1)
161 | self_plan_query,_ = self.self_attention(plan_query, plan_query, plan_query)
162 | #cross attention with bev and map-actors:
163 | map_actors = inputs['encodings'][:, 0]
164 | map_actors_mask = inputs['masks'][:, 0]
165 | map_actors_mask[:,0] = False
166 | dense_feature,_ = self.cross_attention(self_plan_query, map_actors, map_actors, key_padding_mask=map_actors_mask)
167 | b, c, h, w = inputs['bev_feature'].shape
168 | bev_feature = inputs['bev_feature'].reshape(b, c, h*w).permute(0, 2, 1)
169 | bev_feature = self.bev_layer(bev_feature)
170 | bev_feature,_ = self.bev_attention(self_plan_query, bev_feature, bev_feature)
171 |
172 | attention_feature = self.norm_0(dense_feature + bev_feature + plan_query)
173 | output_feature = self.ffn(attention_feature) + attention_feature
174 | output_feature = self.norm_1(output_feature)
175 |
176 | # output:
177 | ego_current = inputs['actors'][:, 0, -1, :]
178 | traj, score = self.planner(output_feature, ego_current)
179 | return traj, score.squeeze(-1)
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import argparse
3 | import time
4 | import sys
5 |
6 | import torch
7 | from torch import optim
8 | import torch.distributed as dist
9 | from torch.utils.data.distributed import DistributedSampler
10 | from torch.nn.parallel import DistributedDataParallel as DDP
11 | from torch.utils.data import DataLoader
12 |
13 | from waymo_open_dataset.protos import occupancy_flow_metrics_pb2
14 | from google.protobuf import text_format
15 |
16 | from model.InteractPlanner import InteractPlanner
17 | from utils.net_utils import *
18 | from metric import TrainingMetrics, ValidationMetrics
19 |
20 |
21 | # define model training epoch
22 | def training_epoch(train_data, optimizer, epoch, scheduler):
23 |
24 | model.train()
25 | current = 0
26 | start_time = time.time()
27 | size = len(train_data)
28 | epoch_loss = []
29 | train_metric = TrainingMetrics()
30 | i = 0
31 | for batch in train_data:
32 | # prepare data
33 | inputs, target = batch_to_dict(batch, local_rank , use_flow=use_flow)
34 |
35 | optimizer.zero_grad()
36 | # query the model
37 | bev_pred, traj, score = model(inputs)
38 | actor_loss, occ_loss, flow_loss = occupancy_loss(bev_pred, target, use_flow=use_flow)
39 | il_loss, _, gt_modes = imitation_loss(traj, score, target)  # no use_planning flag in argparse; imitation_loss ignores it
40 |
41 | loss = il_loss + 100*(actor_loss + occ_loss)
42 | if use_flow:
43 | loss += flow_loss
44 |
45 | loss.backward()
46 | torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
47 | optimizer.step()
48 |
49 | current += args.batch_size
50 | epoch_loss.append(loss.item())
51 | if isinstance(traj, list):
52 | traj, score = traj[-1], score[-1]
53 | ade, fde, l_il, l_ogm = train_metric.update(traj, score, gt_modes, target, il_loss, actor_loss + occ_loss, bev_pred)
54 |
55 | if dist.get_rank() == 0:
56 | sys.stdout.write(f"\rTrain: [{current:>6d}/{size*args.batch_size:>6d}]|Loss: {np.mean(epoch_loss):>.4f}-{l_il:>.4f}-{l_ogm:>.4f}|ADE:{ade:>.4f}-FDE:{fde:>.4f}|{(time.time()-start_time)/current:>.4f}s/sample")
57 | sys.stdout.flush()
58 |
59 | scheduler.step(epoch + i/size)
60 | i += 1
61 |
62 | results = train_metric.result()
63 |
64 | return np.mean(epoch_loss), results
65 |
66 | # define model validation epoch
67 | def validation_epoch(valid_data,epoch):
68 | epoch_metrics = ValidationMetrics()
69 | model.eval()
70 | current = 0
71 | start_time = time.time()
72 | size = len(valid_data)
73 | epoch_loss = []
74 |
75 | print(f'Validation...Epoch{epoch+1}')
76 | for batch in valid_data:
77 | # prepare data
78 | inputs, target = batch_to_dict(batch, local_rank, use_flow=use_flow)
79 |
80 | # query the model
81 | with torch.no_grad():
82 | bev_pred, traj, score = model(inputs)
83 | actor_loss, occ_loss, flow_loss = occupancy_loss(bev_pred, target, use_flow=use_flow)
84 | il_loss, _, gt_modes = imitation_loss(traj, score, target)
85 | loss = il_loss + 100*(actor_loss + occ_loss)
86 | if use_flow:
87 | loss += flow_loss
88 | # compute metrics
89 | epoch_loss.append(loss.item())
90 | if isinstance(traj, list):
91 | traj, score = traj[-1], score[-1]
92 | ade, fde, l_il, l_ogm, ogm_auc,_, occ_auc,_ = epoch_metrics.update(traj, score, bev_pred,
93 | gt_modes, target, il_loss, actor_loss + occ_loss)
94 |
95 | current += args.batch_size
96 | if dist.get_rank() == 0:
97 | sys.stdout.write(f"\rVal: [{current:>6d}/{size*args.batch_size:>6d}]|Loss: {np.mean(epoch_loss):>.4f}-{l_il:>.4f}-{l_ogm:>.4f}|ADE:{ade:>.4f}-FDE:{fde:>.4f}|{(time.time()-start_time)/current:>.4f}s/sample")
98 | sys.stdout.flush()
99 |
100 | # process metrics
101 | epoch_metrics = epoch_metrics.result()
102 |
103 | return epoch_metrics,np.mean(epoch_loss)
104 |
105 | # Define model training process
106 | def model_training(train_data, valid_data, epochs, save_dir):
107 | # define optimizer and loss function
108 | optimizer = optim.AdamW(model.parameters(), lr=args.lr)
109 | scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=2, eta_min=1e-6)
110 |
111 | for epoch in range(epochs):
112 | if dist.get_rank() == 0:
113 | print(f"Epoch {epoch+1}/{epochs}")
114 |
115 | if epoch<=continue_ep and continue_ep!=0:
116 | scheduler.step()
117 | continue
118 |
119 | train_data.sampler.set_epoch(epoch)
120 | valid_data.sampler.set_epoch(epoch)
121 |
122 | train_loss,train_res = training_epoch(train_data, optimizer, epoch, scheduler)
123 | valid_metrics,val_loss = validation_epoch(valid_data,epoch)
124 |
125 | # save to training log
126 | log = {'epoch': epoch+1, 'loss': train_loss, 'lr': optimizer.param_groups[0]['lr']}
127 | log.update(valid_metrics)
128 |
129 | if dist.get_rank() == 0:
130 | if epoch == 0:
131 | with open(save_dir + f'train_log.csv', 'a') as csv_file:
132 | writer = csv.writer(csv_file)
133 | writer.writerow(log.keys())
134 | writer.writerow(log.values())
135 | else:
136 | with open(save_dir + f'train_log.csv', 'a') as csv_file:
137 | writer = csv.writer(csv_file)
138 | writer.writerow(log.values())
139 |
140 | # adjust learning rate
141 | scheduler.step()
142 |
143 | # save model at the end of epoch
144 | if dist.get_rank() == 0:
145 | torch.save(model.state_dict(), save_dir+f'model_{epoch+1}_{train_loss:4f}_{val_loss:4f}.pth')
146 |
147 |
148 | if __name__ == "__main__":
149 |
150 | parser = argparse.ArgumentParser()
151 | parser.add_argument("--local_rank", type=int)
152 | parser.add_argument("--batch_size", type=int, default=8)
153 | parser.add_argument("--dim", type=int, default=256)
154 | parser.add_argument("--lr", type=float, default=1e-4)
155 | parser.add_argument("--epochs", type=int, default=30)
156 | parser.add_argument("--use_flow", type=bool, action='store_true', default=True,
157 | help='whether to use flow warp')
158 |
159 | parser.add_argument("--save_dir", type=str, default='',help='path to save logs')
160 | parser.add_argument("--data_dir", type=str, default='',
161 | help='path to load preprocessed train & val sets')
162 | parser.add_argument("--model_dir", type=str, default='',
163 | help='path to load models for continue training')
164 |
165 | args = parser.parse_args()
166 | local_rank = args.local_rank
167 |
168 | use_flow = args.use_flow
169 |
170 | config = occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig()
171 | config_text = f"""
172 | num_past_steps: {10}
173 | num_future_steps: {50}
174 | num_waypoints: {5}
175 | cumulative_waypoints: {'false'}
176 | normalize_sdc_yaw: true
177 | grid_height_cells: {128}
178 | grid_width_cells: {128}
179 | sdc_y_in_grid: {int(128*0.75)}
180 | sdc_x_in_grid: {64}
181 | pixels_per_meter: {1.6}
182 | agent_points_per_side_length: 48
183 | agent_points_per_side_width: 16
184 | """
185 |
186 | text_format.Parse(config_text, config)
187 |
188 | model = InteractPlanner(config, dim=args.dim, enc_layer=2, heads=8, dropout=0.1,
189 | timestep=5, decoder_dim=384, fpn_len=2, flow_pred=use_flow)
190 |
191 | save_dir = args.save_dir + f"models/"
192 | os.makedirs(save_dir,exist_ok=True)
193 |
194 | torch.cuda.set_device(local_rank)
195 | dist.init_process_group(backend='nccl')
196 |
197 | model = model.to(local_rank)
198 | if args.model_dir!= '':
199 | kw_dict = {}
200 | for k,v in torch.load(save_dir + args.model_dir, map_location='cpu').items():
201 | kw_dict[k[7:]] = v
202 | model.load_state_dict(kw_dict)
203 | continue_ep = int(args.model_dir.split('_')[-3]) - 1
204 | print(f'model loaded!:epoch {continue_ep + 1}')
205 | else:
206 | continue_ep = 0
207 |
208 | model = DDP(model, device_ids=[local_rank], output_device=local_rank)
209 |
210 | train_dataset = DrivingData(args.data_dir + f'train/*.npz',use_flow=use_flow)
211 | valid_dataset = DrivingData(args.data_dir + f'valid/*.npz',use_flow=use_flow)
212 |
213 | training_size = len(train_dataset)
214 | valid_size = len(valid_dataset)
215 | if dist.get_rank() == 0:
216 | print(f'Length train: {training_size}; Valid: {valid_size}')
217 |
218 | train_sampler = DistributedSampler(train_dataset)
219 | valid_sampler = DistributedSampler(valid_dataset, shuffle=False)
220 | train_data = DataLoader(train_dataset, batch_size=args.batch_size,
221 | sampler=train_sampler, num_workers=16)
222 | valid_data = DataLoader(valid_dataset, batch_size=args.batch_size,
223 | sampler=valid_sampler, num_workers=4)
224 |
225 | model_training(train_data, valid_data, args.epochs, save_dir)
226 |
--------------------------------------------------------------------------------
/utils/waymo_tf_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | #### Example field definition
4 | # Features of road graph.
5 | roadgraph_features = {
6 | 'roadgraph_samples/dir':
7 | tf.io.FixedLenFeature([20000, 3], tf.float32, default_value=None),
8 | 'roadgraph_samples/id':
9 | tf.io.FixedLenFeature([20000, 1], tf.int64, default_value=None),
10 | 'roadgraph_samples/type':
11 | tf.io.FixedLenFeature([20000, 1], tf.int64, default_value=None),
12 | 'roadgraph_samples/valid':
13 | tf.io.FixedLenFeature([20000, 1], tf.int64, default_value=None),
14 | 'roadgraph_samples/xyz':
15 | tf.io.FixedLenFeature([20000, 3], tf.float32, default_value=None),
16 | }
17 |
18 | # Features of other agents.
19 | state_features = {
20 | 'state/id':
21 | tf.io.FixedLenFeature([128], tf.float32, default_value=None),
22 | 'state/type':
23 | tf.io.FixedLenFeature([128], tf.float32, default_value=None),
24 | 'state/is_sdc':
25 | tf.io.FixedLenFeature([128], tf.int64, default_value=None),
26 | 'state/tracks_to_predict':
27 | tf.io.FixedLenFeature([128], tf.int64, default_value=None),
28 | 'state/current/bbox_yaw':
29 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
30 | 'state/current/height':
31 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
32 | 'state/current/length':
33 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
34 | 'state/current/timestamp_micros':
35 | tf.io.FixedLenFeature([128, 1], tf.int64, default_value=None),
36 | 'state/current/valid':
37 | tf.io.FixedLenFeature([128, 1], tf.int64, default_value=None),
38 | 'state/current/vel_yaw':
39 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
40 | 'state/current/velocity_x':
41 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
42 | 'state/current/velocity_y':
43 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
44 | 'state/current/speed':
45 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
46 | 'state/current/width':
47 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
48 | 'state/current/x':
49 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
50 | 'state/current/y':
51 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
52 | 'state/current/z':
53 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None),
54 | 'state/future/bbox_yaw':
55 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
56 | 'state/future/height':
57 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
58 | 'state/future/length':
59 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
60 | 'state/future/timestamp_micros':
61 | tf.io.FixedLenFeature([128, 80], tf.int64, default_value=None),
62 | 'state/future/valid':
63 | tf.io.FixedLenFeature([128, 80], tf.int64, default_value=None),
64 | 'state/future/vel_yaw':
65 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
66 | 'state/future/velocity_x':
67 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
68 | 'state/future/velocity_y':
69 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
70 | 'state/future/width':
71 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
72 | 'state/future/x':
73 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
74 | 'state/future/y':
75 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
76 | 'state/future/z':
77 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None),
78 | 'state/past/bbox_yaw':
79 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
80 | 'state/past/height':
81 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
82 | 'state/past/length':
83 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
84 | 'state/past/timestamp_micros':
85 | tf.io.FixedLenFeature([128, 10], tf.int64, default_value=None),
86 | 'state/past/valid':
87 | tf.io.FixedLenFeature([128, 10], tf.int64, default_value=None),
88 | 'state/past/vel_yaw':
89 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
90 | 'state/past/velocity_x':
91 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
92 | 'state/past/velocity_y':
93 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
94 | 'state/past/speed':
95 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
96 | 'state/past/width':
97 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
98 | 'state/past/x':
99 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
100 | 'state/past/y':
101 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
102 | 'state/past/z':
103 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None),
104 | 'scenario/id':
105 | tf.io.FixedLenFeature([1], tf.string, default_value=None),
106 | }
107 |
108 | # Features of traffic lights.
109 | traffic_light_features = {
110 | 'traffic_light_state/current/state':
111 | tf.io.FixedLenFeature([1, 16], tf.int64, default_value=None),
112 | 'traffic_light_state/current/valid':
113 | tf.io.FixedLenFeature([1, 16], tf.int64, default_value=None),
114 | 'traffic_light_state/current/x':
115 | tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None),
116 | 'traffic_light_state/current/y':
117 | tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None),
118 | 'traffic_light_state/current/z':
119 | tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None),
120 | 'traffic_light_state/past/state':
121 | tf.io.FixedLenFeature([10, 16], tf.int64, default_value=None),
122 | 'traffic_light_state/past/valid':
123 | tf.io.FixedLenFeature([10, 16], tf.int64, default_value=None),
124 | 'traffic_light_state/past/x':
125 | tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None),
126 | 'traffic_light_state/past/y':
127 | tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None),
128 | 'traffic_light_state/past/z':
129 | tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None),
130 | }
131 |
132 | features_description = {}
133 | features_description.update(roadgraph_features)
134 | features_description.update(state_features)
135 | features_description.update(traffic_light_features)
136 |
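# Usage sketch (the shard path is a placeholder): decode one serialized tf.Example
# scenario with the combined feature spec above.
#   dataset = tf.data.TFRecordDataset('path/to/tf_example/training/shard-00000')
#   parsed = tf.io.parse_single_example(next(iter(dataset)), features_description)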
137 | # road label
138 | road_label = {1:'LaneCenter-Freeway', 2:'LaneCenter-SurfaceStreet', 3:'LaneCenter-BikeLane', 6:'RoadLine-BrokenSingleWhite',
139 | 7:'RoadLine-SolidSingleWhite', 8:'RoadLine-SolidDoubleWhite', 9:'RoadLine-BrokenSingleYellow', 10:'RoadLine-BrokenDoubleYellow',
140 | 11:'Roadline-SolidSingleYellow', 12:'Roadline-SolidDoubleYellow', 13:'RoadLine-PassingDoubleYellow', 15:'RoadEdgeBoundary',
141 | 16:'RoadEdgeMedian', 17:'StopSign', 18:'Crosswalk', 19:'SpeedBump'}
142 |
143 | road_line_map = {1:['xkcd:grey', 'solid', 14], 2:['xkcd:grey', 'solid', 14], 3:['xkcd:grey', 'solid', 10], 5:['w', 'solid', 2], 6:['w', 'dashed', 2],
144 | 7:['w', 'solid', 2], 8:['w', 'solid', 2], 9:['xkcd:yellow', 'dashed', 4], 10:['xkcd:yellow', 'dashed', 2],
145 | 11:['xkcd:yellow', 'solid', 2], 12:['xkcd:yellow', 'solid', 3], 13:['xkcd:yellow', 'dotted', 1.5], 15:['y', 'solid', 4.5],
146 | 16:['y', 'solid', 4.5], 17:['r', '.', 40], 18:['b', 'solid', 13], 19:['xkcd:orange', 'solid', 13]}
147 |
148 | # traffic light label
149 | light_label = {0:'Unknown', 1:'Arrow_Stop', 2:'Arrow_Caution', 3:'Arrow_Go', 4:'Stop', 5:'Caution', 6:'Go', 7:'Flashing_Stop', 8:'Flashing_Caution'}
150 | light_state_map = {0:'k', 1:'r', 2:'b', 3:'g', 4:'r', 5:'b', 6:'g', 7:'r', 8:'b'}
151 | light_state_map_num = {0:0, 1:1, 2:2, 3:3, 4:1, 5:2, 6:3, 7:1, 8:2}
152 | light_state_rank_map = {0: 1, 1: 4, 2: 3, 3: 2, 4: 4, 5: 3, 6: 2, 7: 4, 8: 3}
153 | light_near_state_map = {0:'black', 1:'darkred', 2:'darkblue', 3:'darkgreen', 4:'darkred', 5:'darkblue', 6:'darkgreen', 7:'darkred', 8:'darkblue'}
154 |
155 | def linecolormap(value,m_per_pixel):
156 | return {'color':value[0], 'linestyle':value[1], 'linewidth': value[2]*m_per_pixel/3}
157 |
158 | def light_linecolormap(light_state,value,m_per_pixel):
159 | return {'color':light_state_map[light_state], 'linestyle':value[1], 'linewidth': value[2]*m_per_pixel/3,
160 | 'zorder': light_state_rank_map[light_state]}
161 |
162 | def lightnear_linecolormap(light_state,value,m_per_pixel):
163 | return {'color':light_near_state_map[light_state], 'linestyle':value[1], 'linewidth': value[2]*m_per_pixel/3,
164 | 'zorder': light_state_rank_map[light_state]}
165 |
166 | def traffic_light_map(road_type, tl_state, tl_near, m_per_pixel):
167 | road_val = road_line_map[road_type]
168 | if tl_near:
169 | return lightnear_linecolormap(tl_state, road_val, m_per_pixel)
170 | else:
171 | return light_linecolormap(tl_state, road_val, m_per_pixel)
172 |
173 | from waymo_open_dataset.protos import scenario_pb2
174 |
175 | _ObjectType = scenario_pb2.Track.ObjectType
176 | ALL_AGENT_TYPES = [
177 | _ObjectType.TYPE_VEHICLE,
178 | _ObjectType.TYPE_PEDESTRIAN,
179 | _ObjectType.TYPE_CYCLIST,
180 | ]
181 |
182 | agent_color={
183 | _ObjectType.TYPE_VEHICLE:'r',
184 | _ObjectType.TYPE_PEDESTRIAN:'g',
185 | _ObjectType.TYPE_CYCLIST:'b'
186 | }
187 |
--------------------------------------------------------------------------------
/utils/net_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
5 | sys.path.append(os.path.dirname(SCRIPT_DIR))
6 |
7 | import torch
8 | import logging
9 | import glob
10 |
11 | import numpy as np
12 | from torch.utils.data import Dataset
13 | from torch.nn import functional as F
14 | from google.protobuf import text_format
15 |
16 | # import torchmetrics
17 | from torchvision.ops.focal_loss import sigmoid_focal_loss
18 |
19 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
20 | import random
21 |
22 |
23 | def initLogging(log_file: str, level: str = "INFO"):
24 | logging.basicConfig(filename=log_file, filemode='w',
25 | level=getattr(logging, level, None),
26 | format='[%(levelname)s %(asctime)s] %(message)s',
27 | datefmt='%m-%d %H:%M:%S')
28 | logging.getLogger().addHandler(logging.StreamHandler())
29 |
30 | class DrivingData(Dataset):
31 | def __init__(self, data_dir, use_flow=False):
32 | self.data_list = glob.glob(data_dir)
33 | self.use_flow = use_flow
34 |
35 | def __len__(self):
36 | return len(self.data_list)
37 |
38 | def __getitem__(self, idx):
39 | data = np.load(self.data_list[idx],allow_pickle=True)
40 | ego = data['ego']
41 | neighbor = data['neighbors'][:, :11, :]
42 |
43 | neighbor_map_lanes = data['neighbor_map_lanes']
44 | ego_map_lane = data['ego_map_lane']
45 |
46 | neighbor_crosswalk = data['neighbor_map_crosswalks']
47 | ego_crosswalk = data['ego_map_crosswalk']
48 |
49 | ego_future_states = data['gt_future_states']
50 |
51 | ref_line = data['ref_line']
52 | goal = data['goal']
53 |
54 | hist_ogm = data['hist_ogm']
55 | ego_ogm = data['ego_ogm']
56 |
57 | gt_obs = data['gt_obs']
58 | gt_occ = data['gt_occ']
59 | gt_ego = data['ego_ogm_gt']
60 |
61 | if self.use_flow:
62 | hist_flow = data['hist_flow']
63 | gt_flow = data['gt_flow']
64 | road_graph = data['rg'].astype(np.float32)
65 | road_graph = np.array(road_graph)
66 |
67 | return ego, neighbor, ego_map_lane, neighbor_map_lanes, ego_crosswalk, neighbor_crosswalk,\
68 | ego_future_states, ref_line, goal, hist_ogm, ego_ogm, gt_obs, gt_occ, gt_ego, hist_flow, gt_flow, road_graph
69 |
70 | return ego, neighbor, ego_map_lane, neighbor_map_lanes, ego_crosswalk, neighbor_crosswalk,\
71 | ego_future_states, ref_line, goal, hist_ogm, ego_ogm, gt_obs, gt_occ, gt_ego
72 |
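# Usage sketch (paths and device are placeholders): DrivingData globs preprocessed .npz
# samples; batch_to_dict then moves a collated batch onto the target device.
#   dataset = DrivingData('path/to/processed_data/train/*.npz', use_flow=True)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=4)
#   inputs, targets = batch_to_dict(next(iter(loader)), torch.device('cuda'), use_flow=True)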
73 | def batch_to_dict(batch, local_rank, use_flow=False):
74 | if use_flow:
75 | ego, neighbor, ego_map_lane, neighbor_map_lanes, ego_crosswalk, neighbor_crosswalk,\
76 | ego_future_states, ref_line, goal, hist_ogm, ego_ogm, gt_obs, gt_occ, gt_ego, hist_flow, gt_flow, road_graph = batch
77 | else:
78 | ego, neighbor, ego_map_lane, neighbor_map_lanes, ego_crosswalk, neighbor_crosswalk,\
79 | ego_future_states, ref_line, goal, hist_ogm, ego_ogm, gt_obs, gt_occ, gt_ego = batch
80 |
81 | # if not use_flow:
82 | ego_mask = (1 - ego_ogm.to(local_rank).float().unsqueeze(-1))
83 | hist_ogm = hist_ogm.to(local_rank).float()*ego_mask
84 | b, h, w, t, c = hist_ogm.shape
85 | hist_ogm = hist_ogm.reshape(b, h ,w, t*c)
86 |
87 | input_dict = {
88 | 'ego_state': ego.to(local_rank).float(),
89 | 'neighbor_state': neighbor.to(local_rank).float(),
90 | 'ego_map_lane': ego_map_lane.to(local_rank).float(),
91 | 'neighbor_map_lanes': neighbor_map_lanes.to(local_rank).float(),
92 | 'ego_map_crosswalk': ego_crosswalk.to(local_rank).float(),
93 | 'neighbor_map_crosswalks': neighbor_crosswalk.to(local_rank).float(),
94 | 'hist_ogm': hist_ogm,
95 | 'ego_ogm': ego_ogm.to(local_rank).float(),
96 | }
97 | target_dict = {
98 | 'ref_line': ref_line.to(local_rank).float(),
99 | 'goal': goal.to(local_rank).float(),
100 | 'ego_future_states':ego_future_states[..., [0,1,4]].to(local_rank).float(),
101 | 'gt_obs':gt_obs.to(local_rank).float(),
102 | 'gt_occ':gt_occ.to(local_rank).sum(-1).clamp(0, 1).float(),
103 | 'gt_ego':gt_ego.to(local_rank).float(),
104 | }
105 | if not use_flow:
106 | return input_dict, target_dict
107 | else:
108 | road_graph = road_graph[:, 128:128+256, 128:128+256, :]
109 | input_dict = {
110 | 'ego_state': ego.to(local_rank).float(),
111 | 'neighbor_state': neighbor.to(local_rank).float(),
112 | 'ego_map_lane': ego_map_lane.to(local_rank).float(),
113 | 'neighbor_map_lanes': neighbor_map_lanes.to(local_rank).float(),
114 | 'ego_map_crosswalk': ego_crosswalk.to(local_rank).float(),
115 | 'neighbor_map_crosswalks': neighbor_crosswalk.to(local_rank).float(),
116 | 'hist_ogm': hist_ogm,
117 | 'ego_ogm': ego_ogm.to(local_rank).float(),
118 | 'hist_flow': hist_flow.to(local_rank).float(),
119 | 'road_graph': road_graph.to(local_rank).float(),
120 | }
121 | target_dict = {
122 | 'ref_line': ref_line.to(local_rank).float(),
123 | 'goal': goal.to(local_rank).float(),
124 | 'ego_future_states':ego_future_states[..., [0,1,4]].to(local_rank).float(),
125 | 'gt_obs':gt_obs.to(local_rank).float(),
126 | 'gt_occ':gt_occ.to(local_rank).sum(-1).clamp(0, 1).float(),
127 | 'gt_ego':gt_ego.to(local_rank).float(),
128 | 'gt_flow':gt_flow.to(local_rank).float()
129 | }
130 | return input_dict, target_dict
131 |
132 | def occupancy_loss(outputs, targets, use_flow=False):
133 | ego_mask = 1-targets['gt_ego']
134 | gt_obs = targets['gt_obs'][..., 0, :] * ego_mask.unsqueeze(-1)
135 | target_ogm = gt_obs.permute(0, 4, 1, 2, 3) #[b, c, t, h, w]
136 | # actors:
137 | actor_loss: torch.Tensor = 0
138 | alpha_list = [0.1, 0.01, 0.01]
139 | for i in range(3):
140 | ref = outputs[:, i]
141 | tar = target_ogm[:, i]
142 | loss = sigmoid_focal_loss(ref, tar, alpha=alpha_list[i], gamma=1, reduction='mean')
143 | actor_loss += loss
144 | actor_loss = actor_loss / 3
145 |
146 | #occlusions:
147 | occ_ref, occ_tar = outputs[:, 3], targets['gt_occ']*ego_mask#[:, 0]
148 | occ_loss = sigmoid_focal_loss(occ_ref, occ_tar, alpha=0.05, gamma=1, reduction='mean')
149 |
150 | if use_flow:
151 | flow_loss: torch.Tensor = 0
152 | target_flow = targets['gt_flow']
153 | target_flow = target_flow.sum(-1)
154 | flow_exists = torch.logical_or(torch.ne(target_flow[..., 0], 0), torch.ne(target_flow[..., 1], 0)).float()
155 | flow_outputs = outputs[:, -2:]
156 | b, c, t, h, w = flow_outputs.shape
157 | flow_outputs = flow_outputs.permute(0, 2, 3, 4, 1)
158 | pred_flow = torch.mul(flow_outputs, flow_exists.unsqueeze(-1))
159 | exist_num = torch.nan_to_num(torch.sum(flow_exists)/2, 1.0)
160 | if exist_num > 0:
161 | flow_loss += F.smooth_l1_loss(pred_flow, target_flow, reduction='sum') / exist_num
162 | else:
163 | flow_loss = None
164 |
165 | return actor_loss, occ_loss, flow_loss
166 |
167 | def modal_selections(x, y, mode):
168 | b, n, t, d = x.shape
169 | if mode=='fde':
170 | fde_dist = torch.norm(x[:, :, -1, :2] - y[:, -1, :2].unsqueeze(1).expand(-1, n, -1), dim=-1)
171 | dist = torch.argmin(fde_dist, dim=-1)
172 | else:
173 | # joint ade and fde
174 | fde_dist = torch.norm(x[:, :, -1, :2] - y[:, -1, :2].unsqueeze(1).expand(-1, n, -1), dim=-1)
175 | ade_dist = torch.norm(x[:, :, :, :2] - y[:, :, :2].unsqueeze(1).expand(-1, n, -1, -1), dim=-1).mean(-1)
176 | dist = torch.argmin(0.5*fde_dist + ade_dist, dim=-1)
177 |
178 | return dist
179 |
180 |
181 | def infer_modal_selection(traj, score, targets, use_planning):
182 | B = score.shape[0]
183 | gt_modes = torch.argmax(score, dim=1)
184 | selected_trajs = traj[torch.arange(B)[:, None], gt_modes.unsqueeze(-1)].squeeze(1)
185 | return selected_trajs, gt_modes
186 |
187 |
188 | def imitation_loss(traj, score, targets, use_planning=False):
189 | gt_future = targets['ego_future_states']
190 | if isinstance(traj, list):
191 | il_loss: torch.Tensor = 0
192 | for tr,sc in zip(traj, score):
193 | loss, selected_trajs, gt_modes = single_layer_planning_loss(tr, sc, targets)
194 | il_loss += loss
195 | else:
196 | il_loss, selected_trajs, gt_modes = single_layer_planning_loss(traj, score, targets)
197 |
198 | return il_loss, selected_trajs, gt_modes
199 |
200 |
201 | def single_layer_planning_loss(traj, score, targets):
202 | gt_future = targets['ego_future_states']
203 | p_d = traj.shape[-1]
204 | gt_future = gt_future[...,:p_d]
205 | gt_modes = modal_selections(traj, gt_future,mode='joint')
206 |
207 | classification_loss = F.cross_entropy(score, gt_modes, label_smoothing=0.2)
208 | B = traj.shape[0]
209 | selected_trajs = traj[torch.arange(B)[:, None], gt_modes.unsqueeze(-1)].squeeze(1)
210 | goal_time = [9, 29, 49]
211 |
212 | ade_loss = F.smooth_l1_loss(selected_trajs, gt_future[...,:p_d])
213 | fde_loss = F.smooth_l1_loss(selected_trajs[...,goal_time,:], gt_future[...,goal_time,:p_d])
214 |
215 | il_loss = ade_loss + 0.5*fde_loss + 2*classification_loss
216 |
217 | return il_loss, selected_trajs, gt_modes
218 |
219 |
220 | def get_contributing_params(y, top_level=True):
221 | nf = y.grad_fn.next_functions if top_level else y.next_functions
222 | for f, _ in nf:
223 | try:
224 | yield f.variable
225 | except AttributeError:
226 | pass # node has no tensor
227 | if f is not None:
228 | yield from get_contributing_params(f, top_level=False)
229 |
230 | def check_non_contributing_params(model, outputs):
231 | contrib_params = set()
232 | all_parameters = set(model.parameters())
233 |
234 | for output in outputs:
235 | contrib_params.update(get_contributing_params(output))
236 | print(all_parameters - contrib_params)
--------------------------------------------------------------------------------
/utils/plan_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
5 | sys.path.append(os.path.dirname(SCRIPT_DIR))
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
11 | import numpy as np
12 | import math
13 | import bisect
14 |
15 | import matplotlib.pyplot as plt
16 |
17 |
18 | def ref_line_grids(ref_line, widths=10, pixels_per_meter=3.2, left=True):
19 | '''
20 | generate the mapping of Cartesian coordinates (x, y)
21 | according to the Frenet grids (s, d)
22 | inputs: ref_lines (b, length, 2), width
23 | outputs: refline_grids (b, length, width*2, 2)
24 | '''
25 | width_d = (torch.arange(-widths, widths) + 0.5) / pixels_per_meter
26 | # print(ref_line.shape)
27 | b, l, c = ref_line.shape
28 | width_d = width_d.unsqueeze(0).unsqueeze(1).expand(b, l, -1).to(ref_line.device)
29 | angle = ref_line[:, :, 2]
30 | angle = (angle + np.pi) % (2*np.pi) - np.pi #-
31 |
32 | ref_x = ref_line[:, :, 0:1]
33 | ref_y = ref_line[:, :, 1:2]
34 |
35 | # output coords always coincide with the OGM's coordinate settings
36 | x = -torch.sin(angle).unsqueeze(-1) * width_d + ref_x
37 | y = torch.cos(angle).unsqueeze(-1) * width_d + ref_y
38 |
39 | cart_grids = torch.stack([x, y], dim=-1)
40 | return cart_grids
41 |
42 | def ref_line_ogm_sample(ogm, rl_grids, config):
43 | """
44 | scatter the ogm fields to ref_line fields
45 | according to the ref_line_grids
46 | inputs: ogm [b, h, w]
47 | grids: [b, l_s, l_d, 2]
48 | outputs: ref_line fields: [b, l_s, l_d]
49 | """
50 | points_x, points_y = rl_grids[..., 0], rl_grids[..., 1]
51 | pixels_per_meter = config.pixels_per_meter
52 | points_x = torch.round(-points_y * pixels_per_meter) + config.sdc_x_in_grid
53 | points_y = torch.round(-points_x * pixels_per_meter) + config.sdc_y_in_grid
54 |
55 | # Filter out points that are located outside the FOV of topdown map.
56 | point_is_in_fov = torch.logical_and(
57 | torch.logical_and(
58 | torch.greater_equal(points_x, 0), torch.greater_equal(points_y, 0)),
59 | torch.logical_and(
60 | torch.less(points_x, config.grid_width_cells),
61 | torch.less(points_y, config.grid_height_cells))).float()
62 |
63 | w_axis_in = points_x * point_is_in_fov
64 | h_axis_in = points_y * point_is_in_fov
65 |
66 | w_axis_in = w_axis_in.long()
67 | h_axis_in = h_axis_in.long()
68 |
69 | b, h, w = w_axis_in.shape
70 | B = torch.arange(b).long()
71 | refline_fields = ogm[B[:, None], h_axis_in.view(b, -1), w_axis_in.view(b, -1)]
72 | refline_fields = refline_fields.view(b, h, w)
73 |
74 | # mask refline_fields not in fovs:
75 | refline_fields = refline_fields * point_is_in_fov
76 | return refline_fields
77 |
78 | def generate_ego_pos_at_field(ego_pos, ref_lines, angle):
79 | '''
80 | transform the ego occupancy into Frenet for refline_field:
81 | inputs: ego_pos: [B, 2] (x, y) angle: ego angles
82 | ref_lines : [B, L, 3] (x, y, angle)
83 | outputs the ego pose in the Frenet frame (s, d, theta)
84 | '''
85 | # 1. Transform ego pos (x, y, angle) into Frenet Coords (s, l, the):
86 | dist = torch.norm(ego_pos.unsqueeze(1) - ref_lines[..., :2], dim=-1, p=2.0)
87 | s = torch.argmin(dist, dim=1)
88 | b = ref_lines.shape[0]
89 | B = torch.arange(b).long()
90 | sel_ref = ref_lines[B, s, :]
91 | s = (s - 200) *0.1
92 |
93 | x_r, y_r, theta_r = sel_ref[:, 0], sel_ref[:, 1], sel_ref[:, 2]
94 | x, y = ego_pos[:, 0], ego_pos[:, 1]
95 |
96 | sgn = (y - y_r) * torch.cos(theta_r) - (x - x_r) * torch.sin(theta_r)
97 | dis = torch.sqrt(torch.square(x - x_r) + torch.square(y - y_r))
98 | l = torch.sign(sgn) * dis
99 |
100 | the = angle - theta_r
101 | # the += np.pi/2
102 | the = (the + np.pi) % (2* np.pi) - np.pi
103 |
104 | ego = torch.stack([s, l, the], dim=-1)
105 |
106 | return ego
107 |
108 |
109 | def refline_meshgrids(ref_line_field, pixels_per_meter=3.2):
110 | '''
111 | build the (s,l) meshgrids for ref_line field
112 | '''
113 | device = ref_line_field.device
114 | b, s, l, _ = ref_line_field.shape
115 | widths = int(l/2)
116 | mesh_l = (torch.arange(-widths, widths) + 0.5) / pixels_per_meter
117 | mesh_s = (torch.arange(s).float() + 0.5) * 0.1 #/ pixels_per_meter
118 | mesh_s, mesh_l = torch.meshgrid(mesh_s, mesh_l)
119 | mesh_sl = torch.stack([mesh_s, mesh_l], dim=-1)
120 | mesh_sl = mesh_sl.unsqueeze(0).expand(b, -1, -1, -1)
121 | mesh_sl = mesh_sl.to(device)
122 | return mesh_sl
123 |
124 |
125 | def gather_nd_slow(ogm, h_axis_in, w_axis_in):
126 | b, h, w = w_axis_in.shape
127 | output = torch.zeros((b, h, w)).to(w_axis_in.device)
128 | for i in range(b):
129 | for j in range(h):
130 | for k in range(w):
131 | output[i, j, k] = ogm[i, h_axis_in[i, j, k], w_axis_in[i, j, k]]
132 | return output
133 |
134 |
135 |
136 | class Spline(object):
137 | """
138 | Cubic Spline class
139 | """
140 |
141 | def __init__(self, x, y):
142 | self.b, self.c, self.d, self.w = [], [], [], []
143 |
144 | self.x = x
145 | self.y = y
146 |
147 | self.nx = len(x) # dimension of x
148 | h = np.diff(x) + 1e-3
149 |
150 | # calc coefficient c
151 | self.a = [iy for iy in y]
152 |
153 | # calc coefficient c
154 | A = self.__calc_A(h)
155 | B = self.__calc_B(h)
156 | self.c = np.linalg.solve(A, B)
157 | # print(self.c1)
158 |
159 | # calc spline coefficient b and d
160 | for i in range(self.nx - 1):
161 | self.d.append((self.c[i + 1] - self.c[i]) / (3.0 * h[i]))
162 | tb = (self.a[i + 1] - self.a[i]) / h[i] - h[i] * (self.c[i + 1] + 2.0 * self.c[i]) / 3.0
163 | self.b.append(tb)
164 |
165 | def calc(self, t):
166 | """
167 | Calc position
168 | if t is outside of the input x, return None
169 | """
170 |
171 | if t < self.x[0]:
172 | return None
173 | elif t > self.x[-1]:
174 | return None
175 |
176 | i = self.__search_index(t)
177 | dx = t - self.x[i]
178 | result = self.a[i] + self.b[i] * dx + self.c[i] * dx ** 2.0 + self.d[i] * dx ** 3.0
179 |
180 | return result
181 |
182 | def calcd(self, t):
183 | """
184 | Calc first derivative
185 | if t is outside of the input x, return None
186 | """
187 |
188 | if t < self.x[0]:
189 | return None
190 | elif t > self.x[-1]:
191 | return None
192 |
193 | i = self.__search_index(t)
194 | dx = t - self.x[i]
195 | result = self.b[i] + 2.0 * self.c[i] * dx + 3.0 * self.d[i] * dx ** 2.0
196 |
197 | return result
198 |
199 | def calcdd(self, t):
200 | """
201 | Calc second derivative
202 | """
203 |
204 | if t < self.x[0]:
205 | return None
206 | elif t > self.x[-1]:
207 | return None
208 |
209 | i = self.__search_index(t)
210 | dx = t - self.x[i]
211 | result = 2.0 * self.c[i] + 6.0 * self.d[i] * dx
212 |
213 | return result
214 |
215 | def __search_index(self, x):
216 | """
217 | search data segment index
218 | """
219 | return bisect.bisect(self.x, x) - 1
220 |
221 | def search_index(self, x):
222 | """
223 | search data segment index
224 | """
225 | return bisect.bisect(self.x, x) - 1
226 |
227 | def __calc_A(self, h):
228 | """
229 | calc matrix A for spline coefficient c
230 | """
231 | A = np.zeros((self.nx, self.nx))
232 | A[0, 0] = 1.0
233 |
234 | for i in range(self.nx - 1):
235 | if i != (self.nx - 2):
236 | A[i + 1, i + 1] = 2.0 * (h[i] + h[i + 1])
237 | A[i + 1, i] = h[i]
238 | A[i, i + 1] = h[i]
239 |
240 | A[0, 1] = 0.0
241 | A[self.nx - 1, self.nx - 2] = 0.0
242 | A[self.nx - 1, self.nx - 1] = 1.0
243 |
244 | return A
245 |
246 | def __calc_B(self, h):
247 | """
248 | calc matrix B for spline coefficient c
249 | """
250 | B = np.zeros(self.nx)
251 |
252 | for i in range(self.nx - 2):
253 | B[i + 1] = 3.0 * (self.a[i + 2] - self.a[i + 1]) / h[i + 1] - 3.0 * (self.a[i + 1] - self.a[i]) / h[i]
254 |
255 | return B
256 |
257 |
258 | class Spline2D:
259 | """
260 | 2D Cubic Spline class
261 | """
262 |
263 | def __init__(self, x, y):
264 | self.s = self.__calc_s(x, y)
265 | self.sx = Spline(self.s, x)
266 | self.sy = Spline(self.s, y)
267 |
268 | def __calc_s(self, x, y):
269 | dx = np.diff(x)
270 | dy = np.diff(y)
271 | self.ds = np.hypot(dx, dy)
272 | s = [0]
273 | s.extend(np.cumsum(self.ds))
274 |
275 | return s
276 |
277 | def calc_position(self, s):
278 | """
279 | calc position
280 | """
281 | x = self.sx.calc(s)
282 | y = self.sy.calc(s)
283 |
284 | return x, y
285 |
286 | def calc_curvature(self, s):
287 | """
288 | calc curvature
289 | """
290 | dx = self.sx.calcd(s)
291 | ddx = self.sx.calcdd(s)
292 | dy = self.sy.calcd(s)
293 | ddy = self.sy.calcdd(s)
294 | k = (ddy * dx - ddx * dy) / ((dx ** 2 + dy ** 2)**(3 / 2))
295 |
296 | return k
297 |
298 | def calc_yaw(self, s):
299 | """
300 | calc yaw
301 | """
302 | dx = self.sx.calcd(s)
303 | dy = self.sy.calcd(s)
304 | yaw = math.atan2(dy, dx)
305 |
306 | return yaw
307 |
308 | def search_index(self,s):
309 | i = self.sx.search_index(s)
310 | j = self.sy.search_index(s)
311 | return i,j
312 |
313 | def generate_target_course(x, y):
314 | csp = Spline2D(x, y)
315 | s = np.arange(0, csp.s[-1], 0.1)
316 | rx, ry, ryaw, rk = [], [], [], []
317 | for i_s in s:
318 | ix, iy = csp.calc_position(i_s)
319 | rx.append(ix)
320 | ry.append(iy)
321 | ryaw.append(csp.calc_yaw(i_s))
322 | rk.append(csp.calc_curvature(i_s))
323 |
324 | return rx, ry, ryaw, rk, csp
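
# Illustrative usage sketch (made-up waypoints, not from the pipeline): fit a 2D
# cubic spline through sparse centerline points and resample a dense reference
# course at 0.1 m arc-length spacing via generate_target_course above.
def _example_reference_course():
    wx = [0.0, 10.0, 20.5, 35.0, 70.5]
    wy = [0.0, -6.0, 5.0, 6.5, 0.0]
    rx, ry, ryaw, rk, csp = generate_target_course(wx, wy)
    # rx/ry: resampled positions, ryaw: headings, rk: curvatures along the course
    return rx, ry, ryaw, rk, csp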
325 |
326 |
327 |
--------------------------------------------------------------------------------
/planner.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
5 | sys.path.append(os.path.dirname(SCRIPT_DIR))
6 | sys.path.append('/theseus')
7 |
8 | import theseus as th
9 | import matplotlib.pyplot as plt
10 |
11 | import torch
12 | import torch.nn as nn
13 | import numpy as np
14 |
15 | from utils.plan_utils import *
16 | import casadi as cs
17 | import math
18 | import scipy
19 |
20 | class Planner(object):
21 | def __init__(self,
22 | device,
23 | horizon=4,
24 | g_length=50,
25 | g_width=40,
26 | test_iters=50,
27 | test_step=0.3,
28 | ):
29 | super(Planner, self).__init__()
30 |
31 | self.g_width = g_width
32 | self.horizon = horizon
33 | control_variables = th.Vector(dof=horizon *10 * 2, name="control_variables")
34 | ref_line_fields = th.Variable(torch.empty(1, horizon, g_length, g_width, 2), name="ref_line_field")
35 | ref_line_costs = th.Variable(torch.empty(1, horizon, g_length, g_width), name="ref_line_costs")
36 | lwt = th.Variable(torch.empty(1, horizon*10, 3), name="lwt")
37 | current_state = th.Variable(torch.empty(1, 2, 4), name="current_state")
38 | spl = th.Variable(torch.empty(1, 1), name="speed_limit")
39 | stp = th.Variable(torch.empty(1, 1), name="stop_point")
40 | il_lane = th.Variable(torch.empty(1, horizon*10, 2), name="il_lane")
41 |
42 | objective = th.Objective()
43 | objective = self.cost_function(objective, control_variables, ref_line_fields,
44 | ref_line_costs, lwt, current_state, spl, stp, il_lane)
45 | self.optimizer = th.GaussNewton(objective, th.CholmodSparseSolver, vectorize=False,
46 | max_iterations=test_iters, step_size=test_step, abs_err_tolerance=1e-2)
47 |
48 | self.layer = th.TheseusLayer(self.optimizer, vectorize=False)
49 | self.layer.to(device=device)
50 |
51 | self.max_acc = 5
52 | self.max_delta = math.pi / 6
53 | self.ts = 0.1
54 | self.sigma = 1.0
55 | self.num = 0
56 |
57 | def preprocess(self, ego_state, ego_plan, ref_lines, ogm_prediction, type_mask, config, left=True):
58 | #computing the angle:
59 | diff_traj = torch.diff(ego_plan, axis=1)
60 | diff_traj = torch.cat([diff_traj, diff_traj[:, -1, :].unsqueeze(1)], dim=1)
61 | angle = torch.nan_to_num(torch.atan2(diff_traj[:,:,1], diff_traj[:,:,0]), 0).clamp(-0.67,0.67)
62 |
63 | #reshape time axis: l to the batch axis
64 | frenet_plan, angle = ego_plan[:, :50, :2], angle[:, :50]
65 | length, width = ego_state[:, -1, 6], ego_state[:, -1, 7]
66 | b, l, d = frenet_plan.shape
67 | length, width = length.unsqueeze(1).expand(-1, l), width.unsqueeze(1).expand(-1, l)
68 | frenet_plan = frenet_plan.reshape(b*l, d)
69 | speed = ego_state[:, -2:, 3]
70 | # print(ego_state[:, -2:])
71 | angle = angle.reshape(b*l)
72 |
73 | # ref_lines = ref_lines[:, ::2, :]
74 | speed_limit = ref_lines[:, :, 4]
75 | b, t, d = ref_lines.shape
76 | orf = ref_lines
77 | ref_lines = ref_lines.unsqueeze(1).expand(-1, l, -1, -1).reshape(b*l, t, d)
78 |
79 | #generate the ego mask
80 | ego = generate_ego_pos_at_field(frenet_plan, ref_lines, angle)
81 | spl = torch.max(speed_limit, dim=1, keepdim=True)[0]
82 | stp = torch.max(speed_limit==0, dim=1, keepdim=True)[1] * 0.1
83 | theta = ego[:, -1].reshape(b, l)
84 |
85 | ref_lines = ref_lines.reshape(b, l, t, d)[:, ::10].reshape(b*l//10, t, d)
86 |
87 | refline_fields = ref_line_grids(ref_lines, widths=int(self.g_width/2), left=left ,pixels_per_meter=3.2)
88 |
89 | ogm_prediction = ogm_prediction[:, 0] * (ogm_prediction[:, 0] > 0.3) + \
90 |                          10*ogm_prediction[:, 1] * type_mask[:, 1,None,None,None] + \
91 | 10*ogm_prediction[:, 2] * type_mask[:, 2,None,None,None]
92 | ogm_prediction = ogm_prediction.clamp(0, 1)
93 |
94 | ogm_prediction = ogm_prediction * (ogm_prediction > 0.1)
95 |
96 | b, lp, h, w = ogm_prediction.shape
97 | og = ogm_prediction
98 | ogm_prediction = ogm_prediction.reshape(b*l//10, h, w)
99 |
100 | ogm_refline_fields = ref_line_ogm_sample(ogm_prediction, refline_fields, config)
101 | _, h, w = ogm_refline_fields.shape
102 |
103 | current_state = generate_ego_pos_at_field(ego_state[:, -1, [0, 1]],orf,ego_state[:, -1, 2])
104 | lcurrent_state = generate_ego_pos_at_field(ego_state[:, -2, [0, 1]],orf,ego_state[:, -2, 2])
105 |
106 | all_ego = torch.cat([current_state[:, None, :2], ego[:, :2].reshape(b, l, 2)], dim=1)
107 | d_ego = torch.diff(all_ego, dim=1) / 0.1
108 |
109 | c_state = torch.stack([lcurrent_state, current_state], dim=-2)
110 | c_state = torch.concat([c_state, speed.unsqueeze(-1)], dim=-1)
111 | il_lane = ego[:, :2].reshape(b, l, 2)
112 | # print(il_lane[0, 9::10])
113 |
114 | return {
115 | 'control_variables': d_ego.reshape(b, l*2).detach(),
116 | 'il_lane': il_lane.detach(),
117 | 'lwt': torch.stack([length, width, theta], dim=-1).detach(),
118 | 'ref_line_field': refline_fields.reshape(b, lp, h, w, 2).detach(),
119 | 'ref_line_costs': ogm_refline_fields.reshape(b, lp, h, w).detach(),
120 | 'current_state': c_state.detach(),
121 | 'speed_limit': spl,
122 | 'stop_point':stp,
123 | }
124 |
125 | def il_cost(self, optim_vars, aux_vars):
126 | ego = optim_vars[0].tensor.view(-1, self.horizon*10, 2)
127 | current_state = aux_vars[0].tensor
128 | ds = ego[:, :, 0].clamp(min=0)
129 | dl = ego[:, :, 1]
130 | s = current_state[:, -1, 0][:,None] + torch.cumsum(ds * 0.1, dim=-1)
131 | L = current_state[:, -1, 1][:,None] + torch.cumsum(dl * 0.1, dim=-1)
132 | ego = torch.stack([s, L], dim=-1)
133 | il_lane = aux_vars[1].tensor
134 | cost = torch.abs(il_lane - ego).mean(-1)
135 | return 1 * cost
136 |
137 | def collision_cost(self, optim_vars, aux_vars):
138 | ego = optim_vars
139 | lwt, refline_fields, ref_line_costs, current_state = aux_vars
140 | lwt, refline_fields, ref_line_costs = lwt.tensor, refline_fields.tensor, ref_line_costs.tensor
141 | b, l, h, w = ref_line_costs.shape
142 | lwt = lwt[:, 9::10,:].reshape(b*l, 3)
143 | ego = ego[0].tensor.view(b, l*10, 2)
144 | ds = ego[:, :, 0].clamp(min=0)
145 | dl = ego[:, :, 1]
146 | s = current_state[:, -1, 0][:,None] + torch.cumsum(ds * 0.1, dim=-1)
147 | L = current_state[:, -1, 1][:,None] + torch.cumsum(dl * 0.1, dim=-1)
148 | ego = torch.stack([s, L], dim=-1)[:, 4::5, :]
149 | ego = ego.view(b, l, 2, 2)
150 | ego = ego.view(b*l, 2, 2)
151 |
152 | refline_fields = refline_fields.view(b*l, h, w, 2)
153 | ref_line_costs = ref_line_costs.view(b*l, h, w)
154 |
155 | mesh_sl = refline_meshgrids(refline_fields, pixels_per_meter=1.6)
156 |
157 | safety_cost_mask = ref_line_costs
158 | safety_cost_mask = safety_cost_mask > 0.3
159 |
160 | diff = mesh_sl.unsqueeze(1) - ego.unsqueeze(-2).unsqueeze(-2)
161 |
162 | interactive_mask = (diff[..., 0] > 0) * (torch.abs(diff[..., 1]) < 7.5)
163 | ego_dist = torch.sqrt(torch.square(diff[..., 0]) + torch.square(diff[..., 1]))
164 | ego_dist = (5 - ego_dist) * (ego_dist < 5) * interactive_mask
165 |
166 | ego_s = ego_dist * safety_cost_mask.unsqueeze(1)
167 | safety_cost_s = ego_s.sum(-1).sum(-1)
168 |
169 | safety_cost = safety_cost_s
170 |
171 | safety_cost = safety_cost.view(b, l*2)
172 |
173 | return 10 * safety_cost
174 |
175 | def red_light(self, optim_vars, aux_vars):
176 | ego = optim_vars
177 | ego = ego[0].tensor.view(-1, self.horizon*10, 2)
178 | stop_point = aux_vars[0].tensor
179 | current_state = aux_vars[1].tensor
180 | s = current_state[:, -1, 0][:,None] + torch.cumsum(ego[..., 0] * 0.1, dim=-1)
181 | stop_distance = stop_point - 3
182 | red_light_error = (s - stop_distance) * (s > stop_distance) * (stop_point != 0)
183 | return 10 * red_light_error
184 |
185 |
186 | def cost_function(self, objective, control_variables, ref_line_fields,
187 | ref_line_costs, lwt, current_state, spl, stp, il_lane, vectorize=True):
188 |
189 | safe_cost = th.AutoDiffCostFunction([control_variables], self.collision_cost, self.horizon*2,
190 | aux_vars=[lwt, ref_line_fields, ref_line_costs, current_state], autograd_vectorize=vectorize, name="safe_cost")
191 | objective.add(safe_cost)
192 |
193 | il_cost = th.AutoDiffCostFunction([control_variables], self.il_cost, self.horizon*10,
194 | aux_vars=[current_state, il_lane],autograd_vectorize=vectorize, name="il_cost")
195 | objective.add(il_cost)
196 |
197 |
198 | rl_cost = th.AutoDiffCostFunction([control_variables], self.red_light, self.horizon*10,
199 | aux_vars=[stp, current_state],autograd_vectorize=vectorize, name="red_light")
200 | objective.add(rl_cost)
201 |
202 | return objective
203 |
204 |
205 | def plan(self, planning_inputs, selected_ref, current_state):
206 |
207 | final_values, info = self.layer.forward(planning_inputs, optimizer_kwargs={'track_best_solution': True})
208 | plan = info.best_solution["control_variables"].view(-1, self.horizon*10, 2)
209 | plan = plan.to(selected_ref.device)
210 | s = planning_inputs['current_state'][:, -1, 0][:,None] + torch.cumsum(plan[..., 0] * 0.1, dim=-1)
211 | l = planning_inputs['current_state'][:, -1, 1][:,None] + torch.cumsum(plan[..., 1] * 0.1, dim=-1)
212 | plan = torch.stack([s, l], dim=-1)
213 |         xy_plan = self.frenet_to_cartesian(plan, selected_ref)
214 | speed = torch.hypot(current_state[:, -1, 2], current_state[:, -1, 3])
215 | last_speed = torch.hypot(current_state[:, -2, 2], current_state[:,-2, 3])
216 | acc = (speed - last_speed)/ 0.1
217 | current_state = torch.stack([current_state[:,-1, 0], current_state[:, -1, 1], current_state[:, -1, 4], speed, acc], dim=-1)
218 | b = xy_plan.shape[0]
219 | res = []
220 | for i in range(b):
221 | pl = self.refine(current_state[i].cpu().numpy(), xy_plan[i].cpu().numpy())
222 | res.append(pl)
223 | return torch.tensor(np.stack(res, 0)).to(xy_plan.device).float()
224 |
225 | def refine(self, current_state, reference):
226 | opti = cs.Opti()
227 |
228 | # Define the optimization variables
229 | X = opti.variable(4, self.horizon*10 + 1)
230 | U = opti.variable(2, self.horizon*10)
231 |
232 | # Define the initial state and the reference trajectory
233 | x0 = current_state[:4] # (x, y, theta, v)
234 | xr = reference.T
235 |
236 | # Define the cost function for the MPC problem
237 | obj = 0
238 |
239 | for i in range(self.horizon*10):
240 | obj += (i+1) / (self.horizon*10) * cs.sumsqr(X[:2, i+1] - xr[:2, i])
241 | obj += 100 * (X[2, i+1] - xr[2, i]) ** 2
242 | obj += 0.1 * (U[0, i]) ** 2
243 | obj += U[1, i] ** 2
244 |
245 | if i >= 1:
246 | obj += 0.1 * (U[0, i] - U[0, i-1]) ** 2
247 | obj += (U[1, i] - U[1, i-1]) ** 2
248 |
249 | opti.minimize(obj)
250 |
251 | # Define the constraints for the MPC problem
252 | opti.subject_to(X[:, 0] == x0)
253 | opti.subject_to(U[0, 0] == current_state[4])
254 |
255 | for i in range(self.horizon*10):
256 | opti.subject_to([X[0, i+1] == X[0, i] + X[3, i] * cs.cos(X[2, i]) * self.ts,
257 | X[1, i+1] == X[1, i] + X[3, i] * cs.sin(X[2, i]) * self.ts,
258 | X[2, i+1] == X[2, i] + X[3, i] / 4 * cs.tan(U[1, i]) * self.ts,
259 | X[3, i+1] == X[3, i] + U[0, i] * self.ts])
260 |
261 | for i in range(self.horizon):
262 | int_step = (i+1)*10
263 | opti.subject_to(X[0, int_step] - xr[0, int_step-1] <= 2)
264 | opti.subject_to(X[0, int_step] - xr[0, int_step-1] >= -2)
265 | opti.subject_to(X[1, int_step] - xr[1, int_step-1] <= 0.5)
266 | opti.subject_to(X[1, int_step] - xr[1, int_step-1] >= -0.5)
267 |
268 | opti.subject_to(X[3, :] >= 0)
269 |
270 | # Create the MPC solver
271 | opts = {'ipopt.print_level': 0, 'print_time': False}
272 | opti.solver('ipopt', opts)
273 | try:
274 | sol = opti.solve()
275 | states = sol.value(X)
276 | # act = sol.value(U)
277 |         except Exception:
278 | print("Solver failed. Returning best solution found.")
279 | states = opti.debug.value(X)
280 |
281 | traj = states.T[1:, :3]
282 |
283 | return traj
284 |
285 |     def frenet_to_cartesian(self, sl, ref_line):
286 | s, l = (sl[:, :, 0]*10 + 200), sl[:, :, 1]
287 | s = s.clamp(0, 1199)
288 | b = ref_line.shape[0]
289 | ref_points = ref_line[torch.arange(b).long()[:,None], s.long(), :]
290 | cartesian_x = ref_points[:, :, 0] - l * torch.sin(ref_points[:, :, 2])
291 | cartesian_y = ref_points[:, :, 1] + l * torch.cos(ref_points[:, :, 2])
292 | angle = ref_points[:, :, 2]
293 |
294 | return torch.stack([cartesian_x, cartesian_y, angle], dim=-1)
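
# Illustrative sketch (standalone, hypothetical straight reference line) of the
# two conversions used in plan() and frenet_to_cartesian above: the decision
# variables are per-step frenet rates (ds, dl) at 10 Hz, integrated by a
# cumulative sum from the current (s, l), then projected back to cartesian with
# x = x_ref - l*sin(theta_ref), y = y_ref + l*cos(theta_ref).
def _example_frenet_roundtrip():
    dt = 0.1
    rates = torch.ones(1, 40, 2) * torch.tensor([5.0, 0.0])   # constant 5 m/s along s, no lateral motion
    current_sl = torch.tensor([[0.0, 1.0]])                    # start at s = 0, 1 m left of the reference
    s = current_sl[:, 0:1] + torch.cumsum(rates[..., 0] * dt, dim=-1)
    l = current_sl[:, 1:2] + torch.cumsum(rates[..., 1] * dt, dim=-1)
    # straight reference line along +x, sampled every 0.1 m and starting 20 m
    # behind the ego, so the frenet index is s*10 + 200 as in frenet_to_cartesian
    ref_x = torch.arange(1200) * 0.1 - 20.0
    ref = torch.stack([ref_x, torch.zeros(1200), torch.zeros(1200)], dim=-1).unsqueeze(0)
    idx = (s * 10 + 200).clamp(0, 1199).long()
    ref_pts = ref[torch.arange(1).long()[:, None], idx, :]
    x = ref_pts[..., 0] - l * torch.sin(ref_pts[..., 2])
    y = ref_pts[..., 1] + l * torch.cos(ref_pts[..., 2])
    return torch.stack([x, y], dim=-1)   # a straight path 1 m left of the reference line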
--------------------------------------------------------------------------------
/metric.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from torchmetrics.functional import auroc
5 |
6 | import numpy as np
7 | import math
8 |
9 | import skimage as ski
10 |
11 | def draw_ego_mask(ego, length=5.2860+1, width=2.332+0.5, size=(128, 128), pixels_per_meter=1.6):
12 |
13 | b = ego.shape[0]
14 | dego = ego.detach().cpu().numpy()
15 | masks = []
16 | for i in range(b):
17 | x, y, angle = dego[i, 0], dego[i, 1], dego[i, 2]
18 | sin, cos = np.sin(angle), np.cos(angle)
19 | front_left = [x + length/2*cos - width/2*sin, y + length/2*sin + width/2*cos]
20 | front_right = [x + length/2*cos + width/2*sin, y + length/2*sin - width/2*cos]
21 | rear_left = [x - length/2*cos - width/2*sin, y - length/2*sin + width/2*cos]
22 | rear_right = [x - length/2*cos + width/2*sin, y - length/2*sin - width/2*cos]
23 | poly_xy = np.array([front_left, front_right, rear_right, rear_left]) #(4, 2)
24 | # x,y -> h, w
25 | poly_h = int(size[0]*0.75) - np.round(poly_xy[:, 0] * pixels_per_meter)
26 | poly_w = int(size[1]*0.5) - np.round(poly_xy[:, 1] * pixels_per_meter)
27 | poly_hw = np.stack([poly_h, poly_w] ,axis=-1)
28 | mask = ski.draw.polygon2mask(size, poly_hw)
29 | masks.append(mask)
30 | masks = np.stack(masks, axis=0)
31 | masks = torch.tensor(masks).to(ego.device)
32 | return masks
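
# Usage sketch (illustrative pose): rasterize an oriented ego box into the
# 128x128 BEV grid used by the collision check. The pose is (x, y, yaw) in the
# grid-centered metric frame assumed by draw_ego_mask.
def _example_ego_mask():
    pose = torch.tensor([[0.0, 0.0, 0.0]])   # ego at the grid origin, heading +x
    mask = draw_ego_mask(pose)               # (1, 128, 128) boolean footprint of the ego box
    return mask.sum()                        # roughly length * width * pixels_per_meter**2 pixels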
33 |
34 | def plan_metrics(trajectories, ego_future):
35 | l = ego_future.shape[-2]
36 | trajectories = trajectories[..., :l, :]
37 | ego_future_valid = torch.ne(ego_future[..., :2], 0)
38 | ego_trajectory = trajectories[..., :2] * ego_future_valid[:, None, :, :]
39 | distance = torch.norm(ego_trajectory[:,:, :, :2] - ego_future[:,None, :, :2], dim=-1)
40 |
41 | ade = distance.mean(-1)
42 | ade, _ = torch.min(ade,dim=-1)
43 | egoADE = torch.mean(ade)
44 | fde = distance[:,:,-1]
45 | fde, _ = torch.min(fde,dim=-1)
46 | egoFDE = torch.mean(fde)
47 |
48 | fde3 = distance[:,:,29]
49 | fde3, _ = torch.min(fde3,dim=-1)
50 | egoFDE3 = torch.mean(fde3)
51 |
52 | fde1 = distance[:,:,9]
53 | fde1, _ = torch.min(fde1,dim=-1)
54 | egoFDE1 = torch.mean(fde1)
55 |
56 | return egoADE.item(), egoFDE.item() ,egoFDE3.item(), egoFDE1.item()
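
# Minimal sketch (random tensors, hypothetical shapes): plan_metrics takes
# multi-modal plans [batch, modes, T, >=2] and ground truth [batch, T, >=2] and
# reports minimum-over-modes ADE/FDE, plus FDE at 3 s and 1 s (indices 29 and 9
# at 10 Hz).
def _example_plan_metrics():
    trajectories = torch.randn(4, 3, 50, 2)   # 4 scenes, 3 modes, 5 s horizon at 10 Hz
    ego_future = torch.randn(4, 50, 2)
    return plan_metrics(trajectories, ego_future)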
57 |
58 | from time import time
59 | def occupancy_metrics(preds, target):
60 |     # preds: [batch, t, h, w] after sigmoid; target: [batch, t, h, w]
61 | T = target.shape[1]
62 | check_time = [1, 3, 5]
63 | auc_list, iou_list = [], []
64 |
65 | for i in range(T):
66 | auc_list.append(auc_metrics(preds[:, i], target[:, i]))
67 | iou_list.append(soft_iou(preds[:, i], target[:, i]))
68 | res_auc, res_iou = [], []
69 | for t in check_time:
70 | res_auc.append(torch.mean(torch.stack(auc_list[:t])).item())
71 | res_iou.append(torch.mean(torch.stack(iou_list[:t])).item())
72 | return res_auc, res_iou
73 |
74 | def all_type_occupancy_metrics(preds, target, n_types=2):
75 | res_list = []
76 | for i in range(n_types - 1):
77 | res_auc, res_iou = occupancy_metrics(preds[:, i], target['gt_obs'][..., 0, i])
78 | res_list.append([res_auc, res_iou])
79 | if preds.shape[1]<=3:
80 | res_auc, res_iou = occupancy_metrics(preds[:, 0], (target['gt_occ'] + target['gt_obs'][..., 0, 0]).clamp(0, 1))
81 | else:
82 | res_auc, res_iou = occupancy_metrics(preds[:, 3], target['gt_occ'])
83 | res_list.append([res_auc, res_iou])
84 | return res_list
85 |
86 | def auc_metrics(inputs, target):
87 | return auroc(inputs, target.int(), task='binary', thresholds=100)
88 |
89 | def soft_iou(inputs, target):
90 | inputs, target = inputs.reshape(-1), target.reshape(-1)
91 | intersection = torch.mean(torch.mul(inputs, target))
92 | T_inputs, T_target = torch.mean(inputs), torch.mean(target)
93 | soft_iou_score = torch.nan_to_num(intersection / (T_inputs + T_target - intersection) ,0.0)
94 | return soft_iou_score
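
# Quick sanity sketch: the soft IoU of a grid with itself is 1, and of two
# disjoint grids is 0 (grids here are hypothetical).
def _example_soft_iou():
    a = torch.zeros(1, 8, 8)
    a[0, :4, :4] = 1.0
    b = torch.zeros(1, 8, 8)
    b[0, 4:, 4:] = 1.0
    return soft_iou(a, a), soft_iou(a, b)   # -> (tensor(1.), tensor(0.))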
95 |
96 | def check_dynamics(traj, current_state):
97 | d_t = 0.1
98 | diff_xy = torch.diff(traj, dim=1)
99 | diff_x, diff_y = diff_xy[:,:, 0], diff_xy[:,:, 1]
100 |
101 | v_x, v_y, theta = diff_x / d_t, diff_y/d_t, np.arctan2(diff_y.cpu().numpy(), diff_x.cpu().numpy() + 1e-6)
102 | theta = torch.tensor(theta).to(v_x.device)
103 | lon_speed = v_x * torch.cos(theta) + v_y * torch.sin(theta)
104 | lat_speed = v_y * torch.cos(theta) - v_x * torch.sin(theta)
105 |
106 | acc = torch.diff(lon_speed,dim=-1) / d_t
107 | jerk = torch.diff(lon_speed,dim=-1,n=2) / d_t**2
108 | lat_acc = torch.diff(lat_speed,dim=-1) / d_t
109 |
110 | return torch.mean(torch.abs(acc)).item(), torch.mean(torch.abs(jerk)).item(), torch.mean(torch.abs(lat_acc)).item()
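
# Worked sketch (synthetic straight-line trajectory): a constant-velocity path
# should yield near-zero longitudinal acceleration, jerk and lateral
# acceleration; current_state is not used by check_dynamics.
def _example_check_dynamics():
    t = torch.arange(50, dtype=torch.float32) * 0.1
    traj = torch.stack([5.0 * t, torch.zeros_like(t)], dim=-1).unsqueeze(0)   # 5 m/s along x
    return check_dynamics(traj, current_state=None)   # ~ (0.0, 0.0, 0.0)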
111 |
112 | def check_traffic(traj, ref_line, gt_modes):
113 | b, t, c = ref_line.shape
114 | red_light = False
115 | off_route = False
116 |
117 | # project to frenet
118 | distance_to_ref = torch.cdist(traj[:,:, :2], ref_line[:,:, :2])
119 | #b, L_ego , s_ref
120 | s_ego = torch.argmin(distance_to_ref, axis=-1)
121 | distance_to_route = torch.min(distance_to_ref, axis=-1).values
122 | off_route = torch.any(distance_to_route > 5, dim=1)
123 |
124 | # get stop point
125 | stop_point = torch.argmax(ref_line[:,:,-2].int(),dim=1)
126 | sig = ref_line[torch.arange(b)[:,None],stop_point.unsqueeze(-1),-1].squeeze(1)
127 | rl_sig = torch.logical_or(sig==1, torch.logical_or(sig==4, sig==7))
128 | s_stp = s_ego-stop_point.unsqueeze(-1)
129 | red_light = torch.logical_and(torch.logical_and(stop_point > 0, torch.any(s_stp > 0,dim=1)), rl_sig)
130 |
131 | return red_light, off_route#.float().mean().item()
132 |
133 | def compare_to_gt(ego_metric, gt_metric):
134 | not_gt_metric = torch.logical_not(gt_metric)
135 | real_metric = torch.logical_and(ego_metric, not_gt_metric)
136 | if not_gt_metric.float().sum()==0:
137 | return not_gt_metric.float().sum().item()
138 | real_val = real_metric.float().sum() / not_gt_metric.float().sum()
139 | return real_val.item()
140 |
141 | def flow_epe(outputs, targets):
142 | if 'gt_flow' not in targets:
143 | return 0
144 | target_flow = targets['gt_flow']
145 | target_flow = target_flow.sum(-1)
146 | flow_exists = torch.logical_or(torch.ne(target_flow[..., 0], 0), torch.ne(target_flow[..., 1], 0)).float()
147 | flow_outputs = outputs[:, -2:]
148 | b, c, t, h, w = flow_outputs.shape
149 | flow_outputs = flow_outputs.permute(0, 2, 3, 4, 1)
150 | pred_flow = torch.mul(flow_outputs, flow_exists.unsqueeze(-1))#.permute(0, 2, 3, 4, 1) #[b, t, h, w, 2]
151 | epe_list = []
152 | # for i in range(3):
153 | if torch.sum(flow_exists) > 0:
154 | flow_epe = torch.sum(torch.norm(pred_flow - target_flow, p=2, dim=-1))/ torch.nan_to_num(torch.sum(flow_exists)/2, 1.0)
155 | return flow_epe.item()
156 | else:
157 | return 0
158 | # return epe_list
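
# Minimal sketch (random tensors, hypothetical shapes) for flow_epe: the
# prediction carries (dx, dy) flow in its last two channels, shaped
# [batch, channels, t, h, w]; the end-point error is averaged only over pixels
# whose ground-truth flow is non-zero.
def _example_flow_epe():
    outputs = torch.randn(2, 6, 5, 16, 16)                   # last two channels are flow
    targets = {'gt_flow': torch.randn(2, 5, 16, 16, 2, 1)}   # summed over the last axis inside flow_epe
    return flow_epe(outputs, targets)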
159 |
160 | class TrainingMetrics:
161 | def __init__(self):
162 | self.ade = []
163 | self.fde = []
164 | self.il_loss = []
165 | self.ogm_loss = []
166 | self.epe = []
167 |
168 | def update(self, traj, score, gt_modes, target, il_loss, ogm_loss, outputs):
169 | ade, fde, fde3, fde1 = plan_metrics(traj, target['ego_future_states'])
170 | self.epe.append(flow_epe(outputs, target))
171 | self.ade.append(ade)
172 | self.fde.append(fde)
173 |
174 | self.il_loss.append(il_loss.item())
175 | self.ogm_loss.append(ogm_loss.item())
176 | return np.mean(self.ade), np.mean(self.fde), np.mean(self.il_loss), np.mean(self.ogm_loss)
177 |
178 | def result(self):
179 | return {
180 | 'T_il_loss':np.mean(self.il_loss),
181 | 'T_ogm_loss':np.mean(self.ogm_loss),
182 | 'T_ade':np.mean(self.ade),
183 | 'T_fde':np.mean(self.fde),
184 | 'epe_v':np.mean(self.epe)
185 | }
186 |
187 | class ValidationMetrics:
188 | def __init__(self):
189 | self.ade = []
190 | self.fde = []
191 | self.fde3 = []
192 | self.fde1 = []
193 | self.il_loss = []
194 | self.ogm_loss = []
195 |
196 | self.ogm_auc = []
197 | self.ogm_iou = []
198 | self.epe = []
199 |
200 | self.ogm_auc_p = []
201 | self.ogm_iou_p = []
202 | self.epe_p = []
203 |
204 | self.ogm_auc_c = []
205 | self.ogm_iou_c = []
206 | self.epe_c = []
207 |
208 | self.occ_auc = []
209 | self.occ_iou = []
210 |
211 | def update(self, traj, score, ogm_pred, gt_modes, target, il_loss, ogm_loss):
212 | ade, fde, fde3, fde1 = plan_metrics(traj, target['ego_future_states'])
213 | self.ade.append(ade)
214 | self.fde.append(fde)
215 | self.fde3.append(fde3)
216 | self.fde1.append(fde1)
217 |
218 | epe_list = flow_epe(ogm_pred, target)
219 | self.epe.append(epe_list)
220 |
221 | self.il_loss.append(il_loss.item())
222 | self.ogm_loss.append(ogm_loss.item())
223 |
224 | ogm_list = all_type_occupancy_metrics(ogm_pred.sigmoid(), target, 4)
225 | ogm_auc, ogm_iou, occ_auc, occ_iou = ogm_list[0][0][1], ogm_list[0][1][1], ogm_list[-1][0][1], ogm_list[-1][1][1]
226 |
227 | self.ogm_auc.append(ogm_auc)
228 | self.ogm_iou.append(ogm_iou)
229 | self.occ_auc.append(occ_auc)
230 | self.occ_iou.append(occ_iou)
231 |
232 | ogm_auc_p, ogm_iou_p, ogm_auc_c, ogm_iou_c = ogm_list[1][0][1], ogm_list[1][1][1], ogm_list[2][0][1], ogm_list[2][1][1]
233 | self.ogm_auc_p.append(ogm_auc_p)
234 | self.ogm_iou_p.append(ogm_iou_p)
235 | self.ogm_auc_c.append(ogm_auc_c)
236 | self.ogm_iou_c.append(ogm_iou_c)
237 |
238 | return np.mean(self.ade), np.mean(self.fde), np.mean(self.il_loss), np.mean(self.ogm_loss),\
239 | np.mean(self.ogm_auc), np.mean(self.ogm_iou), np.mean(self.occ_auc), np.mean(self.occ_iou)
240 |
241 | def result(self):
242 | return {
243 | 'E_il_loss':np.mean(self.il_loss),
244 | 'E_ogm_loss':np.mean(self.ogm_loss),
245 | 'E_ade':np.mean(self.ade),
246 | 'E_fde_5': np.mean(self.fde),
247 | 'E_fde_3': np.mean(self.fde3),
248 | 'E_fde_1': np.mean(self.fde1),
249 | 'ogm_auc_v':np.mean(self.ogm_auc),
250 | 'ogm_iou_v':np.mean(self.ogm_iou),
251 | 'ogm_auc_p':np.mean(self.ogm_auc_p),
252 | 'ogm_iou_p':np.mean(self.ogm_iou_p),
253 | 'ogm_auc_c':np.mean(self.ogm_auc_c),
254 | 'ogm_iou_c':np.mean(self.ogm_iou_c),
255 | 'occ_auc':np.mean(self.occ_auc),
256 | 'occ_iou':np.mean(self.occ_iou),
257 | 'v_epe':np.mean(self.epe)
258 | }
259 |
260 |
261 | class TestingMetrics:
262 | def __init__(self, config, lite_mode=False):
263 |
264 | self.reset()
265 | self.config = config
266 | self.ogm_to_position()
267 | self.lite_mode = lite_mode
268 |
269 | def reset(self):
270 | self.valid_dict = {
271 | 'fde_1s':[], 'fde_3s':[], 'fde_5s':[], 'ade':[],'collisions_rate':[], 'off_road_rate':[], 'red_light':[],
272 | 'acc':[], 'jerk':[], 'lat_acc':[]
273 | }
274 | for m in ['auc','iou']:
275 | for t in [3 ,5, 8]:
276 | for ty in ['v','p','c','occ']:
277 | self.valid_dict[f'{m}_{t}_{ty}'] = []
278 |
279 | def ogm_to_position(self):
280 | indexes = torch.arange(1, self.config.grid_height_cells + 1)
281 | widths_indexes = - (indexes - self.config.sdc_y_in_grid - 0.5) / self.config.pixels_per_meter
282 | heights_indexes = - (indexes - self.config.sdc_x_in_grid - 0.5) / self.config.pixels_per_meter
283 |         # corresponding (x, y) in dense coordinates (h, w, 2)
284 | coordinates = torch.stack(torch.meshgrid([widths_indexes, heights_indexes]), dim=-1)#.permute(1, 0, 2)
285 | self.ogm_coordinates = coordinates
286 |
287 |
288 | def gt_ogm_to_position(self, traj, target, t, current_state, col_thres=3):
289 | ego_mask = 1- target['gt_ego'][:, t]
290 | all_occupancies = target['gt_obs'][:, t, :, :, 0, :].sum(-1) + target['gt_occ'][:, t , :, :]
291 | all_occupancies = all_occupancies * ego_mask
292 | all_occupancies = all_occupancies.clamp(0, 1).unsqueeze(-1)
293 | b = all_occupancies.shape[0]
294 | current_ogm_coordinates = self.ogm_coordinates.unsqueeze(0).expand(b, -1, -1, -1).to(all_occupancies.device)
295 |
296 | b, h, w, c = current_ogm_coordinates.shape
297 | ego_plan = draw_ego_mask(traj[:, t*10 + 9, :3])
298 |
299 | #filtering the occupied occupancies
300 | collision_grids = ego_plan * all_occupancies[..., 0]
301 | collision_grids = collision_grids.reshape(b, h*w)
302 | collision_grids = collision_grids.sum(1) > col_thres
303 |
304 | return collision_grids.float()
305 |
306 |     def collisions_check(self, traj, target, current_state):
307 | col_list = []
308 | for t in range(5):
309 | col_list.append(self.gt_ogm_to_position(traj, target, t, current_state))
310 | col_list = torch.stack(col_list, dim=1).sum(1) >= 1
311 | col_rate = col_list
312 | return col_rate
313 |
314 | def update(self, traj, score, ogm_pred, gt_modes, target, current_state):
315 |
316 | ade, fde, fde3, fde1 = plan_metrics(traj, target['ego_future_states'][:, :50, :])
317 |
318 | self.valid_dict['ade'].append(ade)
319 | self.valid_dict['fde_1s'].append(fde1)
320 | self.valid_dict['fde_3s'].append(fde3)
321 | self.valid_dict['fde_5s'].append(fde)
322 |
323 | ogm_list = all_type_occupancy_metrics(ogm_pred, target, 4)
324 |
325 | for i, ty in enumerate(['v','p','c','occ']):
326 | for j, t in enumerate([3, 5, 8]):
327 | for k, m in enumerate(['auc', 'iou']):
328 | self.valid_dict[f'{m}_{t}_{ty}'].append(ogm_list[i][k][j])
329 | # self.valid_dict[f'{m}_{t}_{ty}'].append(0)
330 |
331 | gt_traj = target['ego_future_states'][..., :,:]
332 | if not self.lite_mode:
333 |             col_rate = self.collisions_check(traj, target, current_state)
334 | self.valid_dict['collisions_rate'].append(col_rate.float().mean().item())
335 |
336 | acc, jerk, lat_acc = check_dynamics(traj, current_state)
337 |
338 | self.valid_dict['acc'].append(acc)
339 | self.valid_dict['lat_acc'].append(lat_acc)
340 | self.valid_dict['jerk'].append(jerk)
341 |
342 | red_light, off_route = check_traffic(traj, target['ref_line'], gt_modes)
343 | gt_red_light, gt_off_route = check_traffic(gt_traj, target['ref_line'], gt_modes)
344 | self.valid_dict['red_light'].append(compare_to_gt(red_light, gt_red_light))
345 | self.valid_dict['off_road_rate'].append(compare_to_gt(off_route, gt_off_route))
346 |
347 | def result(self):
348 | new_dict = {}
349 | for k, v in self.valid_dict.items():
350 | new_dict[k] = np.nanmean(v)
351 | return new_dict
--------------------------------------------------------------------------------
/utils/occupancy_grid_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from functools import partial
4 |
5 | from .occupancy_render_utils import render_occupancy, render_flow_from_inputs, sample_filter, generate_units, render_ego_occupancy
6 | from waymo_open_dataset.utils.occupancy_flow_grids import TimestepGrids, WaypointGrids,_WaypointGridsOneType
7 |
8 | from waymo_open_dataset.protos import occupancy_flow_metrics_pb2
9 | from waymo_open_dataset.utils import occupancy_flow_data
10 |
11 | def create_ground_truth_timestep_grids(
12 | traj_tensor,
13 | valid_tensor,
14 | ego_traj,
15 | config,
16 | flow=False,
17 | flow_origin=False,
18 | sdc_ids=None,
19 | test=False
20 | ):
21 | """Renders topdown views of agents over past/current/future time frames.
22 |
23 | Args:
24 |     traj_tensor / valid_tensor / ego_traj: agent trajectories, their validity mask, and the ego trajectory.
25 |     config: OccupancyFlowTaskConfig proto message.
26 | 
27 |   Returns:
28 |     TimestepGrids holding topdown renders of agents, an observed-validity mask, and the ego occupancy when sdc_ids is given.
29 | """
30 |
31 | timestep_grids = TimestepGrids()
32 |
33 | unit_x, unit_y = generate_units(config.agent_points_per_side_length, config.agent_points_per_side_width)
34 |
35 | # Occupancy grids.
36 | sample_func = partial(
37 | sample_filter,
38 | traj_tensor=traj_tensor,
39 | valid_tensor=valid_tensor,
40 | ego_traj=ego_traj,
41 | config=config,
42 | unit_x=unit_x,unit_y=unit_y
43 |
44 | )
45 |
46 | current_sample = sample_func(
47 | times=['current'],
48 | include_observed=True,
49 | include_occluded=True,
50 | )
51 | current_occupancy, current_valid = render_occupancy(current_sample, config, sdc_ids=None)
52 | #[num_agents]
53 | # print(current_valid.shape)
54 | current_valid = tf.reduce_max(tf.cast(current_valid,tf.int32),axis=-1)[:,0]
55 |
56 | timestep_grids.vehicles.current_occupancy = current_occupancy.vehicles
57 | timestep_grids.pedestrians.current_occupancy = current_occupancy.pedestrians
58 | timestep_grids.cyclists.current_occupancy = current_occupancy.cyclists
59 |
60 | past_sample = sample_func(
61 | times=['past'],
62 | include_observed=True,
63 | include_occluded=True,
64 | )
65 | past_occupancy,_ = render_occupancy(past_sample, config, sdc_ids=None)
66 | timestep_grids.vehicles.past_occupancy = past_occupancy.vehicles
67 | timestep_grids.pedestrians.past_occupancy = past_occupancy.pedestrians
68 | timestep_grids.cyclists.past_occupancy = past_occupancy.cyclists
69 |
70 |     # [num_agents] presence in the field of view at the current timestep
71 | observed_valid = current_valid
72 |
73 | if not test:
74 | future_sample = sample_func(
75 | times=['future'],
76 | include_observed=True,
77 | include_occluded=False,
78 | )
79 | future_obs,_ = render_occupancy(future_sample, config, sdc_ids=None)
80 | timestep_grids.vehicles.future_observed_occupancy = future_obs.vehicles
81 | timestep_grids.pedestrians.future_observed_occupancy = future_obs.pedestrians
82 | timestep_grids.cyclists.future_observed_occupancy = future_obs.cyclists
83 |
84 |
85 | future_sample_occ = sample_func(
86 | times=['future'],
87 | include_observed=False,
88 | include_occluded=True,
89 | )
90 | future_occ, _ = render_occupancy(future_sample_occ, config, sdc_ids=None)
91 | # occluded_valid = tf.reduce_max(tf.cast(occ_valid,tf.int32),axis=-1)[:,0]
92 | timestep_grids.vehicles.future_occluded_occupancy = future_occ.vehicles
93 | timestep_grids.pedestrians.future_occluded_occupancy = future_occ.pedestrians
94 | timestep_grids.cyclists.future_occluded_occupancy = future_occ.cyclists
95 |
96 | # All occupancy for flow_origin_occupancy.
97 | if flow_origin or flow:
98 | all_sample = sample_func(
99 | times=['past', 'current', 'future'],
100 | include_observed=True,
101 | include_occluded=True,
102 | )
103 | if flow_origin:
104 | all_occupancy,_ = render_occupancy(all_sample, config, sdc_ids=None)
105 | timestep_grids.vehicles.all_occupancy = all_occupancy.vehicles
106 | timestep_grids.pedestrians.all_occupancy = all_occupancy.pedestrians
107 | timestep_grids.cyclists.all_occupancy = all_occupancy.cyclists
108 |
109 | # Flow.
110 | # NOTE: Since the future flow depends on the current and past timesteps, we
111 | # need to compute it from [past + current + future] sparse points.
112 | if flow:
113 | all_flow = render_flow_from_inputs(all_sample, config, sdc_ids=None)
114 | timestep_grids.vehicles.all_flow = all_flow.vehicles
115 | timestep_grids.pedestrians.all_flow = all_flow.pedestrians
116 | timestep_grids.cyclists.all_flow = all_flow.cyclists
117 |
118 | if sdc_ids is not None:
119 | all_sample = sample_func(
120 | times=['past', 'current', 'future'],
121 | include_observed=True,
122 | include_occluded=False,
123 | )
124 | ego_occupancy = render_ego_occupancy(all_sample, sdc_ids, config)
125 | return timestep_grids, observed_valid, ego_occupancy
126 |
127 |
128 | return timestep_grids, observed_valid
129 |
130 |
131 |
132 | def create_ground_truth_waypoint_grids(
133 | timestep_grids: TimestepGrids,
134 | config: occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig,
135 | flow_origin: bool=False,
136 | flow: bool=False,
137 | ) -> WaypointGrids:
138 | """Subsamples or aggregates future topdowns as ground-truth labels.
139 |
140 | Args:
141 | timestep_grids: Holds topdown renders of agents over time.
142 | config: OccupancyFlowTaskConfig proto message.
143 |
144 | Returns:
145 | WaypointGrids object.
146 | """
147 | if config.num_future_steps % config.num_waypoints != 0:
148 | raise ValueError(f'num_future_steps({config.num_future_steps}) must be '
149 | f'a multiple of num_waypoints({config.num_waypoints}).')
150 |
151 | true_waypoints = WaypointGrids(
152 | vehicles=_WaypointGridsOneType(
153 | observed_occupancy=[], occluded_occupancy=[], flow=[]),
154 | pedestrians=_WaypointGridsOneType(
155 | observed_occupancy=[], occluded_occupancy=[], flow=[]),
156 | cyclists=_WaypointGridsOneType(
157 | observed_occupancy=[], occluded_occupancy=[], flow=[]),
158 | )
159 |
160 | # Observed occupancy.
161 | _add_ground_truth_observed_occupancy_to_waypoint_grids(
162 | timestep_grids=timestep_grids,
163 | waypoint_grids=true_waypoints,
164 | config=config)
165 | # Occluded occupancy.
166 | _add_ground_truth_occluded_occupancy_to_waypoint_grids(
167 | timestep_grids=timestep_grids,
168 | waypoint_grids=true_waypoints,
169 | config=config)
170 | # Flow origin occupancy.
171 | if flow_origin:
172 | _add_ground_truth_flow_origin_occupancy_to_waypoint_grids(
173 | timestep_grids=timestep_grids,
174 | waypoint_grids=true_waypoints,
175 | config=config)
176 | # Flow.
177 | if flow:
178 | _add_ground_truth_flow_to_waypoint_grids(
179 | timestep_grids=timestep_grids,
180 | waypoint_grids=true_waypoints,
181 | config=config)
182 |
183 | return true_waypoints
184 |
185 | def _ego_ground_truth_occupancy(ego_occupancy, config):
186 | waypoint_size = config.num_future_steps // config.num_waypoints
187 | future_obs = ego_occupancy[..., config.num_past_steps + 1:]
188 | gt_ogm = []
189 | for k in range(config.num_waypoints):
190 | waypoint_end = (k + 1) * waypoint_size
191 | if config.cumulative_waypoints:
192 | waypoint_start = waypoint_end - waypoint_size
193 | # [batch_size, height, width, waypoint_size]
194 | segment = future_obs[..., waypoint_start:waypoint_end]
195 | # [batch_size, height, width, 1]
196 | waypoint_occupancy = tf.reduce_max(segment, axis=-1, keepdims=True)
197 | else:
198 | # [batch_size, height, width, 1]
199 | waypoint_occupancy = future_obs[..., waypoint_end - 1:waypoint_end]
200 | gt_ogm.append(waypoint_occupancy)
201 |
202 | return gt_ogm
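
# Minimal sketch (hypothetical shapes) of the waypoint subsampling above and in
# the _add_ground_truth_* helpers: with num_future_steps = 80 and
# num_waypoints = 8, waypoint_size is 10; the non-cumulative branch keeps every
# 10th future slice, while the cumulative branch max-pools each 10-step segment.
def _example_waypoint_subsampling(cumulative=False):
    future = tf.random.uniform([1, 4, 4, 80])   # [batch, height, width, num_future_steps]
    waypoint_size, num_waypoints = 10, 8
    waypoints = []
    for k in range(num_waypoints):
        end = (k + 1) * waypoint_size
        if cumulative:
            segment = future[..., end - waypoint_size:end]
            waypoints.append(tf.reduce_max(segment, axis=-1, keepdims=True))
        else:
            waypoints.append(future[..., end - 1:end])
    return waypoints   # num_waypoints tensors of shape [1, 4, 4, 1]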
203 |
204 |
205 | def _add_ground_truth_observed_occupancy_to_waypoint_grids(
206 | timestep_grids: TimestepGrids,
207 | waypoint_grids: WaypointGrids,
208 | config: occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig,
209 | ) -> None:
210 | """Subsamples or aggregates future topdowns as ground-truth labels.
211 |
212 | Args:
213 | timestep_grids: Holds topdown renders of agents over time.
214 | waypoint_grids: Holds topdown waypoints selected as ground-truth labels.
215 | config: OccupancyFlowTaskConfig proto message.
216 | """
217 | waypoint_size = config.num_future_steps // config.num_waypoints
218 | for object_type in occupancy_flow_data.ALL_AGENT_TYPES:
219 | # [batch_size, height, width, num_future_steps]
220 | future_obs = timestep_grids.view(object_type).future_observed_occupancy
221 | for k in range(config.num_waypoints):
222 | waypoint_end = (k + 1) * waypoint_size
223 | if config.cumulative_waypoints:
224 | waypoint_start = waypoint_end - waypoint_size
225 | # [batch_size, height, width, waypoint_size]
226 | segment = future_obs[..., waypoint_start:waypoint_end]
227 | # [batch_size, height, width, 1]
228 | waypoint_occupancy = tf.reduce_max(segment, axis=-1, keepdims=True)
229 | else:
230 | # [batch_size, height, width, 1]
231 | waypoint_occupancy = future_obs[..., waypoint_end - 1:waypoint_end]
232 | waypoint_grids.view(object_type).observed_occupancy.append(
233 | waypoint_occupancy)
234 |
235 |
236 | def _add_ground_truth_occluded_occupancy_to_waypoint_grids(
237 | timestep_grids: TimestepGrids,
238 | waypoint_grids: WaypointGrids,
239 | config: occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig,
240 | ) -> None:
241 | """Subsamples or aggregates future topdowns as ground-truth labels.
242 |
243 | Args:
244 | timestep_grids: Holds topdown renders of agents over time.
245 | waypoint_grids: Holds topdown waypoints selected as ground-truth labels.
246 | config: OccupancyFlowTaskConfig proto message.
247 | """
248 | waypoint_size = config.num_future_steps // config.num_waypoints
249 | for object_type in occupancy_flow_data.ALL_AGENT_TYPES:
250 | # [batch_size, height, width, num_future_steps]
251 | future_occ = timestep_grids.view(object_type).future_occluded_occupancy
252 | for k in range(config.num_waypoints):
253 | waypoint_end = (k + 1) * waypoint_size
254 | if config.cumulative_waypoints:
255 | waypoint_start = waypoint_end - waypoint_size
256 | # [batch_size, height, width, waypoint_size]
257 | segment = future_occ[..., waypoint_start:waypoint_end]
258 | # [batch_size, height, width, 1]
259 | waypoint_occupancy = tf.reduce_max(segment, axis=-1, keepdims=True)
260 | else:
261 | # [batch_size, height, width, 1]
262 | waypoint_occupancy = future_occ[..., waypoint_end - 1:waypoint_end]
263 | waypoint_grids.view(object_type).occluded_occupancy.append(
264 | waypoint_occupancy)
265 |
266 |
267 | def _add_ground_truth_flow_origin_occupancy_to_waypoint_grids(
268 | timestep_grids: TimestepGrids,
269 | waypoint_grids: WaypointGrids,
270 | config: occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig,
271 | ) -> None:
272 | """Subsamples or aggregates topdowns as origin occupancies for flow fields.
273 |
274 | Args:
275 | timestep_grids: Holds topdown renders of agents over time.
276 | waypoint_grids: Holds topdown waypoints selected as ground-truth labels.
277 | config: OccupancyFlowTaskConfig proto message.
278 | """
279 | waypoint_size = config.num_future_steps // config.num_waypoints
280 | num_history_steps = config.num_past_steps + 1 # Includes past + current.
281 | num_future_steps = config.num_future_steps
282 | if waypoint_size > num_history_steps:
283 | raise ValueError('If waypoint_size > num_history_steps, we cannot find the '
284 | 'flow origin occupancy for the first waypoint.')
285 |
286 | for object_type in occupancy_flow_data.ALL_AGENT_TYPES:
287 | # [batch_size, height, width, num_past_steps + 1 + num_future_steps]
288 | all_occupancy = timestep_grids.view(object_type).all_occupancy
289 | # Keep only the section containing flow_origin_occupancy timesteps.
290 | # First remove `waypoint_size` from the end. Then keep the tail containing
291 | # num_future_steps timesteps.
292 | flow_origin_occupancy = all_occupancy[..., :-waypoint_size]
293 | # [batch_size, height, width, num_future_steps]
294 | flow_origin_occupancy = flow_origin_occupancy[..., -num_future_steps:]
295 | for k in range(config.num_waypoints):
296 | waypoint_end = (k + 1) * waypoint_size
297 | if config.cumulative_waypoints:
298 | waypoint_start = waypoint_end - waypoint_size
299 | # [batch_size, height, width, waypoint_size]
300 | segment = flow_origin_occupancy[..., waypoint_start:waypoint_end]
301 | # [batch_size, height, width, 1]
302 | waypoint_flow_origin = tf.reduce_max(segment, axis=-1, keepdims=True)
303 | else:
304 | # [batch_size, height, width, 1]
305 | waypoint_flow_origin = flow_origin_occupancy[..., waypoint_end -
306 | 1:waypoint_end]
307 | waypoint_grids.view(object_type).flow_origin_occupancy.append(
308 | waypoint_flow_origin)
309 |
310 |
311 | def _add_ground_truth_flow_to_waypoint_grids(
312 | timestep_grids: TimestepGrids,
313 | waypoint_grids: WaypointGrids,
314 | config: occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig,
315 | ) -> None:
316 | """Subsamples or aggregates future flow fields as ground-truth labels.
317 |
318 | Args:
319 | timestep_grids: Holds topdown renders of agents over time.
320 | waypoint_grids: Holds topdown waypoints selected as ground-truth labels.
321 | config: OccupancyFlowTaskConfig proto message.
322 | """
323 | num_future_steps = config.num_future_steps
324 | waypoint_size = config.num_future_steps // config.num_waypoints
325 |
326 | for object_type in occupancy_flow_data.ALL_AGENT_TYPES:
327 |     # num_flow_steps = (num_past_steps + num_future_steps) - waypoint_size
328 | # [batch_size, height, width, num_flow_steps, 2]
329 | flow = timestep_grids.view(object_type).all_flow
330 | # Keep only the flow tail, containing num_future_steps timesteps.
331 | # [batch_size, height, width, num_future_steps, 2]
332 | flow = flow[..., -num_future_steps:, :]
333 | for k in range(config.num_waypoints):
334 | waypoint_end = (k + 1) * waypoint_size
335 | if config.cumulative_waypoints:
336 | waypoint_start = waypoint_end - waypoint_size
337 | # [batch_size, height, width, waypoint_size, 2]
338 | segment = flow[..., waypoint_start:waypoint_end, :]
339 | # Compute mean flow over the timesteps in this segment by counting
340 | # the number of pixels with non-zero flow and dividing the flow sum
341 | # by that number.
342 | # [batch_size, height, width, waypoint_size, 2]
343 | occupied_pixels = tf.cast(tf.not_equal(segment, 0.0), tf.float32)
344 | # [batch_size, height, width, 2]
345 | num_flow_values = tf.reduce_sum(occupied_pixels, axis=3)
346 | # [batch_size, height, width, 2]
347 | segment_sum = tf.reduce_sum(segment, axis=3)
348 | # [batch_size, height, width, 2]
349 | mean_flow = tf.math.divide_no_nan(segment_sum, num_flow_values)
350 | waypoint_flow = mean_flow
351 | else:
352 | waypoint_flow = flow[..., waypoint_end - 1, :]
353 | waypoint_grids.view(object_type).flow.append(waypoint_flow)
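
# Minimal sketch (hypothetical values) of the cumulative-flow aggregation just
# above: the per-waypoint flow is the mean over timesteps with non-zero flow,
# computed as sum / count with divide_no_nan so that empty pixels stay zero.
def _example_mean_flow():
    segment = tf.constant([[0.0, 2.0, 4.0]])   # one pixel, three timesteps (x-component only)
    occupied = tf.cast(tf.not_equal(segment, 0.0), tf.float32)
    return tf.math.divide_no_nan(tf.reduce_sum(segment, axis=-1),
                                 tf.reduce_sum(occupied, axis=-1))   # -> [3.0]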
--------------------------------------------------------------------------------
/model/encoder.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 |
4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
5 | sys.path.append(os.path.dirname(SCRIPT_DIR))
6 |
7 | import math
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from time import time
12 |
13 | class PositionalEncoding(nn.Module):
14 | def __init__(self, d_model=256, dropout=0.1, max_len=100):
15 | super(PositionalEncoding, self).__init__()
16 | position = torch.arange(max_len).unsqueeze(1)
17 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
18 | pe = torch.zeros(max_len, 1, d_model)
19 | pe[:, 0, 0::2] = torch.sin(position * div_term)
20 | pe[:, 0, 1::2] = torch.cos(position * div_term)
21 | pe = pe.permute(1, 0, 2)
22 | self.register_buffer('pe', pe)
23 | self.dropout = nn.Dropout(p=dropout)
24 |
25 | def forward(self, x):
26 | x = x + self.pe
27 | return self.dropout(x)
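
# Usage sketch (hypothetical shapes): PositionalEncoding adds a fixed sinusoidal
# embedding over the sequence dimension. Because the buffer is added directly
# without slicing, the sequence length must equal max_len.
def _example_positional_encoding():
    pe = PositionalEncoding(d_model=256, dropout=0.0, max_len=100)
    tokens = torch.zeros(8, 100, 256)   # (batch, sequence, features)
    return pe(tokens)                   # same shape, with sin/cos offsets added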
28 |
29 | class AgentEncoder(nn.Module):
30 | def __init__(self):
31 | super(AgentEncoder, self).__init__()
32 | self.motion = nn.LSTM(9, 256, 2, batch_first=True)
33 | self.type_embed = nn.Embedding(3, 256, padding_idx=0)
34 |
35 | def forward(self, inputs):
36 | traj, _ = self.motion(inputs[...,:-1])
37 | types = inputs[...,0,-1].int().clamp(0, 2)
38 | types = self.type_embed(types)
39 | output = traj[:, -1] + types
40 | return output
41 |
42 | class LaneEncoder(nn.Module):
43 | def __init__(self):
44 | super(LaneEncoder, self).__init__()
45 |         # encoder layer
46 | self.self_line = nn.Linear(3, 128)
47 | self.left_line = nn.Linear(3, 128)
48 | self.right_line = nn.Linear(3, 128)
49 | self.speed_limit = nn.Linear(1, 64)
50 | self.self_type = nn.Embedding(4, 64, padding_idx=0)
51 | self.left_type = nn.Embedding(11, 64, padding_idx=0)
52 | self.right_type = nn.Embedding(11, 64, padding_idx=0)
53 | self.traffic_light_type = nn.Embedding(9, 64, padding_idx=0)
54 | self.interpolating = nn.Embedding(2, 64)
55 | self.stop_sign = nn.Embedding(2, 64)
56 | self.stop_point = nn.Embedding(2, 64)
57 |
58 | # hidden layers
59 | self.pointnet = nn.Sequential(nn.Linear(512, 384), nn.ReLU(), nn.Linear(384, 256))
60 | self.position_encode = PositionalEncoding(max_len=100)
61 |
62 | def forward(self, inputs):
63 | # embedding
64 | self_line = self.self_line(inputs[..., :3])
65 | left_line = self.left_line(inputs[..., 3:6])
66 | right_line = self.right_line(inputs[..., 6:9])
67 | speed_limit = self.speed_limit(inputs[..., 9].unsqueeze(-1))
68 | self_type = self.self_type(inputs[..., 10].int().clamp(0, 3))
69 | left_type = self.left_type(inputs[..., 11].int().clamp(0, 10))
70 | right_type = self.right_type(inputs[..., 12].int().clamp(0, 10))
71 | traffic_light = self.traffic_light_type(inputs[..., 13].int().clamp(0, 8))
72 | stop_point = self.stop_point(inputs[..., 14].int().clamp(0, 1))
73 | interpolating = self.interpolating(inputs[..., 15].int().clamp(0, 1))
74 | stop_sign = self.stop_sign(inputs[..., 16].int().clamp(0, 1))
75 |
76 | lane_attr = self_type + left_type + right_type + traffic_light + stop_point + interpolating + stop_sign
77 | lane_embedding = torch.cat([self_line, left_line, right_line, speed_limit, lane_attr], dim=-1)
78 |
79 | # process
80 | output = self.pointnet(lane_embedding)
81 | output = self.position_encode(output)
82 |
83 | return output
84 |
85 | class NeighborLaneEncoder(nn.Module):
86 | def __init__(self):
87 | super(NeighborLaneEncoder, self).__init__()
88 |         # encoder layer
89 | self.self_line = nn.Linear(3, 128)
90 | self.speed_limit = nn.Linear(1, 64)
91 | self.traffic_light_type = nn.Embedding(9, 64, padding_idx=0)
92 |
93 | # hidden layers
94 | self.pointnet = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256))
95 | self.position_encode = PositionalEncoding(max_len=50)
96 |
97 | def forward(self, inputs):
98 | # embedding
99 | self_line = self.self_line(inputs[..., :3])
100 | speed_limit = self.speed_limit(inputs[..., 3].unsqueeze(-1))
101 | traffic_light = self.traffic_light_type(inputs[..., 4].int().clamp(0, 8))
102 | lane_embedding = torch.cat([self_line, speed_limit, traffic_light], dim=-1)
103 |
104 | # process
105 | output = self.pointnet(lane_embedding)
106 | output = self.position_encode(output)
107 |
108 | return output
109 |
110 | class CrosswalkEncoder(nn.Module):
111 | def __init__(self):
112 | super(CrosswalkEncoder, self).__init__()
113 | self.pointnet = nn.Sequential(
114 | nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 256)
115 | )
116 | self.position_encode = PositionalEncoding(max_len=50)
117 |
118 | def forward(self, inputs):
119 | output = self.pointnet(inputs)
120 | output = self.position_encode(output)
121 | return output
122 |
123 | class VectorEncoder(nn.Module):
124 | def __init__(self, dim=256, layers=4, heads=8, dropout=0.1):
125 | super(VectorEncoder, self).__init__()
126 |
127 | self.ego_encoder = AgentEncoder()
128 | self.neighbor_encoder = AgentEncoder()
129 |
130 | # self.map_encoder = NeighborLaneEncoder()
131 | self.ego_map_encoder = LaneEncoder()
132 | self.crosswalk_encoder = CrosswalkEncoder()
133 |
134 | attention_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim*4,
135 | activation='gelu', dropout=dropout, batch_first=True)
136 | self.fusion_encoder = nn.TransformerEncoder(attention_layer, layers)
137 |
138 | def segment_map(self, map, map_encoding):
139 | B, N_r, N_p, D = map_encoding.shape
140 | map_encoding = F.max_pool2d(map_encoding.permute(0, 3, 1, 2), kernel_size=(1, 10))
141 | map_encoding = map_encoding.permute(0, 2, 3, 1).reshape(B, -1, D)
142 |
143 | map_mask = torch.eq(map, 0)[:, :, :, 0].reshape(B, N_r, N_p//10, -1)
144 | map_mask = torch.max(map_mask, dim=-1)[0].reshape(B, -1)
145 | map_mask[:, 0] = False # prevent nan
146 |
147 | return map_encoding, map_mask
148 |
149 | def forward(self, inputs):
150 |
151 | ego = inputs['ego_state']
152 | neighbor = inputs['neighbor_state']
153 |
154 | actors = torch.cat([ego.unsqueeze(1), neighbor], dim=1)
155 | actors_mask = torch.eq(actors, 0)[:, :, -1, 0]
156 | actors_mask[:, 0] = False
157 |
158 | ego = self.ego_encoder(ego)
159 | B, N, T, D = neighbor.shape
160 | neighbor = self.neighbor_encoder(neighbor.reshape(B*N, T, D))
161 | neighbor = neighbor.reshape(B, N, -1)
162 |
163 | encode_actors = torch.cat([ego.unsqueeze(1), neighbor], dim=1)
164 | B,N,C = encode_actors.shape
165 |
166 | ego_maps = inputs['ego_map_lane']
167 | ego_encode_maps = self.ego_map_encoder(ego_maps)
168 | B, M, L, D = ego_maps.shape
169 | ego_encode_maps, ego_map_mask = self.segment_map(ego_maps, ego_encode_maps) #(B*N,N_map,D)
170 | encode_maps = ego_encode_maps
171 | map_mask = ego_map_mask
172 |
173 | crosswalks = inputs['ego_map_crosswalk']
174 | encode_cws = self.crosswalk_encoder(crosswalks)
175 | B, M, L, D = crosswalks.shape
176 | encode_cws, cw_mask = self.segment_map(crosswalks, encode_cws)
177 |
178 | encode_maps = torch.cat([encode_maps, encode_cws], dim=1)
179 | map_mask = torch.cat([map_mask, cw_mask], dim=1)
180 |
181 | encode_inputs = torch.cat([encode_actors, encode_maps], dim=1) #(B*N,N + N_map + N_cw, D)
182 | encode_masks = torch.cat([actors_mask, map_mask], dim=1) #(B*N,N + N_map + N_cw)
183 | encode_masks[:,0] = False
184 |
185 | encodings = self.fusion_encoder(encode_inputs ,src_key_padding_mask=encode_masks)
186 |
187 | _, L, D = encodings.shape
188 | N = 1
189 | encodings = encodings.reshape(B, N, L, D)
190 | encode_masks = encode_masks.reshape(B, N, L)
191 |
192 | encoder_outputs = {
193 | 'actors': actors,
194 | 'encodings': encodings,
195 | 'masks': encode_masks
196 | }
197 |
198 | encoder_outputs.update(inputs)
199 | return encoder_outputs
200 |
201 |
202 | class PredLaneEncoder(nn.Module):
203 | def __init__(self):
204 | super(PredLaneEncoder, self).__init__()
205 |         # encoder layer
206 | self.self_line = nn.Linear(3, 128)
207 | self.map_flow = nn.Linear(3, 128)
208 |
209 | self.speed_limit = nn.Linear(1, 64)
210 | self.traffic_light_type = nn.Embedding(4, 64, padding_idx=0)
211 | self.self_type = nn.Embedding(20, 64, padding_idx=0)
212 | self.stop_sign = nn.Linear(1, 64)
213 |
214 | # hidden layers
215 | self.pointnet = nn.Sequential(nn.Linear(384, 256), nn.ReLU(), nn.Linear(256, 256))
216 | self.position_encode = PositionalEncoding(max_len=20)
217 |
218 | def forward(self, inputs):
219 | # embedding
220 | self_line = self.self_line(inputs[..., :3])
221 |
222 | road_type = self.self_type(inputs[..., 3].int().clamp(0, 19))
223 | traffic_light = self.traffic_light_type(inputs[..., 4].int().clamp(0, 3))
224 | stop_sign = self.stop_sign(inputs[..., 6].unsqueeze(-1))
225 | sp_limit = self.speed_limit(inputs[..., 7].unsqueeze(-1))
226 |
227 | map_flow_feat = self.map_flow(inputs[..., -3:])
228 | lane_feat = road_type + traffic_light + stop_sign
229 |
230 | lane_embedding = torch.cat([self_line, map_flow_feat, lane_feat, sp_limit], dim=-1)
231 |
232 | # process
233 | output = self.pointnet(lane_embedding)
234 | output = self.position_encode(output)
235 |
236 | #max pooling:
237 | output = torch.max(output, dim=-2).values
238 |
239 | return output
240 |
241 | class PredEncoder(nn.Module):
242 | def __init__(self, dim=256, layers=4, heads=8, dropout=0.1,use_map=True):
243 | super(PredEncoder, self).__init__()
244 | self.ego_encoder = AgentEncoder()
245 | self.neighbor_encoder = AgentEncoder()
246 | self.use_map = use_map
247 | if use_map:
248 | self.map_encoder = PredLaneEncoder()
249 |
250 | attention_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim*4,
251 | activation='gelu', dropout=dropout, batch_first=True)
252 | self.fusion_encoder = nn.TransformerEncoder(attention_layer, layers, enable_nested_tensor=False)
253 | print('vec_encoder',sum([p.numel() for p in self.parameters()]))
254 |
255 | def forward(self, inputs):
256 | ego = inputs['ego_state']
257 | neighbor = inputs['neighbor_state']
258 |
259 | actors = torch.cat([ego.unsqueeze(1), neighbor], dim=1)
260 | actors_mask = torch.eq(actors, 0)[:, :, -1, 0]
261 | actors_mask[:, 0] = False
262 |
263 | ego = self.ego_encoder(ego)
264 | B, N, T, D = neighbor.shape
265 | neighbor = self.neighbor_encoder(neighbor.reshape(B*N, T, D))
266 | neighbor = neighbor.reshape(B, N, -1)
267 |
268 | encode_actors = torch.cat([ego.unsqueeze(1), neighbor], dim=1)
269 | B,N,C = encode_actors.shape
270 |
271 | if self.use_map:
272 | maps = inputs['map_segs']
273 | map_mask = torch.eq(maps, 0)[:, :, 0, 0]
274 | maps = self.map_encoder(maps)
275 | encode_actors = torch.cat([encode_actors, maps], dim=1)
276 | encode_masks = torch.cat([actors_mask, map_mask], dim=1)
277 | else:
278 | encode_masks = actors_mask
279 |
280 | encodings = self.fusion_encoder(encode_actors ,src_key_padding_mask=encode_masks)
281 |
282 | encoder_outputs = {
283 | 'actors': actors,
284 | 'encodings': encodings,
285 | 'masks': encode_masks
286 | }
287 |
288 | encoder_outputs.update(inputs)
289 | return encoder_outputs
290 |
291 |
292 | from .swin_T import PredSwinTransformerV2
293 |
294 |
295 | class OGMFlowEncoder(nn.Module):
296 | def __init__(self,sep_flow=False,large_scale=False):
297 | super(OGMFlowEncoder, self).__init__()
298 |
299 | self.map_type_encoder = nn.Embedding(num_embeddings=20, embedding_dim=64 ,padding_idx=0)
300 | self.tl_encoder = nn.Embedding(num_embeddings=4, embedding_dim=64, padding_idx=0)
301 | self.map_encoder = nn.Linear(1, 64)
302 |
303 | self.ogm_rg_encoder = nn.Linear(33 if large_scale else 11, 128)
304 | self.sep_flow = sep_flow
305 | if not self.sep_flow:
306 | self.flow_encoder = nn.Linear(2, 128)
307 | self.offset_encoder = nn.Linear(2, 64)
308 |
309 | self.point_net = nn.Sequential(nn.Linear((384 if not sep_flow else 256), 192), nn.ReLU(), nn.Linear(192, 96))
310 | print('flow_encoder',sum([p.numel() for p in self.parameters()]))
311 |
312 | def forward(self, inputs, offsets):
313 | hist_ogm = inputs['hist_ogm']
314 | hist_flow = inputs['hist_flow']
315 | road_graph = inputs['road_graph']
316 |
317 | rg, road_type, traffic = road_graph[...,0:1], road_graph[..., 1].int(), road_graph[..., 2].int()
318 |
319 | hist_ogm = self.ogm_rg_encoder(hist_ogm)
320 | if not self.sep_flow:
321 | hist_flow = self.flow_encoder(hist_flow)
322 |
323 | offsets = self.offset_encoder(offsets)
324 | maps = self.map_encoder(rg) + self.tl_encoder(traffic.clamp(0, 3)) + self.map_type_encoder(road_type.clamp(0, 19))
325 | maps = maps
326 |
327 | if self.sep_flow:
328 | mm_inputs = torch.cat([hist_ogm, maps, offsets], dim=-1)
329 | else:
330 | mm_inputs = torch.cat([hist_ogm, hist_flow, maps, offsets], dim=-1)
331 | mm_inputs = self.point_net(mm_inputs)
332 | mm_inputs = mm_inputs.permute(0, 3, 1, 2)
333 | return mm_inputs
334 |
335 | class VisualEncoder(PredSwinTransformerV2):
336 | def __init__(self,
337 | config,
338 | input_resolution=(512, 512),
339 | embedding_channels=96,
340 | window_size=8,
341 | in_channels=96,
342 | patch_size=4,
343 | use_checkpoint=False,
344 | sequential_self_attention=False,
345 | use_deformable_block=True,
346 | large_scale=False,
347 | **kwargs):
348 | super(VisualEncoder, self).__init__(input_resolution=input_resolution,
349 | window_size=window_size,
350 | in_channels=in_channels,
351 | use_checkpoint=use_checkpoint,
352 | sequential_self_attention=sequential_self_attention,
353 | embedding_channels=embedding_channels,
354 | patch_size=patch_size,
355 | depths=(2, 2, 2),
356 | number_of_heads=(6, 12, 24),
357 | use_deformable_block=use_deformable_block,
358 | large_scale=large_scale,
359 | **kwargs)
360 |
361 | self.mm_encoder = OGMFlowEncoder(sep_flow=True,large_scale=large_scale)
362 |
363 | self.config = config
364 | self.input_resolution = input_resolution
365 | self._make_position_bias_input()
366 |
367 | self.init_crop = input_resolution[0] / patch_size
368 | print('swin_encoder',sum([p.numel() for p in self.parameters()]))
369 |
370 | def half_cropping(self, tensor, stage=0):
371 | cropped_len = int(self.init_crop / (2 ** stage))
372 | begin, end = int(cropped_len / 4), int(3 * cropped_len / 4)
373 | return tensor[:, :, begin:end, begin:end]
374 |
375 | def crop_output_list(self, output_list):
376 | return [self.half_cropping(outputs, i) for i, outputs in enumerate(output_list)]
377 |
378 | def _make_position_bias_input(self):
379 | device = self.stages[0].blocks[0].window_attention.tau.device
380 | indexes: torch.Tensor = torch.arange(1, self.config.grid_height_cells*2 + 1, device=device)
381 | widths_indexes = - (indexes - (self.config.sdc_x_in_grid + self.config.grid_height_cells) - 0.5) / self.config.pixels_per_meter
382 | heights_indexes = - (indexes - (self.config.sdc_y_in_grid + self.config.grid_width_cells) - 0.5) / self.config.pixels_per_meter
383 |         # corresponding (x, y) in dense coordinates
384 | coordinates: torch.Tensor = torch.stack(torch.meshgrid([heights_indexes, widths_indexes]), dim=-1)
385 | self.register_buffer('input_bias', coordinates)
386 |
387 | def forward(self, inputs):
388 | coordinate_bias = self.input_bias #[b, h, w, 2]
389 | b = inputs['hist_ogm'].shape[0]
390 | offsets = coordinate_bias.unsqueeze(0).expand(b, -1, -1, -1)
391 | visual_inputs = self.mm_encoder(inputs, offsets)
392 | outputs_list, flow = super(VisualEncoder, self).forward(visual_inputs, inputs['hist_flow'].permute(0, 3, 1, 2))
393 | outputs_list = self.crop_output_list(outputs_list)
394 | return outputs_list, None
--------------------------------------------------------------------------------
/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | from shapely.geometry import LineString, Point, Polygon
4 | from shapely.affinity import affine_transform, rotate
5 |
6 | def wrap_to_pi(theta):
7 |     # wrap an angle to the interval [-pi, pi)
8 | return (theta+np.pi) % (2*np.pi) - np.pi
9 |
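  | # Worked example (illustrative): wrap_to_pi(1.5 * np.pi) -> -0.5 * np.pi;
  | # every input angle is mapped back into the interval [-pi, pi).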
10 | def compute_direction_diff(ego_theta, target_theta):
11 | delta = np.abs(ego_theta - target_theta)
12 | delta = np.where(delta > np.pi, 2*np.pi - delta, delta)
13 |
14 | return delta
15 |
16 | def depth_first_search(cur_lane, lanes, dist=0, threshold=100):
17 | """
18 | Perform depth first search over lane graph up to the threshold.
19 | Args:
20 | cur_lane: Starting lane_id
21 | lanes: raw lane data
22 | dist: Distance of the current path
23 | threshold: Threshold after which to stop the search
24 | Returns:
25 | lanes_to_return (list of list of integers): List of sequence of lane ids
26 | """
27 | if dist > threshold:
28 | return [[cur_lane]]
29 | else:
30 | traversed_lanes = []
31 | child_lanes = lanes[cur_lane].exit_lanes
32 |
33 | if child_lanes:
34 | for child in child_lanes:
35 | centerline = np.array([(map_point.x, map_point.y, map_point.z) for map_point in lanes[child].polyline])
36 | cl_length = centerline.shape[0]
37 | curr_lane_ids = depth_first_search(child, lanes, dist + cl_length, threshold)
38 | traversed_lanes.extend(curr_lane_ids)
39 |
40 | if len(traversed_lanes) == 0:
41 | return [[cur_lane]]
42 |
43 | lanes_to_return = []
44 |
45 | for lane_seq in traversed_lanes:
46 | lanes_to_return.append([cur_lane] + lane_seq)
47 |
48 | return lanes_to_return
49 |
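  | # Usage sketch (illustrative; `lanes` is assumed to map feature ids to Waymo lane
  | # protos with `exit_lanes` and `polyline`, as in the dicts built by preprocess.py):
  | #   routes = depth_first_search(cur_lane=start_id, lanes=lanes, dist=0, threshold=300)
  | # returns one lane-id sequence per expanded branch, e.g. [[start_id, 87], [start_id, 88]].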
50 | def is_overlapping_lane_seq(lane_seq1, lane_seq2):
51 | """
52 | Check if the 2 lane sequences are overlapping.
53 | Args:
54 | lane_seq1: list of lane ids
55 | lane_seq2: list of lane ids
56 | Returns:
57 | bool, True if the lane sequences overlap
58 | """
59 |
60 | if lane_seq2[1:] == lane_seq1[1:]:
61 | return True
62 | elif set(lane_seq2) <= set(lane_seq1):
63 | return True
64 |
65 | return False
66 |
67 | def remove_overlapping_lane_seq(lane_seqs):
68 | """
69 | Remove lane sequences which are overlapping to some extent
70 | Args:
71 | lane_seqs (list of list of integers): List of list of lane ids (Eg. [[12345, 12346, 12347], [12345, 12348]])
72 | Returns:
73 | List of sequence of lane ids (e.g. ``[[12345, 12346, 12347], [12345, 12348]]``)
74 | """
75 | redundant_lane_idx = set()
76 |
77 | for i in range(len(lane_seqs)):
78 | for j in range(len(lane_seqs)):
79 | if i in redundant_lane_idx or i == j:
80 | continue
81 | if is_overlapping_lane_seq(lane_seqs[i], lane_seqs[j]):
82 | redundant_lane_idx.add(j)
83 |
84 | unique_lane_seqs = [lane_seqs[i] for i in range(len(lane_seqs)) if i not in redundant_lane_idx]
85 |
86 | return unique_lane_seqs
87 |
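  | # Illustrative example: remove_overlapping_lane_seq([[12345, 12346, 12347], [12345, 12346]])
  | # returns [[12345, 12346, 12347]], since the shorter sequence is a subset of the longer one.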
88 | def polygon_completion(polygon):
89 | polyline_x = []
90 | polyline_y = []
91 |
92 | for i in range(len(polygon)):
93 | if i+1 < len(polygon):
94 |             next_idx = i+1
95 |         else:
96 |             next_idx = 0
97 |
98 |         dist_x = polygon[next_idx].x - polygon[i].x
99 |         dist_y = polygon[next_idx].y - polygon[i].y
100 |         dist = np.linalg.norm([dist_x, dist_y])
101 |         interp_num = np.ceil(dist)*2
102 |         interp_index = np.arange(2+interp_num)
103 |         point_x = np.interp(interp_index, [0, interp_index[-1]], [polygon[i].x, polygon[next_idx].x]).tolist()
104 |         point_y = np.interp(interp_index, [0, interp_index[-1]], [polygon[i].y, polygon[next_idx].y]).tolist()
105 | polyline_x.extend(point_x[:-1])
106 | polyline_y.extend(point_y[:-1])
107 |
108 | polyline_x, polyline_y = np.array(polyline_x), np.array(polyline_y)
109 | polyline_heading = wrap_to_pi(np.arctan2(polyline_y[1:]-polyline_y[:-1], polyline_x[1:]-polyline_x[:-1]))
110 | polyline_heading = np.insert(polyline_heading, -1, polyline_heading[-1])
111 |
112 | return np.stack([polyline_x, polyline_y, polyline_heading], axis=1)
113 |
114 | def get_polylines(lines):
115 | polylines = {}
116 |
117 | for line in lines.keys():
118 | polyline = np.array([(map_point.x, map_point.y) for map_point in lines[line].polyline])
119 | if len(polyline) > 1:
120 | direction = wrap_to_pi(np.arctan2(polyline[1:, 1]-polyline[:-1, 1], polyline[1:, 0]-polyline[:-1, 0]))
121 | direction = np.insert(direction, -1, direction[-1])[:, np.newaxis]
122 | else:
123 | direction = np.array([0])[:, np.newaxis]
124 | polylines[line] = np.concatenate([polyline, direction], axis=-1)
125 |
126 | return polylines
127 |
128 | def find_reference_lanes(agent_type, agent_traj, lanes):
129 | curr_lane_ids = {}
130 |
131 | if agent_type == 2:
132 | distance_threshold = 5
133 |
134 | while len(curr_lane_ids) < 1:
135 | for lane in lanes.keys():
136 | if lanes[lane].shape[0] > 1:
137 | distance_to_agent = LineString(lanes[lane][:, :2]).distance(Point(agent_traj[-1, :2]))
138 | if distance_to_agent < distance_threshold:
139 | curr_lane_ids[lane] = 0
140 |
141 | distance_threshold += 5
142 | if distance_threshold > 50:
143 | break
144 | else:
145 | distance_threshold = 3.5
146 | direction_threshold = 10
147 | while len(curr_lane_ids) < 1:
148 | for lane in lanes.keys():
149 | distance_to_ego = np.linalg.norm(agent_traj[-1, :2] - lanes[lane][:, :2], axis=-1)
150 | direction_to_ego = wrap_to_pi(agent_traj[-1, 2] - lanes[lane][:, -1])
151 | for i, j, k in zip(distance_to_ego, direction_to_ego, range(distance_to_ego.shape[0])):
152 |                     if i <= distance_threshold:  # and np.abs(j) <= np.radians(direction_threshold):
153 | curr_lane_ids[lane] = k
154 | break
155 |
156 | distance_threshold += 3.5
157 | direction_threshold += 10
158 | if distance_threshold > 50:
159 | break
160 |
161 | return curr_lane_ids
162 |
163 | def find_neighbor_lanes(curr_lane_ids, traj, lanes, lane_polylines):
164 | neighbor_lane_ids = {}
165 |
166 | for curr_lane, start in curr_lane_ids.items():
167 | left_lanes = lanes[curr_lane].left_neighbors
168 | right_lanes = lanes[curr_lane].right_neighbors
169 | left_lane = None
170 | right_lane = None
171 | curr_index = start
172 |
173 | for l_lane in left_lanes:
174 | if l_lane.self_start_index <= curr_index <= l_lane.self_end_index and not l_lane.feature_id in curr_lane_ids:
175 | left_lane = l_lane
176 |
177 | for r_lane in right_lanes:
178 | if r_lane.self_start_index <= curr_index <= r_lane.self_end_index and not r_lane.feature_id in curr_lane_ids:
179 | right_lane = r_lane
180 |
181 | if left_lane is not None:
182 | left_polyline = lane_polylines[left_lane.feature_id]
183 | start = np.argmin(np.linalg.norm(traj[-1, :2] - left_polyline[:, :2], axis=-1))
184 | neighbor_lane_ids[left_lane.feature_id] = start
185 |
186 | if right_lane is not None:
187 | right_polyline = lane_polylines[right_lane.feature_id]
188 | start = np.argmin(np.linalg.norm(traj[-1, :2] - right_polyline[:, :2], axis=-1))
189 | neighbor_lane_ids[right_lane.feature_id] = start
190 |
191 | return neighbor_lane_ids
192 |
193 | def find_neareast_point(curr_point, line):
194 | distance_to_curr_point = np.linalg.norm(curr_point[np.newaxis, :2] - line[:, :2], axis=-1)
195 | neareast_point = line[np.argmin(distance_to_curr_point)]
196 |
197 | return neareast_point
198 |
199 | def find_map_waypoint(pos, polylines):
200 | waypoint = [-1, -1, 1e9, 1e9]
201 | direction_threshold = 10
202 |
203 | for id, polyline in polylines.items():
204 | distance_to_gt = np.linalg.norm(pos[np.newaxis, :2] - polyline[:, :2], axis=-1)
205 | direction_to_gt = wrap_to_pi(pos[np.newaxis, 2]-polyline[:, 2])
206 |
207 | for i, j, k in zip(range(polyline.shape[0]), distance_to_gt, direction_to_gt):
208 | if j < waypoint[2] and np.abs(k) <= np.radians(direction_threshold):
209 | waypoint = [id, i, j, k]
210 |
211 | lane_id = waypoint[0]
212 | waypoint_id = waypoint[1]
213 |
214 | if lane_id > 0:
215 | return lane_id, waypoint_id
216 | else:
217 | return None, None
218 |
219 | from .plan_utils import generate_target_course
220 |
221 | def ref_line_norm(ref_line, center, angle):
222 | xy = LineString(ref_line[:, 0:2])
223 | xy = affine_transform(xy, [1, 0, 0, 1, -center[0], -center[1]])
224 | xy = rotate(xy, -angle, origin=(0, 0), use_radians=True)
225 | yaw = wrap_to_pi(ref_line[:, 2] - angle)
226 | c = ref_line[:, 3]
227 | info = ref_line[:, 4]
228 | return np.column_stack((xy.coords, yaw, c, info))
229 |
230 | def find_route(traj, timestep, cur_pos, map_lanes, map_crosswalks, map_signals):
231 | cur_pos = np.array(cur_pos)
232 | lane_polylines = get_polylines(map_lanes)
233 | # print(len(lane_polylines.keys()), cur_pos)
234 | end_lane, end_point = find_map_waypoint(np.array((traj[-1].center_x, traj[-1].center_y, traj[-1].heading)), lane_polylines)
235 | # print(end_lane, end_point)
236 | start_lane, start_point = find_map_waypoint(np.array((traj[0].center_x, traj[0].center_y, traj[0].heading)), lane_polylines)
237 | # print(start_lane, start_point, cur_pos)
238 | cur_lane, _ = find_map_waypoint(cur_pos, lane_polylines)
239 | # print(end_lane, start_lane, cur_lane)
240 |
241 | path_waypoints = []
242 | for t in range(0, len(traj), 10):
243 | lane, point = find_map_waypoint(np.array((traj[t].center_x, traj[t].center_y, traj[t].heading)), lane_polylines)
244 | path_waypoints.append(lane_polylines[lane][point])
245 | # print(path_waypoints)
246 |
247 | before_waypoints = []
248 | if start_point < 40:
249 | if map_lanes[start_lane].entry_lanes:
250 | lane = map_lanes[start_lane].entry_lanes[0]
251 | for waypoint in lane_polylines[lane]:
252 | before_waypoints.append(waypoint)
253 | for waypoint in lane_polylines[start_lane][:start_point]:
254 | before_waypoints.append(waypoint)
255 |
256 | after_waypoints = []
257 | for waypoint in lane_polylines[end_lane][end_point:]:
258 | after_waypoints.append(waypoint)
259 | if len(after_waypoints) < 40:
260 | if map_lanes[end_lane].exit_lanes:
261 | lane = map_lanes[end_lane].exit_lanes[0]
262 | for waypoint in lane_polylines[lane]:
263 | after_waypoints.append(waypoint)
264 |
265 | waypoints = np.concatenate([before_waypoints[::5], path_waypoints, after_waypoints[::5]], axis=0)
266 |
267 | # generate smooth route
268 | tx, ty, tyaw, tc, _ = generate_target_course(waypoints[:, 0], waypoints[:, 1])
269 | ref_line = np.column_stack([tx, ty, tyaw, tc])
270 |
271 | # print(ref_line.shape)
272 |
273 | # get reference path at current timestep
274 | current_location = np.argmin(np.linalg.norm(ref_line[:, :2] - cur_pos[np.newaxis, :2], axis=-1))
275 | start_index = np.max([current_location-200, 0])
276 |
277 | ref_line = ref_line[start_index:start_index+1200]
278 |
279 | # add speed limit, crosswalk, and traffic signal info to ref route
280 | line_info = np.zeros(shape=(ref_line.shape[0], 1))
281 | speed_limit = map_lanes[cur_lane].speed_limit_mph / 2.237
282 | ref_line = np.concatenate([ref_line, line_info], axis=-1)
283 | crosswalks = [Polygon([(point.x, point.y) for point in crosswalk.polygon]) for _, crosswalk in map_crosswalks.items()]
284 | signals = [Point([signal.stop_point.x, signal.stop_point.y]) for signal in map_signals[timestep].lane_states if signal.state in [1, 4, 7]]
285 |
286 | for i in range(ref_line.shape[0]):
287 | if any([Point(ref_line[i, :2]).distance(signal) < 0.2 for signal in signals]):
288 | ref_line[i, 4] = 0 # red light
289 | elif any([crosswalk.contains(Point(ref_line[i, :2])) for crosswalk in crosswalks]):
290 | ref_line[i, 4] = 1 # crosswalk
291 | else:
292 | ref_line[i, 4] = speed_limit
293 |
294 | return ref_line
295 |
296 | def imputer(traj):
297 | x, y, v_x, v_y, theta = traj[:, 0], traj[:, 1], traj[:, 3], traj[:, 4], traj[:, 2]
298 |
299 | if np.any(x==0):
300 | for i in reversed(range(traj.shape[0])):
301 | if x[i] == 0 and i!= traj.shape[0]-1:
302 | v_x[i] = v_x[i+1]
303 | v_y[i] = v_y[i+1]
304 | x[i] = x[i+1] - v_x[i]*0.1
305 | y[i] = y[i+1] - v_y[i]*0.1
306 | theta[i] = theta[i+1]
307 | return np.column_stack((x, y, theta, v_x, v_y))
308 | else:
309 | return np.column_stack((x, y, theta, v_x, v_y))
310 |
311 | def goal_norm(goal, center, angle):
312 | x = goal[0] - center[0]
313 | y = goal[1] - center[1]
314 | new_x = x * np.cos(angle) + y * np.sin(angle)
315 | new_y = -x * np.sin(angle) + y * np.cos(angle)
316 | return np.array([new_x, new_y])
317 |
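  | # Worked example (illustrative): with center=(0, 0) and angle=pi/2, a goal at (0, 10)
  | # in world coordinates maps to (10, 0) in the rotated (ego-aligned) frame.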
318 | def agent_norm(traj, center, angle, impute=False):
319 | if impute:
320 | traj = imputer(traj[:, :5])
321 |
322 | x, y = traj[:, 0] - center[0], traj[:, 1] - center[1]
323 | # angle = np.pi/ 2 - angle
324 | tx = np.cos(angle) * x + np.sin(angle) * y
325 | ty = -np.sin(angle) * x + np.cos(angle) * y
326 | line_rotate = np.stack([tx, ty], axis=-1)
327 | line_rotate[traj[:, :2]==0] = 0
328 | # line_rotate[traj[:, 1]==0, 1] = 0
329 |
330 | heading = wrap_to_pi(traj[:, 4] - angle)
331 | heading[traj[:, 4]==0] = 0
332 |
333 | if traj.shape[-1] > 3:
334 | velocity_x = traj[:, 2] * np.cos(angle) + traj[:, 3] * np.sin(angle)
335 | velocity_x[traj[:, 2]==0] = 0
336 | velocity_y = traj[:, 3] * np.cos(angle) - traj[:, 2] * np.sin(angle)
337 | velocity_y[traj[:, 3]==0] = 0
338 | return np.column_stack((line_rotate, heading, velocity_x, velocity_y))
339 | else:
340 | return np.column_stack((line_rotate, heading))
341 |
342 | def agent_norm_left(traj, center, angle, impute=False):
343 | if impute:
344 | traj = imputer(traj[:, :5])
345 |
346 | x, y = traj[:, 0] - center[0], traj[:, 1] - center[1]
347 | angle = np.pi/ 2 - angle
348 | tx = np.cos(angle) * x - np.sin(angle) * y
349 | ty = np.sin(angle) * x + np.cos(angle) * y
350 | line_rotate = np.stack([tx, ty], axis=-1)
351 | line_rotate[traj[:, :2]==0] = 0
352 | # line_rotate[traj[:, 1]==0, 1] = 0
353 |
354 | heading = wrap_to_pi(traj[:, 4] - angle)
355 | heading[traj[:, 4]==0] = 0
356 |
357 | if traj.shape[-1] > 3:
358 | velocity_x = traj[:, 2] * np.cos(angle) - traj[:, 3] * np.sin(angle)
359 | velocity_x[traj[:, 2]==0] = 0
360 | velocity_y = traj[:, 3] * np.cos(angle) + traj[:, 2] * np.sin(angle)
361 | velocity_y[traj[:, 3]==0] = 0
362 | return np.column_stack((line_rotate, heading, velocity_x, velocity_y))
363 | else:
364 | return np.column_stack((line_rotate, heading))
365 |
366 | def map_norm(map_line, center, angle):
367 | self_line = LineString(map_line[:, 0:2])
368 | self_line = affine_transform(self_line, [1, 0, 0, 1, -center[0], -center[1]])
369 | self_line = rotate(self_line, -angle, origin=(0, 0), use_radians=True)
370 | self_line = np.array(self_line.coords)
371 | self_line[map_line[:, 0:2]==0] = 0
372 | self_heading = wrap_to_pi(map_line[:, 2] - angle)
373 |
374 | if map_line.shape[1] > 3:
375 | left_line = LineString(map_line[:, 3:5])
376 | left_line = affine_transform(left_line, [1, 0, 0, 1, -center[0], -center[1]])
377 | left_line = rotate(left_line, -angle, origin=(0, 0), use_radians=True)
378 | left_line = np.array(left_line.coords)
379 | left_line[map_line[:, 3:5]==0] = 0
380 | left_heading = wrap_to_pi(map_line[:, 5] - angle)
381 | left_heading[map_line[:, 5]==0] = 0
382 |
383 | right_line = LineString(map_line[:, 6:8])
384 | right_line = affine_transform(right_line, [1, 0, 0, 1, -center[0], -center[1]])
385 | right_line = rotate(right_line, -angle, origin=(0, 0), use_radians=True)
386 | right_line = np.array(right_line.coords)
387 | right_line[map_line[:, 6:8]==0] = 0
388 | right_heading = wrap_to_pi(map_line[:, 8] - angle)
389 | right_heading[map_line[:, 8]==0] = 0
390 |
391 | return np.column_stack((self_line, self_heading, left_line, left_heading, right_line, right_heading))
392 | else:
393 | return np.column_stack((self_line, self_heading))
394 |
395 |
396 | def map_norm_left(map_line, center, angle):
397 | angle = np.pi/ 2 - angle
398 | x, y = map_line[:, 0] - center[0], map_line[:, 1] - center[1]
399 | tx = np.cos(angle) * x - np.sin(angle) * y
400 | ty = np.sin(angle) * x + np.cos(angle) * y
401 | self_line = np.stack([tx, ty], axis=-1)
402 | self_line[map_line[:, 0:2]==0] = 0
403 | self_heading = wrap_to_pi(map_line[:, 2] - angle)
404 | self_heading[map_line[:, 2]==0] = 0
405 |
406 | if map_line.shape[1] > 3:
407 | x, y = map_line[:, 3] - center[0], map_line[:, 4] - center[1]
408 | tx = np.cos(angle) * x - np.sin(angle) * y
409 | ty = np.sin(angle) * x + np.cos(angle) * y
410 | left_line = np.stack([tx, ty], axis=-1)
411 | left_line[map_line[:, 3:5]==0] = 0
412 |
413 | left_heading = wrap_to_pi(map_line[:, 5] - angle)
414 | left_heading[map_line[:, 5]==0] = 0
415 |
416 | x, y = map_line[:, 6] - center[0], map_line[:, 7] - center[1]
417 | tx = np.cos(angle) * x - np.sin(angle) * y
418 | ty = np.sin(angle) * x + np.cos(angle) * y
419 | right_line = np.stack([tx, ty], axis=-1)
420 | right_line[map_line[:, 6:8]==0] = 0
421 |
422 | right_heading = wrap_to_pi(map_line[:, 8] - angle)
423 | right_heading[map_line[:, 8]==0] = 0
424 |
425 | return np.column_stack((self_line, self_heading, left_line, left_heading, right_line, right_heading))
426 | else:
427 | return np.column_stack((self_line, self_heading))
428 |
429 |
430 |
431 |
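432 | 
433 | # Minimal self-check (illustrative sketch with synthetic data; the real pipeline feeds
434 | # Waymo map/track protos into these helpers from preprocess.py):
435 | if __name__ == '__main__':
436 |     # a straight 50 m reference line heading east: columns (x, y, yaw, curvature, info)
437 |     xs = np.arange(0, 50, 1.0)
438 |     zeros = np.zeros_like(xs)
439 |     ref_line = np.column_stack([xs, zeros, zeros, zeros, zeros])
440 |     # normalize to an ego pose 10 m along the line, heading east (angle = 0)
441 |     local_ref = ref_line_norm(ref_line, center=(10.0, 0.0), angle=0.0)
442 |     print(local_ref.shape)    # (50, 5)
443 |     print(local_ref[0, :2])   # approx. [-10.  0.]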
--------------------------------------------------------------------------------
/utils/occupancy_render_utils.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from glob import glob
3 | import dataclasses
4 |
5 | import numpy as np
6 | import math
7 |
8 | from waymo_open_dataset.protos import scenario_pb2
9 | from waymo_open_dataset.utils.occupancy_flow_renderer import _transform_to_image_coordinates, rotate_points_around_origin
10 | from waymo_open_dataset.utils import occupancy_flow_data
11 | from waymo_open_dataset.protos import occupancy_flow_metrics_pb2
12 |
13 |
14 | import matplotlib.pyplot as plt
15 | import matplotlib.style as mplstyle
16 |
17 | from PIL import Image
18 |
19 | from .waymo_tf_utils import linecolormap, road_label, road_line_map, traffic_light_map, light_state_map_num
20 |
21 | _ObjectType = scenario_pb2.Track.ObjectType
22 |
23 | @dataclasses.dataclass
24 | class _SampledPoints:
25 | """Set of points sampled from agent boxes.
26 |
27 | All fields have shape -
28 | [batch_size, num_agents, num_steps, num_points] where num_points is
29 | (points_per_side_length * points_per_side_width).
30 | """
31 | # [batch, num_agents, num_steps, points_per_agent].
32 | x: tf.Tensor
33 | # [batch, num_agents, num_steps, points_per_agent].
34 | y: tf.Tensor
35 | # [batch, num_agents, num_steps, points_per_agent].
36 | z: tf.Tensor
37 | # [batch, num_agents, num_steps, points_per_agent].
38 | agent_type: tf.Tensor
39 | # [batch, num_agents, num_steps, points_per_agent].
40 | valid: tf.Tensor
41 |
42 | def pack_trajs(parsed_data, time=199):
43 | #x, y, z, heading, length, width, valid required for occupancy
44 | tracks = parsed_data.tracks
45 | traj_tensor = np.zeros((len(tracks), time+1, 10))
46 | valid_tensor = np.zeros((len(tracks), time+1))
47 | sdc_id = parsed_data.sdc_track_index
48 | goal = None
49 |
50 | for i, track in enumerate(tracks):
51 | object_type = track.object_type
52 | for j, state in enumerate(track.states):
53 | if state.valid:
54 | traj_tensor[i, j] = np.array([state.center_x, state.center_y, state.velocity_x, state.velocity_y,
55 | state.heading, state.center_z, state.length, state.width, state.height, object_type])
56 | valid_tensor[i, j] = 1
57 |                 if i == sdc_id:
58 | goal = [state.center_x, state.center_y]
59 |
60 | traj_tensor = tf.convert_to_tensor(traj_tensor, tf.float32)
61 | valid_tensor = tf.convert_to_tensor(valid_tensor, tf.int32)[...,tf.newaxis]
62 |
63 | return traj_tensor, valid_tensor, goal
64 |
65 | def _np_to_img_coordinate(points_x, points_y, config):
66 | pixels_per_meter = config.pixels_per_meter
67 | points_x = np.round(points_x * pixels_per_meter) + config.sdc_x_in_grid
68 | points_y = np.round(-points_y * pixels_per_meter) + config.sdc_y_in_grid
69 |
70 | # Filter out points that are located outside the FOV of topdown map.
71 | point_is_in_fov = np.logical_and(
72 | np.logical_and(
73 | np.greater_equal(points_x, 0), np.greater_equal(points_y, 0)),
74 | np.logical_and(
75 | np.less(points_x, config.grid_width_cells),
76 | np.less(points_y, config.grid_height_cells)))
77 |
78 | return points_x, points_y, point_is_in_fov.astype(np.float32)
79 |
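  | # _np_to_img_coordinate example (illustrative): with pixels_per_meter=1.6 and
  | # sdc_x_in_grid=sdc_y_in_grid=128, an ego-frame point 10 m ahead (x=0, y=10) maps to
  | # pixel (128, 112); y is negated, so points ahead of the ego sit near the top of the grid.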
80 |
81 | def wrap_to_pi(theta):
82 | return (theta+np.pi) % (2*np.pi) - np.pi
83 |
84 |
85 | def get_polylines_type(lines, traffic_light_lanes, stop_sign_lanes, config, ego_xyh, polyline_len):
86 | tl_keys = set(traffic_light_lanes.keys())
87 | polylines = {}
88 |     org_polylines = []
89 | for line in lines.keys():
90 | types = lines[line].type
91 | polyline = np.array([(map_point.x, map_point.y) for map_point in lines[line].polyline])
92 | if polyline.shape[0] <= 2:
93 | continue
94 | # rotations:
95 | x, y = polyline[:, 0] - ego_xyh[0], polyline[:, 1] - ego_xyh[1]
96 | angle = np.pi/ 2 - ego_xyh[2]
97 | tx = np.cos(angle) * x - np.sin(angle) * y
98 | ty = np.sin(angle) * x + np.cos(angle) * y
99 | new_polyline = np.stack([tx, ty], axis=-1)
100 |
101 | if len(polyline) > 1:
102 | direction = wrap_to_pi(np.arctan2(polyline[1:, 1]-polyline[:-1, 1], polyline[1:, 0]-polyline[:-1, 0]) - angle)
103 | direction = np.insert(direction, -1, direction[-1])[:, np.newaxis]
104 | else:
105 | direction = np.array([0])[:, np.newaxis]
106 |
107 | trajs = np.concatenate([new_polyline, direction], axis=-1)
108 |
109 | #(x_img, y_img, in_fov)
110 | ogm_states = np.zeros((trajs.shape[0], 3))
111 | points_x, points_y, point_is_in_fov = _np_to_img_coordinate(new_polyline[:,0], new_polyline[:,1], config)
112 |
113 | ogm_states = np.stack([points_x, points_y, point_is_in_fov], axis=-1)
114 |
115 | # attrib_states: (type, tl_state, near_tl, stop_sign, sp_limit)
116 | attrib_states = np.zeros((trajs.shape[0], 5))
117 | attrib_states[:, 0] = types
118 | if line in tl_keys:
119 | attrib_states[:, 1] = light_state_map_num[traffic_light_lanes[line][0]]
120 | near_tl = np.less_equal(np.linalg.norm(polyline[:, :2] - np.array(traffic_light_lanes[line][1:])[np.newaxis,...], axis=-1), 3).astype(np.float32)
121 | attrib_states[:, 2] = near_tl
122 | # add stop sign
123 | if line in stop_sign_lanes:
124 | attrib_states[:, 3] = True
125 | try:
126 | attrib_states[:, 4] = lines[line].speed_limit_mph / 2.237
127 |         except AttributeError:  # road lines/edges have no speed_limit_mph
128 | attrib_states[:, 4] = 0
129 |
130 | ogm_center = np.array([0, (config.sdc_y_in_grid - 0.5*config.grid_height_cells)/config.pixels_per_meter])[np.newaxis,...]
131 | polyline_traj = np.concatenate((trajs, attrib_states, ogm_states), axis=-1)
132 |         org_polylines.append(polyline_traj)
133 | traj_splits = np.array_split(polyline_traj, np.ceil(polyline_traj.shape[0] / polyline_len), axis=0)
134 | i = 0
135 | for sub_traj in traj_splits:
136 |
137 | ade = np.mean(np.linalg.norm(sub_traj[:, :2] - ogm_center, axis=-1))
138 | polylines[f'{line}_{i}'] = (sub_traj, ade)
139 |
140 | i += 1
141 |
142 |     return polylines, org_polylines
143 |
144 | def render_roadgraph_tf(rg_tensor):
145 |     if len(rg_tensor) == 0:
146 |         print('Warning: road graph tensor is empty!')
147 | return tf.zeros((512, 512, 3))
148 | rg_tensor = tf.convert_to_tensor(np.concatenate(rg_tensor,axis=0))
149 | topdown_shape = [512, 512, 1]
150 | rg_x, rg_y, point_is_in_fov = rg_tensor[:, -3], rg_tensor[:, -2], rg_tensor[:, -1]
151 |
152 | types = rg_tensor[:, 3]
153 | tl_state = rg_tensor[:, 4]
154 |
155 | should_render_point = tf.cast(point_is_in_fov, tf.bool)
156 | point_indices = tf.cast(tf.where(should_render_point), tf.int32)
157 | x_img_coord = tf.gather_nd(rg_x, point_indices)[..., tf.newaxis]
158 | y_img_coord = tf.gather_nd(rg_y, point_indices)[..., tf.newaxis]
159 |
160 | types = tf.gather_nd(types, point_indices)[..., tf.newaxis]
161 | tl_state = tf.gather_nd(tl_state, point_indices)[..., tf.newaxis]
162 |
163 | num_points_to_render = point_indices.shape.as_list()[0]
164 |
165 | # [num_points_to_render, 3]
166 | xy_img_coord = tf.concat(
167 | [
168 | # point_indices[:, :1],
169 | tf.cast(y_img_coord, tf.int32),
170 | tf.cast(x_img_coord, tf.int32),
171 | ],
172 | axis=1,
173 | )
174 | gt_values = tf.ones_like(x_img_coord, dtype=tf.float32)
175 |
176 | # [batch_size, grid_height_cells, grid_width_cells, 1]
177 | rg_viz = tf.scatter_nd(xy_img_coord, gt_values, topdown_shape)
178 | # assert_shapes([(rg_viz, topdown_shape)])
179 | rg_type = tf.math.divide_no_nan(tf.cast(tf.scatter_nd(xy_img_coord, types, topdown_shape),tf.float32) , rg_viz)
180 | rg_tl_type = tf.math.divide_no_nan(tf.cast(tf.scatter_nd(xy_img_coord, tl_state, topdown_shape),tf.float32) , rg_viz)
181 | rg_viz = tf.clip_by_value(rg_viz, 0.0, 1.0)
182 |
183 | return tf.concat([rg_viz, rg_type, rg_tl_type],axis=-1)
184 |
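  | # render_roadgraph_tf returns a 512x512x3 raster: a binary mask of road-graph points,
  | # the per-pixel mean lane/road type, and the per-pixel mean traffic-light state
  | # (accumulated scatter values are normalized by the point count via divide_no_nan).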
185 |
186 | def get_crosswalk_type(lines, traffic_light_lanes, stop_sign_lanes, config, ego_xyh, polyline_len):
187 | tl_keys = set(traffic_light_lanes.keys())
188 | polylines = {}
189 |     org_polylines = []
190 | # id_list = []
191 | for line in lines.keys():
192 | types = 18
193 | polyline = np.array([(map_point.x, map_point.y) for map_point in lines[line].polygon])
194 | if polyline.shape[0] <= 2:
195 | continue
196 | # rotations:
197 | x, y = polyline[:, 0] - ego_xyh[0], polyline[:, 1] - ego_xyh[1]
198 | angle = np.pi/ 2 - ego_xyh[2]
199 | tx = np.cos(angle) * x - np.sin(angle) * y
200 | ty = np.sin(angle) * x + np.cos(angle) * y
201 | new_polyline = np.stack([tx, ty], axis=-1)
202 |
203 | if len(polyline) > 1:
204 | direction = wrap_to_pi(np.arctan2(polyline[1:, 1]-polyline[:-1, 1], polyline[1:, 0]-polyline[:-1, 0]) - angle)
205 | direction = np.insert(direction, -1, direction[-1])[:, np.newaxis]
206 | else:
207 | direction = np.array([0])[:, np.newaxis]
208 |
209 | trajs = np.concatenate([new_polyline, direction], axis=-1)
210 |
211 | #(x_img, y_img, in_fov)
212 | ogm_states = np.zeros((trajs.shape[0], 3))
213 | points_x, points_y, point_is_in_fov = _np_to_img_coordinate(new_polyline[:,0], new_polyline[:,1], config)
214 |
215 | ogm_states = np.stack([points_x, points_y, point_is_in_fov], axis=-1)
216 |
217 | # attrib_states: (type, tl_state, near_tl, stop_sign, sp_limit)
218 | attrib_states = np.zeros((trajs.shape[0], 5))
219 | attrib_states[:, 0] = types
220 | if line in tl_keys:
221 | attrib_states[:, 1] = light_state_map_num[traffic_light_lanes[line][0]]
222 | near_tl = np.less_equal(np.linalg.norm(polyline[:, :2] - np.array(traffic_light_lanes[line][1:])[np.newaxis,...], axis=-1), 3).astype(np.float32)
223 | attrib_states[:, 2] = near_tl
224 | # add stop sign
225 | if line in stop_sign_lanes:
226 | attrib_states[:, 3] = True
227 | try:
228 | attrib_states[:, 4] = lines[line].speed_limit_mph / 2.237
229 |         except AttributeError:  # crosswalks have no speed_limit_mph
230 | attrib_states[:, 4] = 0
231 |
232 | ogm_center = np.array([0, (config.sdc_y_in_grid - 0.5*config.grid_height_cells)/config.pixels_per_meter])[np.newaxis,...]
233 | polyline_traj = np.concatenate((trajs, attrib_states, ogm_states), axis=-1)
234 |         org_polylines.append(polyline_traj)
235 |         # org_polylines[line] = polyline_traj
236 | traj_splits = np.array_split(polyline_traj, np.ceil(polyline_traj.shape[0] / polyline_len), axis=0)
237 | i = 0
238 |
239 | for sub_traj in traj_splits:
240 |
241 | ade = np.mean(np.linalg.norm(sub_traj[:, :2] - ogm_center, axis=-1))
242 | polylines[f'{line}_{i}'] = (sub_traj, ade)
243 |
244 | i += 1
245 |
246 |     return polylines, org_polylines
247 |
248 |
249 | def pack_maps(lanes, roads, crosswalks, traffic_light_lanes, stop_sign_lanes, config, ego_xy, polyline_len=20):
250 |     '''
251 |     inputs: (Dict) lanes, roads, crosswalks keyed by map feature id
252 |     outputs: a dict of split polylines keyed by '{feature_id}_{i}' -> (sub_polyline, mean distance to the OGM center), plus a list of the full polylines
253 |     '''
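  |     # Per-point feature layout produced by get_polylines_type / get_crosswalk_type:
  |     # (x, y, heading, type, tl_state, near_tl, stop_sign, speed_limit, x_img, y_img, in_fov),
  |     # i.e. 3 pose + 5 attribute + 3 image-grid channels = 11 values per map point.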
254 | all_poly_trajs = {}
255 |     org_polys = []
256 |
257 | lane_polylines, lane_ids = get_polylines_type(lanes, traffic_light_lanes, stop_sign_lanes, config, ego_xy, polyline_len)
258 | all_poly_trajs.update(lane_polylines)
259 | # print([id.shape for id in lane_ids.values()])
260 |     org_polys.extend(lane_ids)
261 |
262 | roads_polylines, road_ids = get_polylines_type(roads, traffic_light_lanes, stop_sign_lanes, config, ego_xy, polyline_len)
263 | all_poly_trajs.update(roads_polylines)
264 |     org_polys.extend(road_ids)
265 |
266 | cw_polylines, cw_ids = get_crosswalk_type(crosswalks, traffic_light_lanes, stop_sign_lanes, config, ego_xy, polyline_len)
267 | all_poly_trajs.update(cw_polylines)
268 |     org_polys.extend(cw_ids)
269 |
270 |     return all_poly_trajs, org_polys
271 |
272 |
273 | def points_sample(
274 | traj_tensor,
275 | valid_tensor,
276 | ego_trajs,
277 | unit_x,
278 | unit_y,
279 | config
280 | ):
281 |
282 | # traj_tensor, valid_tensor = pack_trajs(parsed_data)
283 |
284 | x, y, z = traj_tensor[...,0:1], traj_tensor[...,1:2], traj_tensor[...,5:6]
285 | length, width = traj_tensor[...,6:7], traj_tensor[...,7:8]
286 | bbox_yaw = traj_tensor[...,4:5]
287 |
288 | sdc_x = ego_trajs[0:1][tf.newaxis, tf.newaxis, :]
289 | sdc_y = ego_trajs[1:2][tf.newaxis, tf.newaxis, :]
290 | sdc_z = ego_trajs[5:6][tf.newaxis, tf.newaxis, :]
291 |
292 | x = x - sdc_x
293 | y = y - sdc_y
294 | z = z - sdc_z
295 |
296 | angle = math.pi / 2 - ego_trajs[4:5][tf.newaxis, tf.newaxis, :]
297 | x, y = rotate_points_around_origin(x, y, angle)
298 | bbox_yaw = bbox_yaw + angle
299 |
300 | agent_type = traj_tensor[...,-1][...,tf.newaxis]
301 |
302 | return _sample_points_from_agent_boxes(
303 | x=x,
304 | y=y,
305 | z=z,
306 | bbox_yaw=bbox_yaw,
307 | width=width,
308 | length=length,
309 | agent_type=agent_type,
310 | valid=valid_tensor,
311 | unit_x=unit_x,
312 | unit_y=unit_y
313 | # points_per_side_length=config.points_per_side_length,
314 | # points_per_side_width=config.points_per_side_width,
315 | )
316 |
317 | def sample_filter(
318 | traj_tensor,
319 | valid_tensor,
320 | ego_traj,
321 | config,
322 | times,
323 | unit_x,
324 | unit_y,
325 | include_observed=True,
326 | include_occluded=True
327 | ):
328 |
329 | b,e = _get_num_steps_from_times(times, config)
330 |
331 | # Sample points from agent boxes over specified time frames.
332 | # All fields have shape [num_agents, num_steps, points_per_agent].
333 | sampled_points = points_sample(
334 | traj_tensor[:,b:e],
335 | valid_tensor[:,b:e],
336 | ego_traj,
337 | unit_x,unit_y,
338 | config
339 | )
340 |
341 | agent_valid = tf.cast(sampled_points.valid, tf.bool)
342 |
343 | include_all = include_observed and include_occluded
344 | if not include_all and 'future' in times:
345 | history_times = ['past', 'current']
346 | b,e = _get_num_steps_from_times(history_times, config)
347 | agent_is_observed = valid_tensor[:,b:e]
348 | # [num_agents, 1, 1]
349 | agent_is_observed = tf.reduce_max(agent_is_observed, axis=1, keepdims=True)
350 | agent_is_observed = tf.cast(agent_is_observed, tf.bool)
351 |
352 | if include_observed:
353 | agent_filter = agent_is_observed
354 | elif include_occluded:
355 | agent_filter = tf.logical_not(agent_is_observed)
356 | else: # Both observed and occluded are off.
357 | raise ValueError('Either observed or occluded agents must be requested.')
358 | agent_valid = tf.logical_and(agent_valid, agent_filter)
359 |
360 | return _SampledPoints(
361 | x=sampled_points.x,
362 | y=sampled_points.y,
363 | z=sampled_points.z,
364 | agent_type=sampled_points.agent_type,
365 | valid=agent_valid,
366 | )
367 |
368 | def render_ego_occupancy(
369 | sampled_points,
370 | sdc_ids,
371 | config
372 | ):
373 | agent_x = sampled_points.x
374 | agent_y = sampled_points.y
375 | agent_type = sampled_points.agent_type
376 | agent_valid = sampled_points.valid
377 |
378 | # Set up assert_shapes.
379 | assert_shapes = tf.debugging.assert_shapes
380 | num_agents, num_steps, points_per_agent = agent_x.shape.as_list()
381 | topdown_shape = [
382 | config.grid_height_cells, config.grid_width_cells, num_steps
383 | ]
384 |
385 | # print(topdown_shape)
386 |
387 | # Transform from world coordinates to topdown image coordinates.
388 | # All 3 have shape: [batch, num_agents, num_steps, points_per_agent]
389 | agent_x, agent_y, point_is_in_fov = _transform_to_image_coordinates(
390 | points_x=agent_x,
391 | points_y=agent_y,
392 | config=config,
393 | )
394 |
395 | assert_shapes([(point_is_in_fov,
396 | [num_agents, num_steps, points_per_agent])])
397 |
398 | # Filter out points from invalid objects.
399 | agent_valid = tf.cast(agent_valid, tf.bool)
400 | point_is_in_fov_and_valid = tf.logical_and(point_is_in_fov, agent_valid)
401 | agent_x, agent_y, should_render_point = agent_x[sdc_ids][tf.newaxis,...], agent_y[sdc_ids][tf.newaxis,...], point_is_in_fov_and_valid[sdc_ids][tf.newaxis,...]
402 |
403 | # Collect points for ego vehicle
404 | assert_shapes([
405 | (should_render_point,
406 | [1, num_steps, points_per_agent]),
407 | ])
408 |
409 | # [num_points_to_render, 4]
410 | point_indices = tf.cast(tf.where(should_render_point), tf.int32)
411 |
412 | # [num_points_to_render, 1]
413 | x_img_coord = tf.gather_nd(agent_x, point_indices)[..., tf.newaxis]
414 | y_img_coord = tf.gather_nd(agent_y, point_indices)[..., tf.newaxis]
415 |
416 | num_points_to_render = point_indices.shape.as_list()[0]
417 | assert_shapes([(x_img_coord, [num_points_to_render, 1]),
418 | (y_img_coord, [num_points_to_render, 1])])
419 |
420 | # [num_points_to_render, 4]
421 | xy_img_coord = tf.concat(
422 | [
423 | # point_indices[:, :1],
424 | tf.cast(y_img_coord, tf.int32),
425 | tf.cast(x_img_coord, tf.int32),
426 | point_indices[:, 1:2],
427 | ],
428 | axis=1,
429 | )
430 | # [num_points_to_render]
431 | gt_values = tf.squeeze(tf.ones_like(x_img_coord, dtype=tf.float32), axis=-1)
432 |
433 | # [batch_size, grid_height_cells, grid_width_cells, num_steps]
434 | topdown = tf.scatter_nd(xy_img_coord, gt_values, topdown_shape)
435 |
436 |
437 | assert_shapes([(topdown, topdown_shape)])
438 |
439 | # scatter_nd() accumulates values if there are repeated indices. Since
440 | # we sample densely, this happens all the time. Clip the final values.
441 | topdown = tf.clip_by_value(topdown, 0.0, 1.0)
442 | return topdown
443 |
444 |
445 | def render_occupancy(
446 | sampled_points,
447 | config,
448 | sdc_ids=None,
449 | ):
450 |
451 | agent_x = sampled_points.x
452 | agent_y = sampled_points.y
453 | agent_type = sampled_points.agent_type
454 | agent_valid = sampled_points.valid
455 |
456 | # Set up assert_shapes.
457 | assert_shapes = tf.debugging.assert_shapes
458 | num_agents, num_steps, points_per_agent = agent_x.shape.as_list()
459 | topdown_shape = [
460 | config.grid_height_cells, config.grid_width_cells, num_steps
461 | ]
462 |
463 | # print(topdown_shape)
464 |
465 | # Transform from world coordinates to topdown image coordinates.
466 | # All 3 have shape: [batch, num_agents, num_steps, points_per_agent]
467 | agent_x, agent_y, point_is_in_fov = _transform_to_image_coordinates(
468 | points_x=agent_x,
469 | points_y=agent_y,
470 | config=config,
471 | )
472 | assert_shapes([(point_is_in_fov,
473 | [num_agents, num_steps, points_per_agent])])
474 |
475 | # Filter out points from invalid objects.
476 | agent_valid = tf.cast(agent_valid, tf.bool)
477 |
478 | #cases masking the ego car:
479 | if sdc_ids is not None:
480 | mask = np.ones((num_agents, 1, 1))
481 | mask[sdc_ids] = 0
482 | mask = tf.convert_to_tensor(mask, tf.bool)
483 | agent_valid = tf.logical_and(agent_valid, mask)
484 |
485 | point_is_in_fov_and_valid = tf.logical_and(point_is_in_fov, agent_valid)
486 |
487 | occupancies = {}
488 | for object_type in occupancy_flow_data.ALL_AGENT_TYPES:
489 | # Collect points for each agent type, i.e., pedestrians and vehicles.
490 | agent_type_matches = tf.equal(agent_type, object_type)
491 | should_render_point = tf.logical_and(point_is_in_fov_and_valid,
492 | agent_type_matches)
493 |
494 | assert_shapes([
495 | (should_render_point,
496 | [num_agents, num_steps, points_per_agent]),
497 | ])
498 |
499 | # [num_points_to_render, 4]
500 | point_indices = tf.cast(tf.where(should_render_point), tf.int32)
501 |
502 | # [num_points_to_render, 1]
503 | x_img_coord = tf.gather_nd(agent_x, point_indices)[..., tf.newaxis]
504 | y_img_coord = tf.gather_nd(agent_y, point_indices)[..., tf.newaxis]
505 |
506 | num_points_to_render = point_indices.shape.as_list()[0]
507 | assert_shapes([(x_img_coord, [num_points_to_render, 1]),
508 | (y_img_coord, [num_points_to_render, 1])])
509 |
510 | # [num_points_to_render, 4]
511 | xy_img_coord = tf.concat(
512 | [
513 | # point_indices[:, :1],
514 | tf.cast(y_img_coord, tf.int32),
515 | tf.cast(x_img_coord, tf.int32),
516 | point_indices[:, 1:2],
517 | ],
518 | axis=1,
519 | )
520 | # [num_points_to_render]
521 | gt_values = tf.squeeze(tf.ones_like(x_img_coord, dtype=tf.float32), axis=-1)
522 |
523 | # [batch_size, grid_height_cells, grid_width_cells, num_steps]
524 | topdown = tf.scatter_nd(xy_img_coord, gt_values, topdown_shape)
525 |
526 |
527 | assert_shapes([(topdown, topdown_shape)])
528 |
529 | # scatter_nd() accumulates values if there are repeated indices. Since
530 | # we sample densely, this happens all the time. Clip the final values.
531 | topdown = tf.clip_by_value(topdown, 0.0, 1.0)
532 | occupancies[object_type] = topdown
533 |
534 | return occupancy_flow_data.AgentGrids(
535 | vehicles=occupancies[_ObjectType.TYPE_VEHICLE],
536 | pedestrians=occupancies[_ObjectType.TYPE_PEDESTRIAN],
537 | cyclists=occupancies[_ObjectType.TYPE_CYCLIST],
538 | ), point_is_in_fov_and_valid
539 |
540 |
541 | def render_flow_from_inputs(
542 | sampled_points,
543 | config,
544 | sdc_ids=None,
545 | ):
546 |
547 | agent_x = sampled_points.x
548 | agent_y = sampled_points.y
549 | agent_type = sampled_points.agent_type
550 | agent_valid = sampled_points.valid
551 |
552 | # Set up assert_shapes.
553 | assert_shapes = tf.debugging.assert_shapes
554 | num_agents, num_steps, points_per_agent = agent_x.shape.as_list()
555 | # The timestep distance between flow steps.
556 | waypoint_size = config.num_future_steps // config.num_waypoints
557 | num_flow_steps = num_steps - waypoint_size
558 | topdown_shape = [
559 | config.grid_height_cells, config.grid_width_cells,num_flow_steps
560 | ]
561 |
562 | # Transform from world coordinates to topdown image coordinates.
563 | # All 3 have shape: [batch, num_agents, num_steps, points_per_agent]
564 | agent_x, agent_y, point_is_in_fov = _transform_to_image_coordinates(
565 | points_x=agent_x,
566 | points_y=agent_y,
567 | config=config,
568 | )
569 |
570 | # Filter out points from invalid objects.
571 | agent_valid = tf.cast(agent_valid, tf.bool)
572 | #cases masking the ego car:
573 | if sdc_ids is not None:
574 | mask = np.ones((num_agents, 1, 1))
575 | mask[sdc_ids] = 0
576 | mask = tf.convert_to_tensor(mask, tf.bool)
577 | agent_valid = tf.logical_and(agent_valid, mask)
578 |
579 | # Backward Flow.
580 | # [num_agents, num_flow_steps, points_per_agent]
581 | dx = agent_x[:, :-waypoint_size, :] - agent_x[:, waypoint_size:, :]
582 | dy = agent_y[:, :-waypoint_size, :] - agent_y[:, waypoint_size:, :]
583 |
584 | # Adjust other fields as well to reduce from num_steps to num_flow_steps.
585 | # agent_x, agent_y: Use later timesteps since flow vectors go back in time.
586 | # [batch_size, num_agents, num_flow_steps, points_per_agent]
587 | agent_x = agent_x[:, waypoint_size:, :]
588 | agent_y = agent_y[:, waypoint_size:, :]
589 | # agent_type: Use later timesteps since flow vectors go back in time.
590 | # [batch_size, num_agents, num_flow_steps, points_per_agent]
591 | agent_type = agent_type[:, waypoint_size:, :]
592 | # point_is_in_fov: Use later timesteps since flow vectors go back in time.
593 | # [batch_size, num_agents, num_flow_steps, points_per_agent]
594 | point_is_in_fov = point_is_in_fov[:, waypoint_size:, :]
595 | # agent_valid: And the two timesteps. They both need to be valid.
596 | # [batch_size, num_agents, num_flow_steps, points_per_agent]
597 | agent_valid = tf.logical_and(agent_valid[:, waypoint_size:, :],
598 | agent_valid[:, :-waypoint_size, :])
599 |
600 | # [batch_size, num_agents, num_flow_steps, points_per_agent]
601 | point_is_in_fov_and_valid = tf.logical_and(point_is_in_fov, agent_valid)
602 |
603 | flows = {}
604 | for object_type in occupancy_flow_data.ALL_AGENT_TYPES:
605 | # Collect points for each agent type, i.e., pedestrians and vehicles.
606 | agent_type_matches = tf.equal(agent_type, object_type)
607 | should_render_point = tf.logical_and(point_is_in_fov_and_valid,
608 | agent_type_matches)
609 |
610 | # [batch_size, height, width, num_flow_steps, 2]
611 | flow = _render_flow_points_for_one_agent_type(
612 | agent_x=agent_x,
613 | agent_y=agent_y,
614 | dx=dx,
615 | dy=dy,
616 | should_render_point=should_render_point,
617 | topdown_shape=topdown_shape,
618 | )
619 | flows[object_type] = flow
620 |
621 | return occupancy_flow_data.AgentGrids(
622 | vehicles=flows[_ObjectType.TYPE_VEHICLE],
623 | pedestrians=flows[_ObjectType.TYPE_PEDESTRIAN],
624 | cyclists=flows[_ObjectType.TYPE_CYCLIST],
625 | )
626 |
627 |
628 | def _render_flow_points_for_one_agent_type(
629 | agent_x,
630 | agent_y,
631 | dx,
632 | dy,
633 | should_render_point,
634 | topdown_shape,
635 | ):
636 | assert_shapes = tf.debugging.assert_shapes
637 |
638 | # Scatter points across topdown maps for each timestep. The tensor
639 | # `point_indices` holds the indices where `should_render_point` is True.
640 | # It is a 2-D tensor with shape [n, 3], where n is the number of valid
641 | # agent points inside FOV. Each row in this tensor contains indices over
642 | # the following 3 dimensions: (agent, timestep, point).
643 |
644 | # [num_points_to_render, 3]
645 | point_indices = tf.cast(tf.where(should_render_point), tf.int32)
646 | # [num_points_to_render, 1]
647 | x_img_coord = tf.gather_nd(agent_x, point_indices)[..., tf.newaxis]
648 | y_img_coord = tf.gather_nd(agent_y, point_indices)[..., tf.newaxis]
649 |
650 | num_points_to_render = point_indices.shape.as_list()[0]
651 | assert_shapes([(x_img_coord, [num_points_to_render, 1]),
652 | (y_img_coord, [num_points_to_render, 1])])
653 |
654 | # [num_points_to_render, 4]
655 | xy_img_coord = tf.concat(
656 | [
657 | tf.cast(y_img_coord, tf.int32),
658 | tf.cast(x_img_coord, tf.int32),
659 | point_indices[:, 1:2],
660 | ],
661 | axis=1,
662 | )
663 | # [num_points_to_render]
664 | gt_values_dx = tf.gather_nd(dx, point_indices)
665 | gt_values_dy = tf.gather_nd(dy, point_indices)
666 |
667 | gt_values = tf.squeeze(tf.ones_like(x_img_coord, dtype=tf.float32), axis=-1)
668 |
669 | # [batch_size, grid_height_cells, grid_width_cells, num_flow_steps]
670 | flow_x = tf.scatter_nd(xy_img_coord, gt_values_dx, topdown_shape)
671 | flow_y = tf.scatter_nd(xy_img_coord, gt_values_dy, topdown_shape)
672 | num_values_per_pixel = tf.scatter_nd(xy_img_coord, gt_values, topdown_shape)
673 |
674 | # Undo the accumulation effect of tf.scatter_nd() for repeated indices.
675 | flow_x = tf.math.divide_no_nan(flow_x, num_values_per_pixel)
676 | flow_y = tf.math.divide_no_nan(flow_y, num_values_per_pixel)
677 |
678 | # [batch_size, grid_height_cells, grid_width_cells, num_flow_steps, 2]
679 | flow = tf.stack([flow_x, flow_y], axis=-1)
680 | return flow
681 |
682 |
683 | def generate_units(points_per_side_length, points_per_side_width):
684 |
685 | if points_per_side_length < 1:
686 | raise ValueError('points_per_side_length must be >= 1')
687 | if points_per_side_width < 1:
688 | raise ValueError('points_per_side_width must be >= 1')
689 |
690 | # Create sample points on a unit square or boundary depending on flag.
691 | if points_per_side_length == 1:
692 | step_x = 0.0
693 | else:
694 | step_x = 1.0 / (points_per_side_length - 1)
695 | if points_per_side_width == 1:
696 | step_y = 0.0
697 | else:
698 | step_y = 1.0 / (points_per_side_width - 1)
699 | unit_x = []
700 | unit_y = []
701 | for xi in range(points_per_side_length):
702 | for yi in range(points_per_side_width):
703 | unit_x.append(xi * step_x - 0.5)
704 | unit_y.append(yi * step_y - 0.5)
705 |
706 | # Center unit_x and unit_y if there was only 1 point on those dimensions.
707 | if points_per_side_length == 1:
708 | unit_x = np.array(unit_x) + 0.5
709 | if points_per_side_width == 1:
710 | unit_y = np.array(unit_y) + 0.5
711 |
712 | unit_x = tf.convert_to_tensor(unit_x, tf.float32)
713 | unit_y = tf.convert_to_tensor(unit_y, tf.float32)
714 |
715 | return unit_x, unit_y
716 |
717 |
718 | def _sample_ego_from_boxes(x, y, bbox_yaw, width, length, unit_x, unit_y):
719 | sin_yaw = tf.sin(bbox_yaw)
720 | cos_yaw = tf.cos(bbox_yaw)
721 |
722 | tx = cos_yaw * length * unit_x - sin_yaw * width * unit_y + x
723 | ty = sin_yaw * length * unit_x + cos_yaw * width * unit_y + y
724 | return tx, ty
725 |
726 |
727 | def _sample_points_from_agent_boxes(
728 | x, y, z, bbox_yaw, width, length, agent_type, valid, unit_x, unit_y
729 | ):
730 | assert_shapes = tf.debugging.assert_shapes
731 | assert_shapes([(x, [..., 1])])
732 | x_shape = x.get_shape().as_list()
733 |
734 | # Transform the unit square points to agent dimensions and coordinate frames.
735 | sin_yaw = tf.sin(bbox_yaw)
736 | cos_yaw = tf.cos(bbox_yaw)
737 |
738 | # [..., num_points]
739 | tx = cos_yaw * length * unit_x - sin_yaw * width * unit_y + x
740 | ty = sin_yaw * length * unit_x + cos_yaw * width * unit_y + y
741 | tz = tf.broadcast_to(z, tx.shape)
742 |
743 | # points_shape = x_shape[:-1] + [num_points]
744 | agent_type = tf.broadcast_to(agent_type, tx.shape)
745 | valid = tf.broadcast_to(valid, tx.shape)
746 |
747 | return _SampledPoints(x=tx, y=ty, z=tz, agent_type=agent_type, valid=valid)
748 |
749 |
750 | def _get_num_steps_from_times(
751 | times,
752 | config):
753 | """Returns number of timesteps that exist in requested times."""
754 | p, c, f = config.num_past_steps, 1, config.num_future_steps
755 | dict_1 = {'past':(0, p), 'current':(p, p+c), 'future':(p+c, p+c+f)}
756 | if len(times)==0:
757 | raise NotImplementedError()
758 | elif len(times)==1:
759 | return dict_1[times[0]]
760 | elif len(times)==2:
761 | assert times[0]=='past'
762 | return (0, p + c)
763 | else:
764 | return (0, p + c + f)
765 |
766 |
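767 | 
768 | # Minimal self-check for the box-sampling helpers (illustrative sketch; the real pipeline
769 | # drives these functions from preprocess.py with Waymo scenario protos):
770 | if __name__ == '__main__':
771 |     unit_x, unit_y = generate_units(points_per_side_length=3, points_per_side_width=2)
772 |     print(unit_x.shape, unit_y.shape)   # (6,) (6,) -- 3 x 2 samples on a unit box
773 |     tx, ty = _sample_ego_from_boxes(x=0.0, y=0.0, bbox_yaw=0.0, width=2.0,
774 |                                     length=5.0, unit_x=unit_x, unit_y=unit_y)
775 |     print(tx.numpy(), ty.numpy())       # corners and edge midpoints of a 5 m x 2 m box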
--------------------------------------------------------------------------------
/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
4 |
5 | import tensorflow as tf
6 | import numpy as np
7 | import multiprocessing
8 | from multiprocessing import Pool, Process
9 | import argparse
10 | import random
11 |
12 | import math
13 | import time
14 | import pandas as pd
15 |
16 | from waymo_open_dataset.protos import occupancy_flow_metrics_pb2, scenario_pb2
17 | from google.protobuf import text_format
18 |
19 | from shapely.geometry import LineString, Point, Polygon
20 | from shapely.affinity import affine_transform, rotate
21 |
22 | import matplotlib.pyplot as plt
23 | import matplotlib as mpl
24 |
25 | from functools import partial
26 | from glob import glob
27 | from tqdm import tqdm
28 |
29 | from utils.occupancy_grid_utils import create_ground_truth_timestep_grids, \
30 | create_ground_truth_waypoint_grids, _ego_ground_truth_occupancy
31 | from utils.occupancy_render_utils import pack_trajs, pack_maps, \
32 | render_roadgraph_tf
33 |
34 | from utils.train_utils import *
35 |
36 | class Processor:
37 | def __init__(
38 | self,
39 | height=128,
40 | width=128,
41 | pixels_per_meter=1.6,
42 | hist_len=11,
43 | future_len=50,
44 | gap=5,
45 | num_observed=32,
46 | num_occluded=6,
47 | num_map=3,
48 | map_len=100,
49 | map_buffer=150,
50 | ego_map=6,
51 | ego_map_len=200,
52 | ego_buffer=300,
53 | ref_max_len=1000,
54 | planning_horizon=5,
55 | dt=0.1,
56 | data_files='',
57 | save_dir='',
58 | cumulative_waypoints='false',
59 | ol_test=False,
60 | timestep=199
61 | ):
62 |
63 | self.height = height
64 | self.width = width
65 | self.pixels_per_meter = pixels_per_meter
66 | self.cumulative_waypoints = cumulative_waypoints
67 |
68 | self.hist_len = hist_len
69 | self.future_len = future_len
70 | self.num_observed = num_observed
71 | self.num_occluded = num_occluded
72 | self.num_map = num_map
73 | self.map_len = map_len
74 | self.map_buffer = map_buffer
75 | self.ego_map = ego_map
76 | self.ego_map_len = ego_map_len
77 | self.ego_buffer = ego_buffer
78 | self.ref_max_len = ref_max_len
79 |
80 | self.horizon = planning_horizon #s
81 | self.dt = dt #s
82 |
83 | self.data_files = [data_files]
84 | self.save_dir = save_dir
85 | self.gap = gap
86 |
87 | self.ol_test = ol_test
88 |
89 | self.timestep = timestep
90 | print(f'timestep:{self.timestep}')
91 |
92 | self.test_scenario_ids = None
93 |
94 | self.get_config()
95 |
96 | def get_config(self):
97 | center_x,center_y = int(self.width / 2), int(self.height * 0.75)
98 | config = occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig()
99 | config_text = f"""
100 | num_past_steps: {self.hist_len - 1}
101 | num_future_steps: {self.future_len}
102 | num_waypoints: {self.future_len//5}
103 | cumulative_waypoints: {self.cumulative_waypoints}
104 | normalize_sdc_yaw: true
105 | grid_height_cells: {self.height}
106 | grid_width_cells: {self.width}
107 | sdc_y_in_grid: {center_y}
108 | sdc_x_in_grid: {center_x}
109 | pixels_per_meter: {self.pixels_per_meter}
110 | agent_points_per_side_length: 48
111 | agent_points_per_side_width: 16
112 | """
113 |
114 | text_format.Parse(config_text, config)
115 | self.config = config
116 |
117 | input_config = occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig()
118 | iconfig_text = f"""
119 | num_past_steps: {self.hist_len - 1}
120 | num_future_steps: {self.future_len}
121 | num_waypoints: {self.future_len//5}
122 | cumulative_waypoints: false
123 | normalize_sdc_yaw: true
124 | grid_height_cells: {self.height*2}
125 | grid_width_cells: {self.width*2}
126 | sdc_y_in_grid: {center_y + int(self.height*0.5)}
127 | sdc_x_in_grid: {center_x + int(self.width*0.5)}
128 | pixels_per_meter: {self.pixels_per_meter}
129 | agent_points_per_side_length: 48
130 | agent_points_per_side_width: 16
131 | """
132 |
133 | text_format.Parse(iconfig_text, input_config)
134 | self.input_config = input_config
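  |         # self.config describes the (height x width) ground-truth/output grid, while
  |         # self.input_config doubles the grid and shifts the SDC origin accordingly, so the
  |         # rasterized model inputs cover a larger field of view around the same ego position.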
135 |
136 | def build_map(self, map_features, dynamic_map_states):
137 | self.lanes = {}
138 | self.roads = {}
139 | self.stop_signs = {}
140 | self.crosswalks = {}
141 | self.speed_bumps = {}
142 |
143 | # static map features
144 |         for feature in map_features:
145 |             map_type = feature.WhichOneof("feature_data")
146 |             map_id = feature.id
147 |             feature = getattr(feature, map_type)
148 |
149 |             if map_type == 'lane':
150 |                 self.lanes[map_id] = feature
151 |             elif map_type == 'road_line' or map_type == 'road_edge':
152 |                 self.roads[map_id] = feature
153 |             elif map_type == 'stop_sign':
154 |                 self.stop_signs[map_id] = feature
155 |             elif map_type == 'crosswalk':
156 |                 self.crosswalks[map_id] = feature
157 |             elif map_type == 'speed_bump':
158 |                 self.speed_bumps[map_id] = feature
159 | else:
160 | continue
161 |
162 | # dynamic map features
163 | self.traffic_signals = dynamic_map_states
164 |
165 | def map_process(self, traj, num_map, map_len, map_buffer, ind, goal=None):
166 | '''
167 | Map point attributes
168 |         self_point (x, y, h), left_boundary_point (x, y, h), right_boundary_point (x, y, h), speed limit (float),
169 | self_type (int), left_boundary_type (int), right_boundary_type (int),
170 | traffic light (int), stop_point (bool), interpolating (bool), stop_sign (bool)
171 | '''
172 | vectorized_map = np.zeros(shape=(num_map, map_len, 17))
173 | vectorized_crosswalks = np.zeros(shape=(3, 100, 3))
174 | agent_type = int(traj[-1][-1])
175 |
176 | # get all lane polylines
177 | lane_polylines = get_polylines(self.lanes)
178 |
179 | # get all road lines and edges polylines
180 | road_polylines = get_polylines(self.roads)
181 |
182 | # find current lanes for the agent
183 | ref_lane_ids = find_reference_lanes(agent_type, traj, lane_polylines)
184 |
185 | # find candidate lanes
186 | ref_lanes = []
187 |
188 | # get current lane's forward lanes
189 | for curr_lane, start in ref_lane_ids.items():
190 | candidate = depth_first_search(curr_lane, self.lanes,
191 | dist=lane_polylines[curr_lane][start:].shape[0], threshold=300)
192 | ref_lanes.extend(candidate)
193 |
194 | if agent_type != 2:
195 | # find current lanes' left and right lanes
196 | neighbor_lane_ids = find_neighbor_lanes(ref_lane_ids, traj, self.lanes, lane_polylines)
197 |
198 | # get neighbor lane's forward lanes
199 | for neighbor_lane, start in neighbor_lane_ids.items():
200 | candidate = depth_first_search(neighbor_lane, self.lanes,
201 | dist=lane_polylines[neighbor_lane][start:].shape[0], threshold=300)
202 | ref_lanes.extend(candidate)
203 |
204 | # update reference lane ids
205 | ref_lane_ids.update(neighbor_lane_ids)
206 |
207 | # remove overlapping lanes
208 | ref_lanes = remove_overlapping_lane_seq(ref_lanes)
209 |
210 | # get traffic light controlled lanes and stop sign controlled lanes
211 | traffic_light_lanes = {}
212 | stop_sign_lanes = []
213 |
214 | for signal in self.traffic_signals[ind-1].lane_states:
215 | traffic_light_lanes[signal.lane] = (signal.state, signal.stop_point.x, signal.stop_point.y)
216 | for lane in self.lanes[signal.lane].entry_lanes:
217 | traffic_light_lanes[lane] = (signal.state, signal.stop_point.x, signal.stop_point.y)
218 |
219 | for i, sign in self.stop_signs.items():
220 | stop_sign_lanes.extend(sign.lane)
221 |
222 | # add lanes to the array
223 | added_lanes = 0
224 | for i, s_lane in enumerate(ref_lanes):
225 | added_points = 0
226 | if i > num_map - 1:
227 | break
228 |
229 | # create a data cache
230 | cache_lane = np.zeros(shape=(map_buffer, 17))
231 |
232 | for lane in s_lane:
233 | curr_index = ref_lane_ids[lane] if lane in ref_lane_ids else 0
234 | self_line = lane_polylines[lane][curr_index:]
235 |
236 | if added_points >= map_buffer:
237 | break
238 |
239 | # add info to the array
240 | for point in self_line:
241 | # self_point and type
242 | cache_lane[added_points, 0:3] = point
243 | cache_lane[added_points, 10] = self.lanes[lane].type
244 |
245 | # left_boundary_point and type
246 | for left_boundary in self.lanes[lane].left_boundaries:
247 | left_boundary_id = left_boundary.boundary_feature_id
248 | left_start = left_boundary.lane_start_index
249 | left_end = left_boundary.lane_end_index
250 | left_boundary_type = left_boundary.boundary_type # road line type
251 | if left_boundary_type == 0:
252 | left_boundary_type = self.roads[left_boundary_id].type + 8 # road edge type
253 |
254 | if left_start <= curr_index <= left_end:
255 | left_boundary_line = road_polylines[left_boundary_id]
256 | nearest_point = find_neareast_point(point, left_boundary_line)
257 | cache_lane[added_points, 3:6] = nearest_point
258 | cache_lane[added_points, 11] = left_boundary_type
259 |
260 | # right_boundary_point and type
261 | for right_boundary in self.lanes[lane].right_boundaries:
262 | right_boundary_id = right_boundary.boundary_feature_id
263 | right_start = right_boundary.lane_start_index
264 | right_end = right_boundary.lane_end_index
265 | right_boundary_type = right_boundary.boundary_type # road line type
266 | if right_boundary_type == 0:
267 | right_boundary_type = self.roads[right_boundary_id].type + 8 # road edge type
268 |
269 | if right_start <= curr_index <= right_end:
270 | right_boundary_line = road_polylines[right_boundary_id]
271 | nearest_point = find_neareast_point(point, right_boundary_line)
272 | cache_lane[added_points, 6:9] = nearest_point
273 | cache_lane[added_points, 12] = right_boundary_type
274 |
275 | # speed limit
276 | cache_lane[added_points, 9] = self.lanes[lane].speed_limit_mph / 2.237
277 |
278 | # interpolating
279 | cache_lane[added_points, 15] = self.lanes[lane].interpolating
280 |
281 | # traffic_light
282 | if lane in traffic_light_lanes.keys():
283 | cache_lane[added_points, 13] = traffic_light_lanes[lane][0]
284 | if np.linalg.norm(traffic_light_lanes[lane][1:] - point[:2]) < 3:
285 | cache_lane[added_points, 14] = True
286 |
287 | # add stop sign
288 | if lane in stop_sign_lanes:
289 | cache_lane[added_points, 16] = True
290 |
291 | # count
292 | added_points += 1
293 | curr_index += 1
294 |
295 | if added_points >= map_buffer:
296 | break
297 |
298 | # scale the lane
299 |             vectorized_map[i] = cache_lane[np.linspace(0, added_points, num=map_len, endpoint=False, dtype=int)]
300 |
301 | # count
302 | added_lanes += 1
303 |
304 | if goal is not None:
305 | dist_list = {}
306 | for i in range(vectorized_map.shape[0]):
307 |                 dist_list[i] = np.min(np.linalg.norm(vectorized_map[i, :, :2] - goal, axis=-1))
308 | sorted_inds = sorted(dist_list.items(), key=lambda item:item[1])[:self.ego_map]
309 | sorted_inds = [ind[0] for ind in sorted_inds]
310 | vectorized_map = vectorized_map[sorted_inds]
311 |
312 |
313 | # find surrounding crosswalks and add them to the array
314 | added_cross_walks = 0
315 | detection = Polygon([(0, -5), (50, -20), (50, 20), (0, 5)])
316 | detection = affine_transform(detection, [1, 0, 0, 1, traj[-1][0], traj[-1][1]])
317 | detection = rotate(detection, traj[-1][2], origin=(traj[-1][0], traj[-1][1]), use_radians=True)
318 |
319 | for _, crosswalk in self.crosswalks.items():
320 | polygon = Polygon([(point.x, point.y) for point in crosswalk.polygon])
321 | polyline = polygon_completion(crosswalk.polygon)
322 |             polyline = polyline[np.linspace(0, polyline.shape[0], num=100, endpoint=False, dtype=int)]
323 |
324 | if detection.intersects(polygon):
325 | vectorized_crosswalks[added_cross_walks, :polyline.shape[0]] = polyline
326 | added_cross_walks += 1
327 |
328 | if added_cross_walks >= 3:
329 | break
330 |
331 | #map [3, 100, 17] ; crosswalk [4, 50, 3]
332 | vectorized_map = vectorized_map[:, 0::2, :]
333 | vectorized_crosswalks = vectorized_crosswalks[:, 0::2, :]
334 |
335 | return vectorized_map.astype(np.float32), vectorized_crosswalks.astype(np.float32)
336 |
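    # Note: ego_frame_dynamics below rotates the (x, y) components of a global-frame
    # velocity into the ego frame (a rotation by -theta, where theta is the ego heading);
    # e.g. v = [1, 0] with theta = pi/2 (ego facing +y) maps to ego_v = [0, -1].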
337 | @staticmethod
338 | def ego_frame_dynamics(v, theta):
339 | ego_v = v.copy()
340 | ego_v[0] = v[0] * np.cos(theta) + v[1] * np.sin(theta)
341 | ego_v[1] = v[1] * np.cos(theta) - v[0] * np.sin(theta)
342 |
343 | return ego_v
344 |
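    # dynamic_state_process collects the dynamic scene context at the current frame:
    # the traffic-light state and stop point for every controlled lane and its entry
    # lanes, plus every lane governed by a stop sign. Index 10 is the current timestep
    # of the 11-frame history; the `ind` argument appears to be unused.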
345 | def dynamic_state_process(self, ind=11):
346 | traffic_light_lanes = {}
347 | stop_sign_lanes = []
348 |
349 | for signal in self.traffic_signals[10].lane_states:
350 | traffic_light_lanes[signal.lane] = (signal.state, signal.stop_point.x, signal.stop_point.y)
351 | for lane in self.lanes[signal.lane].entry_lanes:
352 | traffic_light_lanes[lane] = (signal.state, signal.stop_point.x, signal.stop_point.y)
353 |
354 | for i, sign in self.stop_signs.items():
355 | stop_sign_lanes.extend(sign.lane)
356 |
357 | return traffic_light_lanes, stop_sign_lanes
358 |
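    # history_ogm_process rasterizes the observed history into per-class occupancy grids
    # (vehicles / pedestrians / cyclists, past + current frames), the corresponding flow
    # fields, and the ego occupancy over the first 11 frames.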
359 | def history_ogm_process(self, traj_tensor, valid_tensor, ego_traj, sdc_ids=None):
360 | timestep_grids, observed_valid, ego_occupancy = create_ground_truth_timestep_grids(traj_tensor, valid_tensor,
361 | ego_traj, self.input_config,flow=True, flow_origin=False,sdc_ids=sdc_ids)
362 |
363 | vehicle_flow = timestep_grids.vehicles.all_flow[:, :, 0, :].numpy().astype(np.int8)
364 | ped_flow = timestep_grids.pedestrians.all_flow[:, :, 0, :].numpy().astype(np.int8)
365 | cyc_flow = timestep_grids.cyclists.all_flow[:, :, 0, :].numpy().astype(np.int8)
366 |
367 | hist_flow = np.concatenate([vehicle_flow,ped_flow,cyc_flow],axis=-1)
368 |
369 | hist_vehicles, hist_pedestrians, hist_cyclists = timestep_grids.vehicles, timestep_grids.pedestrians, timestep_grids.cyclists
370 | hist_v_ogm = tf.concat([hist_vehicles.past_occupancy,hist_vehicles.current_occupancy],axis=-1)
371 | hist_p_ogm = tf.concat([hist_pedestrians.past_occupancy,hist_pedestrians.current_occupancy],axis=-1)
372 | hist_c_ogm = tf.concat([hist_cyclists.past_occupancy,hist_cyclists.current_occupancy],axis=-1)
373 |
374 | hist_ogm = tf.stack([hist_v_ogm , hist_p_ogm , hist_c_ogm], axis=-1).numpy().astype(np.bool_)
375 | ego_mask_hist = ego_occupancy[:, :, :11]
376 |
377 | return hist_ogm, hist_flow, ego_mask_hist
378 |
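    # gt_ogm_process builds the ground-truth waypoint grids over the planning horizon:
    # per-class observed occupancy, class-merged occluded occupancy, flow fields, and the
    # ego occupancy sampled every 10 frames (one waypoint per second at the 10 Hz rate).
    # With test=True only the validity flags of the observed agents are returned.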
379 | def gt_ogm_process(self, traj_tensor, valid_tensor, ego_traj, sdc_ids=None, test=False):
380 | timestep_grids, observed_valid, ego_occupancy = create_ground_truth_timestep_grids(traj_tensor, valid_tensor,
381 | ego_traj, self.config,flow=True, flow_origin=False, sdc_ids=sdc_ids, test=test)
382 | if test:
383 | return None, None, None, observed_valid, None
384 | true_waypoints = create_ground_truth_waypoint_grids(timestep_grids, self.config, flow=True, flow_origin=False)
385 |
386 | gt_obs_v = tf.stack(true_waypoints.vehicles.observed_occupancy,axis=0)
387 | gt_obs_p = tf.stack(true_waypoints.pedestrians.observed_occupancy,axis=0)
388 | gt_obs_c = tf.stack(true_waypoints.cyclists.observed_occupancy,axis=0)
389 |
390 | gt_occ_v = tf.stack(true_waypoints.vehicles.occluded_occupancy,axis=0)
391 | gt_occ_p = tf.stack(true_waypoints.pedestrians.occluded_occupancy,axis=0)
392 | gt_occ_c = tf.stack(true_waypoints.cyclists.occluded_occupancy,axis=0)
393 |
394 | gt_obs = tf.stack([gt_obs_v, gt_obs_p, gt_obs_c],axis=-1).numpy().astype(np.bool_)
395 | gt_occ = tf.clip_by_value(gt_occ_v + gt_occ_p + gt_occ_c, 0, 1).numpy().astype(np.bool_)
396 |
397 |         gt_flow_v = tf.stack(true_waypoints.vehicles.flow,axis=0)
398 |         gt_flow_p = tf.stack(true_waypoints.pedestrians.flow,axis=0)
399 |         gt_flow_c = tf.stack(true_waypoints.cyclists.flow,axis=0)
400 |
401 |         gt_flow = tf.stack([gt_flow_v, gt_flow_p, gt_flow_c],axis=-1).numpy().astype(np.int8)
402 |
403 | ego_mask = tf.stack([ego_occupancy[:, :, 10 * (i + 2)] for i in range(self.future_len//10)], axis=0).numpy().astype(np.bool_)
404 |
405 | return gt_obs, gt_occ, gt_flow, observed_valid, ego_mask
406 |
407 | def traj_process(self, traj_tensor, valid_tensor, ego_current, observed_valid, occluded_valid, sdc_ids):
408 |         """
409 |         1. Process trajectories for all agents (pack_trajs)
410 |         2. Filter agents in the field of view (observed, i.e. currently present, and occluded)
411 |         3. Sort neighbor agents in the field of view by distance to the ego (observed and occluded handled separately)
412 |         """
413 |
414 | #ego traj
415 | ego_tensor = traj_tensor[sdc_ids]
416 | ego_traj, gt_traj = ego_tensor[:self.hist_len], ego_tensor[self.hist_len:, :5]
417 | self.current_xyh = [ego_current[0], ego_current[1], ego_current[4]]
418 | #filtered agents
419 | observed_ids, occluded_ids = {}, {}
420 | observed_agents = np.zeros((self.num_observed, self.hist_len, 10))
421 | if self.ol_test:
422 | observed_agents = np.zeros((self.num_observed, self.hist_len + self.future_len, 10))
423 |         # keep the current (x, y) position of every valid observed agent
424 |         # (occluded_valid is accepted for interface compatibility but is not used here)
425 | for i in range(traj_tensor.shape[0]):
426 | if i==sdc_ids:
427 | continue
428 | if observed_valid[i]==1:
429 | observed_ids[i] = traj_tensor[i, 10, :2]
430 |
431 | sorted_observed = sorted(observed_ids.items(),
432 | key=lambda item: np.linalg.norm(item[1] - self.current_xyh[:2]))[:self.num_observed]
433 | for i, obs in enumerate(sorted_observed):
434 | if self.ol_test:
435 | observed_agents[i] = traj_tensor[obs[0], :, :]
436 | else:
437 | observed_agents[i] = traj_tensor[obs[0], :self.hist_len, :]
438 | neighbor_traj = observed_agents
439 | return ego_traj.astype(np.float32), neighbor_traj.astype(np.float32), gt_traj.astype(np.float32)
440 |
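    # route_process extracts a reference path along the ground-truth future with find_route
    # and pads it to a fixed length of 1200 points by repeating the last point.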
441 | def route_process(self, sdc_id, timestep, cur_pos, tracks):
442 | # find reference paths according to the gt trajectory
443 | gt_path = tracks[sdc_id].states
444 |         # skip rare cases where no reference route can be found
445 |         try:
446 |             route = find_route(gt_path, timestep, cur_pos, self.lanes, self.crosswalks, self.traffic_signals)
447 |         except Exception:
448 |             return None
449 |
450 | ref_path = np.array(route, dtype=np.float32)
451 |
452 | if ref_path.shape[0] < 1200:
453 | repeated_last_point = np.repeat(ref_path[np.newaxis, -1], 1200-ref_path.shape[0], axis=0)
454 | ref_path = np.append(ref_path, repeated_last_point, axis=0)
455 |
456 | return ref_path
457 |
458 |
459 | def occupancy_process(self,traj_tensor, valid_tensor, ego_traj, sdc_ids=None, infer=False):
460 | """
461 | process the historical and future occupancy
462 | """
463 | timestep_grids, observed_valid, ego_occupancy = create_ground_truth_timestep_grids(traj_tensor, valid_tensor, ego_traj, self.config,flow=True,
464 | flow_origin=False,sdc_ids=sdc_ids)
465 |
466 | hist_vehicles, hist_pedestrians, hist_cyclists = timestep_grids.vehicles, timestep_grids.pedestrians, timestep_grids.cyclists
467 |
468 | hist_v_ogm = tf.concat([hist_vehicles.past_occupancy,hist_vehicles.current_occupancy],axis=-1)
469 | hist_p_ogm = tf.concat([hist_pedestrians.past_occupancy,hist_pedestrians.current_occupancy],axis=-1)
470 | hist_c_ogm = tf.concat([hist_cyclists.past_occupancy,hist_cyclists.current_occupancy],axis=-1)
471 | ego_hist = ego_occupancy[...,:self.hist_len].numpy().astype(np.bool_)
472 |
473 |         # stacked history occupancy: [h, w, hist_len, 3] (vehicles, pedestrians, cyclists)
474 | hist_ogm = tf.stack([hist_v_ogm, hist_p_ogm, hist_c_ogm], axis=-1).numpy().astype(np.bool_)
475 |
476 | if infer:
477 | return hist_ogm, np.zeros((5, 128, 128, 3)), np.zeros((5, 128, 128, 3)), observed_valid.numpy(), None, np.zeros((5, 128, 128, 1)), ego_hist
478 |
479 | true_waypoints = create_ground_truth_waypoint_grids(timestep_grids, self.config,flow=True)
480 | vehicles, pedestrians, cyclists = true_waypoints.vehicles, true_waypoints.pedestrians, true_waypoints.cyclists
481 |
482 | gt_obs_v = tf.stack(true_waypoints.vehicles.observed_occupancy,axis=0)
483 | gt_occ_v = tf.stack(true_waypoints.vehicles.occluded_occupancy,axis=0)
484 |
485 | gt_obs_p = tf.stack(true_waypoints.pedestrians.observed_occupancy,axis=0)
486 | gt_occ_p = tf.stack(true_waypoints.pedestrians.occluded_occupancy,axis=0)
487 |
488 | gt_obs_c = tf.stack(true_waypoints.cyclists.observed_occupancy,axis=0)
489 | gt_occ_c = tf.stack(true_waypoints.cyclists.occluded_occupancy,axis=0)
490 |
491 | gt_obs = tf.concat([gt_obs_v, gt_obs_p, gt_obs_c] ,axis=-1).numpy().astype(np.bool_)
492 | gt_occ = tf.concat([gt_occ_v, gt_occ_p, gt_occ_c] ,axis=-1).numpy().astype(np.bool_)
493 |
494 | ego_future = _ego_ground_truth_occupancy(ego_occupancy, self.config)
495 | ego_future = tf.stack(ego_future, axis=0).numpy().astype(np.bool_)
496 |
497 | return hist_ogm, gt_obs, gt_occ, observed_valid.numpy(), None, ego_future, ego_hist
498 |
499 |
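    # load_open_loop_files reads the csv given by --ol_dir and keeps the set of
    # 'Scenario ID' values; only these scenarios are processed in open-loop test mode.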
500 | def load_open_loop_files(self):
501 | path = args.ol_dir
502 | data = pd.read_csv(path)
503 | self.test_scenario_ids = set(data['Scenario ID'].to_list())
504 |
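    # normalize_data expresses every trajectory, lane, crosswalk, reference line and goal
    # in the local ego frame defined by the current position and heading (center, angle),
    # via the agent_norm / map_norm / ref_line_norm / goal_norm helpers; all-zero rows are
    # treated as padding and left untouched. With viz=True the normalized scene is plotted.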
505 | def normalize_data(self, ego, neighbors, map_lanes, map_crosswalks, ego_map, ego_crosswalk,
506 | ground_truth, goal, ref_line, viz=True,sc_ids='', plan_res=None):
507 | # get the center and heading (local view)
508 | center, angle = self.current_xyh[:2], self.current_xyh[2]
509 | # normalize agent trajectories
510 | ego[:, :5] = agent_norm(ego, center, angle)
511 | ground_truth = agent_norm(ground_truth, center, angle)
512 |
513 | for i in range(neighbors.shape[0]):
514 | if neighbors[i, 10, 0] != 0:
515 | neighbors[i, :, :5] = agent_norm(neighbors[i, :], center, angle, impute=False)
516 |
517 | # normalize map points
518 | for i in range(map_lanes.shape[0]):
519 | lanes = map_lanes[i]
520 | crosswalks = map_crosswalks[i]
521 |
522 | for j in range(map_lanes.shape[1]):
523 | lane = lanes[j]
524 | if lane[0][0] != 0:
525 | lane[:, :9] = map_norm(lane, center, angle)
526 |
527 | for k in range(map_crosswalks.shape[1]):
528 | crosswalk = crosswalks[k]
529 | if crosswalk[0][0] != 0:
530 | crosswalk[:, :3] = map_norm(crosswalk, center, angle)
531 |
532 | for j in range(ego_map.shape[0]):
533 | lane = ego_map[j]
534 | if lane[0][0] != 0:
535 | lane[:, :9] = map_norm(lane, center, angle)
536 |
537 | for k in range(ego_crosswalk.shape[0]):
538 | crosswalk = ego_crosswalk[k]
539 | if crosswalk[0][0] != 0:
540 | crosswalk[:, :3] = map_norm(crosswalk, center, angle)
541 |
542 | plan_lines = np.zeros_like(ref_line)
543 |
544 | ref_line = ref_line_norm(ref_line, center, angle).astype(np.float32)
545 |
546 | goal = goal_norm(goal, center, angle)
547 |
548 |         # visualization
549 | if viz:
550 | plt.figure()
551 | for i in range(map_lanes.shape[0]):
552 | lanes = map_lanes[i]
553 | crosswalks = map_crosswalks[i]
554 | for j in range(map_lanes.shape[1]):
555 | lane = lanes[j]
556 | if lane[0][0] != 0:
557 | centerline = lane[:, 0:2]
558 | centerline = centerline[centerline[:, 0] != 0]
559 |
560 | for k in range(map_crosswalks.shape[1]):
561 | crosswalk = crosswalks[k]
562 | if crosswalk[0][0] != 0:
563 | crosswalk = crosswalk[crosswalk[:, 0] != 0]
564 | plt.plot(crosswalk[:, 0], crosswalk[:, 1], 'b', linewidth=4) # plot crosswalk
565 |
566 | for j in range(ego_map.shape[0]):
567 | lane = ego_map[j]
568 | if lane[0][0] != 0:
569 | centerline = lane[:, 0:2]
570 | centerline = centerline[centerline[:, 0] != 0]
571 | left = lane[:, 3:5]
572 | left = left[left[:, 0] != 0]
573 | right = lane[:, 6:8]
574 | right = right[right[:, 0] != 0]
575 | plt.plot(centerline[:, 0], centerline[:, 1],'k', linewidth=1) # plot centerline
576 |
577 | for k in range(ego_crosswalk.shape[0]):
578 | crosswalk = ego_crosswalk[k]
579 | if crosswalk[0][0] != 0:
580 | crosswalk = crosswalk[crosswalk[:, 0] != 0]
581 | plt.plot(crosswalk[:, 0], crosswalk[:, 1], 'b', linewidth=4) # plot crosswalk
582 |
583 |
584 | rect = plt.Rectangle((ego[-1, 0]-ego[-1, 6]/2, ego[-1, 1]-ego[-1, 7]/2), ego[-1, 6], ego[-1, 7], linewidth=2, color='r', alpha=0.6, zorder=3,
585 | transform=mpl.transforms.Affine2D().rotate_around(*(ego[-1, 0], ego[-1, 1]), ego[-1, 2]) + plt.gca().transData)
586 | plt.gca().add_patch(rect)
587 | plt.plot(ref_line[:, 0][ref_line[:, 0]!=0], ref_line[:, 1][ref_line[:, 0]!=0], 'y', linewidth=2, zorder=4)
588 |
589 | plt.plot(ego[:, 0], ego[:, 1],'royalblue' ,linewidth=3,zorder=3)
590 | future = ground_truth[ground_truth[:, 0] != 0]
591 | plt.plot(future[:, 0], future[:, 1], 'r', linewidth=3, zorder=3)
592 | color_map=['purple','brown','pink','olive','gold','royalblue']
593 |
594 | for i in range(neighbors.shape[0]):
595 | if neighbors[i, 10, 0] != 0:
596 | rect = plt.Rectangle((neighbors[i, 10, 0]-neighbors[i, 10, 6]/2, neighbors[i, 10, 1]-neighbors[i, 10, 7]/2),
597 | neighbors[i, 10, 6], neighbors[i, 10, 7], linewidth=2, color='m', alpha=0.6, zorder=3,
598 | transform=mpl.transforms.Affine2D().rotate_around(*(neighbors[i, 10, 0], neighbors[i, 10, 1]), neighbors[i, 10, 2]) + plt.gca().transData)
599 | plt.gca().add_patch(rect)
600 | mask = neighbors[i, :, 0] + neighbors[i, :, 1] != 0
601 | plt.plot(neighbors[i, mask, 0], neighbors[i, mask, 1], 'm', linewidth=1, zorder=3)
602 |
603 | if plan_res is not None:
604 | plt.plot(plan_res[:, 0], plan_res[:, 1], 'c', linewidth=3,zorder=5)
605 | plt.scatter(plan_res[9::10, 0], plan_res[9::10, 1], 10, 'c',zorder=5)
606 | circle = plt.Circle([goal[0], goal[1]],color='r')
607 | plt.gca().add_patch(circle)
608 | plt.gca().set_aspect('equal')
609 | plt.tight_layout()
610 | plt.show(block=False)
611 | plt.pause(1)
612 | plt.close()
613 |
614 | return ego, neighbors, map_lanes, map_crosswalks, ego_map, ego_crosswalk ,ref_line, ground_truth, goal, plan_lines
615 |
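    # data_process is the main per-scenario loop: parse a TFRecord scenario, build the
    # static map, then slide a window of hist_len + future_len frames every `gap` steps.
    # Each window yields history/ground-truth OGMs and flow, vectorized agent and map
    # features, a reference line and goal, all normalized to the ego frame and saved as
    # one .npz sample named <scenario_id>_<timestep>.npz.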
616 | def data_process(self,vis=False):
617 | for data_file in self.data_files:
618 | dataset = tf.data.TFRecordDataset(data_file)
619 | sample_index = [i for i in range(self.hist_len, self.timestep-self.future_len, self.gap)]
620 |             if len(sample_index) == 0:
621 | sample_index = [11]
622 | total_len = len(list(dataset))*len(sample_index)
623 | print(f"Processing {data_file.split('/')[-1]}", total_len)
624 | start_time = time.time()
625 | current = 1
626 | for data in dataset:
627 | parsed_data = scenario_pb2.Scenario()
628 | parsed_data.ParseFromString(data.numpy())
629 |
630 | scenario_id = parsed_data.scenario_id
631 | if self.ol_test:
632 | if scenario_id not in self.test_scenario_ids:
633 | continue
634 |
635 | sdc_id = parsed_data.sdc_track_index
636 | time_len = len(parsed_data.tracks[sdc_id].states)
637 | self.build_map(parsed_data.map_features, parsed_data.dynamic_map_states)
638 |
639 | traj_tensor, valid_tensor, goal = pack_trajs(parsed_data, self.timestep)
640 | sdc_id = parsed_data.sdc_track_index
641 | cnt = 0
642 | for ind in sample_index:
643 | #slice current points
644 | traj_window, valid_window = traj_tensor[:, ind-self.hist_len:ind+self.future_len, :], valid_tensor[:, ind-self.hist_len:ind+self.future_len, :]
645 |
646 | ego_current = traj_tensor[sdc_id, ind]
647 |
648 | ego_goal_dist = np.linalg.norm(ego_current.numpy()[:2] - np.array(goal))
649 |
650 | if ego_goal_dist < 3 and cnt < 5:
651 | continue
652 |
653 | traffic_light_lanes, stop_sign_lanes = self.dynamic_state_process()
654 | map_dict, org_maps = pack_maps(self.lanes, self.roads, self.crosswalks, traffic_light_lanes,
655 | stop_sign_lanes, self.input_config, [ego_current[0], ego_current[1], ego_current[4]])
656 |
657 | hist_ogm, hist_flow,ego_ogm = self.history_ogm_process(traj_window, valid_window, ego_current, sdc_id)
658 | gt_obs, gt_occ, gt_flow, observed_valid, ego_ogm_gt = self.gt_ogm_process(traj_window, valid_window, ego_current, sdc_id, False)
659 |
660 | #vectorized trajs for agents in fovs:
661 | rg = render_roadgraph_tf(org_maps).numpy().astype(np.uint8)
662 |
663 | ego_traj, neighbor_traj, gt_traj = self.traj_process(traj_window.numpy(), valid_window.numpy(), ego_current.numpy(), observed_valid, None, sdc_id)
664 | ref_line = self.route_process(sdc_id, ind, self.current_xyh, parsed_data.tracks)
665 | if ref_line is None:
666 | continue
667 |
668 | #agents map and crosswalks:
669 | neighbor_map_lanes = np.zeros((self.num_observed, self.num_map, self.map_len//2, 17))
670 | neighbor_map_crosswalks = np.zeros((self.num_observed, 3, 50, 3))
671 | ego_map_lane, ego_map_crosswalk = self.map_process(ego_traj, self.ego_map*2, self.ego_map_len, self.ego_buffer, ind, goal)
672 |
673 |
674 | for i in range(neighbor_traj.shape[0]):
675 | if neighbor_traj[i, -1, 0] != 0:
676 | neighbor_map_lanes[i], neighbor_map_crosswalks[i] = self.map_process(neighbor_traj[i], self.num_map, self.map_len, self.map_buffer, ind)
677 |
678 | #normalize all
679 | self.sc_ids = f'{scenario_id}_{ind}'
680 |
681 | ego, neighbors, neighbor_map_lanes, neighbor_map_crosswalks, ego_map_lane, ego_map_crosswalk , ref_line, ground_truth, new_goal, plan_lines = self.normalize_data(ego_traj, neighbor_traj,
682 | neighbor_map_lanes, neighbor_map_crosswalks, ego_map_lane,ego_map_crosswalk, gt_traj, goal, ref_line, vis, f'{scenario_id}_{ind}')
683 |
684 |                     sys.stdout.write(f"\rProcessing {data_file.split('/')[-1]} | {current}/{total_len} | {(time.time()-start_time)/current:>.4f} s/sample")
685 | sys.stdout.flush()
686 | current += 1
687 | cnt += 1
688 |
689 | filename = self.save_dir + f"{scenario_id}_{ind}.npz"
690 | np.savez(filename, ego=ego, neighbors=neighbors, neighbor_map_lanes=neighbor_map_lanes,
691 | ego_map_lane=ego_map_lane,ego_map_crosswalk=ego_map_crosswalk,
692 | neighbor_map_crosswalks=neighbor_map_crosswalks, gt_future_states=ground_truth,
693 | ref_line=ref_line,goal=new_goal,hist_ogm=hist_ogm, gt_obs=gt_obs, gt_occ=gt_occ,
694 | ego_ogm=ego_ogm,ego_ogm_gt=ego_ogm_gt,rg=rg,hist_flow=hist_flow,gt_flow=gt_flow
695 | )
696 |
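# The three drivers below build a Processor with the appropriate horizon and output
# directory: open-loop test and validation scenarios span 91 frames (timestep=91), the
# 20 s training scenarios span 199 frames; each driver logs the shards it has finished.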
697 | def process_ol_test_data(data_files):
698 | processor = Processor(data_files=data_files,height=128,width=128,gap=10,ref_max_len=1200,ego_map=3,future_len=50,
699 | save_dir=args.save_dir+'/open_loop_test2/',ol_test=True, timestep=91)
700 | processor.load_open_loop_files()
701 | processor.data_process(vis=False)
702 | print(f'{data_files}-done!')
703 | with open(args.save_dir+'ol_log.txt','a') as writer:
704 | writer.write(data_files+'\n')
705 |
706 | def process_training_data(data_files):
707 | for data_file in data_files:
708 | processor = Processor(data_files=data_file,height=128,width=128,gap=10,ref_max_len=1200,ego_map=3,future_len=50,
709 | save_dir=args.save_dir+'train/', timestep=199)
710 | processor.data_process(vis=False)
711 | print(f'{data_file},training_done!')
712 | with open(args.save_dir+'train_log.txt','a') as writer:
713 | writer.write(data_file+'\n')
714 |
715 | def process_validation_data(data_files):
716 |     # called via Pool.map, so data_files here is a single tfrecord shard path
717 | processor = Processor(data_files=data_files,height=128,width=128,gap=10,ref_max_len=1200,ego_map=3,future_len=50,
718 | save_dir=args.save_dir+'valid/',timestep=91)
719 | processor.data_process(vis=False)
720 | print(f'{data_files},_done!')
721 | with open(args.save_dir+'val_log.txt','a') as writer:
722 | writer.write(data_files+'\n')
723 |
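# file_slice / process_map split the list of tfrecord shards into n chunks and process
# them in parallel with one worker process per chunk, joining all workers at the end.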
724 | def file_slice(files, n):
725 |     return [files[i::n] for i in range(n)]  # round-robin split into exactly n chunks so no file is dropped
726 |
727 | def process_map(func, files, n):
728 | sliced_files = file_slice(files, n)
729 |
730 | process_list = []
731 | for i in range(n):
732 | p = Process(target=func, args=(sliced_files[i],))
733 | process_list.append(p)
734 | for i in range(n):
735 | process_list[i].start()
736 | for i in range(n):
737 | process_list[i].join()
738 |
739 |
740 | if __name__ == "__main__":
741 | parser = argparse.ArgumentParser()
742 | parser.add_argument("--processes", type=int,default=16)
743 | parser.add_argument("--root_dir", type=str,default='', help='path to load original Waymo Datasets')
744 | parser.add_argument("--save_dir", type=str,default='', help='path to save processed datasets')
745 | parser.add_argument("--ol_dir", type=str,default='', help='path for open loop test ids csv (optional)')
746 | args = parser.parse_args()
747 | processes = args.processes
748 |
749 |     # Hint: processing the full training set is time-consuming, so you may sample a subset of shards for training
750 | train_root_dir = args.root_dir + 'training_20s/'
751 | train_list = glob(train_root_dir+'*')
752 | process_map(process_training_data, train_list, processes)
753 |
754 |     # randomly select half of the validation shards:
755 | val_root_dir = args.root_dir + 'validation/'
756 | val_list =[
757 | val_root_dir + 'validation.tfrecord-' + "%05d" % i + '-of-00150'
758 | for i in random.sample(range(150), 75)
759 | ]
760 |
761 | with Pool(processes=processes) as p:
762 | p.map(process_validation_data, val_list)
763 |
764 |     # process scenarios for open-loop testing; you may also directly use the sampled validation set for open-loop testing
765 | if args.ol_dir != '':
766 | full_val_list = glob(val_root_dir + '*')
767 | with Pool(processes=processes) as p:
768 | p.map(process_ol_test_data, full_val_list)
--------------------------------------------------------------------------------