├── model ├── __init__.py ├── aggregator.py ├── InteractPlanner.py ├── decoder.py └── encoder.py ├── utils ├── __init__.py ├── waymo_tf_utils.py ├── net_utils.py ├── plan_utils.py ├── occupancy_grid_utils.py ├── train_utils.py └── occupancy_render_utils.py ├── README.md ├── testing.py ├── training.py ├── planner.py ├── metric.py └── preprocess.py /model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /model/aggregator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | import sys 5 | import os 6 | 7 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 8 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 9 | 10 | 11 | def temporal_upsample(inputs, size=(2, 2), mode='nearest'): 12 | assert len(inputs.shape) == 5 13 | b, c, t, h, w = inputs.shape 14 | inputs = inputs.permute(0, 2, 1, 3, 4).contiguous().reshape(b*t, c, h, w) 15 | inputs = F.interpolate(input=inputs, size=(2*h, 2*w), mode=mode, align_corners=True) 16 | inputs = inputs.reshape(b, t, c, 2*h, 2*w).permute(0, 2, 1, 3, 4) 17 | return inputs 18 | 19 | class CrossAttention(nn.Module): 20 | def __init__(self, dim=384, heads=8, dropout=0.1): 21 | super(CrossAttention, self).__init__() 22 | self.cross_attention = nn.MultiheadAttention(dim, heads, dropout=dropout, batch_first=True,) 23 | self.ffn = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Dropout(0.1), nn.Linear(dim*4, dim), nn.Dropout(0.1)) 24 | self.norm_0 = nn.LayerNorm(dim) 25 | self.norm_1 = nn.LayerNorm(dim) 26 | 27 | def forward(self, query, key, mask): 28 | output, _ = self.cross_attention(query, key, key, key_padding_mask=mask) 29 | attention_output = self.norm_0(output + query) 30 | n_output = self.ffn(attention_output) 31 | return self.norm_1(n_output + attention_output) -------------------------------------------------------------------------------- /model/InteractPlanner.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .encoder import * 11 | from .decoder import * 12 | 13 | class InteractPlanner(nn.Module): 14 | def __init__(self, config, dim=256, enc_layer=2, heads=8, dropout=0.1, 15 | timestep=5, decoder_dim=384, fpn_len=2, use_dynamic=True, 16 | large_scale=True,flow_pred=False): 17 | 18 | super(InteractPlanner, self).__init__() 19 | 20 | self.visual_encoder = VisualEncoder(config, input_resolution=(256, 256), 21 | patch_size=2,use_deformable_block=False,large_scale=large_scale) 22 | 23 | self.vector_encoder = VectorEncoder(dim, enc_layer, heads, dropout) 24 | 25 | self.ogm_decoder = STrajNetDecoder(decoder_dim, heads, len_fpn=fpn_len, timestep=timestep, dropout=dropout, 26 | flow_pred=flow_pred,large_scale=large_scale) 27 | 28 | self.plan_decoder = PlanningDecoder(dim, heads, dropout, use_dynamic=use_dynamic, 29 | timestep=timestep) 30 | 31 | 32 | def forward(self, inputs): 33 | encoder_outputs = self.vector_encoder(inputs) 34 | bev_list, _ = self.visual_encoder(inputs) 35 | encoder_outputs.update({ 36 | 
'bev_feature' : bev_list[-1] 37 | }) 38 | bev_pred = self.ogm_decoder(bev_list, encoder_outputs['encodings'][:, 0], encoder_outputs['masks'][:, 0]) 39 | traj, score = self.plan_decoder(encoder_outputs) 40 | return bev_pred, traj, score -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OPGP 2 | 3 | This repo is the implementation of: 4 | 5 | **Occupancy Prediction-Guided Neural Planner for Autonomous Driving** 6 |
[Haochen Liu](https://scholar.google.com/citations?user=iizqKUsAAAAJ&hl=en), [Zhiyu Huang](https://mczhi.github.io/), [Chen Lv](https://scholar.google.com/citations?user=UKVs2CEAAAAJ&hl=en) 7 |
[AutoMan Research Lab, Nanyang Technological University](https://lvchen.wixsite.com/automan) 8 |
**[[Paper]](https://ieeexplore.ieee.org/abstract/document/10422055/)**  **[[arXiv]](https://arxiv.org/abs/2305.03303)**  **[[Zhihu]](https://zhuanlan.zhihu.com/p/680304839)**  9 | 10 | - Code is now released 😀! 11 | 12 | ## Overview 13 | In this repository, you can expect to find the following features 🤩: 14 | * Pipelines for data processing and training 15 | * Open-loop evaluations 16 | 17 | Not included 😵: 18 | * Model weights (due to the WOMD license) 19 | * Real-time planning (the code is not optimized for real-time performance) 20 | 21 | ## Experiment Pipelines 22 | 23 | ### Dataset and Environment 24 | 25 | 26 | - Download the [Waymo Open Motion Dataset](https://waymo.com/open/download/) v1.1. Use the data in ```scenario/training_20s``` as the training set, and the data in ```scenario/validation``` for validation and testing. 27 | 28 | - Clone this repository and install the required packages. 29 | 30 | - **[NOTE]** For the [theseus](https://github.com/facebookresearch/theseus) library, you may need to build it from source and add its install path to the system path in ```planner.py``` 31 | 32 | ### Data Process 33 | 34 | - Preprocess data for training & testing: 35 | 36 | ``` 37 | python preprocess.py \ 38 | --root_dir path/to/your/Waymo_Dataset/scenario/ \ 39 | --save_dir path/to/your/processed_data/ \ 40 | --processes=16 41 | ``` 42 | 43 | - You may also refer to the [Waymo_candid_list](https://github.com/MCZhi/GameFormer/blob/main/open_loop_planning/waymo_candid_list.csv) for interactive and safety-critical scenarios filtered from ```scenario/validation``` 44 | 45 | ### Training & Testing 46 | 47 | - Train and evaluate the model using the command below (set --nproc_per_node to the number of GPUs): 48 | 49 | ``` 50 | python -m torch.distributed.launch \ 51 | --nproc_per_node 1 \ 52 | --master_port 16666 \ 53 | training.py \ 54 | --data_dir path/to/your/processed_data/ \ 55 | --save_dir path/to/save/your/logs/ 56 | ``` 57 | 58 | - Run open-loop testing using the command: 59 | 60 | ``` 61 | python testing.py \ 62 | --data_dir path/to/your/testing_data/ \ 63 | --model_dir path/to/pretrained/model/ 64 | ``` 65 | 66 | ## Citation 67 | If you find this repository useful for your research, please consider giving us a star 🌟 and citing our paper. 
68 | 69 | ```angular2html 70 | @inproceedings{liu2023occupancy, 71 | title={Occupancy prediction-guided neural planner for autonomous driving}, 72 | author={Liu, Haochen and Huang, Zhiyu and Lv, Chen}, 73 | booktitle={2023 IEEE 26th International Conference on Intelligent Transportation Systems (ITSC)}, 74 | pages={4859--4865}, 75 | year={2023}, 76 | organization={IEEE} 77 | } 78 | -------------------------------------------------------------------------------- /testing.py: -------------------------------------------------------------------------------- 1 | 2 | import csv 3 | import argparse 4 | import time 5 | import sys 6 | 7 | import torch 8 | from torch import optim 9 | from torch.utils.data import DataLoader 10 | 11 | from waymo_open_dataset.protos import occupancy_flow_metrics_pb2 12 | from google.protobuf import text_format 13 | 14 | from model.InteractPlanner import InteractPlanner 15 | from utils.net_utils import * 16 | from metric import TestingMetrics 17 | 18 | from planner import Planner 19 | 20 | 21 | def test_modal_selection(traj, score, targets, level=-1): 22 | gt_future = targets['ego_future_states'] 23 | if isinstance(traj, list): 24 | traj, score = traj[level], score[level] 25 | gt_modes = torch.argmax(score, dim=-1) 26 | B = traj.shape[0] 27 | selected_trajs = traj[torch.arange(B)[:, None], gt_modes.unsqueeze(-1)].squeeze(1) 28 | return selected_trajs, gt_modes 29 | 30 | def flow_warp(bev_pred, current_ogm, occ=False): 31 | ogm_pred, pred_flow = bev_pred[:, :4].sigmoid(), bev_pred[:, -2:] 32 | if not occ: 33 | ogm_pred = torch.cat([(ogm_pred[:,0] + ogm_pred[:,-1]).clamp(0,1).unsqueeze(1), 34 | ogm_pred[:, 1:2], ogm_pred[:, 2:3]],dim=1) 35 | 36 | b, c, t, h, w = pred_flow.shape 37 | pred_flow = pred_flow.permute(0, 2, 3, 4, 1) 38 | x = torch.linspace(0, w - 1, w) 39 | y = torch.linspace(0, h - 1, h) 40 | grid = torch.stack(torch.meshgrid([x, y])).transpose(1, 2) 41 | grid = grid.permute(1, 2, 0).unsqueeze(0).unsqueeze(1).expand(b, t, -1, -1, -1).to(local_rank) 42 | 43 | flow_grid = grid + pred_flow + 0.5 44 | flow_grid = 2 * flow_grid / (h) - 1 45 | 46 | warped_flow = [] 47 | for i in range(flow_grid.shape[1]): 48 | flow_origin_ogm = current_ogm if i==0 else ogm_pred[:, :, i-1] 49 | wf = F.grid_sample(flow_origin_ogm, flow_grid[:, i], mode='nearest', align_corners=False) 50 | warped_flow.append(wf) 51 | 52 | warped_flow = torch.stack(warped_flow, dim=2) 53 | warped_ogm = ogm_pred * warped_flow 54 | return warped_ogm 55 | 56 | def model_testing(valid_data): 57 | 58 | epoch_metrics = TestingMetrics(config) 59 | model.eval() 60 | current = 0 61 | start_time = time.time() 62 | size = len(valid_data) 63 | 64 | print(f'Testing....') 65 | for batch in valid_data: 66 | # prepare data 67 | inputs, target = batch_to_dict(batch, local_rank, use_flow) 68 | 69 | # query the model 70 | with torch.no_grad(): 71 | bev_pred, traj, score = model(inputs) 72 | selected_trajs, gt_modes = test_modal_selection(traj, score, target, level=0) 73 | selected_ref = target['ref_line'] 74 | b, h, w, d = inputs['hist_ogm'].shape 75 | types = inputs['hist_ogm'].reshape(b, h, w, d//3, 3) 76 | current_ogm = types[:, :, :, -1, :].permute(0, 3, 1, 2) 77 | type_mask = types[..., -1, :].sum(-2).sum(-2) > 0 78 | warped_ogm = flow_warp(bev_pred, current_ogm) 79 | 80 | planning_inputs = planner.preprocess(inputs['ego_state'], selected_trajs[:, :50, :2], 81 | selected_ref, warped_ogm[:, :, :5], type_mask, config,left=True) 82 | xy_plan = planner.plan(planning_inputs, selected_ref, inputs['ego_state']) 83 | 84 | 
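# accumulate open-loop metrics for this batch: the optimizer-refined plan (xy_plan) plus the flow-warped occupancy prediction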
epoch_metrics.update(xy_plan, score, warped_ogm, gt_modes, target, inputs['ego_state'][:,-1,:]) 85 | 86 | current += args.batch_size 87 | sys.stdout.write(f'\rVal: [{current:>6d}/{size*args.batch_size:>6d}]|{(time.time()-start_time)/current:>.4f}s/sample') 88 | sys.stdout.flush() 89 | 90 | print('Calculating Open Loop Planning Results...') 91 | print(epoch_metrics.result()) 92 | 93 | 94 | if __name__ == "__main__": 95 | 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--local_rank", type=int) 98 | parser.add_argument("--batch_size", type=int,default=4) 99 | parser.add_argument("--dim", type=int,default=256) 100 | parser.add_argument("--use_flow", action='store_true', default=True, 101 | help='whether to use flow warp') 102 | parser.add_argument("--data_dir", type=str, default='', 103 | help='path to load preprocessed data') 104 | parser.add_argument("--model_dir", type=str, default='', 105 | help='path to load pretrained IL model') 106 | 107 | args = parser.parse_args() 108 | local_rank = args.local_rank 109 | 110 | config = occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig() 111 | config_text = f""" 112 | num_past_steps: {10} 113 | num_future_steps: {50} 114 | num_waypoints: {5} 115 | cumulative_waypoints: {'true'} 116 | normalize_sdc_yaw: true 117 | grid_height_cells: {128} 118 | grid_width_cells: {128} 119 | sdc_y_in_grid: {int(128*0.75)} 120 | sdc_x_in_grid: {64} 121 | pixels_per_meter: {1.6} 122 | agent_points_per_side_length: 48 123 | agent_points_per_side_width: 16 124 | """ 125 | 126 | text_format.Parse(config_text, config) 127 | 128 | use_flow = args.use_flow 129 | 130 | model = InteractPlanner(config, dim=args.dim, enc_layer=2, heads=8, dropout=0.1, 131 | timestep=5, decoder_dim=384, fpn_len=2, flow_pred=use_flow) 132 | 133 | local_rank = torch.device('cuda') 134 | print(local_rank) 135 | 136 | model = model.to(local_rank) 137 | 138 | planner = Planner(device=local_rank,g_length=1200,g_width=60, horizon=5,test_iters=50) 139 | 140 | assert args.model_dir != '', 'you must load pretrained weights for OL testing!' 
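# checkpoints are saved from the DDP-wrapped model during training, so every state-dict key
# carries a "module." prefix; the loop below strips the first 7 characters before loading into the bare model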
141 | kw_dict = {} 142 | for k,v in torch.load(args.model_dir,map_location=torch.device('cpu')).items(): 143 | kw_dict[k[7:]] = v 144 | model.load_state_dict(kw_dict) 145 | continue_ep = int(args.model_dir.split('_')[-3]) - 1 146 | print(f'model loaded!:epoch {continue_ep + 1}') 147 | 148 | test_dataset = DrivingData(args.data_dir + f'*.npz', use_flow=True) 149 | 150 | training_size = len(test_dataset) 151 | print(f'Length test: {training_size}') 152 | 153 | test_data = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=8) 154 | model_testing(test_data) 155 | -------------------------------------------------------------------------------- /model/decoder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from .aggregator import CrossAttention, temporal_upsample 12 | 13 | import math 14 | 15 | 16 | class DecodeUpsample(nn.Module): 17 | def __init__(self, input_dim, kernel, timestep): 18 | super(DecodeUpsample, self).__init__() 19 | self.conv = nn.Sequential(nn.Conv3d(input_dim, input_dim//2, (1, kernel, kernel), padding='same'), nn.GELU()) 20 | self.residual_conv = nn.Sequential(nn.Conv3d(input_dim//2, input_dim//2, (timestep, 1, 1)), nn.GELU()) 21 | 22 | def forward(self, inputs, res): 23 | #b, t, c, h, w = inputs.shape 24 | inputs = temporal_upsample(inputs, mode='bilinear') 25 | inputs = self.conv(inputs) + self.residual_conv(res) 26 | return inputs 27 | 28 | 29 | class PredFinalDecoder(nn.Module): 30 | def __init__(self, input_dim, kernel=3, large_scale=False,planning=True, use_flow=False): 31 | super(PredFinalDecoder, self).__init__() 32 | ''' 33 | input h,w = 128 34 | dual deconv for flow and ogms 35 | ''' 36 | self.input_dim = input_dim 37 | if large_scale: 38 | self.ogm_conv = nn.Conv3d(input_dim, 4, (1, kernel, kernel), padding='same') 39 | self.flow_conv = nn.Conv3d(input_dim, 2, (1, kernel, kernel), padding='same') 40 | else: 41 | self.ogm_conv = nn.Sequential(nn.Conv3d(input_dim, input_dim//2, (1, kernel, kernel), padding='same'), 42 | nn.GELU(), nn.Upsample(scale_factor=(1, 2, 2)), 43 | nn.Conv3d(input_dim//2, 4 if planning else 2, (1, kernel, kernel), padding='same')) 44 | 45 | self.flow_conv = nn.Sequential(nn.Conv3d(input_dim, input_dim//2, (1, kernel, kernel), padding='same'), 46 | nn.GELU(), nn.Upsample(scale_factor=(1, 2, 2)), 47 | nn.Conv3d(input_dim//2, 2, (1, kernel, kernel), padding='same')) 48 | 49 | def forward(self, inputs): 50 | ogms = self.ogm_conv(inputs) 51 | flows = self.flow_conv(inputs) 52 | return torch.cat([ogms, flows], dim=1) 53 | 54 | class STrajNetDecoder(nn.Module): 55 | def __init__(self, dim=384, heads=8, len_fpn=2, kernel=3, timestep=5, dropout=0.1, 56 | flow_pred=False, large_scale=False): 57 | super(STrajNetDecoder, self).__init__() 58 | 59 | self.timestep = timestep 60 | self.len_fpn = len_fpn 61 | self.residual_conv = nn.Sequential(nn.Conv3d(dim, dim, (timestep, 1, 1)), nn.GELU()) 62 | self.aggregator = nn.ModuleList([CrossAttention(dim, heads, dropout) for _ in range(timestep)]) 63 | 64 | self.actor_layer = nn.Sequential(nn.Linear(256, dim), nn.GELU()) 65 | 66 | self.fpn_decoders = nn.ModuleList([ 67 | DecodeUpsample(dim // (2 ** i), kernel, timestep) for i in range(len_fpn) 68 | ]) 69 | 70 | self.upsample = nn.Upsample(scale_factor=(1, 2, 2)) 71 | if flow_pred: 72 | 
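# joint head: 4 occupancy channels (3 observed agent classes plus an occluded-occupancy channel) concatenated with a 2-channel flow field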
self.output_conv = PredFinalDecoder(dim // (2 ** len_fpn),large_scale=large_scale) 73 | else: 74 | self.output_conv = nn.Conv3d(dim // (2 ** len_fpn), 4, (1, kernel, kernel), padding='same') 75 | 76 | def forward(self, output_list, actor, actor_mask): 77 | # Aggregations: 78 | enc_output = output_list[-1] 79 | b, c, h, w = enc_output.shape 80 | res_output = enc_output.unsqueeze(2).expand(-1, -1, self.timestep, -1, -1) 81 | enc_output = enc_output.reshape(b, c, h*w).permute(0, 2, 1) 82 | #[b, t, h*w, c] 83 | enc_output = enc_output.unsqueeze(1).expand(-1, self.timestep, -1, -1) 84 | 85 | actor = self.actor_layer(actor) 86 | actor_mask[:, 0] = False 87 | agg_output = torch.stack([self.aggregator[i](enc_output[:, i], actor, actor_mask) for i in range(self.timestep)], dim=2) 88 | agg_output = agg_output.permute(0, 3, 2, 1).reshape(b, -1, self.timestep, h, w) 89 | decode_output = agg_output + self.residual_conv(res_output) 90 | # fpn decoding: 91 | for j in range(self.len_fpn): 92 | decode_output = self.fpn_decoders[j](decode_output, output_list[-2-j].unsqueeze(2).expand(-1, -1, self.timestep, -1, -1)) 93 | decode_output = self.output_conv(self.upsample(decode_output)) 94 | 95 | #[b, t, c, h, w] 96 | return decode_output 97 | 98 | 99 | 100 | class EgoPlanner(nn.Module): 101 | def __init__(self, dim=256, use_dynamic=False,timestep=5): 102 | super(EgoPlanner,self).__init__() 103 | self.timestep = timestep 104 | self.out_step = 2 105 | self.planner = nn.Sequential(nn.Linear(256, 128), nn.ELU(),nn.Dropout(0.1), 106 | nn.Linear(128, timestep*self.out_step *10)) 107 | self.scorer = nn.Sequential(nn.Linear(256, 128), nn.ELU(),nn.Dropout(0.1), 108 | nn.Linear(128, 1)) 109 | self.use_dynamic = use_dynamic 110 | 111 | def physical(self, action, last_state): 112 | d_t = 0.1 113 | d_v = action[:, :, :, 0].clamp(-5, 5) 114 | d_theta = action[:, :, :, 1].clamp(-1, 1) 115 | 116 | x_0 = last_state[:, 0] 117 | y_0 = last_state[:, 1] 118 | theta_0 = last_state[:, 4] 119 | v_0 = torch.hypot(last_state[:, 2], last_state[:, 3]) 120 | 121 | v = v_0.reshape(-1,1,1) + torch.cumsum(d_v * d_t, dim=-1) 122 | v = torch.clamp(v, min=0) 123 | theta = theta_0.reshape(-1,1,1) + torch.cumsum(d_theta * d_t, dim=-1) 124 | theta = torch.fmod(theta, 2*torch.pi) 125 | x = x_0.reshape(-1,1,1) + torch.cumsum(v * torch.cos(theta) * d_t, dim=-1) 126 | y = y_0.reshape(-1,1,1) + torch.cumsum(v * torch.sin(theta) * d_t, dim=-1) 127 | traj = torch.stack([x, y, theta], dim=-1) 128 | return traj 129 | 130 | def forward(self, features, current_state): 131 | traj = self.planner(features).reshape(-1, 9, self.timestep*10, self.out_step) 132 | if self.use_dynamic: 133 | traj = self.physical(traj, current_state) 134 | score = self.scorer(features) 135 | return traj, score 136 | 137 | 138 | class PlanningDecoder(nn.Module): 139 | def __init__(self, dim=256, heads=8, dropout=0.1, use_dynamic=False, timestep=5): 140 | super(PlanningDecoder,self).__init__() 141 | 142 | self.region_embed = nn.Parameter(torch.zeros(1, 9, 256), requires_grad=True) 143 | nn.init.kaiming_uniform_(self.region_embed) 144 | 145 | self.self_attention = nn.MultiheadAttention(dim, heads, dropout, batch_first=True) 146 | self.cross_attention = nn.MultiheadAttention(dim, heads, dropout, batch_first=True) 147 | self.bev_attention = nn.MultiheadAttention(dim, heads, dropout, batch_first=True) 148 | self.bev_layer = nn.Sequential(nn.Linear(384, dim), nn.GELU()) 149 | 150 | self.ffn = nn.Sequential(nn.Linear(dim, dim*4), nn.GELU(), nn.Dropout(0.1), nn.Linear(dim*4, dim), 
nn.Dropout(0.1)) 151 | 152 | self.norm_0 = nn.LayerNorm(dim) 153 | self.norm_1 = nn.LayerNorm(dim) 154 | 155 | self.planner = EgoPlanner(dim, use_dynamic, timestep) 156 | 157 | def forward(self, inputs): 158 | #encode the poly plan as ref: 159 | b = inputs['encodings'].shape[0] 160 | plan_query = self.region_embed.expand(b,-1,-1) 161 | self_plan_query,_ = self.self_attention(plan_query, plan_query, plan_query) 162 | #cross attention with bev and map-actors: 163 | map_actors = inputs['encodings'][:, 0] 164 | map_actors_mask = inputs['masks'][:, 0] 165 | map_actors_mask[:,0] = False 166 | dense_feature,_ = self.cross_attention(self_plan_query, map_actors, map_actors, key_padding_mask=map_actors_mask) 167 | b, c, h, w = inputs['bev_feature'].shape 168 | bev_feature = inputs['bev_feature'].reshape(b, c, h*w).permute(0, 2, 1) 169 | bev_feature = self.bev_layer(bev_feature) 170 | bev_feature,_ = self.bev_attention(self_plan_query, bev_feature, bev_feature) 171 | 172 | attention_feature = self.norm_0(dense_feature + bev_feature + plan_query) 173 | output_feature = self.ffn(attention_feature) + attention_feature 174 | output_feature = self.norm_1(output_feature) 175 | 176 | # output: 177 | ego_current = inputs['actors'][:, 0, -1, :] 178 | traj, score = self.planner(output_feature, ego_current) 179 | return traj, score.squeeze(-1) -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import argparse 3 | import time 4 | import sys 5 | 6 | import torch 7 | from torch import optim 8 | import torch.distributed as dist 9 | from torch.utils.data.distributed import DistributedSampler 10 | from torch.nn.parallel import DistributedDataParallel as DDP 11 | from torch.utils.data import DataLoader 12 | 13 | from waymo_open_dataset.protos import occupancy_flow_metrics_pb2 14 | from google.protobuf import text_format 15 | 16 | from model.InteractPlanner import InteractPlanner 17 | from utils.net_utils import * 18 | from metric import TrainingMetrics, ValidationMetrics 19 | 20 | 21 | # define model training epoch 22 | def training_epoch(train_data, optimizer, epoch, scheduler): 23 | 24 | model.train() 25 | current = 0 26 | start_time = time.time() 27 | size = len(train_data) 28 | epoch_loss = [] 29 | train_metric = TrainingMetrics() 30 | i = 0 31 | for batch in train_data: 32 | # prepare data 33 | inputs, target = batch_to_dict(batch, local_rank , use_flow=use_flow) 34 | 35 | optimizer.zero_grad() 36 | # query the model 37 | bev_pred, traj, score = model(inputs) 38 | actor_loss, occ_loss, flow_loss = occupancy_loss(bev_pred, target, use_flow=use_flow) 39 | il_loss, _, gt_modes = imitation_loss(traj, score, target, args.use_planning) 40 | 41 | loss = il_loss + 100*(actor_loss + occ_loss) 42 | if use_flow: 43 | loss += flow_loss 44 | 45 | loss.backward() 46 | torch.nn.utils.clip_grad_norm_(model.parameters(), 5) 47 | optimizer.step() 48 | 49 | current += args.batch_size 50 | epoch_loss.append(loss.item()) 51 | if isinstance(traj, list): 52 | traj, score = traj[-1], score[-1] 53 | ade, fde, l_il, l_ogm = train_metric.update(traj, score, gt_modes, target, il_loss, actor_loss + occ_loss, bev_pred) 54 | 55 | if dist.get_rank() == 0: 56 | sys.stdout.write(f"\rTrain: [{current:>6d}/{size*args.batch_size:>6d}]|Loss: {np.mean(epoch_loss):>.4f}-{l_il:>.4f}-{l_ogm:>.4f}|ADE:{ade:>.4f}-FDE:{fde:>.4f}|{(time.time()-start_time)/current:>.4f}s/sample") 57 | sys.stdout.flush() 58 | 59 | 
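# advance the cosine warm-restart schedule by a fractional epoch (epoch + i/size) so the learning rate anneals per iteration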
scheduler.step(epoch + i/size) 60 | i += 1 61 | 62 | results = train_metric.result() 63 | 64 | return np.mean(epoch_loss), results 65 | 66 | # define model validation epoch 67 | def validation_epoch(valid_data,epoch): 68 | epoch_metrics = ValidationMetrics() 69 | model.eval() 70 | current = 0 71 | start_time = time.time() 72 | size = len(valid_data) 73 | epoch_loss = [] 74 | 75 | print(f'Validation...Epoch{epoch+1}') 76 | for batch in valid_data: 77 | # prepare data 78 | inputs, target = batch_to_dict(batch, local_rank, use_flow=use_flow) 79 | 80 | # query the model 81 | with torch.no_grad(): 82 | bev_pred, traj, score = model(inputs) 83 | actor_loss, occ_loss, flow_loss = occupancy_loss(bev_pred, target, use_flow=use_flow) 84 | il_loss, _, gt_modes = imitation_loss(traj, score, target, args.use_planning) 85 | loss = il_loss + 100*(actor_loss + occ_loss) 86 | if use_flow: 87 | loss += flow_loss 88 | # compute metrics 89 | epoch_loss.append(loss.item()) 90 | if isinstance(traj, list): 91 | traj, score = traj[-1], score[-1] 92 | ade, fde, l_il, l_ogm, ogm_auc,_, occ_auc,_ = epoch_metrics.update(traj, score, bev_pred, 93 | gt_modes, target, il_loss, actor_loss + occ_loss) 94 | 95 | current += args.batch_size 96 | if dist.get_rank() == 0: 97 | sys.stdout.write(f"\r\Val: [{current:>6d}/{size*args.batch_size:>6d}]|Loss: {np.mean(epoch_loss):>.4f}-{l_il:>.4f}-{l_ogm:>.4f}|ADE:{ade:>.4f}-FDE:{fde:>.4f}{(time.time()-start_time)/current:>.4f}s/sample") 98 | sys.stdout.flush() 99 | 100 | # process metrics 101 | epoch_metrics = epoch_metrics.result() 102 | 103 | return epoch_metrics,np.mean(epoch_loss) 104 | 105 | # Define model training process 106 | def model_training(train_data, valid_data, epochs, save_dir): 107 | # define optimizer and loss function 108 | optimizer = optim.AdamW(model.parameters(), lr=args.lr) 109 | scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=2, eta_min=1e-6) 110 | 111 | for epoch in range(epochs): 112 | if dist.get_rank() == 0: 113 | print(f"Epoch {epoch+1}/{epochs}") 114 | 115 | if epoch<=continue_ep and continue_ep!=0: 116 | scheduler.step() 117 | continue 118 | 119 | train_data.sampler.set_epoch(epoch) 120 | valid_data.sampler.set_epoch(epoch) 121 | 122 | train_loss,train_res = training_epoch(train_data, optimizer, epoch, scheduler) 123 | valid_metrics,val_loss = validation_epoch(valid_data,epoch) 124 | 125 | # save to training log 126 | log = {'epoch': epoch+1, 'loss': train_loss, 'lr': optimizer.param_groups[0]['lr']} 127 | log.update(valid_metrics) 128 | 129 | if dist.get_rank() == 0: 130 | if epoch == 0: 131 | with open(save_dir + f'train_log.csv', 'a') as csv_file: 132 | writer = csv.writer(csv_file) 133 | writer.writerow(log.keys()) 134 | writer.writerow(log.values()) 135 | else: 136 | with open(save_dir + f'train_log.csv', 'a') as csv_file: 137 | writer = csv.writer(csv_file) 138 | writer.writerow(log.values()) 139 | 140 | # adjust learning rate 141 | scheduler.step() 142 | 143 | # save model at the end of epoch 144 | if dist.get_rank() == 0: 145 | torch.save(model.state_dict(), save_dir+f'model_{epoch+1}_{train_loss:4f}_{val_loss:4f}.pth') 146 | 147 | 148 | if __name__ == "__main__": 149 | 150 | parser = argparse.ArgumentParser() 151 | parser.add_argument("--local_rank", type=int) 152 | parser.add_argument("--batch_size", type=int, default=8) 153 | parser.add_argument("--dim", type=int, default=256) 154 | parser.add_argument("--lr", type=float, default=1e-4) 155 | parser.add_argument("--epochs", type=int, default=30) 156 | 
parser.add_argument("--use_flow", action='store_true', default=True, 157 | help='whether to use flow warp') 158 | 159 | parser.add_argument("--save_dir", type=str, default='',help='path to save logs') 160 | parser.add_argument("--data_dir", type=str, default='', 161 | help='path to load preprocessed train & val sets') 162 | parser.add_argument("--model_dir", type=str, default='', 163 | help='path to load models for continued training') 164 | 165 | args = parser.parse_args() 166 | local_rank = args.local_rank 167 | 168 | use_flow = args.use_flow 169 | 170 | config = occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig() 171 | config_text = f""" 172 | num_past_steps: {10} 173 | num_future_steps: {50} 174 | num_waypoints: {5} 175 | cumulative_waypoints: {'false'} 176 | normalize_sdc_yaw: true 177 | grid_height_cells: {128} 178 | grid_width_cells: {128} 179 | sdc_y_in_grid: {int(128*0.75)} 180 | sdc_x_in_grid: {64} 181 | pixels_per_meter: {1.6} 182 | agent_points_per_side_length: 48 183 | agent_points_per_side_width: 16 184 | """ 185 | 186 | text_format.Parse(config_text, config) 187 | 188 | model = InteractPlanner(config, dim=args.dim, enc_layer=2, heads=8, dropout=0.1, 189 | timestep=5, decoder_dim=384, fpn_len=2, flow_pred=use_flow) 190 | 191 | save_dir = args.save_dir + f"models/" 192 | os.makedirs(save_dir,exist_ok=True) 193 | 194 | torch.cuda.set_device(local_rank) 195 | dist.init_process_group(backend='nccl') 196 | 197 | model = model.to(local_rank) 198 | if args.model_dir!= '': 199 | kw_dict = {} 200 | for k,v in torch.load(args.model_dir,map_location='cpu').items(): 201 | kw_dict[k[7:]] = v  # strip the "module." prefix saved by the DDP-wrapped model 202 | model.load_state_dict(kw_dict) 203 | continue_ep = int(args.model_dir.split('_')[-3]) - 1 204 | print(f'model loaded!:epoch {continue_ep + 1}') 205 | else: 206 | continue_ep = 0 207 | 208 | model = DDP(model, device_ids=[local_rank], output_device=local_rank) 209 | 210 | train_dataset = DrivingData(args.data_dir + f'train/*.npz',use_flow=use_flow) 211 | valid_dataset = DrivingData(args.data_dir + f'valid/*.npz',use_flow=use_flow) 212 | 213 | training_size = len(train_dataset) 214 | valid_size = len(valid_dataset) 215 | if dist.get_rank() == 0: 216 | print(f'Length train: {training_size}; Valid: {valid_size}') 217 | 218 | train_sampler = DistributedSampler(train_dataset) 219 | valid_sampler = DistributedSampler(valid_dataset, shuffle=False) 220 | train_data = DataLoader(train_dataset, batch_size=args.batch_size, 221 | sampler=train_sampler, num_workers=16) 222 | valid_data = DataLoader(valid_dataset, batch_size=args.batch_size, 223 | sampler=valid_sampler, num_workers=4) 224 | 225 | model_training(train_data, valid_data, args.epochs, save_dir) 226 | -------------------------------------------------------------------------------- /utils/waymo_tf_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | #### Example field definition 4 | # Features of road graph.
5 | roadgraph_features = { 6 | 'roadgraph_samples/dir': 7 | tf.io.FixedLenFeature([20000, 3], tf.float32, default_value=None), 8 | 'roadgraph_samples/id': 9 | tf.io.FixedLenFeature([20000, 1], tf.int64, default_value=None), 10 | 'roadgraph_samples/type': 11 | tf.io.FixedLenFeature([20000, 1], tf.int64, default_value=None), 12 | 'roadgraph_samples/valid': 13 | tf.io.FixedLenFeature([20000, 1], tf.int64, default_value=None), 14 | 'roadgraph_samples/xyz': 15 | tf.io.FixedLenFeature([20000, 3], tf.float32, default_value=None), 16 | } 17 | 18 | # Features of other agents. 19 | state_features = { 20 | 'state/id': 21 | tf.io.FixedLenFeature([128], tf.float32, default_value=None), 22 | 'state/type': 23 | tf.io.FixedLenFeature([128], tf.float32, default_value=None), 24 | 'state/is_sdc': 25 | tf.io.FixedLenFeature([128], tf.int64, default_value=None), 26 | 'state/tracks_to_predict': 27 | tf.io.FixedLenFeature([128], tf.int64, default_value=None), 28 | 'state/current/bbox_yaw': 29 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None), 30 | 'state/current/height': 31 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None), 32 | 'state/current/length': 33 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None), 34 | 'state/current/timestamp_micros': 35 | tf.io.FixedLenFeature([128, 1], tf.int64, default_value=None), 36 | 'state/current/valid': 37 | tf.io.FixedLenFeature([128, 1], tf.int64, default_value=None), 38 | 'state/current/vel_yaw': 39 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None), 40 | 'state/current/velocity_x': 41 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None), 42 | 'state/current/velocity_y': 43 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None), 44 | 'state/current/speed': 45 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None), 46 | 'state/current/width': 47 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None), 48 | 'state/current/x': 49 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None), 50 | 'state/current/y': 51 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None), 52 | 'state/current/z': 53 | tf.io.FixedLenFeature([128, 1], tf.float32, default_value=None), 54 | 'state/future/bbox_yaw': 55 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None), 56 | 'state/future/height': 57 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None), 58 | 'state/future/length': 59 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None), 60 | 'state/future/timestamp_micros': 61 | tf.io.FixedLenFeature([128, 80], tf.int64, default_value=None), 62 | 'state/future/valid': 63 | tf.io.FixedLenFeature([128, 80], tf.int64, default_value=None), 64 | 'state/future/vel_yaw': 65 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None), 66 | 'state/future/velocity_x': 67 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None), 68 | 'state/future/velocity_y': 69 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None), 70 | 'state/future/width': 71 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None), 72 | 'state/future/x': 73 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None), 74 | 'state/future/y': 75 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None), 76 | 'state/future/z': 77 | tf.io.FixedLenFeature([128, 80], tf.float32, default_value=None), 78 | 'state/past/bbox_yaw': 79 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None), 80 | 'state/past/height': 81 | 
tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None), 82 | 'state/past/length': 83 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None), 84 | 'state/past/timestamp_micros': 85 | tf.io.FixedLenFeature([128, 10], tf.int64, default_value=None), 86 | 'state/past/valid': 87 | tf.io.FixedLenFeature([128, 10], tf.int64, default_value=None), 88 | 'state/past/vel_yaw': 89 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None), 90 | 'state/past/velocity_x': 91 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None), 92 | 'state/past/velocity_y': 93 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None), 94 | 'state/past/speed': 95 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None), 96 | 'state/past/width': 97 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None), 98 | 'state/past/x': 99 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None), 100 | 'state/past/y': 101 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None), 102 | 'state/past/z': 103 | tf.io.FixedLenFeature([128, 10], tf.float32, default_value=None), 104 | 'scenario/id': 105 | tf.io.FixedLenFeature([1], tf.string, default_value=None), 106 | } 107 | 108 | # Features of traffic lights. 109 | traffic_light_features = { 110 | 'traffic_light_state/current/state': 111 | tf.io.FixedLenFeature([1, 16], tf.int64, default_value=None), 112 | 'traffic_light_state/current/valid': 113 | tf.io.FixedLenFeature([1, 16], tf.int64, default_value=None), 114 | 'traffic_light_state/current/x': 115 | tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None), 116 | 'traffic_light_state/current/y': 117 | tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None), 118 | 'traffic_light_state/current/z': 119 | tf.io.FixedLenFeature([1, 16], tf.float32, default_value=None), 120 | 'traffic_light_state/past/state': 121 | tf.io.FixedLenFeature([10, 16], tf.int64, default_value=None), 122 | 'traffic_light_state/past/valid': 123 | tf.io.FixedLenFeature([10, 16], tf.int64, default_value=None), 124 | 'traffic_light_state/past/x': 125 | tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None), 126 | 'traffic_light_state/past/y': 127 | tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None), 128 | 'traffic_light_state/past/z': 129 | tf.io.FixedLenFeature([10, 16], tf.float32, default_value=None), 130 | } 131 | 132 | features_description = {} 133 | features_description.update(roadgraph_features) 134 | features_description.update(state_features) 135 | features_description.update(traffic_light_features) 136 | 137 | # road label 138 | road_label = {1:'LaneCenter-Freeway', 2:'LaneCenter-SurfaceStreet', 3:'LaneCenter-BikeLane', 6:'RoadLine-BrokenSingleWhite', 139 | 7:'RoadLine-SolidSingleWhite', 8:'RoadLine-SolidDoubleWhite', 9:'RoadLine-BrokenSingleYellow', 10:'RoadLine-BrokenDoubleYellow', 140 | 11:'Roadline-SolidSingleYellow', 12:'Roadline-SolidDoubleYellow', 13:'RoadLine-PassingDoubleYellow', 15:'RoadEdgeBoundary', 141 | 16:'RoadEdgeMedian', 17:'StopSign', 18:'Crosswalk', 19:'SpeedBump'} 142 | 143 | road_line_map = {1:['xkcd:grey', 'solid', 14], 2:['xkcd:grey', 'solid', 14], 3:['xkcd:grey', 'solid', 10], 5:['w', 'solid', 2], 6:['w', 'dashed', 2], 144 | 7:['w', 'solid', 2], 8:['w', 'solid', 2], 9:['xkcd:yellow', 'dashed', 4], 10:['xkcd:yellow', 'dashed', 2], 145 | 11:['xkcd:yellow', 'solid', 2], 12:['xkcd:yellow', 'solid', 3], 13:['xkcd:yellow', 'dotted', 1.5], 15:['y', 'solid', 4.5], 146 | 16:['y', 'solid', 4.5], 17:['r', '.', 40], 
18:['b', 'solid', 13], 19:['xkcd:orange', 'solid', 13]} 147 | 148 | # traffic light label 149 | light_label = {0:'Unknown', 1:'Arrow_Stop', 2:'Arrow_Caution', 3:'Arrow_Go', 4:'Stop', 5:'Caution', 6:'Go', 7:'Flashing_Stop', 8:'Flashing_Caution'} 150 | light_state_map = {0:'k', 1:'r', 2:'b', 3:'g', 4:'r', 5:'b', 6:'g', 7:'r', 8:'b'} 151 | light_state_map_num = {0:0, 1:1, 2:2, 3:3, 4:1, 5:2, 6:3, 7:1, 8:2} 152 | light_state_rank_map = {0: 1, 1: 4, 2: 3, 3: 2, 4: 4, 5: 3, 6: 2, 7: 4, 8: 3} 153 | light_near_state_map = {0:'black', 1:'darkred', 2:'darkblue', 3:'darkgreen', 4:'darkred', 5:'darkblue', 6:'darkgreen', 7:'darkred', 8:'darkblue'} 154 | 155 | def linecolormap(value,m_per_pixel): 156 | return {'color':value[0], 'linestyle':value[1], 'linewidth': value[2]*m_per_pixel/3} 157 | 158 | def light_linecolormap(light_state,value,m_per_pixel): 159 | return {'color':light_state_map[light_state], 'linestyle':value[1], 'linewidth': value[2]*m_per_pixel/3, 160 | 'zorder': light_state_rank_map[light_state]} 161 | 162 | def lightnear_linecolormap(light_state,value,m_per_pixel): 163 | return {'color':light_near_state_map[light_state], 'linestyle':value[1], 'linewidth': value[2]*m_per_pixel/3, 164 | 'zorder': light_state_rank_map[light_state]} 165 | 166 | def traffic_light_map(road_type, tl_state, tl_near, m_per_pixel): 167 | road_val = road_line_map[road_type] 168 | if tl_near: 169 | return lightnear_linecolormap(tl_state, road_val, m_per_pixel) 170 | else: 171 | return light_linecolormap(tl_state, road_val, m_per_pixel) 172 | 173 | from waymo_open_dataset.protos import scenario_pb2 174 | 175 | _ObjectType = scenario_pb2.Track.ObjectType 176 | ALL_AGENT_TYPES = [ 177 | _ObjectType.TYPE_VEHICLE, 178 | _ObjectType.TYPE_PEDESTRIAN, 179 | _ObjectType.TYPE_CYCLIST, 180 | ] 181 | 182 | agent_color={ 183 | _ObjectType.TYPE_VEHICLE:'r', 184 | _ObjectType.TYPE_PEDESTRIAN:'g', 185 | _ObjectType.TYPE_CYCLIST:'b' 186 | } 187 | -------------------------------------------------------------------------------- /utils/net_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 6 | 7 | import torch 8 | import logging 9 | import glob 10 | 11 | import numpy as np 12 | from torch.utils.data import Dataset 13 | from torch.nn import functional as F 14 | from google.protobuf import text_format 15 | 16 | # import torchmetrics 17 | from torchvision.ops.focal_loss import sigmoid_focal_loss 18 | 19 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 20 | import random 21 | 22 | 23 | def initLogging(log_file: str, level: str = "INFO"): 24 | logging.basicConfig(filename=log_file, filemode='w', 25 | level=getattr(logging, level, None), 26 | format='[%(levelname)s %(asctime)s] %(message)s', 27 | datefmt='%m-%d %H:%M:%S') 28 | logging.getLogger().addHandler(logging.StreamHandler()) 29 | 30 | class DrivingData(Dataset): 31 | def __init__(self, data_dir, use_flow=False): 32 | self.data_list = glob.glob(data_dir) 33 | self.use_flow = use_flow 34 | 35 | def __len__(self): 36 | return len(self.data_list) 37 | 38 | def __getitem__(self, idx): 39 | data = np.load(self.data_list[idx],allow_pickle=True) 40 | ego = data['ego'] 41 | neighbor = data['neighbors'][:, :11, :] 42 | 43 | neighbor_map_lanes = data['neighbor_map_lanes'] 44 | ego_map_lane = data['ego_map_lane'] 45 | 46 | neighbor_crosswalk = data['neighbor_map_crosswalks'] 47 | ego_crosswalk = data['ego_map_crosswalk'] 48 
| 49 | ego_future_states = data['gt_future_states'] 50 | 51 | ref_line = data['ref_line'] 52 | goal = data['goal'] 53 | 54 | hist_ogm = data['hist_ogm'] 55 | ego_ogm = data['ego_ogm'] 56 | 57 | gt_obs = data['gt_obs'] 58 | gt_occ = data['gt_occ'] 59 | gt_ego = data['ego_ogm_gt'] 60 | 61 | if self.use_flow: 62 | hist_flow = data['hist_flow'] 63 | gt_flow = data['gt_flow'] 64 | road_graph = data['rg'].astype(np.float32) 65 | road_graph = np.array(road_graph) 66 | 67 | return ego, neighbor, ego_map_lane, neighbor_map_lanes, ego_crosswalk, neighbor_crosswalk,\ 68 | ego_future_states, ref_line, goal, hist_ogm, ego_ogm, gt_obs, gt_occ, gt_ego, hist_flow, gt_flow, road_graph 69 | 70 | return ego, neighbor, ego_map_lane, neighbor_map_lanes, ego_crosswalk, neighbor_crosswalk,\ 71 | ego_future_states, ref_line, goal, hist_ogm, ego_ogm, gt_obs, gt_occ, gt_ego 72 | 73 | def batch_to_dict(batch, local_rank, use_flow=False): 74 | if use_flow: 75 | ego, neighbor, ego_map_lane, neighbor_map_lanes, ego_crosswalk, neighbor_crosswalk,\ 76 | ego_future_states, ref_line, goal, hist_ogm, ego_ogm, gt_obs, gt_occ, gt_ego, hist_flow, gt_flow, road_graph = batch 77 | else: 78 | ego, neighbor, ego_map_lane, neighbor_map_lanes, ego_crosswalk, neighbor_crosswalk,\ 79 | ego_future_states, ref_line, goal, hist_ogm, ego_ogm, gt_obs, gt_occ, gt_ego = batch 80 | 81 | # if not use_flow: 82 | ego_mask = (1 - ego_ogm.to(local_rank).float().unsqueeze(-1)) 83 | hist_ogm = hist_ogm.to(local_rank).float()*ego_mask 84 | b, h, w, t, c = hist_ogm.shape 85 | hist_ogm = hist_ogm.reshape(b, h ,w, t*c) 86 | 87 | input_dict = { 88 | 'ego_state': ego.to(local_rank).float(), 89 | 'neighbor_state': neighbor.to(local_rank).float(), 90 | 'ego_map_lane': ego_map_lane.to(local_rank).float(), 91 | 'neighbor_map_lanes': neighbor_map_lanes.to(local_rank).float(), 92 | 'ego_map_crosswalk': ego_crosswalk.to(local_rank).float(), 93 | 'neighbor_map_crosswalks': neighbor_crosswalk.to(local_rank).float(), 94 | 'hist_ogm': hist_ogm, 95 | 'ego_ogm': ego_ogm.to(local_rank).float(), 96 | } 97 | target_dict = { 98 | 'ref_line': ref_line.to(local_rank).float(), 99 | 'goal': goal.to(local_rank).float(), 100 | 'ego_future_states':ego_future_states[..., [0,1,4]].to(local_rank).float(), 101 | 'gt_obs':gt_obs.to(local_rank).float(), 102 | 'gt_occ':gt_occ.to(local_rank).sum(-1).clamp(0, 1).float(), 103 | 'gt_ego':gt_ego.to(local_rank).float(), 104 | } 105 | if not use_flow: 106 | return input_dict, target_dict 107 | else: 108 | road_graph = road_graph[:, 128:128+256, 128:128+256, :] 109 | input_dict = { 110 | 'ego_state': ego.to(local_rank).float(), 111 | 'neighbor_state': neighbor.to(local_rank).float(), 112 | 'ego_map_lane': ego_map_lane.to(local_rank).float(), 113 | 'neighbor_map_lanes': neighbor_map_lanes.to(local_rank).float(), 114 | 'ego_map_crosswalk': ego_crosswalk.to(local_rank).float(), 115 | 'neighbor_map_crosswalks': neighbor_crosswalk.to(local_rank).float(), 116 | 'hist_ogm': hist_ogm, 117 | 'ego_ogm': ego_ogm.to(local_rank).float(), 118 | 'hist_flow': hist_flow.to(local_rank).float(), 119 | 'road_graph': road_graph.to(local_rank).float(), 120 | } 121 | target_dict = { 122 | 'ref_line': ref_line.to(local_rank).float(), 123 | 'goal': goal.to(local_rank).float(), 124 | 'ego_future_states':ego_future_states[..., [0,1,4]].to(local_rank).float(), 125 | 'gt_obs':gt_obs.to(local_rank).float(), 126 | 'gt_occ':gt_occ.to(local_rank).sum(-1).clamp(0, 1).float(), 127 | 'gt_ego':gt_ego.to(local_rank).float(), 128 | 'gt_flow':gt_flow.to(local_rank).float() 129 | } 
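# flow targets and the rasterized road graph (cropped above to the central 256x256 patch) are only returned when use_flow is set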
130 | return input_dict, target_dict 131 | 132 | def occupancy_loss(outputs, targets, use_flow=False): 133 | ego_mask = 1-targets['gt_ego'] 134 | gt_obs = targets['gt_obs'][..., 0, :] * ego_mask.unsqueeze(-1) 135 | target_ogm = gt_obs.permute(0, 4, 1, 2, 3) #[b, c, t, h, w] 136 | # actors: 137 | actor_loss: torch.Tensor = 0 138 | alpha_list = [0.1, 0.01, 0.01] 139 | for i in range(3): 140 | ref = outputs[:, i] 141 | tar = target_ogm[:, i] 142 | loss = sigmoid_focal_loss(ref, tar, alpha=alpha_list[i], gamma=1, reduction='mean') 143 | actor_loss += loss 144 | actor_loss = actor_loss / 3 145 | 146 | #occulsions: 147 | occ_ref, occ_tar = outputs[:, 3], targets['gt_occ']*ego_mask#[:, 0] 148 | occ_loss = sigmoid_focal_loss(occ_ref, occ_tar, alpha=0.05, gamma=1, reduction='mean') 149 | 150 | if use_flow: 151 | flow_loss: torch.Tensor = 0 152 | target_flow = targets['gt_flow'] 153 | target_flow = target_flow.sum(-1) 154 | flow_exists = torch.logical_or(torch.ne(target_flow[..., 0], 0), torch.ne(target_flow[..., 1], 0)).float() 155 | flow_outputs = outputs[:, -2:] 156 | b, c, t, h, w = flow_outputs.shape 157 | flow_outputs = flow_outputs.permute(0, 2, 3, 4, 1) 158 | pred_flow = torch.mul(flow_outputs, flow_exists.unsqueeze(-1)) 159 | exist_num = torch.nan_to_num(torch.sum(flow_exists)/2, 1.0) 160 | if exist_num > 0: 161 | flow_loss += F.smooth_l1_loss(pred_flow, target_flow, reduction='sum') / exist_num 162 | else: 163 | flow_loss = None 164 | 165 | return actor_loss, occ_loss, flow_loss 166 | 167 | def modal_selections(x, y, mode): 168 | b, n, t, d = x.shape 169 | if mode=='fde': 170 | fde_dist = torch.norm(x[:, :, -1, :2] - y[:, -1, :2].unsqueeze(1).expand(-1, n, -1), dim=-1) 171 | dist = torch.argmin(fde_dist, dim=-1) 172 | else: 173 | # joint ade and fde 174 | fde_dist = torch.norm(x[:, :, -1, :2] - y[:, -1, :2].unsqueeze(1).expand(-1, n, -1), dim=-1) 175 | ade_dist = torch.norm(x[:, :, :, :2] - y[:, :, :2].unsqueeze(1).expand(-1, n, -1, -1), dim=-1).mean(-1) 176 | dist = torch.argmin(0.5*fde_dist + ade_dist, dim=-1) 177 | 178 | return dist 179 | 180 | 181 | def infer_modal_selection(traj, score, targets, use_planning): 182 | B = score.shape[0] 183 | gt_modes = torch.argmax(score, dim=1) 184 | selected_trajs = traj[torch.arange(B)[:, None], gt_modes.unsqueeze(-1)].squeeze(1) 185 | return selected_trajs, gt_modes 186 | 187 | 188 | def imitation_loss(traj, score, targets, use_planning=False): 189 | gt_future = targets['ego_future_states'] 190 | if isinstance(traj, list): 191 | il_loss: torch.Tensor = 0 192 | for tr,sc in zip(traj, score): 193 | loss, selected_trajs, gt_modes = single_layer_planning_loss(tr, sc, targets) 194 | il_loss += loss 195 | else: 196 | il_loss, selected_trajs, gt_modes = single_layer_planning_loss(traj, score, targets) 197 | 198 | return il_loss, selected_trajs, gt_modes 199 | 200 | 201 | def single_layer_planning_loss(traj, score, targets): 202 | gt_future = targets['ego_future_states'] 203 | p_d = traj.shape[-1] 204 | gt_future = gt_future[...,:p_d] 205 | gt_modes = modal_selections(traj, gt_future,mode='joint') 206 | 207 | classification_loss = F.cross_entropy(score, gt_modes, label_smoothing=0.2) 208 | B = traj.shape[0] 209 | selected_trajs = traj[torch.arange(B)[:, None], gt_modes.unsqueeze(-1)].squeeze(1) 210 | goal_time = [9, 29, 49] 211 | 212 | ade_loss = F.smooth_l1_loss(selected_trajs, gt_future[...,:p_d]) 213 | fde_loss = F.smooth_l1_loss(selected_trajs[...,goal_time,:], gt_future[...,goal_time,:p_d]) 214 | 215 | il_loss = ade_loss + 0.5*fde_loss + 
2*classification_loss 216 | 217 | return il_loss, selected_trajs, gt_modes 218 | 219 | 220 | def get_contributing_params(y, top_level=True): 221 | nf = y.grad_fn.next_functions if top_level else y.next_functions 222 | for f, _ in nf: 223 | try: 224 | yield f.variable 225 | except AttributeError: 226 | pass # node has no tensor 227 | if f is not None: 228 | yield from get_contributing_params(f, top_level=False) 229 | 230 | def check_non_contributing_params(model, outputs): 231 | contrib_params = set() 232 | all_parameters = set(model.parameters()) 233 | 234 | for output in outputs: 235 | contrib_params.update(get_contributing_params(output)) 236 | print(all_parameters - contrib_params) -------------------------------------------------------------------------------- /utils/plan_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | import numpy as np 12 | import math 13 | import bisect 14 | 15 | import matplotlib.pyplot as plt 16 | 17 | 18 | def ref_line_grids(ref_line, widths=10, pixels_per_meter=3.2, left=True): 19 | ''' 20 | generate the mapping of Cartisan coordinates (x, y) 21 | according to the Frenet grids (s, d) 22 | inputs: ref_lines (b, length, 2), width 23 | outputs: refline_grids (b, length, width*2, 2) 24 | ''' 25 | width_d = (torch.arange(-widths, widths) + 0.5) / pixels_per_meter 26 | # print(ref_line.shape) 27 | b, l, c = ref_line.shape 28 | width_d = width_d.unsqueeze(0).unsqueeze(1).expand(b, l, -1).to(ref_line.device) 29 | angle = ref_line[:, :, 2] 30 | angle = (angle + np.pi) % (2*np.pi) - np.pi #- 31 | 32 | ref_x = ref_line[:, :, 0:1] 33 | ref_y = ref_line[:, :, 1:2] 34 | 35 | # output coords always conincide with ogm's coords settings 36 | x = -torch.sin(angle).unsqueeze(-1) * width_d + ref_x 37 | y = torch.cos(angle).unsqueeze(-1) * width_d + ref_y 38 | 39 | cart_grids = torch.stack([x, y], dim=-1) 40 | return cart_grids 41 | 42 | def ref_line_ogm_sample(ogm, rl_grids, config): 43 | """ 44 | scatter the ogm fields to ref_line fields 45 | according to the ref_line_grids 46 | inputs: ogm [b, h, w] 47 | grids: [b, l_s, l_d, 2] 48 | outputs: ref_line fields: [b, l_s, l_d] 49 | """ 50 | points_x, points_y = rl_grids[..., 0], rl_grids[..., 1] 51 | pixels_per_meter = config.pixels_per_meter 52 | points_x = torch.round(-points_y * pixels_per_meter) + config.sdc_x_in_grid 53 | points_y = torch.round(-points_x * pixels_per_meter) + config.sdc_y_in_grid 54 | 55 | # Filter out points that are located outside the FOV of topdown map. 
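# point_is_in_fov is 1.0 only where the rounded grid indices land inside the H x W grid; out-of-range
# indices are zeroed before the gather below so indexing stays in bounds, and the sampled costs are
# masked again afterwards. E.g., with the config used in testing.py (pixels_per_meter=1.6,
# sdc_x_in_grid=64), a point at y = +10 m in the ego frame maps to grid column round(-10*1.6) + 64 = 48.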
56 | point_is_in_fov = torch.logical_and( 57 | torch.logical_and( 58 | torch.greater_equal(points_x, 0), torch.greater_equal(points_y, 0)), 59 | torch.logical_and( 60 | torch.less(points_x, config.grid_width_cells), 61 | torch.less(points_y, config.grid_height_cells))).float() 62 | 63 | w_axis_in = points_x * point_is_in_fov 64 | h_axis_in = points_y * point_is_in_fov 65 | 66 | w_axis_in = w_axis_in.long() 67 | h_axis_in = h_axis_in.long() 68 | 69 | b, h, w = w_axis_in.shape 70 | B = torch.arange(b).long() 71 | refline_fields = ogm[B[:, None], h_axis_in.view(b, -1), w_axis_in.view(b, -1)] 72 | refline_fields = refline_fields.view(b, h, w) 73 | 74 | # mask refline_fields not in fovs: 75 | refline_fields = refline_fields * point_is_in_fov 76 | return refline_fields 77 | 78 | def generate_ego_pos_at_field(ego_pos, ref_lines, angle): 79 | ''' 80 | transfrom the ego occupancy into Frenet for refline_field: 81 | inputs: ego_pos: [B, 2] (x, y) angle: ego angles 82 | ref_lines : [B, L, 3] (a, y, angle) 83 | outputs 3 quantile points (s, d, theta) and safe-distance ||1/4 h, w||_2 84 | ''' 85 | # 1. Transform ego pos (x, y, angle) into Frenet Coords (s, l, the): 86 | dist = torch.norm(ego_pos.unsqueeze(1) - ref_lines[..., :2], dim=-1, p=2.0) 87 | s = torch.argmin(dist, dim=1) 88 | b = ref_lines.shape[0] 89 | B = torch.arange(b).long() 90 | sel_ref = ref_lines[B, s, :] 91 | s = (s - 200) *0.1 92 | 93 | x_r, y_r, theta_r = sel_ref[:, 0], sel_ref[:, 1], sel_ref[:, 2] 94 | x, y = ego_pos[:, 0], ego_pos[:, 1] 95 | 96 | sgn = (y - y_r) * torch.cos(theta_r) - (x - x_r) * torch.sin(theta_r) 97 | dis = torch.sqrt(torch.square(x - x_r) + torch.square(y - y_r)) 98 | l = torch.sign(sgn) * dis 99 | 100 | the = angle - theta_r 101 | # the += np.pi/2 102 | the = (the + np.pi) % (2* np.pi) - np.pi 103 | 104 | ego = torch.stack([s, l, the], dim=-1) 105 | 106 | return ego 107 | 108 | 109 | def refline_meshgrids(ref_line_field, pixels_per_meter=3.2): 110 | ''' 111 | build the (s,l) meshgrids for ref_line field 112 | ''' 113 | device = ref_line_field.device 114 | b, s, l, _ = ref_line_field.shape 115 | widths = int(l/2) 116 | mesh_l = (torch.arange(-widths, widths) + 0.5) / pixels_per_meter 117 | mesh_s = (torch.arange(s).float() + 0.5) * 0.1 #/ pixels_per_meter 118 | mesh_s, mesh_l = torch.meshgrid(mesh_s, mesh_l) 119 | mesh_sl = torch.stack([mesh_s, mesh_l], dim=-1) 120 | mesh_sl = mesh_sl.unsqueeze(0).expand(b, -1, -1, -1) 121 | mesh_sl = mesh_sl.to(device) 122 | return mesh_sl 123 | 124 | 125 | def gather_nd_slow(ogm, h_axis_in, w_axis_in): 126 | b, h, w = w_axis_in.shape 127 | output = torch.zeros((b, h, w)).to(w_axis_in.device) 128 | for i in range(b): 129 | for j in range(h): 130 | for k in range(w): 131 | output[i, j, k] = ogm[i, h_axis_in[i, j, k], w_axis_in[i, j, k]] 132 | return output 133 | 134 | 135 | 136 | class Spline(object): 137 | """ 138 | Cubic Spline class 139 | """ 140 | 141 | def __init__(self, x, y): 142 | self.b, self.c, self.d, self.w = [], [], [], [] 143 | 144 | self.x = x 145 | self.y = y 146 | 147 | self.nx = len(x) # dimension of x 148 | h = np.diff(x) + 1e-3 149 | 150 | # calc coefficient c 151 | self.a = [iy for iy in y] 152 | 153 | # calc coefficient c 154 | A = self.__calc_A(h) 155 | B = self.__calc_B(h) 156 | self.c = np.linalg.solve(A, B) 157 | # print(self.c1) 158 | 159 | # calc spline coefficient b and d 160 | for i in range(self.nx - 1): 161 | self.d.append((self.c[i + 1] - self.c[i]) / (3.0 * h[i])) 162 | tb = (self.a[i + 1] - self.a[i]) / h[i] - h[i] * (self.c[i + 1] + 2.0 * 
self.c[i]) / 3.0 163 | self.b.append(tb) 164 | 165 | def calc(self, t): 166 | """ 167 | Calc position 168 | if t is outside of the input x, return None 169 | """ 170 | 171 | if t < self.x[0]: 172 | return None 173 | elif t > self.x[-1]: 174 | return None 175 | 176 | i = self.__search_index(t) 177 | dx = t - self.x[i] 178 | result = self.a[i] + self.b[i] * dx + self.c[i] * dx ** 2.0 + self.d[i] * dx ** 3.0 179 | 180 | return result 181 | 182 | def calcd(self, t): 183 | """ 184 | Calc first derivative 185 | if t is outside of the input x, return None 186 | """ 187 | 188 | if t < self.x[0]: 189 | return None 190 | elif t > self.x[-1]: 191 | return None 192 | 193 | i = self.__search_index(t) 194 | dx = t - self.x[i] 195 | result = self.b[i] + 2.0 * self.c[i] * dx + 3.0 * self.d[i] * dx ** 2.0 196 | 197 | return result 198 | 199 | def calcdd(self, t): 200 | """ 201 | Calc second derivative 202 | """ 203 | 204 | if t < self.x[0]: 205 | return None 206 | elif t > self.x[-1]: 207 | return None 208 | 209 | i = self.__search_index(t) 210 | dx = t - self.x[i] 211 | result = 2.0 * self.c[i] + 6.0 * self.d[i] * dx 212 | 213 | return result 214 | 215 | def __search_index(self, x): 216 | """ 217 | search data segment index 218 | """ 219 | return bisect.bisect(self.x, x) - 1 220 | 221 | def search_index(self, x): 222 | """ 223 | search data segment index 224 | """ 225 | return bisect.bisect(self.x, x) - 1 226 | 227 | def __calc_A(self, h): 228 | """ 229 | calc matrix A for spline coefficient c 230 | """ 231 | A = np.zeros((self.nx, self.nx)) 232 | A[0, 0] = 1.0 233 | 234 | for i in range(self.nx - 1): 235 | if i != (self.nx - 2): 236 | A[i + 1, i + 1] = 2.0 * (h[i] + h[i + 1]) 237 | A[i + 1, i] = h[i] 238 | A[i, i + 1] = h[i] 239 | 240 | A[0, 1] = 0.0 241 | A[self.nx - 1, self.nx - 2] = 0.0 242 | A[self.nx - 1, self.nx - 1] = 1.0 243 | 244 | return A 245 | 246 | def __calc_B(self, h): 247 | """ 248 | calc matrix B for spline coefficient c 249 | """ 250 | B = np.zeros(self.nx) 251 | 252 | for i in range(self.nx - 2): 253 | B[i + 1] = 3.0 * (self.a[i + 2] - self.a[i + 1]) / h[i + 1] - 3.0 * (self.a[i + 1] - self.a[i]) / h[i] 254 | 255 | return B 256 | 257 | 258 | class Spline2D: 259 | """ 260 | 2D Cubic Spline class 261 | """ 262 | 263 | def __init__(self, x, y): 264 | self.s = self.__calc_s(x, y) 265 | self.sx = Spline(self.s, x) 266 | self.sy = Spline(self.s, y) 267 | 268 | def __calc_s(self, x, y): 269 | dx = np.diff(x) 270 | dy = np.diff(y) 271 | self.ds = np.hypot(dx, dy) 272 | s = [0] 273 | s.extend(np.cumsum(self.ds)) 274 | 275 | return s 276 | 277 | def calc_position(self, s): 278 | """ 279 | calc position 280 | """ 281 | x = self.sx.calc(s) 282 | y = self.sy.calc(s) 283 | 284 | return x, y 285 | 286 | def calc_curvature(self, s): 287 | """ 288 | calc curvature 289 | """ 290 | dx = self.sx.calcd(s) 291 | ddx = self.sx.calcdd(s) 292 | dy = self.sy.calcd(s) 293 | ddy = self.sy.calcdd(s) 294 | k = (ddy * dx - ddx * dy) / ((dx ** 2 + dy ** 2)**(3 / 2)) 295 | 296 | return k 297 | 298 | def calc_yaw(self, s): 299 | """ 300 | calc yaw 301 | """ 302 | dx = self.sx.calcd(s) 303 | dy = self.sy.calcd(s) 304 | yaw = math.atan2(dy, dx) 305 | 306 | return yaw 307 | 308 | def search_index(self,s): 309 | i = self.sx.search_index(s) 310 | j = self.sy.search_index(s) 311 | return i,j 312 | 313 | def generate_target_course(x, y): 314 | csp = Spline2D(x, y) 315 | s = np.arange(0, csp.s[-1], 0.1) 316 | rx, ry, ryaw, rk = [], [], [], [] 317 | for i_s in s: 318 | ix, iy = csp.calc_position(i_s) 319 | rx.append(ix) 320 | 
ry.append(iy) 321 | ryaw.append(csp.calc_yaw(i_s)) 322 | rk.append(csp.calc_curvature(i_s)) 323 | 324 | return rx, ry, ryaw, rk, csp 325 | 326 | 327 | -------------------------------------------------------------------------------- /planner.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 6 | sys.path.append('/theseus') 7 | 8 | import theseus as th 9 | import matplotlib.pyplot as plt 10 | 11 | import torch 12 | import torch.nn as nn 13 | import numpy as np 14 | 15 | from utils.plan_utils import * 16 | import casadi as cs 17 | import math 18 | import scipy 19 | 20 | class Planner(object): 21 | def __init__(self, 22 | device, 23 | horizon=4, 24 | g_length=50, 25 | g_width=40, 26 | test_iters=50, 27 | test_step=0.3, 28 | ): 29 | super(Planner, self).__init__() 30 | 31 | self.g_width = g_width 32 | self.horizon = horizon 33 | control_variables = th.Vector(dof=horizon *10 * 2, name="control_variables") 34 | ref_line_fields = th.Variable(torch.empty(1, horizon, g_length, g_width, 2), name="ref_line_field") 35 | ref_line_costs = th.Variable(torch.empty(1, horizon, g_length, g_width), name="ref_line_costs") 36 | lwt = th.Variable(torch.empty(1, horizon*10, 3), name="lwt") 37 | current_state = th.Variable(torch.empty(1, 2, 4), name="current_state") 38 | spl = th.Variable(torch.empty(1, 1), name="speed_limit") 39 | stp = th.Variable(torch.empty(1, 1), name="stop_point") 40 | il_lane = th.Variable(torch.empty(1, horizon*10, 2), name="il_lane") 41 | 42 | objective = th.Objective() 43 | objective = self.cost_function(objective, control_variables, ref_line_fields, 44 | ref_line_costs, lwt, current_state, spl, stp, il_lane) 45 | self.optimizer = th.GaussNewton(objective, th.CholmodSparseSolver, vectorize=False, 46 | max_iterations=test_iters, step_size=test_step, abs_err_tolerance=1e-2) 47 | 48 | self.layer = th.TheseusLayer(self.optimizer, vectorize=False) 49 | self.layer.to(device=device) 50 | 51 | self.max_acc = 5 52 | self.max_delta = math.pi / 6 53 | self.ts = 0.1 54 | self.sigma = 1.0 55 | self.num = 0 56 | 57 | def preprocess(self, ego_state, ego_plan, ref_lines, ogm_prediction, type_mask, config, left=True): 58 | #computing the angle: 59 | diff_traj = torch.diff(ego_plan, axis=1) 60 | diff_traj = torch.cat([diff_traj, diff_traj[:, -1, :].unsqueeze(1)], dim=1) 61 | angle = torch.nan_to_num(torch.atan2(diff_traj[:,:,1], diff_traj[:,:,0]), 0).clamp(-0.67,0.67) 62 | 63 | #reshape time axis: l to the batch axis 64 | frenet_plan, angle = ego_plan[:, :50, :2], angle[:, :50] 65 | length, width = ego_state[:, -1, 6], ego_state[:, -1, 7] 66 | b, l, d = frenet_plan.shape 67 | length, width = length.unsqueeze(1).expand(-1, l), width.unsqueeze(1).expand(-1, l) 68 | frenet_plan = frenet_plan.reshape(b*l, d) 69 | speed = ego_state[:, -2:, 3] 70 | # print(ego_state[:, -2:]) 71 | angle = angle.reshape(b*l) 72 | 73 | # ref_lines = ref_lines[:, ::2, :] 74 | speed_limit = ref_lines[:, :, 4] 75 | b, t, d = ref_lines.shape 76 | orf = ref_lines 77 | ref_lines = ref_lines.unsqueeze(1).expand(-1, l, -1, -1).reshape(b*l, t, d) 78 | 79 | #generate the ego mask 80 | ego = generate_ego_pos_at_field(frenet_plan, ref_lines, angle) 81 | spl = torch.max(speed_limit, dim=1, keepdim=True)[0] 82 | stp = torch.max(speed_limit==0, dim=1, keepdim=True)[1] * 0.1 83 | theta = ego[:, -1].reshape(b, l) 84 | 85 | ref_lines = ref_lines.reshape(b, l, t, d)[:, 
::10].reshape(b*l//10, t, d) 86 | 87 | refline_fields = ref_line_grids(ref_lines, widths=int(self.g_width/2), left=left ,pixels_per_meter=3.2) 88 | 89 | ogm_prediction = ogm_prediction[:, 0] * (ogm_prediction[:, 0] > 0.3) + \ 90 | 10*ogm_prediction[:, 1] * type_mask[:, 1,None,None,None] + 91 | 10*ogm_prediction[:, 2] * type_mask[:, 2,None,None,None] 92 | ogm_prediction = ogm_prediction.clamp(0, 1) 93 | 94 | ogm_prediction = ogm_prediction * (ogm_prediction > 0.1) 95 | 96 | b, lp, h, w = ogm_prediction.shape 97 | og = ogm_prediction 98 | ogm_prediction = ogm_prediction.reshape(b*l//10, h, w) 99 | 100 | ogm_refline_fields = ref_line_ogm_sample(ogm_prediction, refline_fields, config) 101 | _, h, w = ogm_refline_fields.shape 102 | 103 | current_state = generate_ego_pos_at_field(ego_state[:, -1, [0, 1]],orf,ego_state[:, -1, 2]) 104 | lcurrent_state = generate_ego_pos_at_field(ego_state[:, -2, [0, 1]],orf,ego_state[:, -2, 2]) 105 | 106 | all_ego = torch.cat([current_state[:, None, :2], ego[:, :2].reshape(b, l, 2)], dim=1) 107 | d_ego = torch.diff(all_ego, dim=1) / 0.1 108 | 109 | c_state = torch.stack([lcurrent_state, current_state], dim=-2) 110 | c_state = torch.concat([c_state, speed.unsqueeze(-1)], dim=-1) 111 | il_lane = ego[:, :2].reshape(b, l, 2) 112 | # print(il_lane[0, 9::10]) 113 | 114 | return { 115 | 'control_variables': d_ego.reshape(b, l*2).detach(), 116 | 'il_lane': il_lane.detach(), 117 | 'lwt': torch.stack([length, width, theta], dim=-1).detach(), 118 | 'ref_line_field': refline_fields.reshape(b, lp, h, w, 2).detach(), 119 | 'ref_line_costs': ogm_refline_fields.reshape(b, lp, h, w).detach(), 120 | 'current_state': c_state.detach(), 121 | 'speed_limit': spl, 122 | 'stop_point':stp, 123 | } 124 | 125 | def il_cost(self, optim_vars, aux_vars): 126 | ego = optim_vars[0].tensor.view(-1, self.horizon*10, 2) 127 | current_state = aux_vars[0].tensor 128 | ds = ego[:, :, 0].clamp(min=0) 129 | dl = ego[:, :, 1] 130 | s = current_state[:, -1, 0][:,None] + torch.cumsum(ds * 0.1, dim=-1) 131 | L = current_state[:, -1, 1][:,None] + torch.cumsum(dl * 0.1, dim=-1) 132 | ego = torch.stack([s, L], dim=-1) 133 | il_lane = aux_vars[1].tensor 134 | cost = torch.abs(il_lane - ego).mean(-1) 135 | return 1 * cost 136 | 137 | def collision_cost(self, optim_vars, aux_vars): 138 | ego = optim_vars 139 | lwt, refline_fields, ref_line_costs, current_state = aux_vars 140 | lwt, refline_fields, ref_line_costs = lwt.tensor, refline_fields.tensor, ref_line_costs.tensor 141 | b, l, h, w = ref_line_costs.shape 142 | lwt = lwt[:, 9::10,:].reshape(b*l, 3) 143 | ego = ego[0].tensor.view(b, l*10, 2) 144 | ds = ego[:, :, 0].clamp(min=0) 145 | dl = ego[:, :, 1] 146 | s = current_state[:, -1, 0][:,None] + torch.cumsum(ds * 0.1, dim=-1) 147 | L = current_state[:, -1, 1][:,None] + torch.cumsum(dl * 0.1, dim=-1) 148 | ego = torch.stack([s, L], dim=-1)[:, 4::5, :] 149 | ego = ego.view(b, l, 2, 2) 150 | ego = ego.view(b*l, 2, 2) 151 | 152 | refline_fields = refline_fields.view(b*l, h, w, 2) 153 | ref_line_costs = ref_line_costs.view(b*l, h, w) 154 | 155 | mesh_sl = refline_meshgrids(refline_fields, pixels_per_meter=1.6) 156 | 157 | safety_cost_mask = ref_line_costs 158 | safety_cost_mask = safety_cost_mask > 0.3 159 | 160 | diff = mesh_sl.unsqueeze(1) - ego.unsqueeze(-2).unsqueeze(-2) 161 | 162 | interactive_mask = (diff[..., 0] > 0) * (torch.abs(diff[..., 1]) < 7.5) 163 | ego_dist = torch.sqrt(torch.square(diff[..., 0]) + torch.square(diff[..., 1])) 164 | ego_dist = (5 - ego_dist) * (ego_dist < 5) * interactive_mask 165 | 
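        # Aggregation below (descriptive note): ego_dist already holds a hinge penalty
        # (5 - d) for every reference-line grid cell that lies ahead of the sampled
        # ego pose (delta_s > 0) and within 5 m of it, gated by |delta_l| < 7.5 m;
        # multiplying by the occupancy mask (sampled cost > 0.3) and summing over the
        # grid gives one scalar safety cost per sampled ego pose, so only occupied
        # cells close to and in front of the plan contribute. The overall weight of 10
        # is applied at the return statement.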
166 | ego_s = ego_dist * safety_cost_mask.unsqueeze(1) 167 | safety_cost_s = ego_s.sum(-1).sum(-1) 168 | 169 | safety_cost = safety_cost_s 170 | 171 | safety_cost = safety_cost.view(b, l*2) 172 | 173 | return 10 * safety_cost 174 | 175 | def red_light(self, optim_vars, aux_vars): 176 | ego = optim_vars 177 | ego = ego[0].tensor.view(-1, self.horizon*10, 2) 178 | stop_point = aux_vars[0].tensor 179 | current_state = aux_vars[1].tensor 180 | s = current_state[:, -1, 0][:,None] + torch.cumsum(ego[..., 0] * 0.1, dim=-1) 181 | stop_distance = stop_point - 3 182 | red_light_error = (s - stop_distance) * (s > stop_distance) * (stop_point != 0) 183 | return 10 * red_light_error 184 | 185 | 186 | def cost_function(self, objective, control_variables, ref_line_fields, 187 | ref_line_costs, lwt, current_state, spl, stp, il_lane, vectorize=True): 188 | 189 | safe_cost = th.AutoDiffCostFunction([control_variables], self.collision_cost, self.horizon*2, 190 | aux_vars=[lwt, ref_line_fields, ref_line_costs, current_state], autograd_vectorize=vectorize, name="safe_cost") 191 | objective.add(safe_cost) 192 | 193 | il_cost = th.AutoDiffCostFunction([control_variables], self.il_cost, self.horizon*10, 194 | aux_vars=[current_state, il_lane],autograd_vectorize=vectorize, name="il_cost") 195 | objective.add(il_cost) 196 | 197 | 198 | rl_cost = th.AutoDiffCostFunction([control_variables], self.red_light, self.horizon*10, 199 | aux_vars=[stp, current_state],autograd_vectorize=vectorize, name="red_light") 200 | objective.add(rl_cost) 201 | 202 | return objective 203 | 204 | 205 | def plan(self, planning_inputs, selected_ref, current_state): 206 | 207 | final_values, info = self.layer.forward(planning_inputs, optimizer_kwargs={'track_best_solution': True}) 208 | plan = info.best_solution["control_variables"].view(-1, self.horizon*10, 2) 209 | plan = plan.to(selected_ref.device) 210 | s = planning_inputs['current_state'][:, -1, 0][:,None] + torch.cumsum(plan[..., 0] * 0.1, dim=-1) 211 | l = planning_inputs['current_state'][:, -1, 1][:,None] + torch.cumsum(plan[..., 1] * 0.1, dim=-1) 212 | plan = torch.stack([s, l], dim=-1) 213 | xy_plan = self.frenet_to_cartiesan(plan, selected_ref) 214 | speed = torch.hypot(current_state[:, -1, 2], current_state[:, -1, 3]) 215 | last_speed = torch.hypot(current_state[:, -2, 2], current_state[:,-2, 3]) 216 | acc = (speed - last_speed)/ 0.1 217 | current_state = torch.stack([current_state[:,-1, 0], current_state[:, -1, 1], current_state[:, -1, 4], speed, acc], dim=-1) 218 | b = xy_plan.shape[0] 219 | res = [] 220 | for i in range(b): 221 | pl = self.refine(current_state[i].cpu().numpy(), xy_plan[i].cpu().numpy()) 222 | res.append(pl) 223 | return torch.tensor(np.stack(res, 0)).to(xy_plan.device).float() 224 | 225 | def refine(self, current_state, reference): 226 | opti = cs.Opti() 227 | 228 | # Define the optimization variables 229 | X = opti.variable(4, self.horizon*10 + 1) 230 | U = opti.variable(2, self.horizon*10) 231 | 232 | # Define the initial state and the reference trajectory 233 | x0 = current_state[:4] # (x, y, theta, v) 234 | xr = reference.T 235 | 236 | # Define the cost function for the MPC problem 237 | obj = 0 238 | 239 | for i in range(self.horizon*10): 240 | obj += (i+1) / (self.horizon*10) * cs.sumsqr(X[:2, i+1] - xr[:2, i]) 241 | obj += 100 * (X[2, i+1] - xr[2, i]) ** 2 242 | obj += 0.1 * (U[0, i]) ** 2 243 | obj += U[1, i] ** 2 244 | 245 | if i >= 1: 246 | obj += 0.1 * (U[0, i] - U[0, i-1]) ** 2 247 | obj += (U[1, i] - U[1, i-1]) ** 2 248 | 249 | opti.minimize(obj) 
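        # The constraints added below roll out a kinematic bicycle model with state
        # X = (x, y, theta, v) and controls U = (a, delta), discretized at ts = 0.1 s:
        #   x_{k+1}     = x_k + v_k * cos(theta_k) * ts
        #   y_{k+1}     = y_k + v_k * sin(theta_k) * ts
        #   theta_{k+1} = theta_k + (v_k / L) * tan(delta_k) * ts   (the /4 below corresponds to an assumed ~4 m wheelbase L)
        #   v_{k+1}     = v_k + a_k * ts
        # The initial state and the first acceleration are pinned to the measured
        # values, every 1 s waypoint is kept within +/-2 m (x) and +/-0.5 m (y) of the
        # reference, and speed is constrained to stay non-negative.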
250 | 251 | # Define the constraints for the MPC problem 252 | opti.subject_to(X[:, 0] == x0) 253 | opti.subject_to(U[0, 0] == current_state[4]) 254 | 255 | for i in range(self.horizon*10): 256 | opti.subject_to([X[0, i+1] == X[0, i] + X[3, i] * cs.cos(X[2, i]) * self.ts, 257 | X[1, i+1] == X[1, i] + X[3, i] * cs.sin(X[2, i]) * self.ts, 258 | X[2, i+1] == X[2, i] + X[3, i] / 4 * cs.tan(U[1, i]) * self.ts, 259 | X[3, i+1] == X[3, i] + U[0, i] * self.ts]) 260 | 261 | for i in range(self.horizon): 262 | int_step = (i+1)*10 263 | opti.subject_to(X[0, int_step] - xr[0, int_step-1] <= 2) 264 | opti.subject_to(X[0, int_step] - xr[0, int_step-1] >= -2) 265 | opti.subject_to(X[1, int_step] - xr[1, int_step-1] <= 0.5) 266 | opti.subject_to(X[1, int_step] - xr[1, int_step-1] >= -0.5) 267 | 268 | opti.subject_to(X[3, :] >= 0) 269 | 270 | # Create the MPC solver 271 | opts = {'ipopt.print_level': 0, 'print_time': False} 272 | opti.solver('ipopt', opts) 273 | try: 274 | sol = opti.solve() 275 | states = sol.value(X) 276 | # act = sol.value(U) 277 | except: 278 | print("Solver failed. Returning best solution found.") 279 | states = opti.debug.value(X) 280 | 281 | traj = states.T[1:, :3] 282 | 283 | return traj 284 | 285 | def frenet_to_cartiesan(self, sl, ref_line): 286 | s, l = (sl[:, :, 0]*10 + 200), sl[:, :, 1] 287 | s = s.clamp(0, 1199) 288 | b = ref_line.shape[0] 289 | ref_points = ref_line[torch.arange(b).long()[:,None], s.long(), :] 290 | cartesian_x = ref_points[:, :, 0] - l * torch.sin(ref_points[:, :, 2]) 291 | cartesian_y = ref_points[:, :, 1] + l * torch.cos(ref_points[:, :, 2]) 292 | angle = ref_points[:, :, 2] 293 | 294 | return torch.stack([cartesian_x, cartesian_y, angle], dim=-1) -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from torchmetrics.functional import auroc 5 | 6 | import numpy as np 7 | import math 8 | 9 | import skimage as ski 10 | 11 | def draw_ego_mask(ego, length=5.2860+1, width=2.332+0.5, size=(128, 128), pixels_per_meter=1.6): 12 | 13 | b = ego.shape[0] 14 | dego = ego.detach().cpu().numpy() 15 | masks = [] 16 | for i in range(b): 17 | x, y, angle = dego[i, 0], dego[i, 1], dego[i, 2] 18 | sin, cos = np.sin(angle), np.cos(angle) 19 | front_left = [x + length/2*cos - width/2*sin, y + length/2*sin + width/2*cos] 20 | front_right = [x + length/2*cos + width/2*sin, y + length/2*sin - width/2*cos] 21 | rear_left = [x - length/2*cos - width/2*sin, y - length/2*sin + width/2*cos] 22 | rear_right = [x - length/2*cos + width/2*sin, y - length/2*sin - width/2*cos] 23 | poly_xy = np.array([front_left, front_right, rear_right, rear_left]) #(4, 2) 24 | # x,y -> h, w 25 | poly_h = int(size[0]*0.75) - np.round(poly_xy[:, 0] * pixels_per_meter) 26 | poly_w = int(size[1]*0.5) - np.round(poly_xy[:, 1] * pixels_per_meter) 27 | poly_hw = np.stack([poly_h, poly_w] ,axis=-1) 28 | mask = ski.draw.polygon2mask(size, poly_hw) 29 | masks.append(mask) 30 | masks = np.stack(masks, axis=0) 31 | masks = torch.tensor(masks).to(ego.device) 32 | return masks 33 | 34 | def plan_metrics(trajectories, ego_future): 35 | l = ego_future.shape[-2] 36 | trajectories = trajectories[..., :l, :] 37 | ego_future_valid = torch.ne(ego_future[..., :2], 0) 38 | ego_trajectory = trajectories[..., :2] * ego_future_valid[:, None, :, :] 39 | distance = torch.norm(ego_trajectory[:,:, :, :2] - ego_future[:,None, :, :2], dim=-1) 40 | 41 | ade = 
distance.mean(-1) 42 | ade, _ = torch.min(ade,dim=-1) 43 | egoADE = torch.mean(ade) 44 | fde = distance[:,:,-1] 45 | fde, _ = torch.min(fde,dim=-1) 46 | egoFDE = torch.mean(fde) 47 | 48 | fde3 = distance[:,:,29] 49 | fde3, _ = torch.min(fde3,dim=-1) 50 | egoFDE3 = torch.mean(fde3) 51 | 52 | fde1 = distance[:,:,9] 53 | fde1, _ = torch.min(fde1,dim=-1) 54 | egoFDE1 = torch.mean(fde1) 55 | 56 | return egoADE.item(), egoFDE.item() ,egoFDE3.item(), egoFDE1.item() 57 | 58 | from time import time 59 | def occupancy_metrics(preds, target): 60 | #preds: [batch, t, h, w] (sigmoid) target[b, t, h, w] 61 | T = target.shape[1] 62 | check_time = [1, 3, 5] 63 | auc_list, iou_list = [], [] 64 | 65 | for i in range(T): 66 | auc_list.append(auc_metrics(preds[:, i], target[:, i])) 67 | iou_list.append(soft_iou(preds[:, i], target[:, i])) 68 | res_auc, res_iou = [], [] 69 | for t in check_time: 70 | res_auc.append(torch.mean(torch.stack(auc_list[:t])).item()) 71 | res_iou.append(torch.mean(torch.stack(iou_list[:t])).item()) 72 | return res_auc, res_iou 73 | 74 | def all_type_occupancy_metrics(preds, target, n_types=2): 75 | res_list = [] 76 | for i in range(n_types - 1): 77 | res_auc, res_iou = occupancy_metrics(preds[:, i], target['gt_obs'][..., 0, i]) 78 | res_list.append([res_auc, res_iou]) 79 | if preds.shape[1]<=3: 80 | res_auc, res_iou = occupancy_metrics(preds[:, 0], (target['gt_occ'] + target['gt_obs'][..., 0, 0]).clamp(0, 1)) 81 | else: 82 | res_auc, res_iou = occupancy_metrics(preds[:, 3], target['gt_occ']) 83 | res_list.append([res_auc, res_iou]) 84 | return res_list 85 | 86 | def auc_metrics(inputs, target): 87 | return auroc(inputs, target.int(), task='binary', thresholds=100) 88 | 89 | def soft_iou(inputs, target): 90 | inputs, target = inputs.reshape(-1), target.reshape(-1) 91 | intersection = torch.mean(torch.mul(inputs, target)) 92 | T_inputs, T_target = torch.mean(inputs), torch.mean(target) 93 | soft_iou_score = torch.nan_to_num(intersection / (T_inputs + T_target - intersection) ,0.0) 94 | return soft_iou_score 95 | 96 | def check_dynamics(traj, current_state): 97 | d_t = 0.1 98 | diff_xy = torch.diff(traj, dim=1) 99 | diff_x, diff_y = diff_xy[:,:, 0], diff_xy[:,:, 1] 100 | 101 | v_x, v_y, theta = diff_x / d_t, diff_y/d_t, np.arctan2(diff_y.cpu().numpy(), diff_x.cpu().numpy() + 1e-6) 102 | theta = torch.tensor(theta).to(v_x.device) 103 | lon_speed = v_x * torch.cos(theta) + v_y * torch.sin(theta) 104 | lat_speed = v_y * torch.cos(theta) - v_x * torch.sin(theta) 105 | 106 | acc = torch.diff(lon_speed,dim=-1) / d_t 107 | jerk = torch.diff(lon_speed,dim=-1,n=2) / d_t**2 108 | lat_acc = torch.diff(lat_speed,dim=-1) / d_t 109 | 110 | return torch.mean(torch.abs(acc)).item(), torch.mean(torch.abs(jerk)).item(), torch.mean(torch.abs(lat_acc)).item() 111 | 112 | def check_traffic(traj, ref_line, gt_modes): 113 | b, t, c = ref_line.shape 114 | red_light = False 115 | off_route = False 116 | 117 | # project to frenet 118 | distance_to_ref = torch.cdist(traj[:,:, :2], ref_line[:,:, :2]) 119 | #b, L_ego , s_ref 120 | s_ego = torch.argmin(distance_to_ref, axis=-1) 121 | distance_to_route = torch.min(distance_to_ref, axis=-1).values 122 | off_route = torch.any(distance_to_route > 5, dim=1) 123 | 124 | # get stop point 125 | stop_point = torch.argmax(ref_line[:,:,-2].int(),dim=1) 126 | sig = ref_line[torch.arange(b)[:,None],stop_point.unsqueeze(-1),-1].squeeze(1) 127 | rl_sig = torch.logical_or(sig==1, torch.logical_or(sig==4, sig==7)) 128 | s_stp = s_ego-stop_point.unsqueeze(-1) 129 | red_light = 
torch.logical_and(torch.logical_and(stop_point > 0, torch.any(s_stp > 0,dim=1)), rl_sig) 130 | 131 | return red_light, off_route#.float().mean().item() 132 | 133 | def compare_to_gt(ego_metric, gt_metric): 134 | not_gt_metric = torch.logical_not(gt_metric) 135 | real_metric = torch.logical_and(ego_metric, not_gt_metric) 136 | if not_gt_metric.float().sum()==0: 137 | return not_gt_metric.float().sum().item() 138 | real_val = real_metric.float().sum() / not_gt_metric.float().sum() 139 | return real_val.item() 140 | 141 | def flow_epe(outputs, targets): 142 | if 'gt_flow' not in targets: 143 | return 0 144 | target_flow = targets['gt_flow'] 145 | target_flow = target_flow.sum(-1) 146 | flow_exists = torch.logical_or(torch.ne(target_flow[..., 0], 0), torch.ne(target_flow[..., 1], 0)).float() 147 | flow_outputs = outputs[:, -2:] 148 | b, c, t, h, w = flow_outputs.shape 149 | flow_outputs = flow_outputs.permute(0, 2, 3, 4, 1) 150 | pred_flow = torch.mul(flow_outputs, flow_exists.unsqueeze(-1))#.permute(0, 2, 3, 4, 1) #[b, t, h, w, 2] 151 | epe_list = [] 152 | # for i in range(3): 153 | if torch.sum(flow_exists) > 0: 154 | flow_epe = torch.sum(torch.norm(pred_flow - target_flow, p=2, dim=-1))/ torch.nan_to_num(torch.sum(flow_exists)/2, 1.0) 155 | return flow_epe.item() 156 | else: 157 | return 0 158 | # return epe_list 159 | 160 | class TrainingMetrics: 161 | def __init__(self): 162 | self.ade = [] 163 | self.fde = [] 164 | self.il_loss = [] 165 | self.ogm_loss = [] 166 | self.epe = [] 167 | 168 | def update(self, traj, score, gt_modes, target, il_loss, ogm_loss, outputs): 169 | ade, fde, fde3, fde1 = plan_metrics(traj, target['ego_future_states']) 170 | self.epe.append(flow_epe(outputs, target)) 171 | self.ade.append(ade) 172 | self.fde.append(fde) 173 | 174 | self.il_loss.append(il_loss.item()) 175 | self.ogm_loss.append(ogm_loss.item()) 176 | return np.mean(self.ade), np.mean(self.fde), np.mean(self.il_loss), np.mean(self.ogm_loss) 177 | 178 | def result(self): 179 | return { 180 | 'T_il_loss':np.mean(self.il_loss), 181 | 'T_ogm_loss':np.mean(self.ogm_loss), 182 | 'T_ade':np.mean(self.ade), 183 | 'T_fde':np.mean(self.fde), 184 | 'epe_v':np.mean(self.epe) 185 | } 186 | 187 | class ValidationMetrics: 188 | def __init__(self): 189 | self.ade = [] 190 | self.fde = [] 191 | self.fde3 = [] 192 | self.fde1 = [] 193 | self.il_loss = [] 194 | self.ogm_loss = [] 195 | 196 | self.ogm_auc = [] 197 | self.ogm_iou = [] 198 | self.epe = [] 199 | 200 | self.ogm_auc_p = [] 201 | self.ogm_iou_p = [] 202 | self.epe_p = [] 203 | 204 | self.ogm_auc_c = [] 205 | self.ogm_iou_c = [] 206 | self.epe_c = [] 207 | 208 | self.occ_auc = [] 209 | self.occ_iou = [] 210 | 211 | def update(self, traj, score, ogm_pred, gt_modes, target, il_loss, ogm_loss): 212 | ade, fde, fde3, fde1 = plan_metrics(traj, target['ego_future_states']) 213 | self.ade.append(ade) 214 | self.fde.append(fde) 215 | self.fde3.append(fde3) 216 | self.fde1.append(fde1) 217 | 218 | epe_list = flow_epe(ogm_pred, target) 219 | self.epe.append(epe_list) 220 | 221 | self.il_loss.append(il_loss.item()) 222 | self.ogm_loss.append(ogm_loss.item()) 223 | 224 | ogm_list = all_type_occupancy_metrics(ogm_pred.sigmoid(), target, 4) 225 | ogm_auc, ogm_iou, occ_auc, occ_iou = ogm_list[0][0][1], ogm_list[0][1][1], ogm_list[-1][0][1], ogm_list[-1][1][1] 226 | 227 | self.ogm_auc.append(ogm_auc) 228 | self.ogm_iou.append(ogm_iou) 229 | self.occ_auc.append(occ_auc) 230 | self.occ_iou.append(occ_iou) 231 | 232 | ogm_auc_p, ogm_iou_p, ogm_auc_c, ogm_iou_c = ogm_list[1][0][1], 
ogm_list[1][1][1], ogm_list[2][0][1], ogm_list[2][1][1] 233 | self.ogm_auc_p.append(ogm_auc_p) 234 | self.ogm_iou_p.append(ogm_iou_p) 235 | self.ogm_auc_c.append(ogm_auc_c) 236 | self.ogm_iou_c.append(ogm_iou_c) 237 | 238 | return np.mean(self.ade), np.mean(self.fde), np.mean(self.il_loss), np.mean(self.ogm_loss),\ 239 | np.mean(self.ogm_auc), np.mean(self.ogm_iou), np.mean(self.occ_auc), np.mean(self.occ_iou) 240 | 241 | def result(self): 242 | return { 243 | 'E_il_loss':np.mean(self.il_loss), 244 | 'E_ogm_loss':np.mean(self.ogm_loss), 245 | 'E_ade':np.mean(self.ade), 246 | 'E_fde_5': np.mean(self.fde), 247 | 'E_fde_3': np.mean(self.fde3), 248 | 'E_fde_1': np.mean(self.fde1), 249 | 'ogm_auc_v':np.mean(self.ogm_auc), 250 | 'ogm_iou_v':np.mean(self.ogm_iou), 251 | 'ogm_auc_p':np.mean(self.ogm_auc_p), 252 | 'ogm_iou_p':np.mean(self.ogm_iou_p), 253 | 'ogm_auc_c':np.mean(self.ogm_auc_c), 254 | 'ogm_iou_c':np.mean(self.ogm_iou_c), 255 | 'occ_auc':np.mean(self.occ_auc), 256 | 'occ_iou':np.mean(self.occ_iou), 257 | 'v_epe':np.mean(self.epe) 258 | } 259 | 260 | 261 | class TestingMetrics: 262 | def __init__(self, config, lite_mode=False): 263 | 264 | self.reset() 265 | self.config = config 266 | self.ogm_to_position() 267 | self.lite_mode = lite_mode 268 | 269 | def reset(self): 270 | self.valid_dict = { 271 | 'fde_1s':[], 'fde_3s':[], 'fde_5s':[], 'ade':[],'collisions_rate':[], 'off_road_rate':[], 'red_light':[], 272 | 'acc':[], 'jerk':[], 'lat_acc':[] 273 | } 274 | for m in ['auc','iou']: 275 | for t in [3 ,5, 8]: 276 | for ty in ['v','p','c','occ']: 277 | self.valid_dict[f'{m}_{t}_{ty}'] = [] 278 | 279 | def ogm_to_position(self): 280 | indexes = torch.arange(1, self.config.grid_height_cells + 1) 281 | widths_indexes = - (indexes - self.config.sdc_y_in_grid - 0.5) / self.config.pixels_per_meter 282 | heights_indexes = - (indexes - self.config.sdc_x_in_grid - 0.5) / self.config.pixels_per_meter 283 | #correspnding (x,y) in dense coordinates (h, w, 2) 284 | coordinates = torch.stack(torch.meshgrid([widths_indexes, heights_indexes]), dim=-1)#.permute(1, 0, 2) 285 | self.ogm_coordinates = coordinates 286 | 287 | 288 | def gt_ogm_to_position(self, traj, target, t, current_state, col_thres=3): 289 | ego_mask = 1- target['gt_ego'][:, t] 290 | all_occupancies = target['gt_obs'][:, t, :, :, 0, :].sum(-1) + target['gt_occ'][:, t , :, :] 291 | all_occupancies = all_occupancies * ego_mask 292 | all_occupancies = all_occupancies.clamp(0, 1).unsqueeze(-1) 293 | b = all_occupancies.shape[0] 294 | current_ogm_coordinates = self.ogm_coordinates.unsqueeze(0).expand(b, -1, -1, -1).to(all_occupancies.device) 295 | 296 | b, h, w, c = current_ogm_coordinates.shape 297 | ego_plan = draw_ego_mask(traj[:, t*10 + 9, :3]) 298 | 299 | #filtering the occupied occupancies 300 | collision_grids = ego_plan * all_occupancies[..., 0] 301 | collision_grids = collision_grids.reshape(b, h*w) 302 | collision_grids = collision_grids.sum(1) > col_thres 303 | 304 | return collision_grids.float() 305 | 306 | def collsions_check(self, traj, target, current_state): 307 | col_list = [] 308 | for t in range(5): 309 | col_list.append(self.gt_ogm_to_position(traj, target, t, current_state)) 310 | col_list = torch.stack(col_list, dim=1).sum(1) >= 1 311 | col_rate = col_list 312 | return col_rate 313 | 314 | def update(self, traj, score, ogm_pred, gt_modes, target, current_state): 315 | 316 | ade, fde, fde3, fde1 = plan_metrics(traj, target['ego_future_states'][:, :50, :]) 317 | 318 | self.valid_dict['ade'].append(ade) 319 | 
self.valid_dict['fde_1s'].append(fde1) 320 | self.valid_dict['fde_3s'].append(fde3) 321 | self.valid_dict['fde_5s'].append(fde) 322 | 323 | ogm_list = all_type_occupancy_metrics(ogm_pred, target, 4) 324 | 325 | for i, ty in enumerate(['v','p','c','occ']): 326 | for j, t in enumerate([3, 5, 8]): 327 | for k, m in enumerate(['auc', 'iou']): 328 | self.valid_dict[f'{m}_{t}_{ty}'].append(ogm_list[i][k][j]) 329 | # self.valid_dict[f'{m}_{t}_{ty}'].append(0) 330 | 331 | gt_traj = target['ego_future_states'][..., :,:] 332 | if not self.lite_mode: 333 | col_rate = self.collsions_check(traj, target, current_state) 334 | self.valid_dict['collisions_rate'].append(col_rate.float().mean().item()) 335 | 336 | acc, jerk, lat_acc = check_dynamics(traj, current_state) 337 | 338 | self.valid_dict['acc'].append(acc) 339 | self.valid_dict['lat_acc'].append(lat_acc) 340 | self.valid_dict['jerk'].append(jerk) 341 | 342 | red_light, off_route = check_traffic(traj, target['ref_line'], gt_modes) 343 | gt_red_light, gt_off_route = check_traffic(gt_traj, target['ref_line'], gt_modes) 344 | self.valid_dict['red_light'].append(compare_to_gt(red_light, gt_red_light)) 345 | self.valid_dict['off_road_rate'].append(compare_to_gt(off_route, gt_off_route)) 346 | 347 | def result(self): 348 | new_dict = {} 349 | for k, v in self.valid_dict.items(): 350 | new_dict[k] = np.nanmean(v) 351 | return new_dict -------------------------------------------------------------------------------- /utils/occupancy_grid_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from functools import partial 4 | 5 | from .occupancy_render_utils import render_occupancy, render_flow_from_inputs, sample_filter, generate_units, render_ego_occupancy 6 | from waymo_open_dataset.utils.occupancy_flow_grids import TimestepGrids, WaypointGrids,_WaypointGridsOneType 7 | 8 | from waymo_open_dataset.protos import occupancy_flow_metrics_pb2 9 | from waymo_open_dataset.utils import occupancy_flow_data 10 | 11 | def create_ground_truth_timestep_grids( 12 | traj_tensor, 13 | valid_tensor, 14 | ego_traj, 15 | config, 16 | flow=False, 17 | flow_origin=False, 18 | sdc_ids=None, 19 | test=False 20 | ): 21 | """Renders topdown views of agents over past/current/future time frames. 22 | 23 | Args: 24 | inputs: Dict of input tensors from the motion dataset. 25 | config: OccupancyFlowTaskConfig proto message. 26 | 27 | Returns: 28 | TimestepGrids object holding topdown renders of agents. 29 | """ 30 | 31 | timestep_grids = TimestepGrids() 32 | 33 | unit_x, unit_y = generate_units(config.agent_points_per_side_length, config.agent_points_per_side_width) 34 | 35 | # Occupancy grids. 
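  # sample_filter is partially applied below so the same agent trajectories can be
  # re-sampled with different time windows ('past', 'current', 'future') and
  # visibility flags; render_occupancy then rasterizes each set of sampled box
  # points into per-type top-down grids (vehicles / pedestrians / cyclists).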
36 | sample_func = partial( 37 | sample_filter, 38 | traj_tensor=traj_tensor, 39 | valid_tensor=valid_tensor, 40 | ego_traj=ego_traj, 41 | config=config, 42 | unit_x=unit_x,unit_y=unit_y 43 | 44 | ) 45 | 46 | current_sample = sample_func( 47 | times=['current'], 48 | include_observed=True, 49 | include_occluded=True, 50 | ) 51 | current_occupancy, current_valid = render_occupancy(current_sample, config, sdc_ids=None) 52 | #[num_agents] 53 | # print(current_valid.shape) 54 | current_valid = tf.reduce_max(tf.cast(current_valid,tf.int32),axis=-1)[:,0] 55 | 56 | timestep_grids.vehicles.current_occupancy = current_occupancy.vehicles 57 | timestep_grids.pedestrians.current_occupancy = current_occupancy.pedestrians 58 | timestep_grids.cyclists.current_occupancy = current_occupancy.cyclists 59 | 60 | past_sample = sample_func( 61 | times=['past'], 62 | include_observed=True, 63 | include_occluded=True, 64 | ) 65 | past_occupancy,_ = render_occupancy(past_sample, config, sdc_ids=None) 66 | timestep_grids.vehicles.past_occupancy = past_occupancy.vehicles 67 | timestep_grids.pedestrians.past_occupancy = past_occupancy.pedestrians 68 | timestep_grids.cyclists.past_occupancy = past_occupancy.cyclists 69 | 70 | #[num_agents] presence in fov {AT Present} 71 | observed_valid = current_valid 72 | 73 | if not test: 74 | future_sample = sample_func( 75 | times=['future'], 76 | include_observed=True, 77 | include_occluded=False, 78 | ) 79 | future_obs,_ = render_occupancy(future_sample, config, sdc_ids=None) 80 | timestep_grids.vehicles.future_observed_occupancy = future_obs.vehicles 81 | timestep_grids.pedestrians.future_observed_occupancy = future_obs.pedestrians 82 | timestep_grids.cyclists.future_observed_occupancy = future_obs.cyclists 83 | 84 | 85 | future_sample_occ = sample_func( 86 | times=['future'], 87 | include_observed=False, 88 | include_occluded=True, 89 | ) 90 | future_occ, _ = render_occupancy(future_sample_occ, config, sdc_ids=None) 91 | # occluded_valid = tf.reduce_max(tf.cast(occ_valid,tf.int32),axis=-1)[:,0] 92 | timestep_grids.vehicles.future_occluded_occupancy = future_occ.vehicles 93 | timestep_grids.pedestrians.future_occluded_occupancy = future_occ.pedestrians 94 | timestep_grids.cyclists.future_occluded_occupancy = future_occ.cyclists 95 | 96 | # All occupancy for flow_origin_occupancy. 97 | if flow_origin or flow: 98 | all_sample = sample_func( 99 | times=['past', 'current', 'future'], 100 | include_observed=True, 101 | include_occluded=True, 102 | ) 103 | if flow_origin: 104 | all_occupancy,_ = render_occupancy(all_sample, config, sdc_ids=None) 105 | timestep_grids.vehicles.all_occupancy = all_occupancy.vehicles 106 | timestep_grids.pedestrians.all_occupancy = all_occupancy.pedestrians 107 | timestep_grids.cyclists.all_occupancy = all_occupancy.cyclists 108 | 109 | # Flow. 110 | # NOTE: Since the future flow depends on the current and past timesteps, we 111 | # need to compute it from [past + current + future] sparse points. 
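  # all_flow therefore spans every rendered timestep; only the tail overlapping the
  # future horizon is used later, when _add_ground_truth_flow_to_waypoint_grids
  # slices flow[..., -num_future_steps:, :].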
112 | if flow: 113 | all_flow = render_flow_from_inputs(all_sample, config, sdc_ids=None) 114 | timestep_grids.vehicles.all_flow = all_flow.vehicles 115 | timestep_grids.pedestrians.all_flow = all_flow.pedestrians 116 | timestep_grids.cyclists.all_flow = all_flow.cyclists 117 | 118 | if sdc_ids is not None: 119 | all_sample = sample_func( 120 | times=['past', 'current', 'future'], 121 | include_observed=True, 122 | include_occluded=False, 123 | ) 124 | ego_occupancy = render_ego_occupancy(all_sample, sdc_ids, config) 125 | return timestep_grids, observed_valid, ego_occupancy 126 | 127 | 128 | return timestep_grids, observed_valid 129 | 130 | 131 | 132 | def create_ground_truth_waypoint_grids( 133 | timestep_grids: TimestepGrids, 134 | config: occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig, 135 | flow_origin: bool=False, 136 | flow: bool=False, 137 | ) -> WaypointGrids: 138 | """Subsamples or aggregates future topdowns as ground-truth labels. 139 | 140 | Args: 141 | timestep_grids: Holds topdown renders of agents over time. 142 | config: OccupancyFlowTaskConfig proto message. 143 | 144 | Returns: 145 | WaypointGrids object. 146 | """ 147 | if config.num_future_steps % config.num_waypoints != 0: 148 | raise ValueError(f'num_future_steps({config.num_future_steps}) must be ' 149 | f'a multiple of num_waypoints({config.num_waypoints}).') 150 | 151 | true_waypoints = WaypointGrids( 152 | vehicles=_WaypointGridsOneType( 153 | observed_occupancy=[], occluded_occupancy=[], flow=[]), 154 | pedestrians=_WaypointGridsOneType( 155 | observed_occupancy=[], occluded_occupancy=[], flow=[]), 156 | cyclists=_WaypointGridsOneType( 157 | observed_occupancy=[], occluded_occupancy=[], flow=[]), 158 | ) 159 | 160 | # Observed occupancy. 161 | _add_ground_truth_observed_occupancy_to_waypoint_grids( 162 | timestep_grids=timestep_grids, 163 | waypoint_grids=true_waypoints, 164 | config=config) 165 | # Occluded occupancy. 166 | _add_ground_truth_occluded_occupancy_to_waypoint_grids( 167 | timestep_grids=timestep_grids, 168 | waypoint_grids=true_waypoints, 169 | config=config) 170 | # Flow origin occupancy. 171 | if flow_origin: 172 | _add_ground_truth_flow_origin_occupancy_to_waypoint_grids( 173 | timestep_grids=timestep_grids, 174 | waypoint_grids=true_waypoints, 175 | config=config) 176 | # Flow. 
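  # Flow waypoints follow the same subsample-or-aggregate rule as occupancy: with
  # cumulative_waypoints the mean over non-zero flow pixels in each waypoint window
  # is taken, otherwise the flow at the last step of the window is kept (see
  # _add_ground_truth_flow_to_waypoint_grids below).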
177 | if flow: 178 | _add_ground_truth_flow_to_waypoint_grids( 179 | timestep_grids=timestep_grids, 180 | waypoint_grids=true_waypoints, 181 | config=config) 182 | 183 | return true_waypoints 184 | 185 | def _ego_ground_truth_occupancy(ego_occupancy, config): 186 | waypoint_size = config.num_future_steps // config.num_waypoints 187 | future_obs = ego_occupancy[..., config.num_past_steps + 1:] 188 | gt_ogm = [] 189 | for k in range(config.num_waypoints): 190 | waypoint_end = (k + 1) * waypoint_size 191 | if config.cumulative_waypoints: 192 | waypoint_start = waypoint_end - waypoint_size 193 | # [batch_size, height, width, waypoint_size] 194 | segment = future_obs[..., waypoint_start:waypoint_end] 195 | # [batch_size, height, width, 1] 196 | waypoint_occupancy = tf.reduce_max(segment, axis=-1, keepdims=True) 197 | else: 198 | # [batch_size, height, width, 1] 199 | waypoint_occupancy = future_obs[..., waypoint_end - 1:waypoint_end] 200 | gt_ogm.append(waypoint_occupancy) 201 | 202 | return gt_ogm 203 | 204 | 205 | def _add_ground_truth_observed_occupancy_to_waypoint_grids( 206 | timestep_grids: TimestepGrids, 207 | waypoint_grids: WaypointGrids, 208 | config: occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig, 209 | ) -> None: 210 | """Subsamples or aggregates future topdowns as ground-truth labels. 211 | 212 | Args: 213 | timestep_grids: Holds topdown renders of agents over time. 214 | waypoint_grids: Holds topdown waypoints selected as ground-truth labels. 215 | config: OccupancyFlowTaskConfig proto message. 216 | """ 217 | waypoint_size = config.num_future_steps // config.num_waypoints 218 | for object_type in occupancy_flow_data.ALL_AGENT_TYPES: 219 | # [batch_size, height, width, num_future_steps] 220 | future_obs = timestep_grids.view(object_type).future_observed_occupancy 221 | for k in range(config.num_waypoints): 222 | waypoint_end = (k + 1) * waypoint_size 223 | if config.cumulative_waypoints: 224 | waypoint_start = waypoint_end - waypoint_size 225 | # [batch_size, height, width, waypoint_size] 226 | segment = future_obs[..., waypoint_start:waypoint_end] 227 | # [batch_size, height, width, 1] 228 | waypoint_occupancy = tf.reduce_max(segment, axis=-1, keepdims=True) 229 | else: 230 | # [batch_size, height, width, 1] 231 | waypoint_occupancy = future_obs[..., waypoint_end - 1:waypoint_end] 232 | waypoint_grids.view(object_type).observed_occupancy.append( 233 | waypoint_occupancy) 234 | 235 | 236 | def _add_ground_truth_occluded_occupancy_to_waypoint_grids( 237 | timestep_grids: TimestepGrids, 238 | waypoint_grids: WaypointGrids, 239 | config: occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig, 240 | ) -> None: 241 | """Subsamples or aggregates future topdowns as ground-truth labels. 242 | 243 | Args: 244 | timestep_grids: Holds topdown renders of agents over time. 245 | waypoint_grids: Holds topdown waypoints selected as ground-truth labels. 246 | config: OccupancyFlowTaskConfig proto message. 
247 | """ 248 | waypoint_size = config.num_future_steps // config.num_waypoints 249 | for object_type in occupancy_flow_data.ALL_AGENT_TYPES: 250 | # [batch_size, height, width, num_future_steps] 251 | future_occ = timestep_grids.view(object_type).future_occluded_occupancy 252 | for k in range(config.num_waypoints): 253 | waypoint_end = (k + 1) * waypoint_size 254 | if config.cumulative_waypoints: 255 | waypoint_start = waypoint_end - waypoint_size 256 | # [batch_size, height, width, waypoint_size] 257 | segment = future_occ[..., waypoint_start:waypoint_end] 258 | # [batch_size, height, width, 1] 259 | waypoint_occupancy = tf.reduce_max(segment, axis=-1, keepdims=True) 260 | else: 261 | # [batch_size, height, width, 1] 262 | waypoint_occupancy = future_occ[..., waypoint_end - 1:waypoint_end] 263 | waypoint_grids.view(object_type).occluded_occupancy.append( 264 | waypoint_occupancy) 265 | 266 | 267 | def _add_ground_truth_flow_origin_occupancy_to_waypoint_grids( 268 | timestep_grids: TimestepGrids, 269 | waypoint_grids: WaypointGrids, 270 | config: occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig, 271 | ) -> None: 272 | """Subsamples or aggregates topdowns as origin occupancies for flow fields. 273 | 274 | Args: 275 | timestep_grids: Holds topdown renders of agents over time. 276 | waypoint_grids: Holds topdown waypoints selected as ground-truth labels. 277 | config: OccupancyFlowTaskConfig proto message. 278 | """ 279 | waypoint_size = config.num_future_steps // config.num_waypoints 280 | num_history_steps = config.num_past_steps + 1 # Includes past + current. 281 | num_future_steps = config.num_future_steps 282 | if waypoint_size > num_history_steps: 283 | raise ValueError('If waypoint_size > num_history_steps, we cannot find the ' 284 | 'flow origin occupancy for the first waypoint.') 285 | 286 | for object_type in occupancy_flow_data.ALL_AGENT_TYPES: 287 | # [batch_size, height, width, num_past_steps + 1 + num_future_steps] 288 | all_occupancy = timestep_grids.view(object_type).all_occupancy 289 | # Keep only the section containing flow_origin_occupancy timesteps. 290 | # First remove `waypoint_size` from the end. Then keep the tail containing 291 | # num_future_steps timesteps. 292 | flow_origin_occupancy = all_occupancy[..., :-waypoint_size] 293 | # [batch_size, height, width, num_future_steps] 294 | flow_origin_occupancy = flow_origin_occupancy[..., -num_future_steps:] 295 | for k in range(config.num_waypoints): 296 | waypoint_end = (k + 1) * waypoint_size 297 | if config.cumulative_waypoints: 298 | waypoint_start = waypoint_end - waypoint_size 299 | # [batch_size, height, width, waypoint_size] 300 | segment = flow_origin_occupancy[..., waypoint_start:waypoint_end] 301 | # [batch_size, height, width, 1] 302 | waypoint_flow_origin = tf.reduce_max(segment, axis=-1, keepdims=True) 303 | else: 304 | # [batch_size, height, width, 1] 305 | waypoint_flow_origin = flow_origin_occupancy[..., waypoint_end - 306 | 1:waypoint_end] 307 | waypoint_grids.view(object_type).flow_origin_occupancy.append( 308 | waypoint_flow_origin) 309 | 310 | 311 | def _add_ground_truth_flow_to_waypoint_grids( 312 | timestep_grids: TimestepGrids, 313 | waypoint_grids: WaypointGrids, 314 | config: occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig, 315 | ) -> None: 316 | """Subsamples or aggregates future flow fields as ground-truth labels. 317 | 318 | Args: 319 | timestep_grids: Holds topdown renders of agents over time. 320 | waypoint_grids: Holds topdown waypoints selected as ground-truth labels. 
321 | config: OccupancyFlowTaskConfig proto message. 322 | """ 323 | num_future_steps = config.num_future_steps 324 | waypoint_size = config.num_future_steps // config.num_waypoints 325 | 326 | for object_type in occupancy_flow_data.ALL_AGENT_TYPES: 327 | # num_flow_steps = (num_past_steps + num_futures_steps) - waypoint_size 328 | # [batch_size, height, width, num_flow_steps, 2] 329 | flow = timestep_grids.view(object_type).all_flow 330 | # Keep only the flow tail, containing num_future_steps timesteps. 331 | # [batch_size, height, width, num_future_steps, 2] 332 | flow = flow[..., -num_future_steps:, :] 333 | for k in range(config.num_waypoints): 334 | waypoint_end = (k + 1) * waypoint_size 335 | if config.cumulative_waypoints: 336 | waypoint_start = waypoint_end - waypoint_size 337 | # [batch_size, height, width, waypoint_size, 2] 338 | segment = flow[..., waypoint_start:waypoint_end, :] 339 | # Compute mean flow over the timesteps in this segment by counting 340 | # the number of pixels with non-zero flow and dividing the flow sum 341 | # by that number. 342 | # [batch_size, height, width, waypoint_size, 2] 343 | occupied_pixels = tf.cast(tf.not_equal(segment, 0.0), tf.float32) 344 | # [batch_size, height, width, 2] 345 | num_flow_values = tf.reduce_sum(occupied_pixels, axis=3) 346 | # [batch_size, height, width, 2] 347 | segment_sum = tf.reduce_sum(segment, axis=3) 348 | # [batch_size, height, width, 2] 349 | mean_flow = tf.math.divide_no_nan(segment_sum, num_flow_values) 350 | waypoint_flow = mean_flow 351 | else: 352 | waypoint_flow = flow[..., waypoint_end - 1, :] 353 | waypoint_grids.view(object_type).flow.append(waypoint_flow) -------------------------------------------------------------------------------- /model/encoder.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 6 | 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from time import time 12 | 13 | class PositionalEncoding(nn.Module): 14 | def __init__(self, d_model=256, dropout=0.1, max_len=100): 15 | super(PositionalEncoding, self).__init__() 16 | position = torch.arange(max_len).unsqueeze(1) 17 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 18 | pe = torch.zeros(max_len, 1, d_model) 19 | pe[:, 0, 0::2] = torch.sin(position * div_term) 20 | pe[:, 0, 1::2] = torch.cos(position * div_term) 21 | pe = pe.permute(1, 0, 2) 22 | self.register_buffer('pe', pe) 23 | self.dropout = nn.Dropout(p=dropout) 24 | 25 | def forward(self, x): 26 | x = x + self.pe 27 | return self.dropout(x) 28 | 29 | class AgentEncoder(nn.Module): 30 | def __init__(self): 31 | super(AgentEncoder, self).__init__() 32 | self.motion = nn.LSTM(9, 256, 2, batch_first=True) 33 | self.type_embed = nn.Embedding(3, 256, padding_idx=0) 34 | 35 | def forward(self, inputs): 36 | traj, _ = self.motion(inputs[...,:-1]) 37 | types = inputs[...,0,-1].int().clamp(0, 2) 38 | types = self.type_embed(types) 39 | output = traj[:, -1] + types 40 | return output 41 | 42 | class LaneEncoder(nn.Module): 43 | def __init__(self): 44 | super(LaneEncoder, self).__init__() 45 | # encdoer layer 46 | self.self_line = nn.Linear(3, 128) 47 | self.left_line = nn.Linear(3, 128) 48 | self.right_line = nn.Linear(3, 128) 49 | self.speed_limit = nn.Linear(1, 64) 50 | self.self_type = nn.Embedding(4, 64, padding_idx=0) 51 | 
self.left_type = nn.Embedding(11, 64, padding_idx=0) 52 | self.right_type = nn.Embedding(11, 64, padding_idx=0) 53 | self.traffic_light_type = nn.Embedding(9, 64, padding_idx=0) 54 | self.interpolating = nn.Embedding(2, 64) 55 | self.stop_sign = nn.Embedding(2, 64) 56 | self.stop_point = nn.Embedding(2, 64) 57 | 58 | # hidden layers 59 | self.pointnet = nn.Sequential(nn.Linear(512, 384), nn.ReLU(), nn.Linear(384, 256)) 60 | self.position_encode = PositionalEncoding(max_len=100) 61 | 62 | def forward(self, inputs): 63 | # embedding 64 | self_line = self.self_line(inputs[..., :3]) 65 | left_line = self.left_line(inputs[..., 3:6]) 66 | right_line = self.right_line(inputs[..., 6:9]) 67 | speed_limit = self.speed_limit(inputs[..., 9].unsqueeze(-1)) 68 | self_type = self.self_type(inputs[..., 10].int().clamp(0, 3)) 69 | left_type = self.left_type(inputs[..., 11].int().clamp(0, 10)) 70 | right_type = self.right_type(inputs[..., 12].int().clamp(0, 10)) 71 | traffic_light = self.traffic_light_type(inputs[..., 13].int().clamp(0, 8)) 72 | stop_point = self.stop_point(inputs[..., 14].int().clamp(0, 1)) 73 | interpolating = self.interpolating(inputs[..., 15].int().clamp(0, 1)) 74 | stop_sign = self.stop_sign(inputs[..., 16].int().clamp(0, 1)) 75 | 76 | lane_attr = self_type + left_type + right_type + traffic_light + stop_point + interpolating + stop_sign 77 | lane_embedding = torch.cat([self_line, left_line, right_line, speed_limit, lane_attr], dim=-1) 78 | 79 | # process 80 | output = self.pointnet(lane_embedding) 81 | output = self.position_encode(output) 82 | 83 | return output 84 | 85 | class NeighborLaneEncoder(nn.Module): 86 | def __init__(self): 87 | super(NeighborLaneEncoder, self).__init__() 88 | # encdoer layer 89 | self.self_line = nn.Linear(3, 128) 90 | self.speed_limit = nn.Linear(1, 64) 91 | self.traffic_light_type = nn.Embedding(9, 64, padding_idx=0) 92 | 93 | # hidden layers 94 | self.pointnet = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 256)) 95 | self.position_encode = PositionalEncoding(max_len=50) 96 | 97 | def forward(self, inputs): 98 | # embedding 99 | self_line = self.self_line(inputs[..., :3]) 100 | speed_limit = self.speed_limit(inputs[..., 3].unsqueeze(-1)) 101 | traffic_light = self.traffic_light_type(inputs[..., 4].int().clamp(0, 8)) 102 | lane_embedding = torch.cat([self_line, speed_limit, traffic_light], dim=-1) 103 | 104 | # process 105 | output = self.pointnet(lane_embedding) 106 | output = self.position_encode(output) 107 | 108 | return output 109 | 110 | class CrosswalkEncoder(nn.Module): 111 | def __init__(self): 112 | super(CrosswalkEncoder, self).__init__() 113 | self.pointnet = nn.Sequential( 114 | nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 256) 115 | ) 116 | self.position_encode = PositionalEncoding(max_len=50) 117 | 118 | def forward(self, inputs): 119 | output = self.pointnet(inputs) 120 | output = self.position_encode(output) 121 | return output 122 | 123 | class VectorEncoder(nn.Module): 124 | def __init__(self, dim=256, layers=4, heads=8, dropout=0.1): 125 | super(VectorEncoder, self).__init__() 126 | 127 | self.ego_encoder = AgentEncoder() 128 | self.neighbor_encoder = AgentEncoder() 129 | 130 | # self.map_encoder = NeighborLaneEncoder() 131 | self.ego_map_encoder = LaneEncoder() 132 | self.crosswalk_encoder = CrosswalkEncoder() 133 | 134 | attention_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim*4, 135 | activation='gelu', dropout=dropout, batch_first=True) 136 | 
self.fusion_encoder = nn.TransformerEncoder(attention_layer, layers) 137 | 138 | def segment_map(self, map, map_encoding): 139 | B, N_r, N_p, D = map_encoding.shape 140 | map_encoding = F.max_pool2d(map_encoding.permute(0, 3, 1, 2), kernel_size=(1, 10)) 141 | map_encoding = map_encoding.permute(0, 2, 3, 1).reshape(B, -1, D) 142 | 143 | map_mask = torch.eq(map, 0)[:, :, :, 0].reshape(B, N_r, N_p//10, -1) 144 | map_mask = torch.max(map_mask, dim=-1)[0].reshape(B, -1) 145 | map_mask[:, 0] = False # prevent nan 146 | 147 | return map_encoding, map_mask 148 | 149 | def forward(self, inputs): 150 | 151 | ego = inputs['ego_state'] 152 | neighbor = inputs['neighbor_state'] 153 | 154 | actors = torch.cat([ego.unsqueeze(1), neighbor], dim=1) 155 | actors_mask = torch.eq(actors, 0)[:, :, -1, 0] 156 | actors_mask[:, 0] = False 157 | 158 | ego = self.ego_encoder(ego) 159 | B, N, T, D = neighbor.shape 160 | neighbor = self.neighbor_encoder(neighbor.reshape(B*N, T, D)) 161 | neighbor = neighbor.reshape(B, N, -1) 162 | 163 | encode_actors = torch.cat([ego.unsqueeze(1), neighbor], dim=1) 164 | B,N,C = encode_actors.shape 165 | 166 | ego_maps = inputs['ego_map_lane'] 167 | ego_encode_maps = self.ego_map_encoder(ego_maps) 168 | B, M, L, D = ego_maps.shape 169 | ego_encode_maps, ego_map_mask = self.segment_map(ego_maps, ego_encode_maps) #(B*N,N_map,D) 170 | encode_maps = ego_encode_maps 171 | map_mask = ego_map_mask 172 | 173 | crosswalks = inputs['ego_map_crosswalk'] 174 | encode_cws = self.crosswalk_encoder(crosswalks) 175 | B, M, L, D = crosswalks.shape 176 | encode_cws, cw_mask = self.segment_map(crosswalks, encode_cws) 177 | 178 | encode_maps = torch.cat([encode_maps, encode_cws], dim=1) 179 | map_mask = torch.cat([map_mask, cw_mask], dim=1) 180 | 181 | encode_inputs = torch.cat([encode_actors, encode_maps], dim=1) #(B*N,N + N_map + N_cw, D) 182 | encode_masks = torch.cat([actors_mask, map_mask], dim=1) #(B*N,N + N_map + N_cw) 183 | encode_masks[:,0] = False 184 | 185 | encodings = self.fusion_encoder(encode_inputs ,src_key_padding_mask=encode_masks) 186 | 187 | _, L, D = encodings.shape 188 | N = 1 189 | encodings = encodings.reshape(B, N, L, D) 190 | encode_masks = encode_masks.reshape(B, N, L) 191 | 192 | encoder_outputs = { 193 | 'actors': actors, 194 | 'encodings': encodings, 195 | 'masks': encode_masks 196 | } 197 | 198 | encoder_outputs.update(inputs) 199 | return encoder_outputs 200 | 201 | 202 | class PredLaneEncoder(nn.Module): 203 | def __init__(self): 204 | super(PredLaneEncoder, self).__init__() 205 | # encdoer layer 206 | self.self_line = nn.Linear(3, 128) 207 | self.map_flow = nn.Linear(3, 128) 208 | 209 | self.speed_limit = nn.Linear(1, 64) 210 | self.traffic_light_type = nn.Embedding(4, 64, padding_idx=0) 211 | self.self_type = nn.Embedding(20, 64, padding_idx=0) 212 | self.stop_sign = nn.Linear(1, 64) 213 | 214 | # hidden layers 215 | self.pointnet = nn.Sequential(nn.Linear(384, 256), nn.ReLU(), nn.Linear(256, 256)) 216 | self.position_encode = PositionalEncoding(max_len=20) 217 | 218 | def forward(self, inputs): 219 | # embedding 220 | self_line = self.self_line(inputs[..., :3]) 221 | 222 | road_type = self.self_type(inputs[..., 3].int().clamp(0, 19)) 223 | traffic_light = self.traffic_light_type(inputs[..., 4].int().clamp(0, 3)) 224 | stop_sign = self.stop_sign(inputs[..., 6].unsqueeze(-1)) 225 | sp_limit = self.speed_limit(inputs[..., 7].unsqueeze(-1)) 226 | 227 | map_flow_feat = self.map_flow(inputs[..., -3:]) 228 | lane_feat = road_type + traffic_light + stop_sign 229 | 230 | 
lane_embedding = torch.cat([self_line, map_flow_feat, lane_feat, sp_limit], dim=-1) 231 | 232 | # process 233 | output = self.pointnet(lane_embedding) 234 | output = self.position_encode(output) 235 | 236 | #max pooling: 237 | output = torch.max(output, dim=-2).values 238 | 239 | return output 240 | 241 | class PredEncoder(nn.Module): 242 | def __init__(self, dim=256, layers=4, heads=8, dropout=0.1,use_map=True): 243 | super(PredEncoder, self).__init__() 244 | self.ego_encoder = AgentEncoder() 245 | self.neighbor_encoder = AgentEncoder() 246 | self.use_map = use_map 247 | if use_map: 248 | self.map_encoder = PredLaneEncoder() 249 | 250 | attention_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=dim*4, 251 | activation='gelu', dropout=dropout, batch_first=True) 252 | self.fusion_encoder = nn.TransformerEncoder(attention_layer, layers, enable_nested_tensor=False) 253 | print('vec_encoder',sum([p.numel() for p in self.parameters()])) 254 | 255 | def forward(self, inputs): 256 | ego = inputs['ego_state'] 257 | neighbor = inputs['neighbor_state'] 258 | 259 | actors = torch.cat([ego.unsqueeze(1), neighbor], dim=1) 260 | actors_mask = torch.eq(actors, 0)[:, :, -1, 0] 261 | actors_mask[:, 0] = False 262 | 263 | ego = self.ego_encoder(ego) 264 | B, N, T, D = neighbor.shape 265 | neighbor = self.neighbor_encoder(neighbor.reshape(B*N, T, D)) 266 | neighbor = neighbor.reshape(B, N, -1) 267 | 268 | encode_actors = torch.cat([ego.unsqueeze(1), neighbor], dim=1) 269 | B,N,C = encode_actors.shape 270 | 271 | if self.use_map: 272 | maps = inputs['map_segs'] 273 | map_mask = torch.eq(maps, 0)[:, :, 0, 0] 274 | maps = self.map_encoder(maps) 275 | encode_actors = torch.cat([encode_actors, maps], dim=1) 276 | encode_masks = torch.cat([actors_mask, map_mask], dim=1) 277 | else: 278 | encode_masks = actors_mask 279 | 280 | encodings = self.fusion_encoder(encode_actors ,src_key_padding_mask=encode_masks) 281 | 282 | encoder_outputs = { 283 | 'actors': actors, 284 | 'encodings': encodings, 285 | 'masks': encode_masks 286 | } 287 | 288 | encoder_outputs.update(inputs) 289 | return encoder_outputs 290 | 291 | 292 | from .swin_T import PredSwinTransformerV2 293 | 294 | 295 | class OGMFlowEncoder(nn.Module): 296 | def __init__(self,sep_flow=False,large_scale=False): 297 | super(OGMFlowEncoder, self).__init__() 298 | 299 | self.map_type_encoder = nn.Embedding(num_embeddings=20, embedding_dim=64 ,padding_idx=0) 300 | self.tl_encoder = nn.Embedding(num_embeddings=4, embedding_dim=64, padding_idx=0) 301 | self.map_encoder = nn.Linear(1, 64) 302 | 303 | self.ogm_rg_encoder = nn.Linear(33 if large_scale else 11, 128) 304 | self.sep_flow = sep_flow 305 | if not self.sep_flow: 306 | self.flow_encoder = nn.Linear(2, 128) 307 | self.offset_encoder = nn.Linear(2, 64) 308 | 309 | self.point_net = nn.Sequential(nn.Linear((384 if not sep_flow else 256), 192), nn.ReLU(), nn.Linear(192, 96)) 310 | print('flow_encoder',sum([p.numel() for p in self.parameters()])) 311 | 312 | def forward(self, inputs, offsets): 313 | hist_ogm = inputs['hist_ogm'] 314 | hist_flow = inputs['hist_flow'] 315 | road_graph = inputs['road_graph'] 316 | 317 | rg, road_type, traffic = road_graph[...,0:1], road_graph[..., 1].int(), road_graph[..., 2].int() 318 | 319 | hist_ogm = self.ogm_rg_encoder(hist_ogm) 320 | if not self.sep_flow: 321 | hist_flow = self.flow_encoder(hist_flow) 322 | 323 | offsets = self.offset_encoder(offsets) 324 | maps = self.map_encoder(rg) + self.tl_encoder(traffic.clamp(0, 3)) + 
self.map_type_encoder(road_type.clamp(0, 19)) 325 | maps = maps 326 | 327 | if self.sep_flow: 328 | mm_inputs = torch.cat([hist_ogm, maps, offsets], dim=-1) 329 | else: 330 | mm_inputs = torch.cat([hist_ogm, hist_flow, maps, offsets], dim=-1) 331 | mm_inputs = self.point_net(mm_inputs) 332 | mm_inputs = mm_inputs.permute(0, 3, 1, 2) 333 | return mm_inputs 334 | 335 | class VisualEncoder(PredSwinTransformerV2): 336 | def __init__(self, 337 | config, 338 | input_resolution=(512, 512), 339 | embedding_channels=96, 340 | window_size=8, 341 | in_channels=96, 342 | patch_size=4, 343 | use_checkpoint=False, 344 | sequential_self_attention=False, 345 | use_deformable_block=True, 346 | large_scale=False, 347 | **kwargs): 348 | super(VisualEncoder, self).__init__(input_resolution=input_resolution, 349 | window_size=window_size, 350 | in_channels=in_channels, 351 | use_checkpoint=use_checkpoint, 352 | sequential_self_attention=sequential_self_attention, 353 | embedding_channels=embedding_channels, 354 | patch_size=patch_size, 355 | depths=(2, 2, 2), 356 | number_of_heads=(6, 12, 24), 357 | use_deformable_block=use_deformable_block, 358 | large_scale=large_scale, 359 | **kwargs) 360 | 361 | self.mm_encoder = OGMFlowEncoder(sep_flow=True,large_scale=large_scale) 362 | 363 | self.config = config 364 | self.input_resolution = input_resolution 365 | self._make_position_bias_input() 366 | 367 | self.init_crop = input_resolution[0] / patch_size 368 | print('swin_encoder',sum([p.numel() for p in self.parameters()])) 369 | 370 | def half_cropping(self, tensor, stage=0): 371 | cropped_len = int(self.init_crop / (2 ** stage)) 372 | begin, end = int(cropped_len / 4), int(3 * cropped_len / 4) 373 | return tensor[:, :, begin:end, begin:end] 374 | 375 | def crop_output_list(self, output_list): 376 | return [self.half_cropping(outputs, i) for i, outputs in enumerate(output_list)] 377 | 378 | def _make_position_bias_input(self): 379 | device = self.stages[0].blocks[0].window_attention.tau.device 380 | indexes: torch.Tensor = torch.arange(1, self.config.grid_height_cells*2 + 1, device=device) 381 | widths_indexes = - (indexes - (self.config.sdc_x_in_grid + self.config.grid_height_cells) - 0.5) / self.config.pixels_per_meter 382 | heights_indexes = - (indexes - (self.config.sdc_y_in_grid + self.config.grid_width_cells) - 0.5) / self.config.pixels_per_meter 383 | #correspnding (x,y) in dense coordinates 384 | coordinates: torch.Tensor = torch.stack(torch.meshgrid([heights_indexes, widths_indexes]), dim=-1) 385 | self.register_buffer('input_bias', coordinates) 386 | 387 | def forward(self, inputs): 388 | coordinate_bias = self.input_bias #[b, h, w, 2] 389 | b = inputs['hist_ogm'].shape[0] 390 | offsets = coordinate_bias.unsqueeze(0).expand(b, -1, -1, -1) 391 | visual_inputs = self.mm_encoder(inputs, offsets) 392 | outputs_list, flow = super(VisualEncoder, self).forward(visual_inputs, inputs['hist_flow'].permute(0, 3, 1, 2)) 393 | outputs_list = self.crop_output_list(outputs_list) 394 | return outputs_list, None -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from shapely.geometry import LineString, Point, Polygon 4 | from shapely.affinity import affine_transform, rotate 5 | 6 | def wrap_to_pi(theta): 7 | #[0, 2pi] ->[0, pi,-pi, 0] 8 | return (theta+np.pi) % (2*np.pi) - np.pi 9 | 10 | def compute_direction_diff(ego_theta, 
target_theta): 11 | delta = np.abs(ego_theta - target_theta) 12 | delta = np.where(delta > np.pi, 2*np.pi - delta, delta) 13 | 14 | return delta 15 | 16 | def depth_first_search(cur_lane, lanes, dist=0, threshold=100): 17 | """ 18 | Perform depth first search over lane graph up to the threshold. 19 | Args: 20 | cur_lane: Starting lane_id 21 | lanes: raw lane data 22 | dist: Distance of the current path 23 | threshold: Threshold after which to stop the search 24 | Returns: 25 | lanes_to_return (list of list of integers): List of sequence of lane ids 26 | """ 27 | if dist > threshold: 28 | return [[cur_lane]] 29 | else: 30 | traversed_lanes = [] 31 | child_lanes = lanes[cur_lane].exit_lanes 32 | 33 | if child_lanes: 34 | for child in child_lanes: 35 | centerline = np.array([(map_point.x, map_point.y, map_point.z) for map_point in lanes[child].polyline]) 36 | cl_length = centerline.shape[0] 37 | curr_lane_ids = depth_first_search(child, lanes, dist + cl_length, threshold) 38 | traversed_lanes.extend(curr_lane_ids) 39 | 40 | if len(traversed_lanes) == 0: 41 | return [[cur_lane]] 42 | 43 | lanes_to_return = [] 44 | 45 | for lane_seq in traversed_lanes: 46 | lanes_to_return.append([cur_lane] + lane_seq) 47 | 48 | return lanes_to_return 49 | 50 | def is_overlapping_lane_seq(lane_seq1, lane_seq2): 51 | """ 52 | Check if the 2 lane sequences are overlapping. 53 | Args: 54 | lane_seq1: list of lane ids 55 | lane_seq2: list of lane ids 56 | Returns: 57 | bool, True if the lane sequences overlap 58 | """ 59 | 60 | if lane_seq2[1:] == lane_seq1[1:]: 61 | return True 62 | elif set(lane_seq2) <= set(lane_seq1): 63 | return True 64 | 65 | return False 66 | 67 | def remove_overlapping_lane_seq(lane_seqs): 68 | """ 69 | Remove lane sequences which are overlapping to some extent 70 | Args: 71 | lane_seqs (list of list of integers): List of list of lane ids (Eg. [[12345, 12346, 12347], [12345, 12348]]) 72 | Returns: 73 | List of sequence of lane ids (e.g. 
``[[12345, 12346, 12347], [12345, 12348]]``) 74 | """ 75 | redundant_lane_idx = set() 76 | 77 | for i in range(len(lane_seqs)): 78 | for j in range(len(lane_seqs)): 79 | if i in redundant_lane_idx or i == j: 80 | continue 81 | if is_overlapping_lane_seq(lane_seqs[i], lane_seqs[j]): 82 | redundant_lane_idx.add(j) 83 | 84 | unique_lane_seqs = [lane_seqs[i] for i in range(len(lane_seqs)) if i not in redundant_lane_idx] 85 | 86 | return unique_lane_seqs 87 | 88 | def polygon_completion(polygon): 89 | polyline_x = [] 90 | polyline_y = [] 91 | 92 | for i in range(len(polygon)): 93 | if i+1 < len(polygon): 94 | next = i+1 95 | else: 96 | next = 0 97 | 98 | dist_x = polygon[next].x - polygon[i].x 99 | dist_y = polygon[next].y - polygon[i].y 100 | dist = np.linalg.norm([dist_x, dist_y]) 101 | interp_num = np.ceil(dist)*2 102 | interp_index = np.arange(2+interp_num) 103 | point_x = np.interp(interp_index, [0, interp_index[-1]], [polygon[i].x, polygon[next].x]).tolist() 104 | point_y = np.interp(interp_index, [0, interp_index[-1]], [polygon[i].y, polygon[next].y]).tolist() 105 | polyline_x.extend(point_x[:-1]) 106 | polyline_y.extend(point_y[:-1]) 107 | 108 | polyline_x, polyline_y = np.array(polyline_x), np.array(polyline_y) 109 | polyline_heading = wrap_to_pi(np.arctan2(polyline_y[1:]-polyline_y[:-1], polyline_x[1:]-polyline_x[:-1])) 110 | polyline_heading = np.insert(polyline_heading, -1, polyline_heading[-1]) 111 | 112 | return np.stack([polyline_x, polyline_y, polyline_heading], axis=1) 113 | 114 | def get_polylines(lines): 115 | polylines = {} 116 | 117 | for line in lines.keys(): 118 | polyline = np.array([(map_point.x, map_point.y) for map_point in lines[line].polyline]) 119 | if len(polyline) > 1: 120 | direction = wrap_to_pi(np.arctan2(polyline[1:, 1]-polyline[:-1, 1], polyline[1:, 0]-polyline[:-1, 0])) 121 | direction = np.insert(direction, -1, direction[-1])[:, np.newaxis] 122 | else: 123 | direction = np.array([0])[:, np.newaxis] 124 | polylines[line] = np.concatenate([polyline, direction], axis=-1) 125 | 126 | return polylines 127 | 128 | def find_reference_lanes(agent_type, agent_traj, lanes): 129 | curr_lane_ids = {} 130 | 131 | if agent_type == 2: 132 | distance_threshold = 5 133 | 134 | while len(curr_lane_ids) < 1: 135 | for lane in lanes.keys(): 136 | if lanes[lane].shape[0] > 1: 137 | distance_to_agent = LineString(lanes[lane][:, :2]).distance(Point(agent_traj[-1, :2])) 138 | if distance_to_agent < distance_threshold: 139 | curr_lane_ids[lane] = 0 140 | 141 | distance_threshold += 5 142 | if distance_threshold > 50: 143 | break 144 | else: 145 | distance_threshold = 3.5 146 | direction_threshold = 10 147 | while len(curr_lane_ids) < 1: 148 | for lane in lanes.keys(): 149 | distance_to_ego = np.linalg.norm(agent_traj[-1, :2] - lanes[lane][:, :2], axis=-1) 150 | direction_to_ego = wrap_to_pi(agent_traj[-1, 2] - lanes[lane][:, -1]) 151 | for i, j, k in zip(distance_to_ego, direction_to_ego, range(distance_to_ego.shape[0])): 152 | if i <= distance_threshold :#and np.abs(j) <= np.radians(direction_threshold): 153 | curr_lane_ids[lane] = k 154 | break 155 | 156 | distance_threshold += 3.5 157 | direction_threshold += 10 158 | if distance_threshold > 50: 159 | break 160 | 161 | return curr_lane_ids 162 | 163 | def find_neighbor_lanes(curr_lane_ids, traj, lanes, lane_polylines): 164 | neighbor_lane_ids = {} 165 | 166 | for curr_lane, start in curr_lane_ids.items(): 167 | left_lanes = lanes[curr_lane].left_neighbors 168 | right_lanes = lanes[curr_lane].right_neighbors 169 | left_lane = 
None 170 | right_lane = None 171 | curr_index = start 172 | 173 | for l_lane in left_lanes: 174 | if l_lane.self_start_index <= curr_index <= l_lane.self_end_index and not l_lane.feature_id in curr_lane_ids: 175 | left_lane = l_lane 176 | 177 | for r_lane in right_lanes: 178 | if r_lane.self_start_index <= curr_index <= r_lane.self_end_index and not r_lane.feature_id in curr_lane_ids: 179 | right_lane = r_lane 180 | 181 | if left_lane is not None: 182 | left_polyline = lane_polylines[left_lane.feature_id] 183 | start = np.argmin(np.linalg.norm(traj[-1, :2] - left_polyline[:, :2], axis=-1)) 184 | neighbor_lane_ids[left_lane.feature_id] = start 185 | 186 | if right_lane is not None: 187 | right_polyline = lane_polylines[right_lane.feature_id] 188 | start = np.argmin(np.linalg.norm(traj[-1, :2] - right_polyline[:, :2], axis=-1)) 189 | neighbor_lane_ids[right_lane.feature_id] = start 190 | 191 | return neighbor_lane_ids 192 | 193 | def find_neareast_point(curr_point, line): 194 | distance_to_curr_point = np.linalg.norm(curr_point[np.newaxis, :2] - line[:, :2], axis=-1) 195 | neareast_point = line[np.argmin(distance_to_curr_point)] 196 | 197 | return neareast_point 198 | 199 | def find_map_waypoint(pos, polylines): 200 | waypoint = [-1, -1, 1e9, 1e9] 201 | direction_threshold = 10 202 | 203 | for id, polyline in polylines.items(): 204 | distance_to_gt = np.linalg.norm(pos[np.newaxis, :2] - polyline[:, :2], axis=-1) 205 | direction_to_gt = wrap_to_pi(pos[np.newaxis, 2]-polyline[:, 2]) 206 | 207 | for i, j, k in zip(range(polyline.shape[0]), distance_to_gt, direction_to_gt): 208 | if j < waypoint[2] and np.abs(k) <= np.radians(direction_threshold): 209 | waypoint = [id, i, j, k] 210 | 211 | lane_id = waypoint[0] 212 | waypoint_id = waypoint[1] 213 | 214 | if lane_id > 0: 215 | return lane_id, waypoint_id 216 | else: 217 | return None, None 218 | 219 | from .plan_utils import generate_target_course 220 | 221 | def ref_line_norm(ref_line, center, angle): 222 | xy = LineString(ref_line[:, 0:2]) 223 | xy = affine_transform(xy, [1, 0, 0, 1, -center[0], -center[1]]) 224 | xy = rotate(xy, -angle, origin=(0, 0), use_radians=True) 225 | yaw = wrap_to_pi(ref_line[:, 2] - angle) 226 | c = ref_line[:, 3] 227 | info = ref_line[:, 4] 228 | return np.column_stack((xy.coords, yaw, c, info)) 229 | 230 | def find_route(traj, timestep, cur_pos, map_lanes, map_crosswalks, map_signals): 231 | cur_pos = np.array(cur_pos) 232 | lane_polylines = get_polylines(map_lanes) 233 | # print(len(lane_polylines.keys()), cur_pos) 234 | end_lane, end_point = find_map_waypoint(np.array((traj[-1].center_x, traj[-1].center_y, traj[-1].heading)), lane_polylines) 235 | # print(end_lane, end_point) 236 | start_lane, start_point = find_map_waypoint(np.array((traj[0].center_x, traj[0].center_y, traj[0].heading)), lane_polylines) 237 | # print(start_lane, start_point, cur_pos) 238 | cur_lane, _ = find_map_waypoint(cur_pos, lane_polylines) 239 | # print(end_lane, start_lane, cur_lane) 240 | 241 | path_waypoints = [] 242 | for t in range(0, len(traj), 10): 243 | lane, point = find_map_waypoint(np.array((traj[t].center_x, traj[t].center_y, traj[t].heading)), lane_polylines) 244 | path_waypoints.append(lane_polylines[lane][point]) 245 | # print(path_waypoints) 246 | 247 | before_waypoints = [] 248 | if start_point < 40: 249 | if map_lanes[start_lane].entry_lanes: 250 | lane = map_lanes[start_lane].entry_lanes[0] 251 | for waypoint in lane_polylines[lane]: 252 | before_waypoints.append(waypoint) 253 | for waypoint in 
lane_polylines[start_lane][:start_point]: 254 | before_waypoints.append(waypoint) 255 | 256 | after_waypoints = [] 257 | for waypoint in lane_polylines[end_lane][end_point:]: 258 | after_waypoints.append(waypoint) 259 | if len(after_waypoints) < 40: 260 | if map_lanes[end_lane].exit_lanes: 261 | lane = map_lanes[end_lane].exit_lanes[0] 262 | for waypoint in lane_polylines[lane]: 263 | after_waypoints.append(waypoint) 264 | 265 | waypoints = np.concatenate([before_waypoints[::5], path_waypoints, after_waypoints[::5]], axis=0) 266 | 267 | # generate smooth route 268 | tx, ty, tyaw, tc, _ = generate_target_course(waypoints[:, 0], waypoints[:, 1]) 269 | ref_line = np.column_stack([tx, ty, tyaw, tc]) 270 | 271 | # print(ref_line.shape) 272 | 273 | # get reference path at current timestep 274 | current_location = np.argmin(np.linalg.norm(ref_line[:, :2] - cur_pos[np.newaxis, :2], axis=-1)) 275 | start_index = np.max([current_location-200, 0]) 276 | 277 | ref_line = ref_line[start_index:start_index+1200] 278 | 279 | # add speed limit, crosswalk, and traffic signal info to ref route 280 | line_info = np.zeros(shape=(ref_line.shape[0], 1)) 281 | speed_limit = map_lanes[cur_lane].speed_limit_mph / 2.237 282 | ref_line = np.concatenate([ref_line, line_info], axis=-1) 283 | crosswalks = [Polygon([(point.x, point.y) for point in crosswalk.polygon]) for _, crosswalk in map_crosswalks.items()] 284 | signals = [Point([signal.stop_point.x, signal.stop_point.y]) for signal in map_signals[timestep].lane_states if signal.state in [1, 4, 7]] 285 | 286 | for i in range(ref_line.shape[0]): 287 | if any([Point(ref_line[i, :2]).distance(signal) < 0.2 for signal in signals]): 288 | ref_line[i, 4] = 0 # red light 289 | elif any([crosswalk.contains(Point(ref_line[i, :2])) for crosswalk in crosswalks]): 290 | ref_line[i, 4] = 1 # crosswalk 291 | else: 292 | ref_line[i, 4] = speed_limit 293 | 294 | return ref_line 295 | 296 | def imputer(traj): 297 | x, y, v_x, v_y, theta = traj[:, 0], traj[:, 1], traj[:, 3], traj[:, 4], traj[:, 2] 298 | 299 | if np.any(x==0): 300 | for i in reversed(range(traj.shape[0])): 301 | if x[i] == 0 and i!= traj.shape[0]-1: 302 | v_x[i] = v_x[i+1] 303 | v_y[i] = v_y[i+1] 304 | x[i] = x[i+1] - v_x[i]*0.1 305 | y[i] = y[i+1] - v_y[i]*0.1 306 | theta[i] = theta[i+1] 307 | return np.column_stack((x, y, theta, v_x, v_y)) 308 | else: 309 | return np.column_stack((x, y, theta, v_x, v_y)) 310 | 311 | def goal_norm(goal, center, angle): 312 | x = goal[0] - center[0] 313 | y = goal[1] - center[1] 314 | new_x = x * np.cos(angle) + y * np.sin(angle) 315 | new_y = -x * np.sin(angle) + y * np.cos(angle) 316 | return np.array([new_x, new_y]) 317 | 318 | def agent_norm(traj, center, angle, impute=False): 319 | if impute: 320 | traj = imputer(traj[:, :5]) 321 | 322 | x, y = traj[:, 0] - center[0], traj[:, 1] - center[1] 323 | # angle = np.pi/ 2 - angle 324 | tx = np.cos(angle) * x + np.sin(angle) * y 325 | ty = -np.sin(angle) * x + np.cos(angle) * y 326 | line_rotate = np.stack([tx, ty], axis=-1) 327 | line_rotate[traj[:, :2]==0] = 0 328 | # line_rotate[traj[:, 1]==0, 1] = 0 329 | 330 | heading = wrap_to_pi(traj[:, 4] - angle) 331 | heading[traj[:, 4]==0] = 0 332 | 333 | if traj.shape[-1] > 3: 334 | velocity_x = traj[:, 2] * np.cos(angle) + traj[:, 3] * np.sin(angle) 335 | velocity_x[traj[:, 2]==0] = 0 336 | velocity_y = traj[:, 3] * np.cos(angle) - traj[:, 2] * np.sin(angle) 337 | velocity_y[traj[:, 3]==0] = 0 338 | return np.column_stack((line_rotate, heading, velocity_x, velocity_y)) 339 | else: 340 | 
return np.column_stack((line_rotate, heading)) 341 | 342 | def agent_norm_left(traj, center, angle, impute=False): 343 | if impute: 344 | traj = imputer(traj[:, :5]) 345 | 346 | x, y = traj[:, 0] - center[0], traj[:, 1] - center[1] 347 | angle = np.pi/ 2 - angle 348 | tx = np.cos(angle) * x - np.sin(angle) * y 349 | ty = np.sin(angle) * x + np.cos(angle) * y 350 | line_rotate = np.stack([tx, ty], axis=-1) 351 | line_rotate[traj[:, :2]==0] = 0 352 | # line_rotate[traj[:, 1]==0, 1] = 0 353 | 354 | heading = wrap_to_pi(traj[:, 4] - angle) 355 | heading[traj[:, 4]==0] = 0 356 | 357 | if traj.shape[-1] > 3: 358 | velocity_x = traj[:, 2] * np.cos(angle) - traj[:, 3] * np.sin(angle) 359 | velocity_x[traj[:, 2]==0] = 0 360 | velocity_y = traj[:, 3] * np.cos(angle) + traj[:, 2] * np.sin(angle) 361 | velocity_y[traj[:, 3]==0] = 0 362 | return np.column_stack((line_rotate, heading, velocity_x, velocity_y)) 363 | else: 364 | return np.column_stack((line_rotate, heading)) 365 | 366 | def map_norm(map_line, center, angle): 367 | self_line = LineString(map_line[:, 0:2]) 368 | self_line = affine_transform(self_line, [1, 0, 0, 1, -center[0], -center[1]]) 369 | self_line = rotate(self_line, -angle, origin=(0, 0), use_radians=True) 370 | self_line = np.array(self_line.coords) 371 | self_line[map_line[:, 0:2]==0] = 0 372 | self_heading = wrap_to_pi(map_line[:, 2] - angle) 373 | 374 | if map_line.shape[1] > 3: 375 | left_line = LineString(map_line[:, 3:5]) 376 | left_line = affine_transform(left_line, [1, 0, 0, 1, -center[0], -center[1]]) 377 | left_line = rotate(left_line, -angle, origin=(0, 0), use_radians=True) 378 | left_line = np.array(left_line.coords) 379 | left_line[map_line[:, 3:5]==0] = 0 380 | left_heading = wrap_to_pi(map_line[:, 5] - angle) 381 | left_heading[map_line[:, 5]==0] = 0 382 | 383 | right_line = LineString(map_line[:, 6:8]) 384 | right_line = affine_transform(right_line, [1, 0, 0, 1, -center[0], -center[1]]) 385 | right_line = rotate(right_line, -angle, origin=(0, 0), use_radians=True) 386 | right_line = np.array(right_line.coords) 387 | right_line[map_line[:, 6:8]==0] = 0 388 | right_heading = wrap_to_pi(map_line[:, 8] - angle) 389 | right_heading[map_line[:, 8]==0] = 0 390 | 391 | return np.column_stack((self_line, self_heading, left_line, left_heading, right_line, right_heading)) 392 | else: 393 | return np.column_stack((self_line, self_heading)) 394 | 395 | 396 | def map_norm_left(map_line, center, angle): 397 | angle = np.pi/ 2 - angle 398 | x, y = map_line[:, 0] - center[0], map_line[:, 1] - center[1] 399 | tx = np.cos(angle) * x - np.sin(angle) * y 400 | ty = np.sin(angle) * x + np.cos(angle) * y 401 | self_line = np.stack([tx, ty], axis=-1) 402 | self_line[map_line[:, 0:2]==0] = 0 403 | self_heading = wrap_to_pi(map_line[:, 2] - angle) 404 | self_heading[map_line[:, 2]==0] = 0 405 | 406 | if map_line.shape[1] > 3: 407 | x, y = map_line[:, 3] - center[0], map_line[:, 4] - center[1] 408 | tx = np.cos(angle) * x - np.sin(angle) * y 409 | ty = np.sin(angle) * x + np.cos(angle) * y 410 | left_line = np.stack([tx, ty], axis=-1) 411 | left_line[map_line[:, 3:5]==0] = 0 412 | 413 | left_heading = wrap_to_pi(map_line[:, 5] - angle) 414 | left_heading[map_line[:, 5]==0] = 0 415 | 416 | x, y = map_line[:, 6] - center[0], map_line[:, 7] - center[1] 417 | tx = np.cos(angle) * x - np.sin(angle) * y 418 | ty = np.sin(angle) * x + np.cos(angle) * y 419 | right_line = np.stack([tx, ty], axis=-1) 420 | right_line[map_line[:, 6:8]==0] = 0 421 | 422 | right_heading = wrap_to_pi(map_line[:, 8] - 
angle) 423 | right_heading[map_line[:, 8]==0] = 0 424 | 425 | return np.column_stack((self_line, self_heading, left_line, left_heading, right_line, right_heading)) 426 | else: 427 | return np.column_stack((self_line, self_heading)) 428 | 429 | 430 | 431 | -------------------------------------------------------------------------------- /utils/occupancy_render_utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from glob import glob 3 | import dataclasses 4 | 5 | import numpy as np 6 | import math 7 | 8 | from waymo_open_dataset.protos import scenario_pb2 9 | from waymo_open_dataset.utils.occupancy_flow_renderer import _transform_to_image_coordinates, rotate_points_around_origin 10 | from waymo_open_dataset.utils import occupancy_flow_data 11 | from waymo_open_dataset.protos import occupancy_flow_metrics_pb2 12 | from waymo_open_dataset.protos import scenario_pb2 13 | 14 | import matplotlib.pyplot as plt 15 | import matplotlib.style as mplstyle 16 | 17 | from PIL import Image 18 | 19 | from .waymo_tf_utils import linecolormap, road_label, road_line_map, traffic_light_map, light_state_map_num 20 | 21 | _ObjectType = scenario_pb2.Track.ObjectType 22 | 23 | @dataclasses.dataclass 24 | class _SampledPoints: 25 | """Set of points sampled from agent boxes. 26 | 27 | All fields have shape - 28 | [batch_size, num_agents, num_steps, num_points] where num_points is 29 | (points_per_side_length * points_per_side_width). 30 | """ 31 | # [batch, num_agents, num_steps, points_per_agent]. 32 | x: tf.Tensor 33 | # [batch, num_agents, num_steps, points_per_agent]. 34 | y: tf.Tensor 35 | # [batch, num_agents, num_steps, points_per_agent]. 36 | z: tf.Tensor 37 | # [batch, num_agents, num_steps, points_per_agent]. 38 | agent_type: tf.Tensor 39 | # [batch, num_agents, num_steps, points_per_agent]. 40 | valid: tf.Tensor 41 | 42 | def pack_trajs(parsed_data, time=199): 43 | #x, y, z, heading, length, width, valid required for occupancy 44 | tracks = parsed_data.tracks 45 | traj_tensor = np.zeros((len(tracks), time+1, 10)) 46 | valid_tensor = np.zeros((len(tracks), time+1)) 47 | sdc_id = parsed_data.sdc_track_index 48 | goal = None 49 | 50 | for i, track in enumerate(tracks): 51 | object_type = track.object_type 52 | for j, state in enumerate(track.states): 53 | if state.valid: 54 | traj_tensor[i, j] = np.array([state.center_x, state.center_y, state.velocity_x, state.velocity_y, 55 | state.heading, state.center_z, state.length, state.width, state.height, object_type]) 56 | valid_tensor[i, j] = 1 57 | if i== sdc_id: 58 | goal = [state.center_x, state.center_y] 59 | 60 | traj_tensor = tf.convert_to_tensor(traj_tensor, tf.float32) 61 | valid_tensor = tf.convert_to_tensor(valid_tensor, tf.int32)[...,tf.newaxis] 62 | 63 | return traj_tensor, valid_tensor, goal 64 | 65 | def _np_to_img_coordinate(points_x, points_y, config): 66 | pixels_per_meter = config.pixels_per_meter 67 | points_x = np.round(points_x * pixels_per_meter) + config.sdc_x_in_grid 68 | points_y = np.round(-points_y * pixels_per_meter) + config.sdc_y_in_grid 69 | 70 | # Filter out points that are located outside the FOV of topdown map. 
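# NOTE (illustrative): the two lines above implement the world-to-grid mapping
#   points_x = round(x * pixels_per_meter) + sdc_x_in_grid
#   points_y = round(-y * pixels_per_meter) + sdc_y_in_grid
# For example, with pixels_per_meter=1.6, sdc_x_in_grid=64 and sdc_y_in_grid=96
# (the base grid values set in preprocess.py), a point 20 m directly ahead of
# the ego in the rotated ego frame (x=0, y=+20) maps to
#   points_x = round(0.0 * 1.6) + 64 = 64
#   points_y = round(-20.0 * 1.6) + 96 = 64
# i.e. 32 rows above the ego row: image rows grow downwards while the ego's
# forward direction is +y after the pi/2 - heading rotation used in this file.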
71 | point_is_in_fov = np.logical_and( 72 | np.logical_and( 73 | np.greater_equal(points_x, 0), np.greater_equal(points_y, 0)), 74 | np.logical_and( 75 | np.less(points_x, config.grid_width_cells), 76 | np.less(points_y, config.grid_height_cells))) 77 | 78 | return points_x, points_y, point_is_in_fov.astype(np.float32) 79 | 80 | 81 | def wrap_to_pi(theta): 82 | return (theta+np.pi) % (2*np.pi) - np.pi 83 | 84 | 85 | def get_polylines_type(lines, traffic_light_lanes, stop_sign_lanes, config, ego_xyh, polyline_len): 86 | tl_keys = set(traffic_light_lanes.keys()) 87 | polylines = {} 88 | org_polylnes = [] 89 | for line in lines.keys(): 90 | types = lines[line].type 91 | polyline = np.array([(map_point.x, map_point.y) for map_point in lines[line].polyline]) 92 | if polyline.shape[0] <= 2: 93 | continue 94 | # rotations: 95 | x, y = polyline[:, 0] - ego_xyh[0], polyline[:, 1] - ego_xyh[1] 96 | angle = np.pi/ 2 - ego_xyh[2] 97 | tx = np.cos(angle) * x - np.sin(angle) * y 98 | ty = np.sin(angle) * x + np.cos(angle) * y 99 | new_polyline = np.stack([tx, ty], axis=-1) 100 | 101 | if len(polyline) > 1: 102 | direction = wrap_to_pi(np.arctan2(polyline[1:, 1]-polyline[:-1, 1], polyline[1:, 0]-polyline[:-1, 0]) - angle) 103 | direction = np.insert(direction, -1, direction[-1])[:, np.newaxis] 104 | else: 105 | direction = np.array([0])[:, np.newaxis] 106 | 107 | trajs = np.concatenate([new_polyline, direction], axis=-1) 108 | 109 | #(x_img, y_img, in_fov) 110 | ogm_states = np.zeros((trajs.shape[0], 3)) 111 | points_x, points_y, point_is_in_fov = _np_to_img_coordinate(new_polyline[:,0], new_polyline[:,1], config) 112 | 113 | ogm_states = np.stack([points_x, points_y, point_is_in_fov], axis=-1) 114 | 115 | # attrib_states: (type, tl_state, near_tl, stop_sign, sp_limit) 116 | attrib_states = np.zeros((trajs.shape[0], 5)) 117 | attrib_states[:, 0] = types 118 | if line in tl_keys: 119 | attrib_states[:, 1] = light_state_map_num[traffic_light_lanes[line][0]] 120 | near_tl = np.less_equal(np.linalg.norm(polyline[:, :2] - np.array(traffic_light_lanes[line][1:])[np.newaxis,...], axis=-1), 3).astype(np.float32) 121 | attrib_states[:, 2] = near_tl 122 | # add stop sign 123 | if line in stop_sign_lanes: 124 | attrib_states[:, 3] = True 125 | try: 126 | attrib_states[:, 4] = lines[line].speed_limit_mph / 2.237 127 | except: 128 | attrib_states[:, 4] = 0 129 | 130 | ogm_center = np.array([0, (config.sdc_y_in_grid - 0.5*config.grid_height_cells)/config.pixels_per_meter])[np.newaxis,...] 
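# NOTE (illustrative): each row of the concatenation below packs 11 per-point
# features: (x, y, heading) in the rotated ego frame, the five attributes
# (type, tl_state, near_tl, stop_sign, speed_limit), and the three grid-space
# values (x_img, y_img, in_fov). ogm_center is the center of the occupancy
# grid expressed in the ego frame (about (0, 20), i.e. 20 m straight ahead of
# the ego for the configs built in preprocess.py), so the `ade` value stored
# with every sub-polyline is its mean distance to the grid center, presumably
# used downstream to rank polylines by relevance to the field of view.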
131 | polyline_traj = np.concatenate((trajs, attrib_states, ogm_states), axis=-1) 132 | org_polylnes.append(polyline_traj) 133 | traj_splits = np.array_split(polyline_traj, np.ceil(polyline_traj.shape[0] / polyline_len), axis=0) 134 | i = 0 135 | for sub_traj in traj_splits: 136 | 137 | ade = np.mean(np.linalg.norm(sub_traj[:, :2] - ogm_center, axis=-1)) 138 | polylines[f'{line}_{i}'] = (sub_traj, ade) 139 | 140 | i += 1 141 | 142 | return polylines, org_polylnes 143 | 144 | def render_roadgraph_tf(rg_tensor): 145 | if len(rg_tensor)==0: 146 | print('Warning: RG tensor is 0!') 147 | return tf.zeros((512, 512, 3)) 148 | rg_tensor = tf.convert_to_tensor(np.concatenate(rg_tensor,axis=0)) 149 | topdown_shape = [512, 512, 1] 150 | rg_x, rg_y, point_is_in_fov = rg_tensor[:, -3], rg_tensor[:, -2], rg_tensor[:, -1] 151 | 152 | types = rg_tensor[:, 3] 153 | tl_state = rg_tensor[:, 4] 154 | 155 | should_render_point = tf.cast(point_is_in_fov, tf.bool) 156 | point_indices = tf.cast(tf.where(should_render_point), tf.int32) 157 | x_img_coord = tf.gather_nd(rg_x, point_indices)[..., tf.newaxis] 158 | y_img_coord = tf.gather_nd(rg_y, point_indices)[..., tf.newaxis] 159 | 160 | types = tf.gather_nd(types, point_indices)[..., tf.newaxis] 161 | tl_state = tf.gather_nd(tl_state, point_indices)[..., tf.newaxis] 162 | 163 | num_points_to_render = point_indices.shape.as_list()[0] 164 | 165 | # [num_points_to_render, 3] 166 | xy_img_coord = tf.concat( 167 | [ 168 | # point_indices[:, :1], 169 | tf.cast(y_img_coord, tf.int32), 170 | tf.cast(x_img_coord, tf.int32), 171 | ], 172 | axis=1, 173 | ) 174 | gt_values = tf.ones_like(x_img_coord, dtype=tf.float32) 175 | 176 | # [batch_size, grid_height_cells, grid_width_cells, 1] 177 | rg_viz = tf.scatter_nd(xy_img_coord, gt_values, topdown_shape) 178 | # assert_shapes([(rg_viz, topdown_shape)]) 179 | rg_type = tf.math.divide_no_nan(tf.cast(tf.scatter_nd(xy_img_coord, types, topdown_shape),tf.float32) , rg_viz) 180 | rg_tl_type = tf.math.divide_no_nan(tf.cast(tf.scatter_nd(xy_img_coord, tl_state, topdown_shape),tf.float32) , rg_viz) 181 | rg_viz = tf.clip_by_value(rg_viz, 0.0, 1.0) 182 | 183 | return tf.concat([rg_viz, rg_type, rg_tl_type],axis=-1) 184 | 185 | 186 | def get_crosswalk_type(lines, traffic_light_lanes, stop_sign_lanes, config, ego_xyh, polyline_len): 187 | tl_keys = set(traffic_light_lanes.keys()) 188 | polylines = {} 189 | org_polylnes = [] 190 | # id_list = [] 191 | for line in lines.keys(): 192 | types = 18 193 | polyline = np.array([(map_point.x, map_point.y) for map_point in lines[line].polygon]) 194 | if polyline.shape[0] <= 2: 195 | continue 196 | # rotations: 197 | x, y = polyline[:, 0] - ego_xyh[0], polyline[:, 1] - ego_xyh[1] 198 | angle = np.pi/ 2 - ego_xyh[2] 199 | tx = np.cos(angle) * x - np.sin(angle) * y 200 | ty = np.sin(angle) * x + np.cos(angle) * y 201 | new_polyline = np.stack([tx, ty], axis=-1) 202 | 203 | if len(polyline) > 1: 204 | direction = wrap_to_pi(np.arctan2(polyline[1:, 1]-polyline[:-1, 1], polyline[1:, 0]-polyline[:-1, 0]) - angle) 205 | direction = np.insert(direction, -1, direction[-1])[:, np.newaxis] 206 | else: 207 | direction = np.array([0])[:, np.newaxis] 208 | 209 | trajs = np.concatenate([new_polyline, direction], axis=-1) 210 | 211 | #(x_img, y_img, in_fov) 212 | ogm_states = np.zeros((trajs.shape[0], 3)) 213 | points_x, points_y, point_is_in_fov = _np_to_img_coordinate(new_polyline[:,0], new_polyline[:,1], config) 214 | 215 | ogm_states = np.stack([points_x, points_y, point_is_in_fov], axis=-1) 216 | 217 | # 
attrib_states: (type, tl_state, near_tl, stop_sign, sp_limit) 218 | attrib_states = np.zeros((trajs.shape[0], 5)) 219 | attrib_states[:, 0] = types 220 | if line in tl_keys: 221 | attrib_states[:, 1] = light_state_map_num[traffic_light_lanes[line][0]] 222 | near_tl = np.less_equal(np.linalg.norm(polyline[:, :2] - np.array(traffic_light_lanes[line][1:])[np.newaxis,...], axis=-1), 3).astype(np.float32) 223 | attrib_states[:, 2] = near_tl 224 | # add stop sign 225 | if line in stop_sign_lanes: 226 | attrib_states[:, 3] = True 227 | try: 228 | attrib_states[:, 4] = lines[line].speed_limit_mph / 2.237 229 | except: 230 | attrib_states[:, 4] = 0 231 | 232 | ogm_center = np.array([0, (config.sdc_y_in_grid - 0.5*config.grid_height_cells)/config.pixels_per_meter])[np.newaxis,...] 233 | polyline_traj = np.concatenate((trajs, attrib_states, ogm_states), axis=-1) 234 | org_polylnes.append(polyline_traj) 235 | # org_polylnes[line] = polyline_traj 236 | traj_splits = np.array_split(polyline_traj, np.ceil(polyline_traj.shape[0] / polyline_len), axis=0) 237 | i = 0 238 | 239 | for sub_traj in traj_splits: 240 | 241 | ade = np.mean(np.linalg.norm(sub_traj[:, :2] - ogm_center, axis=-1)) 242 | polylines[f'{line}_{i}'] = (sub_traj, ade) 243 | 244 | i += 1 245 | 246 | return polylines, org_polylnes 247 | 248 | 249 | def pack_maps(lanes, roads, crosswalks, traffic_light_lanes, stop_sign_lanes, config, ego_xy, polyline_len=20): 250 | ''' 251 | inputs: (Dict) lanes, roads, crosswalks 252 | outputs: dict of polyline-trajs with their ids which are inside fov 253 | ''' 254 | all_poly_trajs = {} 255 | org_ploys = [] 256 | 257 | lane_polylines, lane_ids = get_polylines_type(lanes, traffic_light_lanes, stop_sign_lanes, config, ego_xy, polyline_len) 258 | all_poly_trajs.update(lane_polylines) 259 | # print([id.shape for id in lane_ids.values()]) 260 | org_ploys.extend(lane_ids) 261 | 262 | roads_polylines, road_ids = get_polylines_type(roads, traffic_light_lanes, stop_sign_lanes, config, ego_xy, polyline_len) 263 | all_poly_trajs.update(roads_polylines) 264 | org_ploys.extend(road_ids) 265 | 266 | cw_polylines, cw_ids = get_crosswalk_type(crosswalks, traffic_light_lanes, stop_sign_lanes, config, ego_xy, polyline_len) 267 | all_poly_trajs.update(cw_polylines) 268 | org_ploys.extend(cw_ids) 269 | 270 | return all_poly_trajs, org_ploys 271 | 272 | 273 | def points_sample( 274 | traj_tensor, 275 | valid_tensor, 276 | ego_trajs, 277 | unit_x, 278 | unit_y, 279 | config 280 | ): 281 | 282 | # traj_tensor, valid_tensor = pack_trajs(parsed_data) 283 | 284 | x, y, z = traj_tensor[...,0:1], traj_tensor[...,1:2], traj_tensor[...,5:6] 285 | length, width = traj_tensor[...,6:7], traj_tensor[...,7:8] 286 | bbox_yaw = traj_tensor[...,4:5] 287 | 288 | sdc_x = ego_trajs[0:1][tf.newaxis, tf.newaxis, :] 289 | sdc_y = ego_trajs[1:2][tf.newaxis, tf.newaxis, :] 290 | sdc_z = ego_trajs[5:6][tf.newaxis, tf.newaxis, :] 291 | 292 | x = x - sdc_x 293 | y = y - sdc_y 294 | z = z - sdc_z 295 | 296 | angle = math.pi / 2 - ego_trajs[4:5][tf.newaxis, tf.newaxis, :] 297 | x, y = rotate_points_around_origin(x, y, angle) 298 | bbox_yaw = bbox_yaw + angle 299 | 300 | agent_type = traj_tensor[...,-1][...,tf.newaxis] 301 | 302 | return _sample_points_from_agent_boxes( 303 | x=x, 304 | y=y, 305 | z=z, 306 | bbox_yaw=bbox_yaw, 307 | width=width, 308 | length=length, 309 | agent_type=agent_type, 310 | valid=valid_tensor, 311 | unit_x=unit_x, 312 | unit_y=unit_y 313 | # points_per_side_length=config.points_per_side_length, 314 | # 
points_per_side_width=config.points_per_side_width, 315 | ) 316 | 317 | def sample_filter( 318 | traj_tensor, 319 | valid_tensor, 320 | ego_traj, 321 | config, 322 | times, 323 | unit_x, 324 | unit_y, 325 | include_observed=True, 326 | include_occluded=True 327 | ): 328 | 329 | b,e = _get_num_steps_from_times(times, config) 330 | 331 | # Sample points from agent boxes over specified time frames. 332 | # All fields have shape [num_agents, num_steps, points_per_agent]. 333 | sampled_points = points_sample( 334 | traj_tensor[:,b:e], 335 | valid_tensor[:,b:e], 336 | ego_traj, 337 | unit_x,unit_y, 338 | config 339 | ) 340 | 341 | agent_valid = tf.cast(sampled_points.valid, tf.bool) 342 | 343 | include_all = include_observed and include_occluded 344 | if not include_all and 'future' in times: 345 | history_times = ['past', 'current'] 346 | b,e = _get_num_steps_from_times(history_times, config) 347 | agent_is_observed = valid_tensor[:,b:e] 348 | # [num_agents, 1, 1] 349 | agent_is_observed = tf.reduce_max(agent_is_observed, axis=1, keepdims=True) 350 | agent_is_observed = tf.cast(agent_is_observed, tf.bool) 351 | 352 | if include_observed: 353 | agent_filter = agent_is_observed 354 | elif include_occluded: 355 | agent_filter = tf.logical_not(agent_is_observed) 356 | else: # Both observed and occluded are off. 357 | raise ValueError('Either observed or occluded agents must be requested.') 358 | agent_valid = tf.logical_and(agent_valid, agent_filter) 359 | 360 | return _SampledPoints( 361 | x=sampled_points.x, 362 | y=sampled_points.y, 363 | z=sampled_points.z, 364 | agent_type=sampled_points.agent_type, 365 | valid=agent_valid, 366 | ) 367 | 368 | def render_ego_occupancy( 369 | sampled_points, 370 | sdc_ids, 371 | config 372 | ): 373 | agent_x = sampled_points.x 374 | agent_y = sampled_points.y 375 | agent_type = sampled_points.agent_type 376 | agent_valid = sampled_points.valid 377 | 378 | # Set up assert_shapes. 379 | assert_shapes = tf.debugging.assert_shapes 380 | num_agents, num_steps, points_per_agent = agent_x.shape.as_list() 381 | topdown_shape = [ 382 | config.grid_height_cells, config.grid_width_cells, num_steps 383 | ] 384 | 385 | # print(topdown_shape) 386 | 387 | # Transform from world coordinates to topdown image coordinates. 388 | # All 3 have shape: [batch, num_agents, num_steps, points_per_agent] 389 | agent_x, agent_y, point_is_in_fov = _transform_to_image_coordinates( 390 | points_x=agent_x, 391 | points_y=agent_y, 392 | config=config, 393 | ) 394 | 395 | assert_shapes([(point_is_in_fov, 396 | [num_agents, num_steps, points_per_agent])]) 397 | 398 | # Filter out points from invalid objects. 399 | agent_valid = tf.cast(agent_valid, tf.bool) 400 | point_is_in_fov_and_valid = tf.logical_and(point_is_in_fov, agent_valid) 401 | agent_x, agent_y, should_render_point = agent_x[sdc_ids][tf.newaxis,...], agent_y[sdc_ids][tf.newaxis,...], point_is_in_fov_and_valid[sdc_ids][tf.newaxis,...] 
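# NOTE: indexing with sdc_ids drops the agent axis and [tf.newaxis] restores it
# with size 1, so from here on the tensors describe only the ego vehicle; the
# assert below expects [1, num_steps, points_per_agent] and the rendered ego
# occupancy comes out as [grid_height_cells, grid_width_cells, num_steps].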
402 | 403 | # Collect points for ego vehicle 404 | assert_shapes([ 405 | (should_render_point, 406 | [1, num_steps, points_per_agent]), 407 | ]) 408 | 409 | # [num_points_to_render, 4] 410 | point_indices = tf.cast(tf.where(should_render_point), tf.int32) 411 | 412 | # [num_points_to_render, 1] 413 | x_img_coord = tf.gather_nd(agent_x, point_indices)[..., tf.newaxis] 414 | y_img_coord = tf.gather_nd(agent_y, point_indices)[..., tf.newaxis] 415 | 416 | num_points_to_render = point_indices.shape.as_list()[0] 417 | assert_shapes([(x_img_coord, [num_points_to_render, 1]), 418 | (y_img_coord, [num_points_to_render, 1])]) 419 | 420 | # [num_points_to_render, 4] 421 | xy_img_coord = tf.concat( 422 | [ 423 | # point_indices[:, :1], 424 | tf.cast(y_img_coord, tf.int32), 425 | tf.cast(x_img_coord, tf.int32), 426 | point_indices[:, 1:2], 427 | ], 428 | axis=1, 429 | ) 430 | # [num_points_to_render] 431 | gt_values = tf.squeeze(tf.ones_like(x_img_coord, dtype=tf.float32), axis=-1) 432 | 433 | # [batch_size, grid_height_cells, grid_width_cells, num_steps] 434 | topdown = tf.scatter_nd(xy_img_coord, gt_values, topdown_shape) 435 | 436 | 437 | assert_shapes([(topdown, topdown_shape)]) 438 | 439 | # scatter_nd() accumulates values if there are repeated indices. Since 440 | # we sample densely, this happens all the time. Clip the final values. 441 | topdown = tf.clip_by_value(topdown, 0.0, 1.0) 442 | return topdown 443 | 444 | 445 | def render_occupancy( 446 | sampled_points, 447 | config, 448 | sdc_ids=None, 449 | ): 450 | 451 | agent_x = sampled_points.x 452 | agent_y = sampled_points.y 453 | agent_type = sampled_points.agent_type 454 | agent_valid = sampled_points.valid 455 | 456 | # Set up assert_shapes. 457 | assert_shapes = tf.debugging.assert_shapes 458 | num_agents, num_steps, points_per_agent = agent_x.shape.as_list() 459 | topdown_shape = [ 460 | config.grid_height_cells, config.grid_width_cells, num_steps 461 | ] 462 | 463 | # print(topdown_shape) 464 | 465 | # Transform from world coordinates to topdown image coordinates. 466 | # All 3 have shape: [batch, num_agents, num_steps, points_per_agent] 467 | agent_x, agent_y, point_is_in_fov = _transform_to_image_coordinates( 468 | points_x=agent_x, 469 | points_y=agent_y, 470 | config=config, 471 | ) 472 | assert_shapes([(point_is_in_fov, 473 | [num_agents, num_steps, points_per_agent])]) 474 | 475 | # Filter out points from invalid objects. 476 | agent_valid = tf.cast(agent_valid, tf.bool) 477 | 478 | #cases masking the ego car: 479 | if sdc_ids is not None: 480 | mask = np.ones((num_agents, 1, 1)) 481 | mask[sdc_ids] = 0 482 | mask = tf.convert_to_tensor(mask, tf.bool) 483 | agent_valid = tf.logical_and(agent_valid, mask) 484 | 485 | point_is_in_fov_and_valid = tf.logical_and(point_is_in_fov, agent_valid) 486 | 487 | occupancies = {} 488 | for object_type in occupancy_flow_data.ALL_AGENT_TYPES: 489 | # Collect points for each agent type, i.e., pedestrians and vehicles. 
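# NOTE: occupancy_flow_data.ALL_AGENT_TYPES actually covers vehicles,
# pedestrians and cyclists; each class gets its own binary
# [grid_height_cells, grid_width_cells, num_steps] grid, and the three grids
# are wrapped into an AgentGrids container at the end of this function.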
490 | agent_type_matches = tf.equal(agent_type, object_type) 491 | should_render_point = tf.logical_and(point_is_in_fov_and_valid, 492 | agent_type_matches) 493 | 494 | assert_shapes([ 495 | (should_render_point, 496 | [num_agents, num_steps, points_per_agent]), 497 | ]) 498 | 499 | # [num_points_to_render, 4] 500 | point_indices = tf.cast(tf.where(should_render_point), tf.int32) 501 | 502 | # [num_points_to_render, 1] 503 | x_img_coord = tf.gather_nd(agent_x, point_indices)[..., tf.newaxis] 504 | y_img_coord = tf.gather_nd(agent_y, point_indices)[..., tf.newaxis] 505 | 506 | num_points_to_render = point_indices.shape.as_list()[0] 507 | assert_shapes([(x_img_coord, [num_points_to_render, 1]), 508 | (y_img_coord, [num_points_to_render, 1])]) 509 | 510 | # [num_points_to_render, 4] 511 | xy_img_coord = tf.concat( 512 | [ 513 | # point_indices[:, :1], 514 | tf.cast(y_img_coord, tf.int32), 515 | tf.cast(x_img_coord, tf.int32), 516 | point_indices[:, 1:2], 517 | ], 518 | axis=1, 519 | ) 520 | # [num_points_to_render] 521 | gt_values = tf.squeeze(tf.ones_like(x_img_coord, dtype=tf.float32), axis=-1) 522 | 523 | # [batch_size, grid_height_cells, grid_width_cells, num_steps] 524 | topdown = tf.scatter_nd(xy_img_coord, gt_values, topdown_shape) 525 | 526 | 527 | assert_shapes([(topdown, topdown_shape)]) 528 | 529 | # scatter_nd() accumulates values if there are repeated indices. Since 530 | # we sample densely, this happens all the time. Clip the final values. 531 | topdown = tf.clip_by_value(topdown, 0.0, 1.0) 532 | occupancies[object_type] = topdown 533 | 534 | return occupancy_flow_data.AgentGrids( 535 | vehicles=occupancies[_ObjectType.TYPE_VEHICLE], 536 | pedestrians=occupancies[_ObjectType.TYPE_PEDESTRIAN], 537 | cyclists=occupancies[_ObjectType.TYPE_CYCLIST], 538 | ), point_is_in_fov_and_valid 539 | 540 | 541 | def render_flow_from_inputs( 542 | sampled_points, 543 | config, 544 | sdc_ids=None, 545 | ): 546 | 547 | agent_x = sampled_points.x 548 | agent_y = sampled_points.y 549 | agent_type = sampled_points.agent_type 550 | agent_valid = sampled_points.valid 551 | 552 | # Set up assert_shapes. 553 | assert_shapes = tf.debugging.assert_shapes 554 | num_agents, num_steps, points_per_agent = agent_x.shape.as_list() 555 | # The timestep distance between flow steps. 556 | waypoint_size = config.num_future_steps // config.num_waypoints 557 | num_flow_steps = num_steps - waypoint_size 558 | topdown_shape = [ 559 | config.grid_height_cells, config.grid_width_cells,num_flow_steps 560 | ] 561 | 562 | # Transform from world coordinates to topdown image coordinates. 563 | # All 3 have shape: [batch, num_agents, num_steps, points_per_agent] 564 | agent_x, agent_y, point_is_in_fov = _transform_to_image_coordinates( 565 | points_x=agent_x, 566 | points_y=agent_y, 567 | config=config, 568 | ) 569 | 570 | # Filter out points from invalid objects. 571 | agent_valid = tf.cast(agent_valid, tf.bool) 572 | #cases masking the ego car: 573 | if sdc_ids is not None: 574 | mask = np.ones((num_agents, 1, 1)) 575 | mask[sdc_ids] = 0 576 | mask = tf.convert_to_tensor(mask, tf.bool) 577 | agent_valid = tf.logical_and(agent_valid, mask) 578 | 579 | # Backward Flow. 580 | # [num_agents, num_flow_steps, points_per_agent] 581 | dx = agent_x[:, :-waypoint_size, :] - agent_x[:, waypoint_size:, :] 582 | dy = agent_y[:, :-waypoint_size, :] - agent_y[:, waypoint_size:, :] 583 | 584 | # Adjust other fields as well to reduce from num_steps to num_flow_steps. 
585 | # agent_x, agent_y: Use later timesteps since flow vectors go back in time. 586 | # [batch_size, num_agents, num_flow_steps, points_per_agent] 587 | agent_x = agent_x[:, waypoint_size:, :] 588 | agent_y = agent_y[:, waypoint_size:, :] 589 | # agent_type: Use later timesteps since flow vectors go back in time. 590 | # [batch_size, num_agents, num_flow_steps, points_per_agent] 591 | agent_type = agent_type[:, waypoint_size:, :] 592 | # point_is_in_fov: Use later timesteps since flow vectors go back in time. 593 | # [batch_size, num_agents, num_flow_steps, points_per_agent] 594 | point_is_in_fov = point_is_in_fov[:, waypoint_size:, :] 595 | # agent_valid: And the two timesteps. They both need to be valid. 596 | # [batch_size, num_agents, num_flow_steps, points_per_agent] 597 | agent_valid = tf.logical_and(agent_valid[:, waypoint_size:, :], 598 | agent_valid[:, :-waypoint_size, :]) 599 | 600 | # [batch_size, num_agents, num_flow_steps, points_per_agent] 601 | point_is_in_fov_and_valid = tf.logical_and(point_is_in_fov, agent_valid) 602 | 603 | flows = {} 604 | for object_type in occupancy_flow_data.ALL_AGENT_TYPES: 605 | # Collect points for each agent type, i.e., pedestrians and vehicles. 606 | agent_type_matches = tf.equal(agent_type, object_type) 607 | should_render_point = tf.logical_and(point_is_in_fov_and_valid, 608 | agent_type_matches) 609 | 610 | # [batch_size, height, width, num_flow_steps, 2] 611 | flow = _render_flow_points_for_one_agent_type( 612 | agent_x=agent_x, 613 | agent_y=agent_y, 614 | dx=dx, 615 | dy=dy, 616 | should_render_point=should_render_point, 617 | topdown_shape=topdown_shape, 618 | ) 619 | flows[object_type] = flow 620 | 621 | return occupancy_flow_data.AgentGrids( 622 | vehicles=flows[_ObjectType.TYPE_VEHICLE], 623 | pedestrians=flows[_ObjectType.TYPE_PEDESTRIAN], 624 | cyclists=flows[_ObjectType.TYPE_CYCLIST], 625 | ) 626 | 627 | 628 | def _render_flow_points_for_one_agent_type( 629 | agent_x, 630 | agent_y, 631 | dx, 632 | dy, 633 | should_render_point, 634 | topdown_shape, 635 | ): 636 | assert_shapes = tf.debugging.assert_shapes 637 | 638 | # Scatter points across topdown maps for each timestep. The tensor 639 | # `point_indices` holds the indices where `should_render_point` is True. 640 | # It is a 2-D tensor with shape [n, 3], where n is the number of valid 641 | # agent points inside FOV. Each row in this tensor contains indices over 642 | # the following 3 dimensions: (agent, timestep, point). 
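# NOTE (illustrative): the dx/dy passed into this helper are backward
# displacements in image (pixel) units: element k holds
# x(t_k) - x(t_k + waypoint_size) and is scattered at the later position, so
# each occupied cell stores the offset back to where it came from. With the
# preprocess.py defaults (future_len=50, num_waypoints=10, dt=0.1 s)
# waypoint_size is 5 frames, i.e. flow over 0.5 s. The scatter-and-average
# step used below can be checked on a toy tensor:
#   idx  = tf.constant([[0, 0], [0, 0], [1, 1]])
#   vals = tf.constant([2.0, 4.0, 5.0])
#   cnt  = tf.scatter_nd(idx, tf.ones_like(vals), [2, 2])  # [[2., 0.], [0., 1.]]
#   acc  = tf.scatter_nd(idx, vals, [2, 2])                # [[6., 0.], [0., 5.]]
#   tf.math.divide_no_nan(acc, cnt)                        # [[3., 0.], [0., 5.]]
# repeated pixel hits are averaged rather than summed, and empty cells stay 0.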
643 | 644 | # [num_points_to_render, 3] 645 | point_indices = tf.cast(tf.where(should_render_point), tf.int32) 646 | # [num_points_to_render, 1] 647 | x_img_coord = tf.gather_nd(agent_x, point_indices)[..., tf.newaxis] 648 | y_img_coord = tf.gather_nd(agent_y, point_indices)[..., tf.newaxis] 649 | 650 | num_points_to_render = point_indices.shape.as_list()[0] 651 | assert_shapes([(x_img_coord, [num_points_to_render, 1]), 652 | (y_img_coord, [num_points_to_render, 1])]) 653 | 654 | # [num_points_to_render, 4] 655 | xy_img_coord = tf.concat( 656 | [ 657 | tf.cast(y_img_coord, tf.int32), 658 | tf.cast(x_img_coord, tf.int32), 659 | point_indices[:, 1:2], 660 | ], 661 | axis=1, 662 | ) 663 | # [num_points_to_render] 664 | gt_values_dx = tf.gather_nd(dx, point_indices) 665 | gt_values_dy = tf.gather_nd(dy, point_indices) 666 | 667 | gt_values = tf.squeeze(tf.ones_like(x_img_coord, dtype=tf.float32), axis=-1) 668 | 669 | # [batch_size, grid_height_cells, grid_width_cells, num_flow_steps] 670 | flow_x = tf.scatter_nd(xy_img_coord, gt_values_dx, topdown_shape) 671 | flow_y = tf.scatter_nd(xy_img_coord, gt_values_dy, topdown_shape) 672 | num_values_per_pixel = tf.scatter_nd(xy_img_coord, gt_values, topdown_shape) 673 | 674 | # Undo the accumulation effect of tf.scatter_nd() for repeated indices. 675 | flow_x = tf.math.divide_no_nan(flow_x, num_values_per_pixel) 676 | flow_y = tf.math.divide_no_nan(flow_y, num_values_per_pixel) 677 | 678 | # [batch_size, grid_height_cells, grid_width_cells, num_flow_steps, 2] 679 | flow = tf.stack([flow_x, flow_y], axis=-1) 680 | return flow 681 | 682 | 683 | def generate_units(points_per_side_length, points_per_side_width): 684 | 685 | if points_per_side_length < 1: 686 | raise ValueError('points_per_side_length must be >= 1') 687 | if points_per_side_width < 1: 688 | raise ValueError('points_per_side_width must be >= 1') 689 | 690 | # Create sample points on a unit square or boundary depending on flag. 691 | if points_per_side_length == 1: 692 | step_x = 0.0 693 | else: 694 | step_x = 1.0 / (points_per_side_length - 1) 695 | if points_per_side_width == 1: 696 | step_y = 0.0 697 | else: 698 | step_y = 1.0 / (points_per_side_width - 1) 699 | unit_x = [] 700 | unit_y = [] 701 | for xi in range(points_per_side_length): 702 | for yi in range(points_per_side_width): 703 | unit_x.append(xi * step_x - 0.5) 704 | unit_y.append(yi * step_y - 0.5) 705 | 706 | # Center unit_x and unit_y if there was only 1 point on those dimensions. 707 | if points_per_side_length == 1: 708 | unit_x = np.array(unit_x) + 0.5 709 | if points_per_side_width == 1: 710 | unit_y = np.array(unit_y) + 0.5 711 | 712 | unit_x = tf.convert_to_tensor(unit_x, tf.float32) 713 | unit_y = tf.convert_to_tensor(unit_y, tf.float32) 714 | 715 | return unit_x, unit_y 716 | 717 | 718 | def _sample_ego_from_boxes(x, y, bbox_yaw, width, length, unit_x, unit_y): 719 | sin_yaw = tf.sin(bbox_yaw) 720 | cos_yaw = tf.cos(bbox_yaw) 721 | 722 | tx = cos_yaw * length * unit_x - sin_yaw * width * unit_y + x 723 | ty = sin_yaw * length * unit_x + cos_yaw * width * unit_y + y 724 | return tx, ty 725 | 726 | 727 | def _sample_points_from_agent_boxes( 728 | x, y, z, bbox_yaw, width, length, agent_type, valid, unit_x, unit_y 729 | ): 730 | assert_shapes = tf.debugging.assert_shapes 731 | assert_shapes([(x, [..., 1])]) 732 | x_shape = x.get_shape().as_list() 733 | 734 | # Transform the unit square points to agent dimensions and coordinate frames. 
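# NOTE (illustrative): unit_x/unit_y from generate_units() span the box in
# normalized coordinates [-0.5, 0.5], and the rotation/scaling below maps them
# onto the actual box footprint. For a hypothetical 4 m x 2 m agent at
# (x, y) = (10, 5) with bbox_yaw = 0, the unit corner (0.5, 0.5) becomes
#   tx = cos(0)*4*0.5 - sin(0)*2*0.5 + 10 = 12.0
#   ty = sin(0)*4*0.5 + cos(0)*2*0.5 + 5  =  6.0
# i.e. offset by half the length along the heading and half the width
# laterally from the box center.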
735 | sin_yaw = tf.sin(bbox_yaw) 736 | cos_yaw = tf.cos(bbox_yaw) 737 | 738 | # [..., num_points] 739 | tx = cos_yaw * length * unit_x - sin_yaw * width * unit_y + x 740 | ty = sin_yaw * length * unit_x + cos_yaw * width * unit_y + y 741 | tz = tf.broadcast_to(z, tx.shape) 742 | 743 | # points_shape = x_shape[:-1] + [num_points] 744 | agent_type = tf.broadcast_to(agent_type, tx.shape) 745 | valid = tf.broadcast_to(valid, tx.shape) 746 | 747 | return _SampledPoints(x=tx, y=ty, z=tz, agent_type=agent_type, valid=valid) 748 | 749 | 750 | def _get_num_steps_from_times( 751 | times, 752 | config): 753 | """Returns number of timesteps that exist in requested times.""" 754 | p, c, f = config.num_past_steps, 1, config.num_future_steps 755 | dict_1 = {'past':(0, p), 'current':(p, p+c), 'future':(p+c, p+c+f)} 756 | if len(times)==0: 757 | raise NotImplementedError() 758 | elif len(times)==1: 759 | return dict_1[times[0]] 760 | elif len(times)==2: 761 | assert times[0]=='past' 762 | return (0, p + c) 763 | else: 764 | return (0, p + c + f) 765 | 766 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '-1' 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | import multiprocessing 8 | from multiprocessing import Pool, Process 9 | import argparse 10 | import random 11 | 12 | import math 13 | import time 14 | import pandas as pd 15 | 16 | from waymo_open_dataset.protos import occupancy_flow_metrics_pb2, scenario_pb2 17 | from google.protobuf import text_format 18 | 19 | from shapely.geometry import LineString, Point, Polygon 20 | from shapely.affinity import affine_transform, rotate 21 | 22 | import matplotlib.pyplot as plt 23 | import matplotlib as mpl 24 | 25 | from functools import partial 26 | from glob import glob 27 | from tqdm import tqdm 28 | 29 | from utils.occupancy_grid_utils import create_ground_truth_timestep_grids, \ 30 | create_ground_truth_waypoint_grids, _ego_ground_truth_occupancy 31 | from utils.occupancy_render_utils import pack_trajs, pack_maps, \ 32 | render_roadgraph_tf 33 | 34 | from utils.train_utils import * 35 | 36 | class Processor: 37 | def __init__( 38 | self, 39 | height=128, 40 | width=128, 41 | pixels_per_meter=1.6, 42 | hist_len=11, 43 | future_len=50, 44 | gap=5, 45 | num_observed=32, 46 | num_occluded=6, 47 | num_map=3, 48 | map_len=100, 49 | map_buffer=150, 50 | ego_map=6, 51 | ego_map_len=200, 52 | ego_buffer=300, 53 | ref_max_len=1000, 54 | planning_horizon=5, 55 | dt=0.1, 56 | data_files='', 57 | save_dir='', 58 | cumulative_waypoints='false', 59 | ol_test=False, 60 | timestep=199 61 | ): 62 | 63 | self.height = height 64 | self.width = width 65 | self.pixels_per_meter = pixels_per_meter 66 | self.cumulative_waypoints = cumulative_waypoints 67 | 68 | self.hist_len = hist_len 69 | self.future_len = future_len 70 | self.num_observed = num_observed 71 | self.num_occluded = num_occluded 72 | self.num_map = num_map 73 | self.map_len = map_len 74 | self.map_buffer = map_buffer 75 | self.ego_map = ego_map 76 | self.ego_map_len = ego_map_len 77 | self.ego_buffer = ego_buffer 78 | self.ref_max_len = ref_max_len 79 | 80 | self.horizon = planning_horizon #s 81 | self.dt = dt #s 82 | 83 | self.data_files = [data_files] 84 | self.save_dir = save_dir 85 | self.gap = gap 86 | 87 | self.ol_test = ol_test 88 | 89 | self.timestep = timestep 90 | print(f'timestep:{self.timestep}') 
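# NOTE: with the default arguments above, the output grid is 128 x 128 cells at
# 1.6 pixels/m, i.e. an 80 m x 80 m field of view with the ego placed at 75% of
# the grid height (roughly 60 m visible ahead and 20 m behind); get_config()
# also builds a doubled 256 x 256 input grid (about 160 m x 160 m) that is
# used for the history rasters in history_ogm_process().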
91 | 92 | self.test_scenario_ids = None 93 | 94 | self.get_config() 95 | 96 | def get_config(self): 97 | center_x,center_y = int(self.width / 2), int(self.height * 0.75) 98 | config = occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig() 99 | config_text = f""" 100 | num_past_steps: {self.hist_len - 1} 101 | num_future_steps: {self.future_len} 102 | num_waypoints: {self.future_len//5} 103 | cumulative_waypoints: {self.cumulative_waypoints} 104 | normalize_sdc_yaw: true 105 | grid_height_cells: {self.height} 106 | grid_width_cells: {self.width} 107 | sdc_y_in_grid: {center_y} 108 | sdc_x_in_grid: {center_x} 109 | pixels_per_meter: {self.pixels_per_meter} 110 | agent_points_per_side_length: 48 111 | agent_points_per_side_width: 16 112 | """ 113 | 114 | text_format.Parse(config_text, config) 115 | self.config = config 116 | 117 | input_config = occupancy_flow_metrics_pb2.OccupancyFlowTaskConfig() 118 | iconfig_text = f""" 119 | num_past_steps: {self.hist_len - 1} 120 | num_future_steps: {self.future_len} 121 | num_waypoints: {self.future_len//5} 122 | cumulative_waypoints: false 123 | normalize_sdc_yaw: true 124 | grid_height_cells: {self.height*2} 125 | grid_width_cells: {self.width*2} 126 | sdc_y_in_grid: {center_y + int(self.height*0.5)} 127 | sdc_x_in_grid: {center_x + int(self.width*0.5)} 128 | pixels_per_meter: {self.pixels_per_meter} 129 | agent_points_per_side_length: 48 130 | agent_points_per_side_width: 16 131 | """ 132 | 133 | text_format.Parse(iconfig_text, input_config) 134 | self.input_config = input_config 135 | 136 | def build_map(self, map_features, dynamic_map_states): 137 | self.lanes = {} 138 | self.roads = {} 139 | self.stop_signs = {} 140 | self.crosswalks = {} 141 | self.speed_bumps = {} 142 | 143 | # static map features 144 | for map in map_features: 145 | map_type = map.WhichOneof("feature_data") 146 | map_id = map.id 147 | map = getattr(map, map_type) 148 | 149 | if map_type == 'lane': 150 | self.lanes[map_id] = map 151 | elif map_type == 'road_line' or map_type == 'road_edge': 152 | self.roads[map_id] = map 153 | elif map_type == 'stop_sign': 154 | self.stop_signs[map_id] = map 155 | elif map_type == 'crosswalk': 156 | self.crosswalks[map_id] = map 157 | elif map_type == 'speed_bump': 158 | self.speed_bumps[map_id] = map 159 | else: 160 | continue 161 | 162 | # dynamic map features 163 | self.traffic_signals = dynamic_map_states 164 | 165 | def map_process(self, traj, num_map, map_len, map_buffer, ind, goal=None): 166 | ''' 167 | Map point attributes 168 | self_point (x, y, h), left_boundary_point (x, y, h), right_boundary_pont (x, y, h), speed limit (float), 169 | self_type (int), left_boundary_type (int), right_boundary_type (int), 170 | traffic light (int), stop_point (bool), interpolating (bool), stop_sign (bool) 171 | ''' 172 | vectorized_map = np.zeros(shape=(num_map, map_len, 17)) 173 | vectorized_crosswalks = np.zeros(shape=(3, 100, 3)) 174 | agent_type = int(traj[-1][-1]) 175 | 176 | # get all lane polylines 177 | lane_polylines = get_polylines(self.lanes) 178 | 179 | # get all road lines and edges polylines 180 | road_polylines = get_polylines(self.roads) 181 | 182 | # find current lanes for the agent 183 | ref_lane_ids = find_reference_lanes(agent_type, traj, lane_polylines) 184 | 185 | # find candidate lanes 186 | ref_lanes = [] 187 | 188 | # get current lane's forward lanes 189 | for curr_lane, start in ref_lane_ids.items(): 190 | candidate = depth_first_search(curr_lane, self.lanes, 191 | dist=lane_polylines[curr_lane][start:].shape[0], threshold=300) 
192 | ref_lanes.extend(candidate) 193 | 194 | if agent_type != 2: 195 | # find current lanes' left and right lanes 196 | neighbor_lane_ids = find_neighbor_lanes(ref_lane_ids, traj, self.lanes, lane_polylines) 197 | 198 | # get neighbor lane's forward lanes 199 | for neighbor_lane, start in neighbor_lane_ids.items(): 200 | candidate = depth_first_search(neighbor_lane, self.lanes, 201 | dist=lane_polylines[neighbor_lane][start:].shape[0], threshold=300) 202 | ref_lanes.extend(candidate) 203 | 204 | # update reference lane ids 205 | ref_lane_ids.update(neighbor_lane_ids) 206 | 207 | # remove overlapping lanes 208 | ref_lanes = remove_overlapping_lane_seq(ref_lanes) 209 | 210 | # get traffic light controlled lanes and stop sign controlled lanes 211 | traffic_light_lanes = {} 212 | stop_sign_lanes = [] 213 | 214 | for signal in self.traffic_signals[ind-1].lane_states: 215 | traffic_light_lanes[signal.lane] = (signal.state, signal.stop_point.x, signal.stop_point.y) 216 | for lane in self.lanes[signal.lane].entry_lanes: 217 | traffic_light_lanes[lane] = (signal.state, signal.stop_point.x, signal.stop_point.y) 218 | 219 | for i, sign in self.stop_signs.items(): 220 | stop_sign_lanes.extend(sign.lane) 221 | 222 | # add lanes to the array 223 | added_lanes = 0 224 | for i, s_lane in enumerate(ref_lanes): 225 | added_points = 0 226 | if i > num_map - 1: 227 | break 228 | 229 | # create a data cache 230 | cache_lane = np.zeros(shape=(map_buffer, 17)) 231 | 232 | for lane in s_lane: 233 | curr_index = ref_lane_ids[lane] if lane in ref_lane_ids else 0 234 | self_line = lane_polylines[lane][curr_index:] 235 | 236 | if added_points >= map_buffer: 237 | break 238 | 239 | # add info to the array 240 | for point in self_line: 241 | # self_point and type 242 | cache_lane[added_points, 0:3] = point 243 | cache_lane[added_points, 10] = self.lanes[lane].type 244 | 245 | # left_boundary_point and type 246 | for left_boundary in self.lanes[lane].left_boundaries: 247 | left_boundary_id = left_boundary.boundary_feature_id 248 | left_start = left_boundary.lane_start_index 249 | left_end = left_boundary.lane_end_index 250 | left_boundary_type = left_boundary.boundary_type # road line type 251 | if left_boundary_type == 0: 252 | left_boundary_type = self.roads[left_boundary_id].type + 8 # road edge type 253 | 254 | if left_start <= curr_index <= left_end: 255 | left_boundary_line = road_polylines[left_boundary_id] 256 | nearest_point = find_neareast_point(point, left_boundary_line) 257 | cache_lane[added_points, 3:6] = nearest_point 258 | cache_lane[added_points, 11] = left_boundary_type 259 | 260 | # right_boundary_point and type 261 | for right_boundary in self.lanes[lane].right_boundaries: 262 | right_boundary_id = right_boundary.boundary_feature_id 263 | right_start = right_boundary.lane_start_index 264 | right_end = right_boundary.lane_end_index 265 | right_boundary_type = right_boundary.boundary_type # road line type 266 | if right_boundary_type == 0: 267 | right_boundary_type = self.roads[right_boundary_id].type + 8 # road edge type 268 | 269 | if right_start <= curr_index <= right_end: 270 | right_boundary_line = road_polylines[right_boundary_id] 271 | nearest_point = find_neareast_point(point, right_boundary_line) 272 | cache_lane[added_points, 6:9] = nearest_point 273 | cache_lane[added_points, 12] = right_boundary_type 274 | 275 | # speed limit 276 | cache_lane[added_points, 9] = self.lanes[lane].speed_limit_mph / 2.237 277 | 278 | # interpolating 279 | cache_lane[added_points, 15] = 
self.lanes[lane].interpolating 280 | 281 | # traffic_light 282 | if lane in traffic_light_lanes.keys(): 283 | cache_lane[added_points, 13] = traffic_light_lanes[lane][0] 284 | if np.linalg.norm(traffic_light_lanes[lane][1:] - point[:2]) < 3: 285 | cache_lane[added_points, 14] = True 286 | 287 | # add stop sign 288 | if lane in stop_sign_lanes: 289 | cache_lane[added_points, 16] = True 290 | 291 | # count 292 | added_points += 1 293 | curr_index += 1 294 | 295 | if added_points >= map_buffer: 296 | break 297 | 298 | # scale the lane 299 | vectorized_map[i] = cache_lane[np.linspace(0, added_points, num=map_len, endpoint=False, dtype=np.int)] 300 | 301 | # count 302 | added_lanes += 1 303 | 304 | if goal is not None: 305 | dist_list = {} 306 | for i in range(vectorized_map.shape[0]): 307 | dist_list[i] = np.min(np.linalg.norm(vectorized_map[i ,:, :2] - goal, axis=-1)) 308 | sorted_inds = sorted(dist_list.items(), key=lambda item:item[1])[:self.ego_map] 309 | sorted_inds = [ind[0] for ind in sorted_inds] 310 | vectorized_map = vectorized_map[sorted_inds] 311 | 312 | 313 | # find surrounding crosswalks and add them to the array 314 | added_cross_walks = 0 315 | detection = Polygon([(0, -5), (50, -20), (50, 20), (0, 5)]) 316 | detection = affine_transform(detection, [1, 0, 0, 1, traj[-1][0], traj[-1][1]]) 317 | detection = rotate(detection, traj[-1][2], origin=(traj[-1][0], traj[-1][1]), use_radians=True) 318 | 319 | for _, crosswalk in self.crosswalks.items(): 320 | polygon = Polygon([(point.x, point.y) for point in crosswalk.polygon]) 321 | polyline = polygon_completion(crosswalk.polygon) 322 | polyline = polyline[np.linspace(0, polyline.shape[0], num=100, endpoint=False, dtype=np.int)] 323 | 324 | if detection.intersects(polygon): 325 | vectorized_crosswalks[added_cross_walks, :polyline.shape[0]] = polyline 326 | added_cross_walks += 1 327 | 328 | if added_cross_walks >= 3: 329 | break 330 | 331 | #map [3, 100, 17] ; crosswalk [4, 50, 3] 332 | vectorized_map = vectorized_map[:, 0::2, :] 333 | vectorized_crosswalks = vectorized_crosswalks[:, 0::2, :] 334 | 335 | return vectorized_map.astype(np.float32), vectorized_crosswalks.astype(np.float32) 336 | 337 | @staticmethod 338 | def ego_frame_dynamics(v, theta): 339 | ego_v = v.copy() 340 | ego_v[0] = v[0] * np.cos(theta) + v[1] * np.sin(theta) 341 | ego_v[1] = v[1] * np.cos(theta) - v[0] * np.sin(theta) 342 | 343 | return ego_v 344 | 345 | def dynamic_state_process(self, ind=11): 346 | traffic_light_lanes = {} 347 | stop_sign_lanes = [] 348 | 349 | for signal in self.traffic_signals[10].lane_states: 350 | traffic_light_lanes[signal.lane] = (signal.state, signal.stop_point.x, signal.stop_point.y) 351 | for lane in self.lanes[signal.lane].entry_lanes: 352 | traffic_light_lanes[lane] = (signal.state, signal.stop_point.x, signal.stop_point.y) 353 | 354 | for i, sign in self.stop_signs.items(): 355 | stop_sign_lanes.extend(sign.lane) 356 | 357 | return traffic_light_lanes, stop_sign_lanes 358 | 359 | def history_ogm_process(self, traj_tensor, valid_tensor, ego_traj, sdc_ids=None): 360 | timestep_grids, observed_valid, ego_occupancy = create_ground_truth_timestep_grids(traj_tensor, valid_tensor, 361 | ego_traj, self.input_config,flow=True, flow_origin=False,sdc_ids=sdc_ids) 362 | 363 | vehicle_flow = timestep_grids.vehicles.all_flow[:, :, 0, :].numpy().astype(np.int8) 364 | ped_flow = timestep_grids.pedestrians.all_flow[:, :, 0, :].numpy().astype(np.int8) 365 | cyc_flow = timestep_grids.cyclists.all_flow[:, :, 0, :].numpy().astype(np.int8) 366 | 367 | 
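# NOTE (assumption): assuming all_flow is laid out as
# [grid_height, grid_width, flow_steps, 2] as in the Waymo occupancy-flow grid
# utilities, each [:, :, 0, :] slice above is an [H, W, 2] backward-flow raster
# for one agent class stored as int8 pixel offsets, and the concatenation below
# stacks the three classes channel-wise into an [H, W, 6] history-flow tensor.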
hist_flow = np.concatenate([vehicle_flow,ped_flow,cyc_flow],axis=-1) 368 | 369 | hist_vehicles, hist_pedestrians, hist_cyclists = timestep_grids.vehicles, timestep_grids.pedestrians, timestep_grids.cyclists 370 | hist_v_ogm = tf.concat([hist_vehicles.past_occupancy,hist_vehicles.current_occupancy],axis=-1) 371 | hist_p_ogm = tf.concat([hist_pedestrians.past_occupancy,hist_pedestrians.current_occupancy],axis=-1) 372 | hist_c_ogm = tf.concat([hist_cyclists.past_occupancy,hist_cyclists.current_occupancy],axis=-1) 373 | 374 | hist_ogm = tf.stack([hist_v_ogm, hist_p_ogm, hist_c_ogm], axis=-1).numpy().astype(np.bool_) 375 | ego_mask_hist = ego_occupancy[:, :, :11] 376 | 377 | return hist_ogm, hist_flow, ego_mask_hist 378 | 379 | def gt_ogm_process(self, traj_tensor, valid_tensor, ego_traj, sdc_ids=None, test=False): 380 | timestep_grids, observed_valid, ego_occupancy = create_ground_truth_timestep_grids(traj_tensor, valid_tensor, 381 | ego_traj, self.config,flow=True, flow_origin=False, sdc_ids=sdc_ids, test=test) 382 | if test: 383 | return None, None, None, observed_valid, None 384 | true_waypoints = create_ground_truth_waypoint_grids(timestep_grids, self.config, flow=True, flow_origin=False) 385 | 386 | gt_obs_v = tf.stack(true_waypoints.vehicles.observed_occupancy,axis=0) 387 | gt_obs_p = tf.stack(true_waypoints.pedestrians.observed_occupancy,axis=0) 388 | gt_obs_c = tf.stack(true_waypoints.cyclists.observed_occupancy,axis=0) 389 | 390 | gt_occ_v = tf.stack(true_waypoints.vehicles.occluded_occupancy,axis=0) 391 | gt_occ_p = tf.stack(true_waypoints.pedestrians.occluded_occupancy,axis=0) 392 | gt_occ_c = tf.stack(true_waypoints.cyclists.occluded_occupancy,axis=0) 393 | 394 | gt_obs = tf.stack([gt_obs_v, gt_obs_p, gt_obs_c],axis=-1).numpy().astype(np.bool_) 395 | gt_occ = tf.clip_by_value(gt_occ_v + gt_occ_p + gt_occ_c, 0, 1).numpy().astype(np.bool_) 396 | 397 | gt_flow_v = tf.stack(true_waypoints.vehicles.flow,axis=0) 398 | gt_flow_p = tf.stack(true_waypoints.pedestrians.flow,axis=0) 399 | gt_flow_c = tf.stack(true_waypoints.cyclists.flow,axis=0) 400 | 401 | gt_flow = tf.stack([gt_flow_v, gt_flow_p, gt_flow_c],axis=-1).numpy().astype(np.int8) 402 | 403 | ego_mask = tf.stack([ego_occupancy[:, :, 10 * (i + 2)] for i in range(self.future_len//10)], axis=0).numpy().astype(np.bool_) 404 | 405 | return gt_obs, gt_occ, gt_flow, observed_valid, ego_mask 406 | 407 | def traj_process(self, traj_tensor, valid_tensor, ego_current, observed_valid, occluded_valid, sdc_ids): 408 | """ 409 | 1. Process trajectories for all agents (pack_trajs) 410 | 2. Filter agents in the FOV (currently observed and occluded) 411 | 3.
Sort neighbor agents in the FOV separately (observed, occluded) 412 | """ 413 | 414 | # ego trajectory 415 | ego_tensor = traj_tensor[sdc_ids] 416 | ego_traj, gt_traj = ego_tensor[:self.hist_len], ego_tensor[self.hist_len:, :5] 417 | self.current_xyh = [ego_current[0], ego_current[1], ego_current[4]] 418 | # filtered agents 419 | observed_ids, occluded_ids = {}, {} 420 | observed_agents = np.zeros((self.num_observed, self.hist_len, 10)) 421 | if self.ol_test: 422 | observed_agents = np.zeros((self.num_observed, self.hist_len + self.future_len, 10)) 423 | # x, y coordinates 424 | observed_valid, occluded_valid = observed_valid, occluded_valid 425 | for i in range(traj_tensor.shape[0]): 426 | if i==sdc_ids: 427 | continue 428 | if observed_valid[i]==1: 429 | observed_ids[i] = traj_tensor[i, 10, :2] 430 | 431 | sorted_observed = sorted(observed_ids.items(), 432 | key=lambda item: np.linalg.norm(item[1] - self.current_xyh[:2]))[:self.num_observed] 433 | for i, obs in enumerate(sorted_observed): 434 | if self.ol_test: 435 | observed_agents[i] = traj_tensor[obs[0], :, :] 436 | else: 437 | observed_agents[i] = traj_tensor[obs[0], :self.hist_len, :] 438 | neighbor_traj = observed_agents 439 | return ego_traj.astype(np.float32), neighbor_traj.astype(np.float32), gt_traj.astype(np.float32) 440 | 441 | def route_process(self, sdc_id, timestep, cur_pos, tracks): 442 | # find reference paths according to the gt trajectory 443 | gt_path = tracks[sdc_id].states 444 | # skip rare cases where no route can be found 445 | try: 446 | route = find_route(gt_path, timestep, cur_pos, self.lanes, self.crosswalks, self.traffic_signals) 447 | except Exception: 448 | return None 449 | 450 | ref_path = np.array(route, dtype=np.float32) 451 | 452 | if ref_path.shape[0] < 1200: 453 | repeated_last_point = np.repeat(ref_path[np.newaxis, -1], 1200-ref_path.shape[0], axis=0) 454 | ref_path = np.append(ref_path, repeated_last_point, axis=0) 455 | 456 | return ref_path 457 | 458 | 459 | def occupancy_process(self, traj_tensor, valid_tensor, ego_traj, sdc_ids=None, infer=False): 460 | """ 461 | Process the historical and future occupancy grids. 462 | """ 463 | timestep_grids, observed_valid, ego_occupancy = create_ground_truth_timestep_grids(traj_tensor, valid_tensor, ego_traj, self.config,flow=True, 464 | flow_origin=False,sdc_ids=sdc_ids) 465 | 466 | hist_vehicles, hist_pedestrians, hist_cyclists = timestep_grids.vehicles, timestep_grids.pedestrians, timestep_grids.cyclists 467 | 468 | hist_v_ogm = tf.concat([hist_vehicles.past_occupancy,hist_vehicles.current_occupancy],axis=-1) 469 | hist_p_ogm = tf.concat([hist_pedestrians.past_occupancy,hist_pedestrians.current_occupancy],axis=-1) 470 | hist_c_ogm = tf.concat([hist_cyclists.past_occupancy,hist_cyclists.current_occupancy],axis=-1) 471 | ego_hist = ego_occupancy[...,:self.hist_len].numpy().astype(np.bool_) 472 | 473 | # [h, w, hist_len, 3] 474 | hist_ogm = tf.stack([hist_v_ogm, hist_p_ogm, hist_c_ogm], axis=-1).numpy().astype(np.bool_) 475 | 476 | if infer: 477 | return hist_ogm, np.zeros((5, 128, 128, 3)), np.zeros((5, 128, 128, 3)), observed_valid.numpy(), None, np.zeros((5, 128, 128, 1)), ego_hist 478 | 479 | true_waypoints = create_ground_truth_waypoint_grids(timestep_grids, self.config,flow=True) 480 | vehicles, pedestrians, cyclists = true_waypoints.vehicles, true_waypoints.pedestrians, true_waypoints.cyclists 481 | 482 | gt_obs_v = tf.stack(true_waypoints.vehicles.observed_occupancy,axis=0) 483 | gt_occ_v = tf.stack(true_waypoints.vehicles.occluded_occupancy,axis=0) 484 | 485 | gt_obs_p =
tf.stack(true_waypoints.pedestrians.observed_occupancy,axis=0) 486 | gt_occ_p = tf.stack(true_waypoints.pedestrians.occluded_occupancy,axis=0) 487 | 488 | gt_obs_c = tf.stack(true_waypoints.cyclists.observed_occupancy,axis=0) 489 | gt_occ_c = tf.stack(true_waypoints.cyclists.occluded_occupancy,axis=0) 490 | 491 | gt_obs = tf.concat([gt_obs_v, gt_obs_p, gt_obs_c] ,axis=-1).numpy().astype(np.bool_) 492 | gt_occ = tf.concat([gt_occ_v, gt_occ_p, gt_occ_c] ,axis=-1).numpy().astype(np.bool_) 493 | 494 | ego_future = _ego_ground_truth_occupancy(ego_occupancy, self.config) 495 | ego_future = tf.stack(ego_future, axis=0).numpy().astype(np.bool_) 496 | 497 | return hist_ogm, gt_obs, gt_occ, observed_valid.numpy(), None, ego_future, ego_hist 498 | 499 | 500 | def load_open_loop_files(self): 501 | path = args.ol_dir 502 | data = pd.read_csv(path) 503 | self.test_scenario_ids = set(data['Scenario ID'].to_list()) 504 | 505 | def normalize_data(self, ego, neighbors, map_lanes, map_crosswalks, ego_map, ego_crosswalk, 506 | ground_truth, goal, ref_line, viz=True,sc_ids='', plan_res=None): 507 | # get the center and heading (local view) 508 | center, angle = self.current_xyh[:2], self.current_xyh[2] 509 | # normalize agent trajectories 510 | ego[:, :5] = agent_norm(ego, center, angle) 511 | ground_truth = agent_norm(ground_truth, center, angle) 512 | 513 | for i in range(neighbors.shape[0]): 514 | if neighbors[i, 10, 0] != 0: 515 | neighbors[i, :, :5] = agent_norm(neighbors[i, :], center, angle, impute=False) 516 | 517 | # normalize map points 518 | for i in range(map_lanes.shape[0]): 519 | lanes = map_lanes[i] 520 | crosswalks = map_crosswalks[i] 521 | 522 | for j in range(map_lanes.shape[1]): 523 | lane = lanes[j] 524 | if lane[0][0] != 0: 525 | lane[:, :9] = map_norm(lane, center, angle) 526 | 527 | for k in range(map_crosswalks.shape[1]): 528 | crosswalk = crosswalks[k] 529 | if crosswalk[0][0] != 0: 530 | crosswalk[:, :3] = map_norm(crosswalk, center, angle) 531 | 532 | for j in range(ego_map.shape[0]): 533 | lane = ego_map[j] 534 | if lane[0][0] != 0: 535 | lane[:, :9] = map_norm(lane, center, angle) 536 | 537 | for k in range(ego_crosswalk.shape[0]): 538 | crosswalk = ego_crosswalk[k] 539 | if crosswalk[0][0] != 0: 540 | crosswalk[:, :3] = map_norm(crosswalk, center, angle) 541 | 542 | plan_lines = np.zeros_like(ref_line) 543 | 544 | ref_line = ref_line_norm(ref_line, center, angle).astype(np.float32) 545 | 546 | goal = goal_norm(goal, center, angle) 547 | 548 | # visulization 549 | if viz: 550 | plt.figure() 551 | for i in range(map_lanes.shape[0]): 552 | lanes = map_lanes[i] 553 | crosswalks = map_crosswalks[i] 554 | for j in range(map_lanes.shape[1]): 555 | lane = lanes[j] 556 | if lane[0][0] != 0: 557 | centerline = lane[:, 0:2] 558 | centerline = centerline[centerline[:, 0] != 0] 559 | 560 | for k in range(map_crosswalks.shape[1]): 561 | crosswalk = crosswalks[k] 562 | if crosswalk[0][0] != 0: 563 | crosswalk = crosswalk[crosswalk[:, 0] != 0] 564 | plt.plot(crosswalk[:, 0], crosswalk[:, 1], 'b', linewidth=4) # plot crosswalk 565 | 566 | for j in range(ego_map.shape[0]): 567 | lane = ego_map[j] 568 | if lane[0][0] != 0: 569 | centerline = lane[:, 0:2] 570 | centerline = centerline[centerline[:, 0] != 0] 571 | left = lane[:, 3:5] 572 | left = left[left[:, 0] != 0] 573 | right = lane[:, 6:8] 574 | right = right[right[:, 0] != 0] 575 | plt.plot(centerline[:, 0], centerline[:, 1],'k', linewidth=1) # plot centerline 576 | 577 | for k in range(ego_crosswalk.shape[0]): 578 | crosswalk = ego_crosswalk[k] 
579 | if crosswalk[0][0] != 0: 580 | crosswalk = crosswalk[crosswalk[:, 0] != 0] 581 | plt.plot(crosswalk[:, 0], crosswalk[:, 1], 'b', linewidth=4) # plot crosswalk 582 | 583 | 584 | rect = plt.Rectangle((ego[-1, 0]-ego[-1, 6]/2, ego[-1, 1]-ego[-1, 7]/2), ego[-1, 6], ego[-1, 7], linewidth=2, color='r', alpha=0.6, zorder=3, 585 | transform=mpl.transforms.Affine2D().rotate_around(*(ego[-1, 0], ego[-1, 1]), ego[-1, 2]) + plt.gca().transData) 586 | plt.gca().add_patch(rect) 587 | plt.plot(ref_line[:, 0][ref_line[:, 0]!=0], ref_line[:, 1][ref_line[:, 0]!=0], 'y', linewidth=2, zorder=4) 588 | 589 | plt.plot(ego[:, 0], ego[:, 1],'royalblue' ,linewidth=3,zorder=3) 590 | future = ground_truth[ground_truth[:, 0] != 0] 591 | plt.plot(future[:, 0], future[:, 1], 'r', linewidth=3, zorder=3) 592 | color_map=['purple','brown','pink','olive','gold','royalblue'] 593 | 594 | for i in range(neighbors.shape[0]): 595 | if neighbors[i, 10, 0] != 0: 596 | rect = plt.Rectangle((neighbors[i, 10, 0]-neighbors[i, 10, 6]/2, neighbors[i, 10, 1]-neighbors[i, 10, 7]/2), 597 | neighbors[i, 10, 6], neighbors[i, 10, 7], linewidth=2, color='m', alpha=0.6, zorder=3, 598 | transform=mpl.transforms.Affine2D().rotate_around(*(neighbors[i, 10, 0], neighbors[i, 10, 1]), neighbors[i, 10, 2]) + plt.gca().transData) 599 | plt.gca().add_patch(rect) 600 | mask = neighbors[i, :, 0] + neighbors[i, :, 1] != 0 601 | plt.plot(neighbors[i, mask, 0], neighbors[i, mask, 1], 'm', linewidth=1, zorder=3) 602 | 603 | if plan_res is not None: 604 | plt.plot(plan_res[:, 0], plan_res[:, 1], 'c', linewidth=3,zorder=5) 605 | plt.scatter(plan_res[9::10, 0], plan_res[9::10, 1], 10, 'c',zorder=5) 606 | circle = plt.Circle([goal[0], goal[1]],color='r') 607 | plt.gca().add_patch(circle) 608 | plt.gca().set_aspect('equal') 609 | plt.tight_layout() 610 | plt.show(block=False) 611 | plt.pause(1) 612 | plt.close() 613 | 614 | return ego, neighbors, map_lanes, map_crosswalks, ego_map, ego_crosswalk ,ref_line, ground_truth, goal, plan_lines 615 | 616 | def data_process(self,vis=False): 617 | for data_file in self.data_files: 618 | dataset = tf.data.TFRecordDataset(data_file) 619 | sample_index = [i for i in range(self.hist_len, self.timestep-self.future_len, self.gap)] 620 | if len(sample_index) ==0: 621 | sample_index = [11] 622 | total_len = len(list(dataset))*len(sample_index) 623 | print(f"Processing {data_file.split('/')[-1]}", total_len) 624 | start_time = time.time() 625 | current = 1 626 | for data in dataset: 627 | parsed_data = scenario_pb2.Scenario() 628 | parsed_data.ParseFromString(data.numpy()) 629 | 630 | scenario_id = parsed_data.scenario_id 631 | if self.ol_test: 632 | if scenario_id not in self.test_scenario_ids: 633 | continue 634 | 635 | sdc_id = parsed_data.sdc_track_index 636 | time_len = len(parsed_data.tracks[sdc_id].states) 637 | self.build_map(parsed_data.map_features, parsed_data.dynamic_map_states) 638 | 639 | traj_tensor, valid_tensor, goal = pack_trajs(parsed_data, self.timestep) 640 | sdc_id = parsed_data.sdc_track_index 641 | cnt = 0 642 | for ind in sample_index: 643 | #slice current points 644 | traj_window, valid_window = traj_tensor[:, ind-self.hist_len:ind+self.future_len, :], valid_tensor[:, ind-self.hist_len:ind+self.future_len, :] 645 | 646 | ego_current = traj_tensor[sdc_id, ind] 647 | 648 | ego_goal_dist = np.linalg.norm(ego_current.numpy()[:2] - np.array(goal)) 649 | 650 | if ego_goal_dist < 3 and cnt < 5: 651 | continue 652 | 653 | traffic_light_lanes, stop_sign_lanes = self.dynamic_state_process() 654 | map_dict, 
org_maps = pack_maps(self.lanes, self.roads, self.crosswalks, traffic_light_lanes, 655 | stop_sign_lanes, self.input_config, [ego_current[0], ego_current[1], ego_current[4]]) 656 | 657 | hist_ogm, hist_flow,ego_ogm = self.history_ogm_process(traj_window, valid_window, ego_current, sdc_id) 658 | gt_obs, gt_occ, gt_flow, observed_valid, ego_ogm_gt = self.gt_ogm_process(traj_window, valid_window, ego_current, sdc_id, False) 659 | 660 | #vectorized trajs for agents in fovs: 661 | rg = render_roadgraph_tf(org_maps).numpy().astype(np.uint8) 662 | 663 | ego_traj, neighbor_traj, gt_traj = self.traj_process(traj_window.numpy(), valid_window.numpy(), ego_current.numpy(), observed_valid, None, sdc_id) 664 | ref_line = self.route_process(sdc_id, ind, self.current_xyh, parsed_data.tracks) 665 | if ref_line is None: 666 | continue 667 | 668 | #agents map and crosswalks: 669 | neighbor_map_lanes = np.zeros((self.num_observed, self.num_map, self.map_len//2, 17)) 670 | neighbor_map_crosswalks = np.zeros((self.num_observed, 3, 50, 3)) 671 | ego_map_lane, ego_map_crosswalk = self.map_process(ego_traj, self.ego_map*2, self.ego_map_len, self.ego_buffer, ind, goal) 672 | 673 | 674 | for i in range(neighbor_traj.shape[0]): 675 | if neighbor_traj[i, -1, 0] != 0: 676 | neighbor_map_lanes[i], neighbor_map_crosswalks[i] = self.map_process(neighbor_traj[i], self.num_map, self.map_len, self.map_buffer, ind) 677 | 678 | #normalize all 679 | self.sc_ids = f'{scenario_id}_{ind}' 680 | 681 | ego, neighbors, neighbor_map_lanes, neighbor_map_crosswalks, ego_map_lane, ego_map_crosswalk , ref_line, ground_truth, new_goal, plan_lines = self.normalize_data(ego_traj, neighbor_traj, 682 | neighbor_map_lanes, neighbor_map_crosswalks, ego_map_lane,ego_map_crosswalk, gt_traj, goal, ref_line, vis, f'{scenario_id}_{ind}') 683 | 684 | sys.stdout.write(f"\rProcessing{data_file.split('/')[-1]}|length:{current}/{total_len}|{(time.time()-start_time)/current:>.4f}s/sample") 685 | sys.stdout.flush() 686 | current += 1 687 | cnt += 1 688 | 689 | filename = self.save_dir + f"{scenario_id}_{ind}.npz" 690 | np.savez(filename, ego=ego, neighbors=neighbors, neighbor_map_lanes=neighbor_map_lanes, 691 | ego_map_lane=ego_map_lane,ego_map_crosswalk=ego_map_crosswalk, 692 | neighbor_map_crosswalks=neighbor_map_crosswalks, gt_future_states=ground_truth, 693 | ref_line=ref_line,goal=new_goal,hist_ogm=hist_ogm, gt_obs=gt_obs, gt_occ=gt_occ, 694 | ego_ogm=ego_ogm,ego_ogm_gt=ego_ogm_gt,rg=rg,hist_flow=hist_flow,gt_flow=gt_flow 695 | ) 696 | 697 | def process_ol_test_data(data_files): 698 | processor = Processor(data_files=data_files,height=128,width=128,gap=10,ref_max_len=1200,ego_map=3,future_len=50, 699 | save_dir=args.save_dir+'/open_loop_test2/',ol_test=True, timestep=91) 700 | processor.load_open_loop_files() 701 | processor.data_process(vis=False) 702 | print(f'{data_files}-done!') 703 | with open(args.save_dir+'ol_log.txt','a') as writer: 704 | writer.write(data_files+'\n') 705 | 706 | def process_training_data(data_files): 707 | for data_file in data_files: 708 | processor = Processor(data_files=data_file,height=128,width=128,gap=10,ref_max_len=1200,ego_map=3,future_len=50, 709 | save_dir=args.save_dir+'train/', timestep=199) 710 | processor.data_process(vis=False) 711 | print(f'{data_file},training_done!') 712 | with open(args.save_dir+'train_log.txt','a') as writer: 713 | writer.write(data_file+'\n') 714 | 715 | def process_validation_data(data_files): 716 | # for data_file in data_files: 717 | processor = 
Processor(data_files=data_files,height=128,width=128,gap=10,ref_max_len=1200,ego_map=3,future_len=50, 718 | save_dir=args.save_dir+'valid/',timestep=91) 719 | processor.data_process(vis=False) 720 | print(f'{data_files} - done!') 721 | with open(args.save_dir+'val_log.txt','a') as writer: 722 | writer.write(data_files+'\n') 723 | 724 | def file_slice(files, n): 725 | return [files[i::n] for i in range(n)] # round-robin split into exactly n shards, so no files are dropped when len(files) is not divisible by n 726 | 727 | def process_map(func, files, n): 728 | sliced_files = file_slice(files, n) 729 | 730 | process_list = [] 731 | for i in range(n): 732 | p = Process(target=func, args=(sliced_files[i],)) 733 | process_list.append(p) 734 | for i in range(n): 735 | process_list[i].start() 736 | for i in range(n): 737 | process_list[i].join() 738 | 739 | 740 | if __name__ == "__main__": 741 | parser = argparse.ArgumentParser() 742 | parser.add_argument("--processes", type=int,default=16) 743 | parser.add_argument("--root_dir", type=str,default='', help='path to the original Waymo dataset') 744 | parser.add_argument("--save_dir", type=str,default='', help='path to save processed datasets') 745 | parser.add_argument("--ol_dir", type=str,default='', help='path for open loop test ids csv (optional)') 746 | args = parser.parse_args() 747 | processes = args.processes 748 | 749 | # Hint: processing the full training set is time-consuming, so you may sample only a fraction of the files for training 750 | train_root_dir = args.root_dir + 'training_20s/' 751 | train_list = glob(train_root_dir+'*') 752 | process_map(process_training_data, train_list, processes) 753 | 754 | # randomly select half of the validation files: 755 | val_root_dir = args.root_dir + 'validation/' 756 | val_list =[ 757 | val_root_dir + 'validation.tfrecord-' + "%05d" % i + '-of-00150' 758 | for i in random.sample(range(150), 75) 759 | ] 760 | 761 | with Pool(processes=processes) as p: 762 | p.map(process_validation_data, val_list) 763 | 764 | # process data for the open-loop test; you may also directly reuse the sampled validation set for ol-testing 765 | if args.ol_dir != '': 766 | full_val_list = glob(val_root_dir + '*') 767 | with Pool(processes=processes) as p: 768 | p.map(process_ol_test_data, full_val_list) --------------------------------------------------------------------------------
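A minimal sketch of how the preprocessing above might be run and how one of the saved samples could be inspected. The command-line flags come from the argparse block in preprocess.py, the 'train/' subfolder follows process_training_data, and the array keys mirror the np.savez call in data_process; the absolute paths here are placeholders, not part of the repo.

# Run preprocessing first (paths are placeholders):
#   python preprocess.py --root_dir /data/waymo_motion/ --save_dir /data/opgp_processed/ --processes 16

import numpy as np
from glob import glob

files = glob('/data/opgp_processed/train/*.npz')  # process_training_data writes to save_dir + 'train/'
sample = np.load(files[0])
for key in ['ego', 'neighbors', 'neighbor_map_lanes', 'ego_map_lane', 'ego_map_crosswalk',
            'neighbor_map_crosswalks', 'gt_future_states', 'ref_line', 'goal',
            'hist_ogm', 'gt_obs', 'gt_occ', 'ego_ogm', 'ego_ogm_gt', 'rg',
            'hist_flow', 'gt_flow']:
    print(key, sample[key].shape, sample[key].dtype)  # e.g. hist_ogm and gt_obs are stored as boolean grids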