├── .github └── FUNDING.yml ├── .gitignore ├── README.md ├── docs ├── DATA_PREP.md ├── INSTALL.md └── TRAIN_EVAL.md ├── preprocess ├── map_pb2.py ├── preprocess_waymo.py └── scenario_pb2.py ├── requirements.txt ├── src └── pipeline.PNG └── training ├── model.py ├── train.py └── waymo_dataset.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [OpenDriveLab] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | dataset/waymo/* 2 | *tfrecord* 3 | *.pkl 4 | *logs* 5 | *__pycache__* 6 | *.log 7 | *.out 8 | *.tar.bz2 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HDGT: Modeling the Driving Scene with Heterogeneity and Relativity 2 | 3 | > **HDGT: Heterogeneous Driving Graph Transformer for Multi-Agent Trajectory Prediction via Scene Encoding** [IEEE TPAMI 2023] 4 | >![pipeline](src/pipeline.PNG) 5 | > - [Paper](http://arxiv.org/abs/2205.09753) 6 | 7 | ## Introduction 8 | 9 | HDGT is a unified heterogeneous transformer-based graph neural network for driving scene encoding. It is a **SOTA method** on the [INTERACTION](http://challenge.interaction-dataset.com/leader-board) and [Waymo](https://waymo.com/open/challenges/2021/motion-prediction/) Motion Prediction Challenges. 10 | 11 | At the time of its release in April 2022, the proposed method achieved new state-of-the-art results on the INTERACTION Prediction Challenge and the Waymo Open Motion Challenge, ranking **first** and **second** respectively in terms of the minADE/minFDE metrics. 12 | 13 | ## Getting Started 14 | 15 | - [Installation](docs/INSTALL.md) 16 | - [Prepare Dataset](docs/DATA_PREP.md) 17 | - [Train & Evaluation](docs/TRAIN_EVAL.md) 18 | 19 | 20 | ## License 21 | 22 | All assets and code are under the [Apache 2.0 license](./LICENSE) unless specified otherwise. 23 | 24 | ## BibTeX 25 | If this work is helpful for your research, please consider citing the following BibTeX entries.
26 | 27 | ``` 28 | @article{jia2023hdgt, 29 | title={HDGT: Heterogeneous Driving Graph Transformer for Multi-Agent Trajectory Prediction via Scene Encoding}, 30 | author={Jia, Xiaosong and Wu, Penghao and Chen, Li and Liu, Yu and Li, Hongyang and Yan, Junchi}, 31 | journal = {IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI)}, 32 | year = {2023}, 33 | } 34 | ``` 35 | 36 | ``` 37 | @inproceedings{jia2022temporal, 38 | title={Towards Capturing the Temporal Dynamics for Trajectory Prediction: a Coarse-to-Fine Approach}, 39 | author={Jia, Xiaosong and Chen, Li and Wu, Penghao and Zeng, Jia and Yan, Junchi and Li, Hongyang and Qiao, Yu}, 40 | booktitle={CoRL}, 41 | year={2022} 42 | } 43 | ``` 44 | 45 | 46 | -------------------------------------------------------------------------------- /docs/DATA_PREP.md: -------------------------------------------------------------------------------- 1 | # Prepare Dataset 2 | 3 | ## Download the Dataset 4 | You should prepare ~2 TB of disk space in total to run HDGT on Waymo. First, download the Waymo Open Motion Dataset from their official website. The folder structure should be: 5 | 6 | HDGT/ 7 | ├── dataset 8 | ├── waymo 9 | ├── training 10 | ├── validation 11 | ├── ... 12 | ├── preprocess 13 | ├── preprocess_waymo.py 14 | ├── ... 15 | Note that you could use the Linux command *ln -s* (symbolic links) to avoid copying the huge Waymo dataset around; see the example at the end of this page. 16 | 17 | 18 | ## Preprocess the Dataset 19 | Then, we preprocess these tfrecords so that each scene (91 steps) is saved as one pickle file. To preprocess the files in parallel, we split all training tfrecords into 12 parts by index (the validation set is just 1 part) and thus run 13 processes to conduct the preprocessing: 20 | ```shell 21 | ## In the HDGT/preprocess directory 22 | preprocess_data_folder_name=hdgt_waymo 23 | for i in {0..12} 24 | do 25 | nohup python preprocess_waymo.py $i $preprocess_data_folder_name 2>&1 > preprocess_$i.log & 26 | done 27 | ``` 28 | It could take ~12 hours and ~700 GB of disk space.
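29 | 30 | For reference, a minimal sketch of the symlink setup mentioned above (run from the HDGT/ root; the source paths below are placeholders for wherever you actually downloaded the Waymo tfrecords): 31 | ```shell 32 | ## Placeholder download location -- adjust to your own path 33 | mkdir -p dataset/waymo 34 | ln -s /path/to/waymo_open_motion/training dataset/waymo/training 35 | ln -s /path/to/waymo_open_motion/validation dataset/waymo/validation 36 | ```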
-------------------------------------------------------------------------------- /docs/INSTALL.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | HDGT was developed under a certain version of [DGL](https://www.dgl.ai/), after which DGL underwent a major refactor. Thus, we suggest using exactly the same environment as provided below to avoid any issues: 3 | 4 | ```shell 5 | conda create -n hdgt python=3.8 6 | conda activate hdgt 7 | 8 | pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu113/torch_stable.html 9 | pip install tensorflow tensorboard 10 | 11 | wget https://anaconda.org/dglteam/dgl-cuda11.3/0.7.2/download/linux-64/dgl-cuda11.3-0.7.2-py38_0.tar.bz2 12 | conda install --use-local dgl-cuda11.3-0.7.2-py38_0.tar.bz2 -y 13 | conda install protobuf=3.20 -y 14 | 15 | pip install -r requirements.txt 16 | ``` -------------------------------------------------------------------------------- /docs/TRAIN_EVAL.md: -------------------------------------------------------------------------------- 1 | ## Training & Evaluation 2 | 3 | We train our model with 8 GPUs in 4-5 days on the Waymo Open Motion Dataset with the following command line: 4 | 5 | ``` 6 | ## Use all available GPUs by default 7 | 8 | python train.py --n_epoch 30 --batch_size 16 --val_batch_size 128 --ddp_mode True --port 31253 --name hdgt_refine --amp bf16 9 | ``` 10 | 11 | - batch_size=16 fits on 3090/V100/A100 GPUs, while bfloat16 is only available on 3090/A100. Remove "--amp bf16" if your GPU is neither a 3090 nor an A100. 12 | - Our code adopts PyTorch DDP by manually spawning multiple processes. 13 | - The results on the validation set are around **ade6 0.5806, fde6 1.1757, mr6 0.1495**. Better results could be achieved with more data augmentation tricks, heavier regularization, and longer training. 14 | - This codebase is a mix of our TPAMI paper, CoRL paper, and some modifications used during the competitions. 15 | 16 | -------------------------------------------------------------------------------- /preprocess/map_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
3 | # source: map.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import message as _message 7 | from google.protobuf import reflection as _reflection 8 | from google.protobuf import symbol_database as _symbol_database 9 | # @@protoc_insertion_point(imports) 10 | 11 | _sym_db = _symbol_database.Default() 12 | 13 | 14 | 15 | 16 | DESCRIPTOR = _descriptor.FileDescriptor( 17 | name='map.proto', 18 | package='waymo.open_dataset', 19 | syntax='proto2', 20 | serialized_options=None, 21 | create_key=_descriptor._internal_create_key, 22 | serialized_pb=b'\n\tmap.proto\x12\x12waymo.open_dataset\"u\n\x03Map\x12\x34\n\x0cmap_features\x18\x01 \x03(\x0b\x32\x1e.waymo.open_dataset.MapFeature\x12\x38\n\x0e\x64ynamic_states\x18\x02 \x03(\x0b\x32 .waymo.open_dataset.DynamicState\"j\n\x0c\x44ynamicState\x12\x19\n\x11timestamp_seconds\x18\x01 \x01(\x01\x12?\n\x0blane_states\x18\x02 \x03(\x0b\x32*.waymo.open_dataset.TrafficSignalLaneState\"\x8c\x03\n\x16TrafficSignalLaneState\x12\x0c\n\x04lane\x18\x01 \x01(\x03\x12?\n\x05state\x18\x02 \x01(\x0e\x32\x30.waymo.open_dataset.TrafficSignalLaneState.State\x12\x30\n\nstop_point\x18\x03 \x01(\x0b\x32\x1c.waymo.open_dataset.MapPoint\"\xf0\x01\n\x05State\x12\x16\n\x12LANE_STATE_UNKNOWN\x10\x00\x12\x19\n\x15LANE_STATE_ARROW_STOP\x10\x01\x12\x1c\n\x18LANE_STATE_ARROW_CAUTION\x10\x02\x12\x17\n\x13LANE_STATE_ARROW_GO\x10\x03\x12\x13\n\x0fLANE_STATE_STOP\x10\x04\x12\x16\n\x12LANE_STATE_CAUTION\x10\x05\x12\x11\n\rLANE_STATE_GO\x10\x06\x12\x1c\n\x18LANE_STATE_FLASHING_STOP\x10\x07\x12\x1f\n\x1bLANE_STATE_FLASHING_CAUTION\x10\x08\"\xda\x02\n\nMapFeature\x12\n\n\x02id\x18\x01 \x01(\x03\x12.\n\x04lane\x18\x03 \x01(\x0b\x32\x1e.waymo.open_dataset.LaneCenterH\x00\x12\x31\n\troad_line\x18\x04 \x01(\x0b\x32\x1c.waymo.open_dataset.RoadLineH\x00\x12\x31\n\troad_edge\x18\x05 \x01(\x0b\x32\x1c.waymo.open_dataset.RoadEdgeH\x00\x12\x31\n\tstop_sign\x18\x07 \x01(\x0b\x32\x1c.waymo.open_dataset.StopSignH\x00\x12\x32\n\tcrosswalk\x18\x08 \x01(\x0b\x32\x1d.waymo.open_dataset.CrosswalkH\x00\x12\x33\n\nspeed_bump\x18\t \x01(\x0b\x32\x1d.waymo.open_dataset.SpeedBumpH\x00\x42\x0e\n\x0c\x66\x65\x61ture_data\"+\n\x08MapPoint\x12\t\n\x01x\x18\x01 \x01(\x01\x12\t\n\x01y\x18\x02 \x01(\x01\x12\t\n\x01z\x18\x03 \x01(\x01\"\xa2\x01\n\x0f\x42oundarySegment\x12\x18\n\x10lane_start_index\x18\x01 \x01(\x05\x12\x16\n\x0elane_end_index\x18\x02 \x01(\x05\x12\x1b\n\x13\x62oundary_feature_id\x18\x03 \x01(\x03\x12@\n\rboundary_type\x18\x04 \x01(\x0e\x32).waymo.open_dataset.RoadLine.RoadLineType\"\xc7\x01\n\x0cLaneNeighbor\x12\x12\n\nfeature_id\x18\x01 \x01(\x03\x12\x18\n\x10self_start_index\x18\x02 \x01(\x05\x12\x16\n\x0eself_end_index\x18\x03 \x01(\x05\x12\x1c\n\x14neighbor_start_index\x18\x04 \x01(\x05\x12\x1a\n\x12neighbor_end_index\x18\x05 \x01(\x05\x12\x37\n\nboundaries\x18\x06 \x03(\x0b\x32#.waymo.open_dataset.BoundarySegment\"\xa5\x04\n\nLaneCenter\x12\x17\n\x0fspeed_limit_mph\x18\x01 \x01(\x01\x12\x35\n\x04type\x18\x02 \x01(\x0e\x32\'.waymo.open_dataset.LaneCenter.LaneType\x12\x15\n\rinterpolating\x18\x03 \x01(\x08\x12.\n\x08polyline\x18\x08 \x03(\x0b\x32\x1c.waymo.open_dataset.MapPoint\x12\x17\n\x0b\x65ntry_lanes\x18\t \x03(\x03\x42\x02\x10\x01\x12\x16\n\nexit_lanes\x18\n \x03(\x03\x42\x02\x10\x01\x12<\n\x0fleft_boundaries\x18\r \x03(\x0b\x32#.waymo.open_dataset.BoundarySegment\x12=\n\x10right_boundaries\x18\x0e \x03(\x0b\x32#.waymo.open_dataset.BoundarySegment\x12\x38\n\x0eleft_neighbors\x18\x0b \x03(\x0b\x32 
.waymo.open_dataset.LaneNeighbor\x12\x39\n\x0fright_neighbors\x18\x0c \x03(\x0b\x32 .waymo.open_dataset.LaneNeighbor\"]\n\x08LaneType\x12\x12\n\x0eTYPE_UNDEFINED\x10\x00\x12\x10\n\x0cTYPE_FREEWAY\x10\x01\x12\x17\n\x13TYPE_SURFACE_STREET\x10\x02\x12\x12\n\x0eTYPE_BIKE_LANE\x10\x03\"\xcd\x01\n\x08RoadEdge\x12\x37\n\x04type\x18\x01 \x01(\x0e\x32).waymo.open_dataset.RoadEdge.RoadEdgeType\x12.\n\x08polyline\x18\x02 \x03(\x0b\x32\x1c.waymo.open_dataset.MapPoint\"X\n\x0cRoadEdgeType\x12\x10\n\x0cTYPE_UNKNOWN\x10\x00\x12\x1b\n\x17TYPE_ROAD_EDGE_BOUNDARY\x10\x01\x12\x19\n\x15TYPE_ROAD_EDGE_MEDIAN\x10\x02\"\x88\x03\n\x08RoadLine\x12\x37\n\x04type\x18\x01 \x01(\x0e\x32).waymo.open_dataset.RoadLine.RoadLineType\x12.\n\x08polyline\x18\x02 \x03(\x0b\x32\x1c.waymo.open_dataset.MapPoint\"\x92\x02\n\x0cRoadLineType\x12\x10\n\x0cTYPE_UNKNOWN\x10\x00\x12\x1c\n\x18TYPE_BROKEN_SINGLE_WHITE\x10\x01\x12\x1b\n\x17TYPE_SOLID_SINGLE_WHITE\x10\x02\x12\x1b\n\x17TYPE_SOLID_DOUBLE_WHITE\x10\x03\x12\x1d\n\x19TYPE_BROKEN_SINGLE_YELLOW\x10\x04\x12\x1d\n\x19TYPE_BROKEN_DOUBLE_YELLOW\x10\x05\x12\x1c\n\x18TYPE_SOLID_SINGLE_YELLOW\x10\x06\x12\x1c\n\x18TYPE_SOLID_DOUBLE_YELLOW\x10\x07\x12\x1e\n\x1aTYPE_PASSING_DOUBLE_YELLOW\x10\x08\"H\n\x08StopSign\x12\x0c\n\x04lane\x18\x01 \x03(\x03\x12.\n\x08position\x18\x02 \x01(\x0b\x32\x1c.waymo.open_dataset.MapPoint\":\n\tCrosswalk\x12-\n\x07polygon\x18\x01 \x03(\x0b\x32\x1c.waymo.open_dataset.MapPoint\":\n\tSpeedBump\x12-\n\x07polygon\x18\x01 \x03(\x0b\x32\x1c.waymo.open_dataset.MapPoint' 23 | ) 24 | 25 | 26 | 27 | _TRAFFICSIGNALLANESTATE_STATE = _descriptor.EnumDescriptor( 28 | name='State', 29 | full_name='waymo.open_dataset.TrafficSignalLaneState.State', 30 | filename=None, 31 | file=DESCRIPTOR, 32 | create_key=_descriptor._internal_create_key, 33 | values=[ 34 | _descriptor.EnumValueDescriptor( 35 | name='LANE_STATE_UNKNOWN', index=0, number=0, 36 | serialized_options=None, 37 | type=None, 38 | create_key=_descriptor._internal_create_key), 39 | _descriptor.EnumValueDescriptor( 40 | name='LANE_STATE_ARROW_STOP', index=1, number=1, 41 | serialized_options=None, 42 | type=None, 43 | create_key=_descriptor._internal_create_key), 44 | _descriptor.EnumValueDescriptor( 45 | name='LANE_STATE_ARROW_CAUTION', index=2, number=2, 46 | serialized_options=None, 47 | type=None, 48 | create_key=_descriptor._internal_create_key), 49 | _descriptor.EnumValueDescriptor( 50 | name='LANE_STATE_ARROW_GO', index=3, number=3, 51 | serialized_options=None, 52 | type=None, 53 | create_key=_descriptor._internal_create_key), 54 | _descriptor.EnumValueDescriptor( 55 | name='LANE_STATE_STOP', index=4, number=4, 56 | serialized_options=None, 57 | type=None, 58 | create_key=_descriptor._internal_create_key), 59 | _descriptor.EnumValueDescriptor( 60 | name='LANE_STATE_CAUTION', index=5, number=5, 61 | serialized_options=None, 62 | type=None, 63 | create_key=_descriptor._internal_create_key), 64 | _descriptor.EnumValueDescriptor( 65 | name='LANE_STATE_GO', index=6, number=6, 66 | serialized_options=None, 67 | type=None, 68 | create_key=_descriptor._internal_create_key), 69 | _descriptor.EnumValueDescriptor( 70 | name='LANE_STATE_FLASHING_STOP', index=7, number=7, 71 | serialized_options=None, 72 | type=None, 73 | create_key=_descriptor._internal_create_key), 74 | _descriptor.EnumValueDescriptor( 75 | name='LANE_STATE_FLASHING_CAUTION', index=8, number=8, 76 | serialized_options=None, 77 | type=None, 78 | create_key=_descriptor._internal_create_key), 79 | ], 80 | containing_type=None, 81 | serialized_options=None, 82 
| serialized_start=417, 83 | serialized_end=657, 84 | ) 85 | _sym_db.RegisterEnumDescriptor(_TRAFFICSIGNALLANESTATE_STATE) 86 | 87 | _LANECENTER_LANETYPE = _descriptor.EnumDescriptor( 88 | name='LaneType', 89 | full_name='waymo.open_dataset.LaneCenter.LaneType', 90 | filename=None, 91 | file=DESCRIPTOR, 92 | create_key=_descriptor._internal_create_key, 93 | values=[ 94 | _descriptor.EnumValueDescriptor( 95 | name='TYPE_UNDEFINED', index=0, number=0, 96 | serialized_options=None, 97 | type=None, 98 | create_key=_descriptor._internal_create_key), 99 | _descriptor.EnumValueDescriptor( 100 | name='TYPE_FREEWAY', index=1, number=1, 101 | serialized_options=None, 102 | type=None, 103 | create_key=_descriptor._internal_create_key), 104 | _descriptor.EnumValueDescriptor( 105 | name='TYPE_SURFACE_STREET', index=2, number=2, 106 | serialized_options=None, 107 | type=None, 108 | create_key=_descriptor._internal_create_key), 109 | _descriptor.EnumValueDescriptor( 110 | name='TYPE_BIKE_LANE', index=3, number=3, 111 | serialized_options=None, 112 | type=None, 113 | create_key=_descriptor._internal_create_key), 114 | ], 115 | containing_type=None, 116 | serialized_options=None, 117 | serialized_start=1877, 118 | serialized_end=1970, 119 | ) 120 | _sym_db.RegisterEnumDescriptor(_LANECENTER_LANETYPE) 121 | 122 | _ROADEDGE_ROADEDGETYPE = _descriptor.EnumDescriptor( 123 | name='RoadEdgeType', 124 | full_name='waymo.open_dataset.RoadEdge.RoadEdgeType', 125 | filename=None, 126 | file=DESCRIPTOR, 127 | create_key=_descriptor._internal_create_key, 128 | values=[ 129 | _descriptor.EnumValueDescriptor( 130 | name='TYPE_UNKNOWN', index=0, number=0, 131 | serialized_options=None, 132 | type=None, 133 | create_key=_descriptor._internal_create_key), 134 | _descriptor.EnumValueDescriptor( 135 | name='TYPE_ROAD_EDGE_BOUNDARY', index=1, number=1, 136 | serialized_options=None, 137 | type=None, 138 | create_key=_descriptor._internal_create_key), 139 | _descriptor.EnumValueDescriptor( 140 | name='TYPE_ROAD_EDGE_MEDIAN', index=2, number=2, 141 | serialized_options=None, 142 | type=None, 143 | create_key=_descriptor._internal_create_key), 144 | ], 145 | containing_type=None, 146 | serialized_options=None, 147 | serialized_start=2090, 148 | serialized_end=2178, 149 | ) 150 | _sym_db.RegisterEnumDescriptor(_ROADEDGE_ROADEDGETYPE) 151 | 152 | _ROADLINE_ROADLINETYPE = _descriptor.EnumDescriptor( 153 | name='RoadLineType', 154 | full_name='waymo.open_dataset.RoadLine.RoadLineType', 155 | filename=None, 156 | file=DESCRIPTOR, 157 | create_key=_descriptor._internal_create_key, 158 | values=[ 159 | _descriptor.EnumValueDescriptor( 160 | name='TYPE_UNKNOWN', index=0, number=0, 161 | serialized_options=None, 162 | type=None, 163 | create_key=_descriptor._internal_create_key), 164 | _descriptor.EnumValueDescriptor( 165 | name='TYPE_BROKEN_SINGLE_WHITE', index=1, number=1, 166 | serialized_options=None, 167 | type=None, 168 | create_key=_descriptor._internal_create_key), 169 | _descriptor.EnumValueDescriptor( 170 | name='TYPE_SOLID_SINGLE_WHITE', index=2, number=2, 171 | serialized_options=None, 172 | type=None, 173 | create_key=_descriptor._internal_create_key), 174 | _descriptor.EnumValueDescriptor( 175 | name='TYPE_SOLID_DOUBLE_WHITE', index=3, number=3, 176 | serialized_options=None, 177 | type=None, 178 | create_key=_descriptor._internal_create_key), 179 | _descriptor.EnumValueDescriptor( 180 | name='TYPE_BROKEN_SINGLE_YELLOW', index=4, number=4, 181 | serialized_options=None, 182 | type=None, 183 | 
create_key=_descriptor._internal_create_key), 184 | _descriptor.EnumValueDescriptor( 185 | name='TYPE_BROKEN_DOUBLE_YELLOW', index=5, number=5, 186 | serialized_options=None, 187 | type=None, 188 | create_key=_descriptor._internal_create_key), 189 | _descriptor.EnumValueDescriptor( 190 | name='TYPE_SOLID_SINGLE_YELLOW', index=6, number=6, 191 | serialized_options=None, 192 | type=None, 193 | create_key=_descriptor._internal_create_key), 194 | _descriptor.EnumValueDescriptor( 195 | name='TYPE_SOLID_DOUBLE_YELLOW', index=7, number=7, 196 | serialized_options=None, 197 | type=None, 198 | create_key=_descriptor._internal_create_key), 199 | _descriptor.EnumValueDescriptor( 200 | name='TYPE_PASSING_DOUBLE_YELLOW', index=8, number=8, 201 | serialized_options=None, 202 | type=None, 203 | create_key=_descriptor._internal_create_key), 204 | ], 205 | containing_type=None, 206 | serialized_options=None, 207 | serialized_start=2299, 208 | serialized_end=2573, 209 | ) 210 | _sym_db.RegisterEnumDescriptor(_ROADLINE_ROADLINETYPE) 211 | 212 | 213 | _MAP = _descriptor.Descriptor( 214 | name='Map', 215 | full_name='waymo.open_dataset.Map', 216 | filename=None, 217 | file=DESCRIPTOR, 218 | containing_type=None, 219 | create_key=_descriptor._internal_create_key, 220 | fields=[ 221 | _descriptor.FieldDescriptor( 222 | name='map_features', full_name='waymo.open_dataset.Map.map_features', index=0, 223 | number=1, type=11, cpp_type=10, label=3, 224 | has_default_value=False, default_value=[], 225 | message_type=None, enum_type=None, containing_type=None, 226 | is_extension=False, extension_scope=None, 227 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 228 | _descriptor.FieldDescriptor( 229 | name='dynamic_states', full_name='waymo.open_dataset.Map.dynamic_states', index=1, 230 | number=2, type=11, cpp_type=10, label=3, 231 | has_default_value=False, default_value=[], 232 | message_type=None, enum_type=None, containing_type=None, 233 | is_extension=False, extension_scope=None, 234 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 235 | ], 236 | extensions=[ 237 | ], 238 | nested_types=[], 239 | enum_types=[ 240 | ], 241 | serialized_options=None, 242 | is_extendable=False, 243 | syntax='proto2', 244 | extension_ranges=[], 245 | oneofs=[ 246 | ], 247 | serialized_start=33, 248 | serialized_end=150, 249 | ) 250 | 251 | 252 | _DYNAMICSTATE = _descriptor.Descriptor( 253 | name='DynamicState', 254 | full_name='waymo.open_dataset.DynamicState', 255 | filename=None, 256 | file=DESCRIPTOR, 257 | containing_type=None, 258 | create_key=_descriptor._internal_create_key, 259 | fields=[ 260 | _descriptor.FieldDescriptor( 261 | name='timestamp_seconds', full_name='waymo.open_dataset.DynamicState.timestamp_seconds', index=0, 262 | number=1, type=1, cpp_type=5, label=1, 263 | has_default_value=False, default_value=float(0), 264 | message_type=None, enum_type=None, containing_type=None, 265 | is_extension=False, extension_scope=None, 266 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 267 | _descriptor.FieldDescriptor( 268 | name='lane_states', full_name='waymo.open_dataset.DynamicState.lane_states', index=1, 269 | number=2, type=11, cpp_type=10, label=3, 270 | has_default_value=False, default_value=[], 271 | message_type=None, enum_type=None, containing_type=None, 272 | is_extension=False, extension_scope=None, 273 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 274 
| ], 275 | extensions=[ 276 | ], 277 | nested_types=[], 278 | enum_types=[ 279 | ], 280 | serialized_options=None, 281 | is_extendable=False, 282 | syntax='proto2', 283 | extension_ranges=[], 284 | oneofs=[ 285 | ], 286 | serialized_start=152, 287 | serialized_end=258, 288 | ) 289 | 290 | 291 | _TRAFFICSIGNALLANESTATE = _descriptor.Descriptor( 292 | name='TrafficSignalLaneState', 293 | full_name='waymo.open_dataset.TrafficSignalLaneState', 294 | filename=None, 295 | file=DESCRIPTOR, 296 | containing_type=None, 297 | create_key=_descriptor._internal_create_key, 298 | fields=[ 299 | _descriptor.FieldDescriptor( 300 | name='lane', full_name='waymo.open_dataset.TrafficSignalLaneState.lane', index=0, 301 | number=1, type=3, cpp_type=2, label=1, 302 | has_default_value=False, default_value=0, 303 | message_type=None, enum_type=None, containing_type=None, 304 | is_extension=False, extension_scope=None, 305 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 306 | _descriptor.FieldDescriptor( 307 | name='state', full_name='waymo.open_dataset.TrafficSignalLaneState.state', index=1, 308 | number=2, type=14, cpp_type=8, label=1, 309 | has_default_value=False, default_value=0, 310 | message_type=None, enum_type=None, containing_type=None, 311 | is_extension=False, extension_scope=None, 312 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 313 | _descriptor.FieldDescriptor( 314 | name='stop_point', full_name='waymo.open_dataset.TrafficSignalLaneState.stop_point', index=2, 315 | number=3, type=11, cpp_type=10, label=1, 316 | has_default_value=False, default_value=None, 317 | message_type=None, enum_type=None, containing_type=None, 318 | is_extension=False, extension_scope=None, 319 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 320 | ], 321 | extensions=[ 322 | ], 323 | nested_types=[], 324 | enum_types=[ 325 | _TRAFFICSIGNALLANESTATE_STATE, 326 | ], 327 | serialized_options=None, 328 | is_extendable=False, 329 | syntax='proto2', 330 | extension_ranges=[], 331 | oneofs=[ 332 | ], 333 | serialized_start=261, 334 | serialized_end=657, 335 | ) 336 | 337 | 338 | _MAPFEATURE = _descriptor.Descriptor( 339 | name='MapFeature', 340 | full_name='waymo.open_dataset.MapFeature', 341 | filename=None, 342 | file=DESCRIPTOR, 343 | containing_type=None, 344 | create_key=_descriptor._internal_create_key, 345 | fields=[ 346 | _descriptor.FieldDescriptor( 347 | name='id', full_name='waymo.open_dataset.MapFeature.id', index=0, 348 | number=1, type=3, cpp_type=2, label=1, 349 | has_default_value=False, default_value=0, 350 | message_type=None, enum_type=None, containing_type=None, 351 | is_extension=False, extension_scope=None, 352 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 353 | _descriptor.FieldDescriptor( 354 | name='lane', full_name='waymo.open_dataset.MapFeature.lane', index=1, 355 | number=3, type=11, cpp_type=10, label=1, 356 | has_default_value=False, default_value=None, 357 | message_type=None, enum_type=None, containing_type=None, 358 | is_extension=False, extension_scope=None, 359 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 360 | _descriptor.FieldDescriptor( 361 | name='road_line', full_name='waymo.open_dataset.MapFeature.road_line', index=2, 362 | number=4, type=11, cpp_type=10, label=1, 363 | has_default_value=False, default_value=None, 364 | message_type=None, enum_type=None, containing_type=None, 
365 | is_extension=False, extension_scope=None, 366 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 367 | _descriptor.FieldDescriptor( 368 | name='road_edge', full_name='waymo.open_dataset.MapFeature.road_edge', index=3, 369 | number=5, type=11, cpp_type=10, label=1, 370 | has_default_value=False, default_value=None, 371 | message_type=None, enum_type=None, containing_type=None, 372 | is_extension=False, extension_scope=None, 373 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 374 | _descriptor.FieldDescriptor( 375 | name='stop_sign', full_name='waymo.open_dataset.MapFeature.stop_sign', index=4, 376 | number=7, type=11, cpp_type=10, label=1, 377 | has_default_value=False, default_value=None, 378 | message_type=None, enum_type=None, containing_type=None, 379 | is_extension=False, extension_scope=None, 380 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 381 | _descriptor.FieldDescriptor( 382 | name='crosswalk', full_name='waymo.open_dataset.MapFeature.crosswalk', index=5, 383 | number=8, type=11, cpp_type=10, label=1, 384 | has_default_value=False, default_value=None, 385 | message_type=None, enum_type=None, containing_type=None, 386 | is_extension=False, extension_scope=None, 387 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 388 | _descriptor.FieldDescriptor( 389 | name='speed_bump', full_name='waymo.open_dataset.MapFeature.speed_bump', index=6, 390 | number=9, type=11, cpp_type=10, label=1, 391 | has_default_value=False, default_value=None, 392 | message_type=None, enum_type=None, containing_type=None, 393 | is_extension=False, extension_scope=None, 394 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 395 | ], 396 | extensions=[ 397 | ], 398 | nested_types=[], 399 | enum_types=[ 400 | ], 401 | serialized_options=None, 402 | is_extendable=False, 403 | syntax='proto2', 404 | extension_ranges=[], 405 | oneofs=[ 406 | _descriptor.OneofDescriptor( 407 | name='feature_data', full_name='waymo.open_dataset.MapFeature.feature_data', 408 | index=0, containing_type=None, 409 | create_key=_descriptor._internal_create_key, 410 | fields=[]), 411 | ], 412 | serialized_start=660, 413 | serialized_end=1006, 414 | ) 415 | 416 | 417 | _MAPPOINT = _descriptor.Descriptor( 418 | name='MapPoint', 419 | full_name='waymo.open_dataset.MapPoint', 420 | filename=None, 421 | file=DESCRIPTOR, 422 | containing_type=None, 423 | create_key=_descriptor._internal_create_key, 424 | fields=[ 425 | _descriptor.FieldDescriptor( 426 | name='x', full_name='waymo.open_dataset.MapPoint.x', index=0, 427 | number=1, type=1, cpp_type=5, label=1, 428 | has_default_value=False, default_value=float(0), 429 | message_type=None, enum_type=None, containing_type=None, 430 | is_extension=False, extension_scope=None, 431 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 432 | _descriptor.FieldDescriptor( 433 | name='y', full_name='waymo.open_dataset.MapPoint.y', index=1, 434 | number=2, type=1, cpp_type=5, label=1, 435 | has_default_value=False, default_value=float(0), 436 | message_type=None, enum_type=None, containing_type=None, 437 | is_extension=False, extension_scope=None, 438 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 439 | _descriptor.FieldDescriptor( 440 | name='z', full_name='waymo.open_dataset.MapPoint.z', index=2, 441 | number=3, type=1, 
cpp_type=5, label=1, 442 | has_default_value=False, default_value=float(0), 443 | message_type=None, enum_type=None, containing_type=None, 444 | is_extension=False, extension_scope=None, 445 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 446 | ], 447 | extensions=[ 448 | ], 449 | nested_types=[], 450 | enum_types=[ 451 | ], 452 | serialized_options=None, 453 | is_extendable=False, 454 | syntax='proto2', 455 | extension_ranges=[], 456 | oneofs=[ 457 | ], 458 | serialized_start=1008, 459 | serialized_end=1051, 460 | ) 461 | 462 | 463 | _BOUNDARYSEGMENT = _descriptor.Descriptor( 464 | name='BoundarySegment', 465 | full_name='waymo.open_dataset.BoundarySegment', 466 | filename=None, 467 | file=DESCRIPTOR, 468 | containing_type=None, 469 | create_key=_descriptor._internal_create_key, 470 | fields=[ 471 | _descriptor.FieldDescriptor( 472 | name='lane_start_index', full_name='waymo.open_dataset.BoundarySegment.lane_start_index', index=0, 473 | number=1, type=5, cpp_type=1, label=1, 474 | has_default_value=False, default_value=0, 475 | message_type=None, enum_type=None, containing_type=None, 476 | is_extension=False, extension_scope=None, 477 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 478 | _descriptor.FieldDescriptor( 479 | name='lane_end_index', full_name='waymo.open_dataset.BoundarySegment.lane_end_index', index=1, 480 | number=2, type=5, cpp_type=1, label=1, 481 | has_default_value=False, default_value=0, 482 | message_type=None, enum_type=None, containing_type=None, 483 | is_extension=False, extension_scope=None, 484 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 485 | _descriptor.FieldDescriptor( 486 | name='boundary_feature_id', full_name='waymo.open_dataset.BoundarySegment.boundary_feature_id', index=2, 487 | number=3, type=3, cpp_type=2, label=1, 488 | has_default_value=False, default_value=0, 489 | message_type=None, enum_type=None, containing_type=None, 490 | is_extension=False, extension_scope=None, 491 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 492 | _descriptor.FieldDescriptor( 493 | name='boundary_type', full_name='waymo.open_dataset.BoundarySegment.boundary_type', index=3, 494 | number=4, type=14, cpp_type=8, label=1, 495 | has_default_value=False, default_value=0, 496 | message_type=None, enum_type=None, containing_type=None, 497 | is_extension=False, extension_scope=None, 498 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 499 | ], 500 | extensions=[ 501 | ], 502 | nested_types=[], 503 | enum_types=[ 504 | ], 505 | serialized_options=None, 506 | is_extendable=False, 507 | syntax='proto2', 508 | extension_ranges=[], 509 | oneofs=[ 510 | ], 511 | serialized_start=1054, 512 | serialized_end=1216, 513 | ) 514 | 515 | 516 | _LANENEIGHBOR = _descriptor.Descriptor( 517 | name='LaneNeighbor', 518 | full_name='waymo.open_dataset.LaneNeighbor', 519 | filename=None, 520 | file=DESCRIPTOR, 521 | containing_type=None, 522 | create_key=_descriptor._internal_create_key, 523 | fields=[ 524 | _descriptor.FieldDescriptor( 525 | name='feature_id', full_name='waymo.open_dataset.LaneNeighbor.feature_id', index=0, 526 | number=1, type=3, cpp_type=2, label=1, 527 | has_default_value=False, default_value=0, 528 | message_type=None, enum_type=None, containing_type=None, 529 | is_extension=False, extension_scope=None, 530 | serialized_options=None, file=DESCRIPTOR, 
create_key=_descriptor._internal_create_key), 531 | _descriptor.FieldDescriptor( 532 | name='self_start_index', full_name='waymo.open_dataset.LaneNeighbor.self_start_index', index=1, 533 | number=2, type=5, cpp_type=1, label=1, 534 | has_default_value=False, default_value=0, 535 | message_type=None, enum_type=None, containing_type=None, 536 | is_extension=False, extension_scope=None, 537 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 538 | _descriptor.FieldDescriptor( 539 | name='self_end_index', full_name='waymo.open_dataset.LaneNeighbor.self_end_index', index=2, 540 | number=3, type=5, cpp_type=1, label=1, 541 | has_default_value=False, default_value=0, 542 | message_type=None, enum_type=None, containing_type=None, 543 | is_extension=False, extension_scope=None, 544 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 545 | _descriptor.FieldDescriptor( 546 | name='neighbor_start_index', full_name='waymo.open_dataset.LaneNeighbor.neighbor_start_index', index=3, 547 | number=4, type=5, cpp_type=1, label=1, 548 | has_default_value=False, default_value=0, 549 | message_type=None, enum_type=None, containing_type=None, 550 | is_extension=False, extension_scope=None, 551 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 552 | _descriptor.FieldDescriptor( 553 | name='neighbor_end_index', full_name='waymo.open_dataset.LaneNeighbor.neighbor_end_index', index=4, 554 | number=5, type=5, cpp_type=1, label=1, 555 | has_default_value=False, default_value=0, 556 | message_type=None, enum_type=None, containing_type=None, 557 | is_extension=False, extension_scope=None, 558 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 559 | _descriptor.FieldDescriptor( 560 | name='boundaries', full_name='waymo.open_dataset.LaneNeighbor.boundaries', index=5, 561 | number=6, type=11, cpp_type=10, label=3, 562 | has_default_value=False, default_value=[], 563 | message_type=None, enum_type=None, containing_type=None, 564 | is_extension=False, extension_scope=None, 565 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 566 | ], 567 | extensions=[ 568 | ], 569 | nested_types=[], 570 | enum_types=[ 571 | ], 572 | serialized_options=None, 573 | is_extendable=False, 574 | syntax='proto2', 575 | extension_ranges=[], 576 | oneofs=[ 577 | ], 578 | serialized_start=1219, 579 | serialized_end=1418, 580 | ) 581 | 582 | 583 | _LANECENTER = _descriptor.Descriptor( 584 | name='LaneCenter', 585 | full_name='waymo.open_dataset.LaneCenter', 586 | filename=None, 587 | file=DESCRIPTOR, 588 | containing_type=None, 589 | create_key=_descriptor._internal_create_key, 590 | fields=[ 591 | _descriptor.FieldDescriptor( 592 | name='speed_limit_mph', full_name='waymo.open_dataset.LaneCenter.speed_limit_mph', index=0, 593 | number=1, type=1, cpp_type=5, label=1, 594 | has_default_value=False, default_value=float(0), 595 | message_type=None, enum_type=None, containing_type=None, 596 | is_extension=False, extension_scope=None, 597 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 598 | _descriptor.FieldDescriptor( 599 | name='type', full_name='waymo.open_dataset.LaneCenter.type', index=1, 600 | number=2, type=14, cpp_type=8, label=1, 601 | has_default_value=False, default_value=0, 602 | message_type=None, enum_type=None, containing_type=None, 603 | is_extension=False, extension_scope=None, 604 | 
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 605 | _descriptor.FieldDescriptor( 606 | name='interpolating', full_name='waymo.open_dataset.LaneCenter.interpolating', index=2, 607 | number=3, type=8, cpp_type=7, label=1, 608 | has_default_value=False, default_value=False, 609 | message_type=None, enum_type=None, containing_type=None, 610 | is_extension=False, extension_scope=None, 611 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 612 | _descriptor.FieldDescriptor( 613 | name='polyline', full_name='waymo.open_dataset.LaneCenter.polyline', index=3, 614 | number=8, type=11, cpp_type=10, label=3, 615 | has_default_value=False, default_value=[], 616 | message_type=None, enum_type=None, containing_type=None, 617 | is_extension=False, extension_scope=None, 618 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 619 | _descriptor.FieldDescriptor( 620 | name='entry_lanes', full_name='waymo.open_dataset.LaneCenter.entry_lanes', index=4, 621 | number=9, type=3, cpp_type=2, label=3, 622 | has_default_value=False, default_value=[], 623 | message_type=None, enum_type=None, containing_type=None, 624 | is_extension=False, extension_scope=None, 625 | serialized_options=b'\020\001', file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 626 | _descriptor.FieldDescriptor( 627 | name='exit_lanes', full_name='waymo.open_dataset.LaneCenter.exit_lanes', index=5, 628 | number=10, type=3, cpp_type=2, label=3, 629 | has_default_value=False, default_value=[], 630 | message_type=None, enum_type=None, containing_type=None, 631 | is_extension=False, extension_scope=None, 632 | serialized_options=b'\020\001', file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 633 | _descriptor.FieldDescriptor( 634 | name='left_boundaries', full_name='waymo.open_dataset.LaneCenter.left_boundaries', index=6, 635 | number=13, type=11, cpp_type=10, label=3, 636 | has_default_value=False, default_value=[], 637 | message_type=None, enum_type=None, containing_type=None, 638 | is_extension=False, extension_scope=None, 639 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 640 | _descriptor.FieldDescriptor( 641 | name='right_boundaries', full_name='waymo.open_dataset.LaneCenter.right_boundaries', index=7, 642 | number=14, type=11, cpp_type=10, label=3, 643 | has_default_value=False, default_value=[], 644 | message_type=None, enum_type=None, containing_type=None, 645 | is_extension=False, extension_scope=None, 646 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 647 | _descriptor.FieldDescriptor( 648 | name='left_neighbors', full_name='waymo.open_dataset.LaneCenter.left_neighbors', index=8, 649 | number=11, type=11, cpp_type=10, label=3, 650 | has_default_value=False, default_value=[], 651 | message_type=None, enum_type=None, containing_type=None, 652 | is_extension=False, extension_scope=None, 653 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 654 | _descriptor.FieldDescriptor( 655 | name='right_neighbors', full_name='waymo.open_dataset.LaneCenter.right_neighbors', index=9, 656 | number=12, type=11, cpp_type=10, label=3, 657 | has_default_value=False, default_value=[], 658 | message_type=None, enum_type=None, containing_type=None, 659 | is_extension=False, extension_scope=None, 660 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 661 | ], 662 | 
extensions=[ 663 | ], 664 | nested_types=[], 665 | enum_types=[ 666 | _LANECENTER_LANETYPE, 667 | ], 668 | serialized_options=None, 669 | is_extendable=False, 670 | syntax='proto2', 671 | extension_ranges=[], 672 | oneofs=[ 673 | ], 674 | serialized_start=1421, 675 | serialized_end=1970, 676 | ) 677 | 678 | 679 | _ROADEDGE = _descriptor.Descriptor( 680 | name='RoadEdge', 681 | full_name='waymo.open_dataset.RoadEdge', 682 | filename=None, 683 | file=DESCRIPTOR, 684 | containing_type=None, 685 | create_key=_descriptor._internal_create_key, 686 | fields=[ 687 | _descriptor.FieldDescriptor( 688 | name='type', full_name='waymo.open_dataset.RoadEdge.type', index=0, 689 | number=1, type=14, cpp_type=8, label=1, 690 | has_default_value=False, default_value=0, 691 | message_type=None, enum_type=None, containing_type=None, 692 | is_extension=False, extension_scope=None, 693 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 694 | _descriptor.FieldDescriptor( 695 | name='polyline', full_name='waymo.open_dataset.RoadEdge.polyline', index=1, 696 | number=2, type=11, cpp_type=10, label=3, 697 | has_default_value=False, default_value=[], 698 | message_type=None, enum_type=None, containing_type=None, 699 | is_extension=False, extension_scope=None, 700 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 701 | ], 702 | extensions=[ 703 | ], 704 | nested_types=[], 705 | enum_types=[ 706 | _ROADEDGE_ROADEDGETYPE, 707 | ], 708 | serialized_options=None, 709 | is_extendable=False, 710 | syntax='proto2', 711 | extension_ranges=[], 712 | oneofs=[ 713 | ], 714 | serialized_start=1973, 715 | serialized_end=2178, 716 | ) 717 | 718 | 719 | _ROADLINE = _descriptor.Descriptor( 720 | name='RoadLine', 721 | full_name='waymo.open_dataset.RoadLine', 722 | filename=None, 723 | file=DESCRIPTOR, 724 | containing_type=None, 725 | create_key=_descriptor._internal_create_key, 726 | fields=[ 727 | _descriptor.FieldDescriptor( 728 | name='type', full_name='waymo.open_dataset.RoadLine.type', index=0, 729 | number=1, type=14, cpp_type=8, label=1, 730 | has_default_value=False, default_value=0, 731 | message_type=None, enum_type=None, containing_type=None, 732 | is_extension=False, extension_scope=None, 733 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 734 | _descriptor.FieldDescriptor( 735 | name='polyline', full_name='waymo.open_dataset.RoadLine.polyline', index=1, 736 | number=2, type=11, cpp_type=10, label=3, 737 | has_default_value=False, default_value=[], 738 | message_type=None, enum_type=None, containing_type=None, 739 | is_extension=False, extension_scope=None, 740 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 741 | ], 742 | extensions=[ 743 | ], 744 | nested_types=[], 745 | enum_types=[ 746 | _ROADLINE_ROADLINETYPE, 747 | ], 748 | serialized_options=None, 749 | is_extendable=False, 750 | syntax='proto2', 751 | extension_ranges=[], 752 | oneofs=[ 753 | ], 754 | serialized_start=2181, 755 | serialized_end=2573, 756 | ) 757 | 758 | 759 | _STOPSIGN = _descriptor.Descriptor( 760 | name='StopSign', 761 | full_name='waymo.open_dataset.StopSign', 762 | filename=None, 763 | file=DESCRIPTOR, 764 | containing_type=None, 765 | create_key=_descriptor._internal_create_key, 766 | fields=[ 767 | _descriptor.FieldDescriptor( 768 | name='lane', full_name='waymo.open_dataset.StopSign.lane', index=0, 769 | number=1, type=3, cpp_type=2, label=3, 770 | has_default_value=False, 
default_value=[], 771 | message_type=None, enum_type=None, containing_type=None, 772 | is_extension=False, extension_scope=None, 773 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 774 | _descriptor.FieldDescriptor( 775 | name='position', full_name='waymo.open_dataset.StopSign.position', index=1, 776 | number=2, type=11, cpp_type=10, label=1, 777 | has_default_value=False, default_value=None, 778 | message_type=None, enum_type=None, containing_type=None, 779 | is_extension=False, extension_scope=None, 780 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 781 | ], 782 | extensions=[ 783 | ], 784 | nested_types=[], 785 | enum_types=[ 786 | ], 787 | serialized_options=None, 788 | is_extendable=False, 789 | syntax='proto2', 790 | extension_ranges=[], 791 | oneofs=[ 792 | ], 793 | serialized_start=2575, 794 | serialized_end=2647, 795 | ) 796 | 797 | 798 | _CROSSWALK = _descriptor.Descriptor( 799 | name='Crosswalk', 800 | full_name='waymo.open_dataset.Crosswalk', 801 | filename=None, 802 | file=DESCRIPTOR, 803 | containing_type=None, 804 | create_key=_descriptor._internal_create_key, 805 | fields=[ 806 | _descriptor.FieldDescriptor( 807 | name='polygon', full_name='waymo.open_dataset.Crosswalk.polygon', index=0, 808 | number=1, type=11, cpp_type=10, label=3, 809 | has_default_value=False, default_value=[], 810 | message_type=None, enum_type=None, containing_type=None, 811 | is_extension=False, extension_scope=None, 812 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 813 | ], 814 | extensions=[ 815 | ], 816 | nested_types=[], 817 | enum_types=[ 818 | ], 819 | serialized_options=None, 820 | is_extendable=False, 821 | syntax='proto2', 822 | extension_ranges=[], 823 | oneofs=[ 824 | ], 825 | serialized_start=2649, 826 | serialized_end=2707, 827 | ) 828 | 829 | 830 | _SPEEDBUMP = _descriptor.Descriptor( 831 | name='SpeedBump', 832 | full_name='waymo.open_dataset.SpeedBump', 833 | filename=None, 834 | file=DESCRIPTOR, 835 | containing_type=None, 836 | create_key=_descriptor._internal_create_key, 837 | fields=[ 838 | _descriptor.FieldDescriptor( 839 | name='polygon', full_name='waymo.open_dataset.SpeedBump.polygon', index=0, 840 | number=1, type=11, cpp_type=10, label=3, 841 | has_default_value=False, default_value=[], 842 | message_type=None, enum_type=None, containing_type=None, 843 | is_extension=False, extension_scope=None, 844 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 845 | ], 846 | extensions=[ 847 | ], 848 | nested_types=[], 849 | enum_types=[ 850 | ], 851 | serialized_options=None, 852 | is_extendable=False, 853 | syntax='proto2', 854 | extension_ranges=[], 855 | oneofs=[ 856 | ], 857 | serialized_start=2709, 858 | serialized_end=2767, 859 | ) 860 | 861 | _MAP.fields_by_name['map_features'].message_type = _MAPFEATURE 862 | _MAP.fields_by_name['dynamic_states'].message_type = _DYNAMICSTATE 863 | _DYNAMICSTATE.fields_by_name['lane_states'].message_type = _TRAFFICSIGNALLANESTATE 864 | _TRAFFICSIGNALLANESTATE.fields_by_name['state'].enum_type = _TRAFFICSIGNALLANESTATE_STATE 865 | _TRAFFICSIGNALLANESTATE.fields_by_name['stop_point'].message_type = _MAPPOINT 866 | _TRAFFICSIGNALLANESTATE_STATE.containing_type = _TRAFFICSIGNALLANESTATE 867 | _MAPFEATURE.fields_by_name['lane'].message_type = _LANECENTER 868 | _MAPFEATURE.fields_by_name['road_line'].message_type = _ROADLINE 869 | _MAPFEATURE.fields_by_name['road_edge'].message_type = 
_ROADEDGE 870 | _MAPFEATURE.fields_by_name['stop_sign'].message_type = _STOPSIGN 871 | _MAPFEATURE.fields_by_name['crosswalk'].message_type = _CROSSWALK 872 | _MAPFEATURE.fields_by_name['speed_bump'].message_type = _SPEEDBUMP 873 | _MAPFEATURE.oneofs_by_name['feature_data'].fields.append( 874 | _MAPFEATURE.fields_by_name['lane']) 875 | _MAPFEATURE.fields_by_name['lane'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data'] 876 | _MAPFEATURE.oneofs_by_name['feature_data'].fields.append( 877 | _MAPFEATURE.fields_by_name['road_line']) 878 | _MAPFEATURE.fields_by_name['road_line'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data'] 879 | _MAPFEATURE.oneofs_by_name['feature_data'].fields.append( 880 | _MAPFEATURE.fields_by_name['road_edge']) 881 | _MAPFEATURE.fields_by_name['road_edge'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data'] 882 | _MAPFEATURE.oneofs_by_name['feature_data'].fields.append( 883 | _MAPFEATURE.fields_by_name['stop_sign']) 884 | _MAPFEATURE.fields_by_name['stop_sign'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data'] 885 | _MAPFEATURE.oneofs_by_name['feature_data'].fields.append( 886 | _MAPFEATURE.fields_by_name['crosswalk']) 887 | _MAPFEATURE.fields_by_name['crosswalk'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data'] 888 | _MAPFEATURE.oneofs_by_name['feature_data'].fields.append( 889 | _MAPFEATURE.fields_by_name['speed_bump']) 890 | _MAPFEATURE.fields_by_name['speed_bump'].containing_oneof = _MAPFEATURE.oneofs_by_name['feature_data'] 891 | _BOUNDARYSEGMENT.fields_by_name['boundary_type'].enum_type = _ROADLINE_ROADLINETYPE 892 | _LANENEIGHBOR.fields_by_name['boundaries'].message_type = _BOUNDARYSEGMENT 893 | _LANECENTER.fields_by_name['type'].enum_type = _LANECENTER_LANETYPE 894 | _LANECENTER.fields_by_name['polyline'].message_type = _MAPPOINT 895 | _LANECENTER.fields_by_name['left_boundaries'].message_type = _BOUNDARYSEGMENT 896 | _LANECENTER.fields_by_name['right_boundaries'].message_type = _BOUNDARYSEGMENT 897 | _LANECENTER.fields_by_name['left_neighbors'].message_type = _LANENEIGHBOR 898 | _LANECENTER.fields_by_name['right_neighbors'].message_type = _LANENEIGHBOR 899 | _LANECENTER_LANETYPE.containing_type = _LANECENTER 900 | _ROADEDGE.fields_by_name['type'].enum_type = _ROADEDGE_ROADEDGETYPE 901 | _ROADEDGE.fields_by_name['polyline'].message_type = _MAPPOINT 902 | _ROADEDGE_ROADEDGETYPE.containing_type = _ROADEDGE 903 | _ROADLINE.fields_by_name['type'].enum_type = _ROADLINE_ROADLINETYPE 904 | _ROADLINE.fields_by_name['polyline'].message_type = _MAPPOINT 905 | _ROADLINE_ROADLINETYPE.containing_type = _ROADLINE 906 | _STOPSIGN.fields_by_name['position'].message_type = _MAPPOINT 907 | _CROSSWALK.fields_by_name['polygon'].message_type = _MAPPOINT 908 | _SPEEDBUMP.fields_by_name['polygon'].message_type = _MAPPOINT 909 | DESCRIPTOR.message_types_by_name['Map'] = _MAP 910 | DESCRIPTOR.message_types_by_name['DynamicState'] = _DYNAMICSTATE 911 | DESCRIPTOR.message_types_by_name['TrafficSignalLaneState'] = _TRAFFICSIGNALLANESTATE 912 | DESCRIPTOR.message_types_by_name['MapFeature'] = _MAPFEATURE 913 | DESCRIPTOR.message_types_by_name['MapPoint'] = _MAPPOINT 914 | DESCRIPTOR.message_types_by_name['BoundarySegment'] = _BOUNDARYSEGMENT 915 | DESCRIPTOR.message_types_by_name['LaneNeighbor'] = _LANENEIGHBOR 916 | DESCRIPTOR.message_types_by_name['LaneCenter'] = _LANECENTER 917 | DESCRIPTOR.message_types_by_name['RoadEdge'] = _ROADEDGE 918 | DESCRIPTOR.message_types_by_name['RoadLine'] = _ROADLINE 919 | 
DESCRIPTOR.message_types_by_name['StopSign'] = _STOPSIGN 920 | DESCRIPTOR.message_types_by_name['Crosswalk'] = _CROSSWALK 921 | DESCRIPTOR.message_types_by_name['SpeedBump'] = _SPEEDBUMP 922 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 923 | 924 | Map = _reflection.GeneratedProtocolMessageType('Map', (_message.Message,), { 925 | 'DESCRIPTOR' : _MAP, 926 | '__module__' : 'map_pb2' 927 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.Map) 928 | }) 929 | _sym_db.RegisterMessage(Map) 930 | 931 | DynamicState = _reflection.GeneratedProtocolMessageType('DynamicState', (_message.Message,), { 932 | 'DESCRIPTOR' : _DYNAMICSTATE, 933 | '__module__' : 'map_pb2' 934 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.DynamicState) 935 | }) 936 | _sym_db.RegisterMessage(DynamicState) 937 | 938 | TrafficSignalLaneState = _reflection.GeneratedProtocolMessageType('TrafficSignalLaneState', (_message.Message,), { 939 | 'DESCRIPTOR' : _TRAFFICSIGNALLANESTATE, 940 | '__module__' : 'map_pb2' 941 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.TrafficSignalLaneState) 942 | }) 943 | _sym_db.RegisterMessage(TrafficSignalLaneState) 944 | 945 | MapFeature = _reflection.GeneratedProtocolMessageType('MapFeature', (_message.Message,), { 946 | 'DESCRIPTOR' : _MAPFEATURE, 947 | '__module__' : 'map_pb2' 948 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.MapFeature) 949 | }) 950 | _sym_db.RegisterMessage(MapFeature) 951 | 952 | MapPoint = _reflection.GeneratedProtocolMessageType('MapPoint', (_message.Message,), { 953 | 'DESCRIPTOR' : _MAPPOINT, 954 | '__module__' : 'map_pb2' 955 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.MapPoint) 956 | }) 957 | _sym_db.RegisterMessage(MapPoint) 958 | 959 | BoundarySegment = _reflection.GeneratedProtocolMessageType('BoundarySegment', (_message.Message,), { 960 | 'DESCRIPTOR' : _BOUNDARYSEGMENT, 961 | '__module__' : 'map_pb2' 962 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.BoundarySegment) 963 | }) 964 | _sym_db.RegisterMessage(BoundarySegment) 965 | 966 | LaneNeighbor = _reflection.GeneratedProtocolMessageType('LaneNeighbor', (_message.Message,), { 967 | 'DESCRIPTOR' : _LANENEIGHBOR, 968 | '__module__' : 'map_pb2' 969 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.LaneNeighbor) 970 | }) 971 | _sym_db.RegisterMessage(LaneNeighbor) 972 | 973 | LaneCenter = _reflection.GeneratedProtocolMessageType('LaneCenter', (_message.Message,), { 974 | 'DESCRIPTOR' : _LANECENTER, 975 | '__module__' : 'map_pb2' 976 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.LaneCenter) 977 | }) 978 | _sym_db.RegisterMessage(LaneCenter) 979 | 980 | RoadEdge = _reflection.GeneratedProtocolMessageType('RoadEdge', (_message.Message,), { 981 | 'DESCRIPTOR' : _ROADEDGE, 982 | '__module__' : 'map_pb2' 983 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.RoadEdge) 984 | }) 985 | _sym_db.RegisterMessage(RoadEdge) 986 | 987 | RoadLine = _reflection.GeneratedProtocolMessageType('RoadLine', (_message.Message,), { 988 | 'DESCRIPTOR' : _ROADLINE, 989 | '__module__' : 'map_pb2' 990 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.RoadLine) 991 | }) 992 | _sym_db.RegisterMessage(RoadLine) 993 | 994 | StopSign = _reflection.GeneratedProtocolMessageType('StopSign', (_message.Message,), { 995 | 'DESCRIPTOR' : _STOPSIGN, 996 | '__module__' : 'map_pb2' 997 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.StopSign) 998 | }) 999 | _sym_db.RegisterMessage(StopSign) 1000 | 1001 | Crosswalk = 
_reflection.GeneratedProtocolMessageType('Crosswalk', (_message.Message,), { 1002 | 'DESCRIPTOR' : _CROSSWALK, 1003 | '__module__' : 'map_pb2' 1004 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.Crosswalk) 1005 | }) 1006 | _sym_db.RegisterMessage(Crosswalk) 1007 | 1008 | SpeedBump = _reflection.GeneratedProtocolMessageType('SpeedBump', (_message.Message,), { 1009 | 'DESCRIPTOR' : _SPEEDBUMP, 1010 | '__module__' : 'map_pb2' 1011 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.SpeedBump) 1012 | }) 1013 | _sym_db.RegisterMessage(SpeedBump) 1014 | 1015 | 1016 | _LANECENTER.fields_by_name['entry_lanes']._options = None 1017 | _LANECENTER.fields_by_name['exit_lanes']._options = None 1018 | # @@protoc_insertion_point(module_scope) 1019 | -------------------------------------------------------------------------------- /preprocess/preprocess_waymo.py: -------------------------------------------------------------------------------- 1 | ## Copyright 2022 Xiaosong Jia. All Rights Reserved. 2 | ## Check https://github.com/waymo-research/waymo-open-dataset/blob/656f759070a7b1356f9f0403b17cd85323e0626c/src/waymo_open_dataset/protos/map.proto and https://github.com/waymo-research/waymo-open-dataset/blob/656f759070a7b1356f9f0403b17cd85323e0626c/src/waymo_open_dataset/protos/scenario.proto for details about the data structure and data type 3 | from typing_extensions import final 4 | import os 5 | os.sys.path.append('.') 6 | import map_pb2 7 | import scenario_pb2 8 | import os 9 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 10 | import pickle 11 | from posixpath import basename 12 | import numpy as np 13 | import tensorflow as tf 14 | from google.protobuf import text_format 15 | import sys 16 | import gc 17 | import math 18 | import shutil 19 | from shapely.geometry import Polygon ## assumed source of the Polygon class used in transfer_poly_to_rectangle below 20 | ## Obtain the indices used to interpolate missing frames of the inputs 21 | def get_all_break_point(total_lis, sub_lis): 22 | i = 0 23 | j = 0 24 | last_conti_index = -1 25 | break_point_i_lis = [] 26 | break_point_j_lis = [] 27 | while j < len(sub_lis): 28 | while total_lis[i] != sub_lis[j]: 29 | i += 1 30 | if last_conti_index == -1: 31 | last_conti_index = i 32 | elif i == last_conti_index + 1: 33 | last_conti_index += 1 34 | else: 35 | break_point_i_lis.append((last_conti_index, i)) 36 | break_point_j_lis.append((j-1, j)) 37 | last_conti_index = i 38 | j += 1 39 | return break_point_i_lis, break_point_j_lis 40 | 41 | import scipy.interpolate as interp 42 | ## Interpolate polylines to the target number of points 43 | def interpolate_polyline(polyline, num_points): 44 | if np.allclose(polyline[0], polyline[1]): 45 | return polyline[0][np.newaxis, :].repeat(num_points, axis=0) 46 | tck, u = interp.splprep(polyline.T, s=0, k=1) 47 | u = np.linspace(0.0, 1.0, num_points) 48 | return np.column_stack(interp.splev(u, tck)) 49 | 50 | def euclid(label, pred): 51 | return np.sqrt((label[..., 0]-pred[...,0])**2 + (label[...,1]-pred[...,1])**2) 52 | 53 | 54 | 55 | ## For interpolation: replace a polygon by its bounding rectangle computed in the yaw-aligned frame 56 | def transfer_poly_to_rectangle(poly_coor, yaw): 57 | cos_theta = np.cos(-yaw) 58 | sin_theta = np.sin(-yaw) 59 | poly_coor[..., 0], poly_coor[..., 1] = poly_coor[..., 0]*cos_theta - poly_coor[..., 1]*sin_theta, poly_coor[..., 1]*cos_theta + poly_coor[..., 0]*sin_theta 60 | p = Polygon(poly_coor) 61 | xmin, ymin, xmax, ymax = p.bounds 62 | poly_coor = np.array([[xmin, ymin], [xmin, ymax], [xmax, ymax], [xmax, ymin]]) 63 | cos_theta = np.cos(yaw) 64 | sin_theta = np.sin(yaw) 65 | poly_coor[..., 0], poly_coor[..., 1] = poly_coor[..., 0]*cos_theta - poly_coor[...,
1]*sin_theta, poly_coor[..., 1]*cos_theta + poly_coor[..., 0]*sin_theta 66 | return poly_coor 67 | 68 | num_of_scene_per_folder = 84 ## parallel preprocessing 69 | current_scene_index = 0 if len(sys.argv) == 1 else int(sys.argv[1]) 70 | parent_path = os.path.join(os.path.dirname(os.getcwd()), "dataset", "waymo") 71 | 72 | data_version = "hdgt_waymo_dev_tmp" if len(sys.argv) == 1 else sys.argv[2] 73 | 74 | if current_scene_index == 12: ## Validation Set 75 | data_split_type = "validation" 76 | saving_folder = os.path.join(parent_path, data_split_type, data_version+str(current_scene_index)) 77 | 78 | else: 79 | data_split_type = "training" 80 | saving_folder = os.path.join(parent_path, data_split_type, data_version+str(current_scene_index)) 81 | if os.path.exists(saving_folder): 82 | shutil.rmtree(saving_folder) 83 | os.mkdir(saving_folder) 84 | print("Scene Index:", current_scene_index, "Data Split:", data_split_type, "Start!!!") 85 | is_val = ("validation" == data_split_type) 86 | 87 | case_cnt = 0 88 | num_of_element_lis = [] ## The number of elements (agents + map) in each scene; For balanced batching; Otherwise, the GPU memory usage might vary a lot 89 | 90 | fdir = os.listdir(os.path.join(parent_path, data_split_type)) ## The directory contains all tfrecord file 91 | fdir.sort() 92 | for fname in fdir: 93 | # if case_cnt > 5: 94 | # break 95 | if "tfrecord" in fname: 96 | record_index = int(fname.split("-")[1]) 97 | if not is_val and (record_index < current_scene_index * num_of_scene_per_folder or record_index >= (current_scene_index+1) * num_of_scene_per_folder): ## Not for this worker 98 | continue 99 | raw_dataset = tf.data.TFRecordDataset([os.path.join(parent_path, data_split_type, fname)]) 100 | print("new file", fname, flush=True) 101 | for raw_record_index, raw_record in enumerate(raw_dataset): 102 | proto_string = raw_record.numpy() 103 | proto = scenario_pb2.Scenario() 104 | proto.ParseFromString(proto_string) 105 | ## Agent Feature 106 | now_num_time_step = len(proto.timestamps_seconds) 107 | now_all_tracks = proto.tracks 108 | total_num_of_agent = len(now_all_tracks) 109 | 110 | ### Debug 111 | # if case_cnt > 5: 112 | # break 113 | 114 | now_predict_agent_index = [] 115 | now_difficulty = [] 116 | for predict_track_info in proto.tracks_to_predict: 117 | now_predict_agent_index.append(predict_track_info.track_index) 118 | now_difficulty.append(predict_track_info.difficulty) 119 | 120 | now_difficulty = np.array(now_difficulty) 121 | now_other_agent_index = [_ for _ in range(total_num_of_agent) if _ not in now_predict_agent_index] 122 | all_object_index = now_predict_agent_index + now_other_agent_index 123 | 124 | 125 | all_predict_agent_feature = [] 126 | all_label = [] 127 | all_auxiliary_label = [] 128 | all_label_mask = [] 129 | all_predict_agent_type = [] 130 | new_predict_index = [] ## After filering and reordering agents 131 | 132 | all_other_agent_type = [] 133 | all_other_agent_feature = [] 134 | all_other_label = [] 135 | all_other_label_mask = [] 136 | new_other_index = [] 137 | 138 | now_index2id = {} 139 | pred_num = 0 140 | now_scene_id = proto.scenario_id 141 | for now_track_index_cnt, now_track_index in enumerate(all_object_index): 142 | now_track = now_all_tracks[now_track_index] 143 | now_mask = np.array([1] * 11) 144 | now_index2id[now_track_index] = now_track.id 145 | now_fea = [] 146 | obs_end_index = None 147 | for timestep, timestep_state in enumerate(now_track.states): 148 | if timestep_state.valid: 149 | now_fea.append(np.array([int(timestep), 
float(timestep_state.center_x), float(timestep_state.center_y), float(timestep_state.center_z), float(timestep_state.velocity_x), float(timestep_state.velocity_y), float(timestep_state.heading), float(timestep_state.length), float(timestep_state.width), float(timestep_state.height),])) 150 | if timestep <= 10: 151 | obs_end_index = len(now_fea) 152 | now_fea = np.stack(now_fea, axis=0) #Num_Observed_Timestep, C: #T, x, y, z, vx, vy, psi, length, width, height 153 | ## The index of first timestep > 10 -> Not in the observation interval [0, 10] - 1.1 second 154 | if now_fea[0][0] > 10: 155 | continue 156 | now_obs_fea = now_fea[:obs_end_index] # 157 | ## Not fully observable - interpolate with constant acceleratation assumption 158 | if now_obs_fea.shape[0] < 11: 159 | start_step = int(now_obs_fea[0, 0]) 160 | end_step = int(now_obs_fea[-1, 0]) 161 | padded_fea = now_obs_fea[:, 1:7] ## We only need to interpolate x, y, z, vx, vy, psi 162 | 163 | if now_obs_fea.shape[0] != end_step-start_step+1: 164 | break_point_i_lis, break_point_j_lis = get_all_break_point(list(range(0, 11)), now_obs_fea[:, 0]) 165 | line_lis = [now_obs_fea[:break_point_j_lis[0][0]+1, 1:7]] 166 | for bi in range(len(break_point_i_lis)): 167 | line = interp.interp1d(x=[0.0, 1.0], y=now_obs_fea[break_point_j_lis[bi][0]:break_point_j_lis[bi][1]+1, 1:7], assume_sorted=True, axis=0)(np.linspace(0.0, 1.0, break_point_i_lis[bi][1]-break_point_i_lis[bi][0]+1)) 168 | v_xy = line[1:, 3:5] 169 | cumsum_step = np.cumsum(v_xy/10.0, axis=0) + line[0, 0:2][np.newaxis, :] 170 | line[1:, 0:2] = cumsum_step 171 | now_mask[break_point_i_lis[bi][0]+1:break_point_i_lis[bi][1]] = 0 172 | line_lis.append(line[1:-1, :]) 173 | if bi == len(break_point_i_lis) - 1: 174 | line_lis.append(now_obs_fea[break_point_j_lis[bi][1]:, 1:7]) 175 | else: 176 | line_lis.append(now_obs_fea[break_point_j_lis[bi][1]:break_point_j_lis[bi+1][0]+1, 1:7]) 177 | padded_fea = np.concatenate(line_lis, axis=0) 178 | if start_step > 0: 179 | v_xy = padded_fea[0, 3:5] 180 | cumsum_step = np.cumsum((-v_xy/10.0)[np.newaxis, :].repeat(start_step, axis=0), axis=0)[::-1, :] + padded_fea[0, 0:2][np.newaxis, :] 181 | padded_fea = np.concatenate([padded_fea[0][np.newaxis,:].repeat(start_step, axis=0), padded_fea], axis=0) 182 | padded_fea[:start_step, :2] = cumsum_step 183 | now_mask[:start_step] = 0 184 | if end_step < 10: 185 | v_xy = padded_fea[-1, 3:5] 186 | cumsum_step = np.cumsum((v_xy/10.0)[np.newaxis, :].repeat(10-end_step, axis=0), axis=0) + padded_fea[-1, 0:2][np.newaxis, :] 187 | padded_fea = np.concatenate([padded_fea, padded_fea[-1][np.newaxis,:].repeat(10-end_step, axis=0)], axis=0) 188 | padded_fea[end_step+1:, :2] = cumsum_step 189 | now_mask[end_step+1:] = 0 190 | now_obs_fea_padded = np.concatenate([padded_fea, np.array([now_obs_fea[:, -3].mean()]*11)[:, np.newaxis], np.array([now_obs_fea[:, -2].mean()]*11)[:, np.newaxis], np.array([now_obs_fea[:, -1].mean()]*11)[:, np.newaxis], now_mask[:, np.newaxis]], axis=-1) 191 | else: 192 | now_obs_fea_padded = np.concatenate([now_obs_fea[:, 1:], now_mask[:, np.newaxis]], axis=-1) 193 | 194 | 195 | now_future_fea = now_fea[obs_end_index:] 196 | if now_future_fea.shape[0] != 80: 197 | now_label_mask = np.array([0] * 80) 198 | tmp_label = np.zeros((1, 80, 2)) 199 | tmp_auxiliary_label = np.zeros((1, 80, 3)) 200 | for label_time_index_i in range(now_future_fea.shape[0]): 201 | label_time_index_in_lis = int(now_future_fea[label_time_index_i, 0]) - 11 202 | now_label_mask[label_time_index_in_lis] = 1 203 | tmp_label[0, 
label_time_index_in_lis, :] = now_future_fea[label_time_index_i, [1, 2]] 204 | tmp_auxiliary_label[0, label_time_index_in_lis, :] = now_future_fea[label_time_index_i, [4, 5, 6]] 205 | else: 206 | now_label_mask = np.array([1] * 80) 207 | tmp_label = now_future_fea[:, 1:3][np.newaxis, :, :] 208 | tmp_auxiliary_label = now_future_fea[:, 4:7][np.newaxis, :, :] 209 | 210 | if now_track_index_cnt >= len(now_predict_agent_index) or now_future_fea.shape[0] == 0: 211 | ## Other Label 212 | all_other_agent_feature.append(now_obs_fea_padded) 213 | all_other_label.append(tmp_label) 214 | all_other_label_mask.append(now_label_mask) 215 | new_other_index.append(now_track_index) 216 | all_other_agent_type.append(now_track.object_type) 217 | else: 218 | pred_num += 1 219 | all_predict_agent_feature.append(now_obs_fea_padded) 220 | all_label.append(tmp_label) 221 | all_auxiliary_label.append(tmp_auxiliary_label) 222 | all_label_mask.append(now_label_mask) 223 | new_predict_index.append(now_track_index) 224 | all_predict_agent_type.append(now_track.object_type) 225 | 226 | all_object_id = [now_index2id[_] for _ in new_predict_index] + [now_index2id[_] for _ in new_other_index] 227 | all_agent_type = np.array(all_predict_agent_type + all_other_agent_type) 228 | all_agent_feature = np.stack(all_predict_agent_feature+all_other_agent_feature, axis=0) ## Num_agent, T_observed, C: x, y, z, vx, vy, heading, length, width, height, mask 229 | 230 | #all_agent_obs_final_v = np.sqrt(all_input_data[:, -1, 3]**2+all_input_data[:, -1, 4]**2) 231 | all_agent_map_size = np.ones(all_agent_feature.shape[0]) * 999.0 ## During preprocessing, we simply keep all map elements 232 | #all_agent_obs_final_v * 8.0 + np.vectorize(map_size_lis.__getitem__)(all_agent_type) 233 | all_dynamic_map_fea_dic = {} 234 | for time_step in range(11): 235 | for map_element_index in range(len(proto.dynamic_map_states[time_step].lane_states)): 236 | now_tuple = (float(proto.dynamic_map_states[time_step].lane_states[map_element_index].stop_point.x), float(proto.dynamic_map_states[time_step].lane_states[map_element_index].stop_point.y), float(proto.dynamic_map_states[time_step].lane_states[map_element_index].stop_point.z), proto.dynamic_map_states[time_step].lane_states[map_element_index].lane) 237 | if now_tuple not in all_dynamic_map_fea_dic: 238 | all_dynamic_map_fea_dic[now_tuple] = [0] * 11 ## 0 represents unknown 239 | all_dynamic_map_fea_dic[now_tuple][time_step] = proto.dynamic_map_states[time_step].lane_states[map_element_index].state 240 | all_unkown_traffic = [0] * 11 241 | traffic_light_info_to_remove = [] 242 | for k, v in all_dynamic_map_fea_dic.items(): 243 | if v == all_unkown_traffic: 244 | traffic_light_info_to_remove.append(k) 245 | for k in traffic_light_info_to_remove: 246 | all_dynamic_map_fea_dic.pop(k) 247 | all_dynamic_map_fea = {int(k[3]):np.array([k[0], k[1], k[2]] + v) for k, v in all_dynamic_map_fea_dic.items()} 248 | 249 | #id: (type, polygon) type:0,1 250 | all_polygon_fea = [] 251 | #id: (lane_id_lis, [x,y]) 252 | all_stopsign_fea = {} 253 | #id: lane_info_dic 254 | all_lane_fea = {} 255 | all_road_edge = [] 256 | all_road_line = [] 257 | for now_map_fea in proto.map_features: 258 | if now_map_fea.HasField("crosswalk"): 259 | now_polygon_fea = [0, np.array([[_.x, _.y, _.z]for _ in now_map_fea.crosswalk.polygon])] 260 | all_polygon_fea.append(now_polygon_fea) 261 | if now_map_fea.HasField("speed_bump"): 262 | now_polygon_fea = [1, np.array([[_.x, _.y, _.z]for _ in now_map_fea.speed_bump.polygon])] 263 | 
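## Note: the first element of each all_polygon_fea entry is a single integer type code:
## 0 = crosswalk, 1 = speed bump, road edges are stored as road_edge.type + 2, and road
## lines as road_line.type + 5 (written "+2+3" further below), i.e. 14 categories in
## total -- which is what the polygon type embedding in training/model.py
## (num_embeddings=14) expects.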
all_polygon_fea.append(now_polygon_fea) 264 | if now_map_fea.HasField("stop_sign"): 265 | all_stopsign_fea[int(now_map_fea.id)] = [list(now_map_fea.stop_sign.lane), [now_map_fea.stop_sign.position.x, now_map_fea.stop_sign.position.y, now_map_fea.stop_sign.position.z]] 266 | if now_map_fea.HasField("lane"): 267 | all_lane_fea[int(now_map_fea.id)] = {} 268 | all_lane_fea[int(now_map_fea.id)]["speed_limit"] = now_map_fea.lane.speed_limit_mph 269 | all_lane_fea[int(now_map_fea.id)]["type"] = now_map_fea.lane.type #5 types 270 | all_lane_fea[int(now_map_fea.id)]["xyz"] = np.array([[_.x, _.y, _.z]for _ in now_map_fea.lane.polyline]) 271 | ## A list of IDs for lanes that this lane may be entered from. 272 | all_lane_fea[int(now_map_fea.id)]["entry"] = list(now_map_fea.lane.entry_lanes) 273 | ## A list of IDs for lanes that this lane may exit to. 274 | all_lane_fea[int(now_map_fea.id)]["exit"] = list(now_map_fea.lane.exit_lanes) 275 | all_lane_fea[int(now_map_fea.id)]["left"] = [] 276 | for left_neighbor in now_map_fea.lane.left_neighbors: 277 | boundary_type_lis = [int(_.boundary_type) for _ in left_neighbor.boundaries] 278 | ## For simplicity, we use the first appeared boundary type as the type 279 | if 1 in boundary_type_lis: 280 | boundary_type = 1 281 | elif 2 in boundary_type_lis: 282 | boundary_type = 2 283 | elif 3 in boundary_type_lis: 284 | boundary_type = 3 285 | else: 286 | boundary_type = 0 287 | ## ID -> Neighbor ID, self_start, self_end, neighbor_start, neighbor_end, type (4) 288 | all_lane_fea[int(now_map_fea.id)]["left"].append([left_neighbor.feature_id, left_neighbor.self_start_index, left_neighbor.self_end_index, left_neighbor.neighbor_start_index, left_neighbor.neighbor_end_index, boundary_type]) 289 | all_lane_fea[int(now_map_fea.id)]["right"] = [] 290 | for right_neighbor in now_map_fea.lane.right_neighbors: 291 | boundary_type_lis = [int(_.boundary_type) for _ in right_neighbor.boundaries] 292 | if 1 in boundary_type_lis: 293 | boundary_type = 1 294 | elif 2 in boundary_type_lis: 295 | boundary_type = 2 296 | elif 3 in boundary_type_lis: 297 | boundary_type = 3 298 | else: 299 | boundary_type = 0 300 | ##ID -> Neighbor ID, self_start, self_end, neighbor_start, neighbor_end, type 301 | all_lane_fea[int(now_map_fea.id)]["right"].append([right_neighbor.feature_id, right_neighbor.self_start_index, right_neighbor.self_end_index, right_neighbor.neighbor_start_index, right_neighbor.neighbor_end_index, boundary_type]) 302 | if now_map_fea.HasField("road_edge"): 303 | road_edge_xy = np.array([[_.x, _.y, _.z]for _ in now_map_fea.road_edge.polyline]) 304 | if road_edge_xy.shape[0] > 2: 305 | all_polygon_fea.append([now_map_fea.road_edge.type+2, road_edge_xy]) 306 | if now_map_fea.HasField("road_line"): 307 | road_line_xy = np.array([[_.x, _.y, _.z]for _ in now_map_fea.road_line.polyline]) 308 | if road_line_xy.shape[0] > 2: 309 | all_polygon_fea.append([now_map_fea.road_line.type+2+3, road_line_xy]) ## 14 types 310 | 311 | ## Split Long Lane and interpolate to the same number of points (20) per polyline (20m). 
Then, we need to update the up/front/left/right relations of the splitted lanes 312 | length_per_polyline = 40.0 # 20 meters 313 | point_per_polyline = 21 314 | space = int(length_per_polyline // (point_per_polyline-1)) 315 | 316 | new_lane_fea = [] 317 | old_lane_id_to_new_lane_index_lis = {} 318 | 319 | for old_lane_id, old_lane_info in all_lane_fea.items(): 320 | if old_lane_info["xyz"].shape[0] <= length_per_polyline: 321 | old_lane_id_to_new_lane_index_lis[old_lane_id] = [len(new_lane_fea)] 322 | new_lane_xy = old_lane_info["xyz"] 323 | if new_lane_xy.shape[0] > 1: 324 | new_lane_xy = interpolate_polyline(new_lane_xy, point_per_polyline) 325 | else: 326 | new_lane_xy = np.broadcast_to(new_lane_xy, (point_per_polyline, 3)) 327 | new_lane_fea.append({"xyz":new_lane_xy, "speed_limit":old_lane_info["speed_limit"], "type":old_lane_info["type"], "left":[], "right":[], "prev":[], "follow":[] , "stop":[], "signal":[]}) 328 | else: 329 | num_of_new_lane = math.ceil(old_lane_info["xyz"].shape[0]/length_per_polyline) 330 | now_lanelet_new_index_lis = list(range(len(new_lane_fea), len(new_lane_fea)+num_of_new_lane)) 331 | old_lane_id_to_new_lane_index_lis[old_lane_id] = now_lanelet_new_index_lis 332 | new_lane_xy = [] 333 | for _ in range(num_of_new_lane-1): 334 | tmp = old_lane_info["xyz"][int(_*length_per_polyline):int(_*length_per_polyline+length_per_polyline+1)] 335 | new_lane_xy.append(tmp[::space, :]) 336 | tmp = old_lane_info["xyz"][int((num_of_new_lane-1)*length_per_polyline):] 337 | if tmp.shape[0] == 1: 338 | tmp = np.concatenate([old_lane_info["xyz"][int((num_of_new_lane-1)*length_per_polyline-1)][np.newaxis, :], tmp], axis=0) 339 | new_lane_xy.append(interpolate_polyline(tmp, point_per_polyline)) 340 | #tmp = tmp[::2, :] 341 | for _ in range(len(new_lane_xy)): 342 | new_lane_fea.append({"xyz":new_lane_xy[_], "speed_limit":old_lane_info["speed_limit"], "type":old_lane_info["type"], "left":[], "right":[], "prev":[], "follow":[], "stop":[], "signal":[]}) 343 | 344 | ## Update relations 345 | for old_lane_id, new_lane_lis in old_lane_id_to_new_lane_index_lis.items(): 346 | if len(new_lane_lis) > 0: 347 | for j in range(1, len(new_lane_lis)): 348 | prev_index = new_lane_lis[j-1] 349 | next_index = new_lane_lis[j] 350 | new_lane_fea[prev_index]["follow"].append([next_index, 0]) 351 | new_lane_fea[next_index]["prev"].append([prev_index, 1]) 352 | ## Follow 353 | tmp_index = new_lane_lis[-1] 354 | for old_adj_index in all_lane_fea[old_lane_id]["exit"]: 355 | new_lane_fea[tmp_index]["follow"].append([old_lane_id_to_new_lane_index_lis[old_adj_index][0], 0]) 356 | 357 | ## Prev 358 | tmp_index = new_lane_lis[0] 359 | for old_adj_index in all_lane_fea[old_lane_id]["entry"]: 360 | new_lane_fea[tmp_index]["prev"].append([old_lane_id_to_new_lane_index_lis[old_adj_index][-1], 1]) 361 | 362 | ## Left Right 363 | for edge_type in ["left", "right"]: 364 | old_adj_info_lis = all_lane_fea[old_lane_id][edge_type] 365 | ## ID, self_start, end, neighbor_start, end, type 366 | for old_adj_info in old_adj_info_lis: 367 | can_turn_new_lane_lis = new_lane_lis[int(old_adj_info[1]//length_per_polyline):int(old_adj_info[2]//length_per_polyline+1)] 368 | can_turn_new_adj_lane_lis = old_lane_id_to_new_lane_index_lis[old_adj_info[0]][int(old_adj_info[3]//length_per_polyline):int(old_adj_info[4]//length_per_polyline+1)] 369 | l1 = len(can_turn_new_lane_lis) 370 | l2 = len(can_turn_new_adj_lane_lis) 371 | boundary_type = old_adj_info[5] 372 | if l1 == l2: 373 | for tmp_index_i in range(l1): 374 | tmp_index = 
can_turn_new_lane_lis[tmp_index_i] 375 | new_lane_fea[tmp_index][edge_type].append([can_turn_new_adj_lane_lis[tmp_index_i], boundary_type+2]) 376 | elif l1 < l2: 377 | ratio = int(math.ceil(float(l2)/float(l1))) 378 | for tmp_index_i in range(l1): 379 | tmp_index = can_turn_new_lane_lis[tmp_index_i] 380 | ratio_index = 0 381 | gap = ratio - 1 382 | if l2%l1 == 0: 383 | gap += 1 384 | while ratio_index < ratio and ratio_index + tmp_index_i * gap < l2: 385 | new_lane_fea[tmp_index][edge_type].append([can_turn_new_adj_lane_lis[int(ratio_index + tmp_index_i * gap)], boundary_type+2]) 386 | ratio_index += 1 387 | elif l1 > l2: 388 | ratio = int(math.ceil(float(l1)/float(l2))) 389 | for adj_index_i in range(l2): 390 | tmp_adj_index = can_turn_new_adj_lane_lis[adj_index_i] 391 | ratio_index = 0 392 | gap = ratio - 1 393 | if l1%l2 == 0: 394 | gap += 1 395 | while ratio_index < ratio and ratio_index + adj_index_i * gap < l1: 396 | tmp_index = can_turn_new_lane_lis[ratio_index + adj_index_i * gap] 397 | new_lane_fea[tmp_index][edge_type].append([tmp_adj_index, boundary_type+2]) 398 | ratio_index += 1 399 | for stop_sign_id in all_stopsign_fea: 400 | new_relate_to_stop_sign_id_lis = [] 401 | for _ in all_stopsign_fea[stop_sign_id][0]: 402 | new_relate_to_stop_sign_id_lis += old_lane_id_to_new_lane_index_lis[_] 403 | for _ in new_relate_to_stop_sign_id_lis: 404 | new_lane_fea[_]["stop"].append(all_stopsign_fea[stop_sign_id][1]) 405 | for old_lane_id in all_dynamic_map_fea: 406 | new_lane_id_lis = old_lane_id_to_new_lane_index_lis[old_lane_id] 407 | for _ in new_lane_id_lis: 408 | new_lane_fea[_]["signal"].append(all_dynamic_map_fea[old_lane_id]) 409 | for _ in range(len(new_lane_fea)): 410 | new_lane_fea[_]["yaw"] = np.arctan2(new_lane_fea[_]["xyz"][-1, 1]-new_lane_fea[_]["xyz"][0, 1], new_lane_fea[_]["xyz"][-1, 0]-new_lane_fea[_]["xyz"][0, 0]) 411 | 412 | ##Split and Regularize Polygon fea 413 | ##20m per piece, 20 point 414 | new_polygon_fea = [] 415 | for polygon_index in range(len(all_polygon_fea)): 416 | if all_polygon_fea[polygon_index][0] not in [0, 1]: 417 | if len(all_polygon_fea[polygon_index][1]) > length_per_polyline: 418 | num_of_piece = int(len(all_polygon_fea[polygon_index][1]) // length_per_polyline + 1) 419 | length_per_piece = len(all_polygon_fea[polygon_index][1])//num_of_piece + 1 420 | for _ in range(num_of_piece): 421 | polygon_coor_of_current_piece = all_polygon_fea[polygon_index][1][int(_*length_per_piece):int((_+1)*length_per_piece)] 422 | if polygon_coor_of_current_piece.shape[0] > 1: 423 | new_polygon_fea.append([all_polygon_fea[polygon_index][0], polygon_coor_of_current_piece]) 424 | else: 425 | if all_polygon_fea[polygon_index][1].shape[0] > 1: 426 | new_polygon_fea.append(all_polygon_fea[polygon_index]) 427 | else: 428 | new_polygon_fea.append([all_polygon_fea[polygon_index][0], np.concatenate([all_polygon_fea[polygon_index][1], all_polygon_fea[polygon_index][1][0, :][np.newaxis, :]], axis=0)]) 429 | 430 | all_polygon_fea = [[_[0], interpolate_polyline(_[1], point_per_polyline)] for _ in new_polygon_fea] 431 | 432 | num_of_agent = all_agent_feature.shape[0] 433 | # ## Split Too Much Agent 434 | new_dist_between_agent_lane = (euclid(all_agent_feature[:, -1, :2][:, np.newaxis, np.newaxis, :], np.stack([_["xyz"] for _ in new_lane_fea])[np.newaxis, :, :, :]).min(2) < all_agent_map_size[:, np.newaxis]) 435 | if len(all_polygon_fea) > 0: 436 | new_dist_between_agent_polygon = (euclid(all_agent_feature[:, -1, [0,1]][:, np.newaxis, np.newaxis, :], np.stack([_[1] for _ in 
all_polygon_fea], axis=0)[np.newaxis, :, :, :]).min(2) < all_agent_map_size[:, np.newaxis]) 437 | 438 | 439 | lane_new_index_to_final_index = {} 440 | for agent_index_i in range(new_dist_between_agent_lane.shape[0]): 441 | nearby_lane_new_index_lis = np.where(new_dist_between_agent_lane[agent_index_i, :])[0].tolist() 442 | for nearby_lane_new_index in nearby_lane_new_index_lis: 443 | if nearby_lane_new_index not in lane_new_index_to_final_index: 444 | lane_new_index_to_final_index[nearby_lane_new_index] = len(lane_new_index_to_final_index) 445 | final_lane_fea = [{} for _ in range(len(lane_new_index_to_final_index))] 446 | for lane_new_index in lane_new_index_to_final_index: 447 | for transfer_key in ["xyz", "speed_limit", "type", "stop", "signal", "yaw"]: 448 | final_lane_fea[lane_new_index_to_final_index[lane_new_index]][transfer_key] = new_lane_fea[lane_new_index][transfer_key] 449 | for transfer_key in ["left", "right", "prev", "follow"]: 450 | final_lane_fea[lane_new_index_to_final_index[lane_new_index]][transfer_key] = [[lane_new_index_to_final_index[_[0]], _[1]] for _ in new_lane_fea[lane_new_index][transfer_key] if _[0] in lane_new_index_to_final_index] 451 | polygon_new_index_to_final_index = {} 452 | if len(all_polygon_fea) > 0: 453 | for agent_index_i in range(new_dist_between_agent_polygon.shape[0]): 454 | nearby_polygon_lis = np.where(new_dist_between_agent_polygon[agent_index_i, :])[0].tolist() 455 | for nearby_polygon_new_index in nearby_polygon_lis: 456 | if nearby_polygon_new_index not in polygon_new_index_to_final_index: 457 | polygon_new_index_to_final_index[nearby_polygon_new_index] = len(polygon_new_index_to_final_index) 458 | final_polygon_fea = [[] for _ in range(len(polygon_new_index_to_final_index))] 459 | for polygon_new_index in polygon_new_index_to_final_index: 460 | final_polygon_fea[polygon_new_index_to_final_index[polygon_new_index]] = all_polygon_fea[polygon_new_index] 461 | 462 | ## Visualization 463 | # # os.environ['KMP_DUPLICATE_LIB_OK']= "True" 464 | # # from matplotlib import pyplot as plt 465 | # # plt.gca().axis('equal') 466 | # # plt.cla() 467 | # # plt.clf() 468 | 469 | # # for agent_index in range(len(all_agent_feature)): 470 | # # plt.plot(all_agent_feature[agent_index, :, 0], all_agent_feature[agent_index, :, 1], "blue", zorder=20) 471 | # # for lane_index in range(len(final_lane_fea)): 472 | # # plt.plot(final_lane_fea[lane_index]["xyz"][:, 0], final_lane_fea[lane_index]["xyz"][:, 1], "black") 473 | # # for polygon_index in range(len(final_polygon_fea)): 474 | # # plt.plot(final_polygon_fea[polygon_index][1][:, 0], final_polygon_fea[polygon_index][1][:, 1], "black") 475 | 476 | # # ### Change this to increase resolution 477 | # # plt.xlim(950, 1000) 478 | # # plt.ylim(-2150, -2200) 479 | # # visualization_key = "left" 480 | # # for lane_index in range(len(final_lane_fea)): 481 | # # if len(final_lane_fea[lane_index][visualization_key]) != 0: 482 | # # for neighbor_lane_info in final_lane_fea[lane_index][visualization_key]: 483 | # # plt.plot(final_lane_fea[lane_index]["xyz"][point_per_polyline//2, 0], final_lane_fea[lane_index]["xyz"][point_per_polyline//2, 1], final_lane_fea[neighbor_lane_info[0]]["xyz"][point_per_polyline//2, 0], final_lane_fea[neighbor_lane_info[0]]["xyz"][point_per_polyline//2, 1], marker="o", c="red") 484 | # # import ipdb 485 | # # ipdb.set_trace() 486 | # # plt.savefig("tmp.png") 487 | 488 | all_data = {} 489 | all_data["fname"] = fname 490 | all_data["agent_feature"] = all_agent_feature 491 | all_data["label"] = 
np.concatenate(all_label, axis=0) 492 | all_data["auxiliary_label"] = np.concatenate(all_auxiliary_label, axis=0) 493 | all_data["label_mask"] = np.stack(all_label_mask, axis=0) 494 | all_data["difficulty"] = now_difficulty 495 | 496 | all_data["pred_num"] = pred_num 497 | if len(all_other_label) != 0: 498 | all_other_label = np.concatenate(all_other_label, axis=0) 499 | all_other_label_mask = np.stack(all_other_label_mask, axis=0) 500 | all_data["other_label"] = all_other_label 501 | all_data["other_label_mask"] = all_other_label_mask 502 | all_data["obejct_id_lis"] = np.array(all_object_id) 503 | all_data["scene_id"] = now_scene_id 504 | all_data["agent_type"] = all_agent_type 505 | all_data["map_fea"] = [final_lane_fea, final_polygon_fea] 506 | 507 | with open(os.path.join(saving_folder, data_version + "_case"+str(case_cnt)+".pkl"), "wb") as g: 508 | pickle.dump(all_data, g) 509 | 510 | 511 | num_of_element_lis.append(all_agent_feature.shape[0]+len(final_lane_fea)+len(final_polygon_fea)) 512 | del all_data 513 | gc.collect() 514 | if case_cnt % 10000 == 0: 515 | print(data_split_type, case_cnt, "done", flush=True) 516 | case_cnt += 1 517 | 518 | with open(os.path.join(saving_folder, data_version+"_number_of_case.pkl"), "wb") as g: 519 | pickle.dump(np.array(num_of_element_lis), g) 520 | 521 | print("Scene Index:", current_scene_index, "Data Split:", data_split_type, "Done!!!") 522 | 523 | -------------------------------------------------------------------------------- /preprocess/scenario_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 3 | # source: scenario.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import message as _message 7 | from google.protobuf import reflection as _reflection 8 | from google.protobuf import symbol_database as _symbol_database 9 | # @@protoc_insertion_point(imports) 10 | 11 | _sym_db = _symbol_database.Default() 12 | 13 | 14 | import map_pb2 as map__pb2 15 | 16 | 17 | DESCRIPTOR = _descriptor.FileDescriptor( 18 | name='scenario.proto', 19 | package='waymo.open_dataset', 20 | syntax='proto2', 21 | serialized_options=None, 22 | create_key=_descriptor._internal_create_key, 23 | serialized_pb=b'\n\x0escenario.proto\x12\x12waymo.open_dataset\x1a\tmap.proto\"\xba\x01\n\x0bObjectState\x12\x10\n\x08\x63\x65nter_x\x18\x02 \x01(\x01\x12\x10\n\x08\x63\x65nter_y\x18\x03 \x01(\x01\x12\x10\n\x08\x63\x65nter_z\x18\x04 \x01(\x01\x12\x0e\n\x06length\x18\x05 \x01(\x02\x12\r\n\x05width\x18\x06 \x01(\x02\x12\x0e\n\x06height\x18\x07 \x01(\x02\x12\x0f\n\x07heading\x18\x08 \x01(\x02\x12\x12\n\nvelocity_x\x18\t \x01(\x02\x12\x12\n\nvelocity_y\x18\n \x01(\x02\x12\r\n\x05valid\x18\x0b \x01(\x08\"\xe6\x01\n\x05Track\x12\n\n\x02id\x18\x01 \x01(\x05\x12\x39\n\x0bobject_type\x18\x02 \x01(\x0e\x32$.waymo.open_dataset.Track.ObjectType\x12/\n\x06states\x18\x03 \x03(\x0b\x32\x1f.waymo.open_dataset.ObjectState\"e\n\nObjectType\x12\x0e\n\nTYPE_UNSET\x10\x00\x12\x10\n\x0cTYPE_VEHICLE\x10\x01\x12\x13\n\x0fTYPE_PEDESTRIAN\x10\x02\x12\x10\n\x0cTYPE_CYCLIST\x10\x03\x12\x0e\n\nTYPE_OTHER\x10\x04\"R\n\x0f\x44ynamicMapState\x12?\n\x0blane_states\x18\x01 \x03(\x0b\x32*.waymo.open_dataset.TrafficSignalLaneState\"\xac\x01\n\x12RequiredPrediction\x12\x13\n\x0btrack_index\x18\x01 \x01(\x05\x12J\n\ndifficulty\x18\x02 
\x01(\x0e\x32\x36.waymo.open_dataset.RequiredPrediction.DifficultyLevel\"5\n\x0f\x44ifficultyLevel\x12\x08\n\x04NONE\x10\x00\x12\x0b\n\x07LEVEL_1\x10\x01\x12\x0b\n\x07LEVEL_2\x10\x02\"\xf8\x02\n\x08Scenario\x12\x13\n\x0bscenario_id\x18\x05 \x01(\t\x12\x1a\n\x12timestamps_seconds\x18\x01 \x03(\x01\x12\x1a\n\x12\x63urrent_time_index\x18\n \x01(\x05\x12)\n\x06tracks\x18\x02 \x03(\x0b\x32\x19.waymo.open_dataset.Track\x12?\n\x12\x64ynamic_map_states\x18\x07 \x03(\x0b\x32#.waymo.open_dataset.DynamicMapState\x12\x34\n\x0cmap_features\x18\x08 \x03(\x0b\x32\x1e.waymo.open_dataset.MapFeature\x12\x17\n\x0fsdc_track_index\x18\x06 \x01(\x05\x12\x1b\n\x13objects_of_interest\x18\x04 \x03(\x05\x12\x41\n\x11tracks_to_predict\x18\x0b \x03(\x0b\x32&.waymo.open_dataset.RequiredPredictionJ\x04\x08\t\x10\n' 24 | , 25 | dependencies=[map__pb2.DESCRIPTOR,]) 26 | 27 | 28 | 29 | _TRACK_OBJECTTYPE = _descriptor.EnumDescriptor( 30 | name='ObjectType', 31 | full_name='waymo.open_dataset.Track.ObjectType', 32 | filename=None, 33 | file=DESCRIPTOR, 34 | create_key=_descriptor._internal_create_key, 35 | values=[ 36 | _descriptor.EnumValueDescriptor( 37 | name='TYPE_UNSET', index=0, number=0, 38 | serialized_options=None, 39 | type=None, 40 | create_key=_descriptor._internal_create_key), 41 | _descriptor.EnumValueDescriptor( 42 | name='TYPE_VEHICLE', index=1, number=1, 43 | serialized_options=None, 44 | type=None, 45 | create_key=_descriptor._internal_create_key), 46 | _descriptor.EnumValueDescriptor( 47 | name='TYPE_PEDESTRIAN', index=2, number=2, 48 | serialized_options=None, 49 | type=None, 50 | create_key=_descriptor._internal_create_key), 51 | _descriptor.EnumValueDescriptor( 52 | name='TYPE_CYCLIST', index=3, number=3, 53 | serialized_options=None, 54 | type=None, 55 | create_key=_descriptor._internal_create_key), 56 | _descriptor.EnumValueDescriptor( 57 | name='TYPE_OTHER', index=4, number=4, 58 | serialized_options=None, 59 | type=None, 60 | create_key=_descriptor._internal_create_key), 61 | ], 62 | containing_type=None, 63 | serialized_options=None, 64 | serialized_start=368, 65 | serialized_end=469, 66 | ) 67 | _sym_db.RegisterEnumDescriptor(_TRACK_OBJECTTYPE) 68 | 69 | _REQUIREDPREDICTION_DIFFICULTYLEVEL = _descriptor.EnumDescriptor( 70 | name='DifficultyLevel', 71 | full_name='waymo.open_dataset.RequiredPrediction.DifficultyLevel', 72 | filename=None, 73 | file=DESCRIPTOR, 74 | create_key=_descriptor._internal_create_key, 75 | values=[ 76 | _descriptor.EnumValueDescriptor( 77 | name='NONE', index=0, number=0, 78 | serialized_options=None, 79 | type=None, 80 | create_key=_descriptor._internal_create_key), 81 | _descriptor.EnumValueDescriptor( 82 | name='LEVEL_1', index=1, number=1, 83 | serialized_options=None, 84 | type=None, 85 | create_key=_descriptor._internal_create_key), 86 | _descriptor.EnumValueDescriptor( 87 | name='LEVEL_2', index=2, number=2, 88 | serialized_options=None, 89 | type=None, 90 | create_key=_descriptor._internal_create_key), 91 | ], 92 | containing_type=None, 93 | serialized_options=None, 94 | serialized_start=675, 95 | serialized_end=728, 96 | ) 97 | _sym_db.RegisterEnumDescriptor(_REQUIREDPREDICTION_DIFFICULTYLEVEL) 98 | 99 | 100 | _OBJECTSTATE = _descriptor.Descriptor( 101 | name='ObjectState', 102 | full_name='waymo.open_dataset.ObjectState', 103 | filename=None, 104 | file=DESCRIPTOR, 105 | containing_type=None, 106 | create_key=_descriptor._internal_create_key, 107 | fields=[ 108 | _descriptor.FieldDescriptor( 109 | name='center_x', 
full_name='waymo.open_dataset.ObjectState.center_x', index=0, 110 | number=2, type=1, cpp_type=5, label=1, 111 | has_default_value=False, default_value=float(0), 112 | message_type=None, enum_type=None, containing_type=None, 113 | is_extension=False, extension_scope=None, 114 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 115 | _descriptor.FieldDescriptor( 116 | name='center_y', full_name='waymo.open_dataset.ObjectState.center_y', index=1, 117 | number=3, type=1, cpp_type=5, label=1, 118 | has_default_value=False, default_value=float(0), 119 | message_type=None, enum_type=None, containing_type=None, 120 | is_extension=False, extension_scope=None, 121 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 122 | _descriptor.FieldDescriptor( 123 | name='center_z', full_name='waymo.open_dataset.ObjectState.center_z', index=2, 124 | number=4, type=1, cpp_type=5, label=1, 125 | has_default_value=False, default_value=float(0), 126 | message_type=None, enum_type=None, containing_type=None, 127 | is_extension=False, extension_scope=None, 128 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 129 | _descriptor.FieldDescriptor( 130 | name='length', full_name='waymo.open_dataset.ObjectState.length', index=3, 131 | number=5, type=2, cpp_type=6, label=1, 132 | has_default_value=False, default_value=float(0), 133 | message_type=None, enum_type=None, containing_type=None, 134 | is_extension=False, extension_scope=None, 135 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 136 | _descriptor.FieldDescriptor( 137 | name='width', full_name='waymo.open_dataset.ObjectState.width', index=4, 138 | number=6, type=2, cpp_type=6, label=1, 139 | has_default_value=False, default_value=float(0), 140 | message_type=None, enum_type=None, containing_type=None, 141 | is_extension=False, extension_scope=None, 142 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 143 | _descriptor.FieldDescriptor( 144 | name='height', full_name='waymo.open_dataset.ObjectState.height', index=5, 145 | number=7, type=2, cpp_type=6, label=1, 146 | has_default_value=False, default_value=float(0), 147 | message_type=None, enum_type=None, containing_type=None, 148 | is_extension=False, extension_scope=None, 149 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 150 | _descriptor.FieldDescriptor( 151 | name='heading', full_name='waymo.open_dataset.ObjectState.heading', index=6, 152 | number=8, type=2, cpp_type=6, label=1, 153 | has_default_value=False, default_value=float(0), 154 | message_type=None, enum_type=None, containing_type=None, 155 | is_extension=False, extension_scope=None, 156 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 157 | _descriptor.FieldDescriptor( 158 | name='velocity_x', full_name='waymo.open_dataset.ObjectState.velocity_x', index=7, 159 | number=9, type=2, cpp_type=6, label=1, 160 | has_default_value=False, default_value=float(0), 161 | message_type=None, enum_type=None, containing_type=None, 162 | is_extension=False, extension_scope=None, 163 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 164 | _descriptor.FieldDescriptor( 165 | name='velocity_y', full_name='waymo.open_dataset.ObjectState.velocity_y', index=8, 166 | number=10, type=2, cpp_type=6, label=1, 167 | has_default_value=False, default_value=float(0), 
168 | message_type=None, enum_type=None, containing_type=None, 169 | is_extension=False, extension_scope=None, 170 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 171 | _descriptor.FieldDescriptor( 172 | name='valid', full_name='waymo.open_dataset.ObjectState.valid', index=9, 173 | number=11, type=8, cpp_type=7, label=1, 174 | has_default_value=False, default_value=False, 175 | message_type=None, enum_type=None, containing_type=None, 176 | is_extension=False, extension_scope=None, 177 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 178 | ], 179 | extensions=[ 180 | ], 181 | nested_types=[], 182 | enum_types=[ 183 | ], 184 | serialized_options=None, 185 | is_extendable=False, 186 | syntax='proto2', 187 | extension_ranges=[], 188 | oneofs=[ 189 | ], 190 | serialized_start=50, 191 | serialized_end=236, 192 | ) 193 | 194 | 195 | _TRACK = _descriptor.Descriptor( 196 | name='Track', 197 | full_name='waymo.open_dataset.Track', 198 | filename=None, 199 | file=DESCRIPTOR, 200 | containing_type=None, 201 | create_key=_descriptor._internal_create_key, 202 | fields=[ 203 | _descriptor.FieldDescriptor( 204 | name='id', full_name='waymo.open_dataset.Track.id', index=0, 205 | number=1, type=5, cpp_type=1, label=1, 206 | has_default_value=False, default_value=0, 207 | message_type=None, enum_type=None, containing_type=None, 208 | is_extension=False, extension_scope=None, 209 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 210 | _descriptor.FieldDescriptor( 211 | name='object_type', full_name='waymo.open_dataset.Track.object_type', index=1, 212 | number=2, type=14, cpp_type=8, label=1, 213 | has_default_value=False, default_value=0, 214 | message_type=None, enum_type=None, containing_type=None, 215 | is_extension=False, extension_scope=None, 216 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 217 | _descriptor.FieldDescriptor( 218 | name='states', full_name='waymo.open_dataset.Track.states', index=2, 219 | number=3, type=11, cpp_type=10, label=3, 220 | has_default_value=False, default_value=[], 221 | message_type=None, enum_type=None, containing_type=None, 222 | is_extension=False, extension_scope=None, 223 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 224 | ], 225 | extensions=[ 226 | ], 227 | nested_types=[], 228 | enum_types=[ 229 | _TRACK_OBJECTTYPE, 230 | ], 231 | serialized_options=None, 232 | is_extendable=False, 233 | syntax='proto2', 234 | extension_ranges=[], 235 | oneofs=[ 236 | ], 237 | serialized_start=239, 238 | serialized_end=469, 239 | ) 240 | 241 | 242 | _DYNAMICMAPSTATE = _descriptor.Descriptor( 243 | name='DynamicMapState', 244 | full_name='waymo.open_dataset.DynamicMapState', 245 | filename=None, 246 | file=DESCRIPTOR, 247 | containing_type=None, 248 | create_key=_descriptor._internal_create_key, 249 | fields=[ 250 | _descriptor.FieldDescriptor( 251 | name='lane_states', full_name='waymo.open_dataset.DynamicMapState.lane_states', index=0, 252 | number=1, type=11, cpp_type=10, label=3, 253 | has_default_value=False, default_value=[], 254 | message_type=None, enum_type=None, containing_type=None, 255 | is_extension=False, extension_scope=None, 256 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 257 | ], 258 | extensions=[ 259 | ], 260 | nested_types=[], 261 | enum_types=[ 262 | ], 263 | serialized_options=None, 264 | is_extendable=False, 
265 | syntax='proto2', 266 | extension_ranges=[], 267 | oneofs=[ 268 | ], 269 | serialized_start=471, 270 | serialized_end=553, 271 | ) 272 | 273 | 274 | _REQUIREDPREDICTION = _descriptor.Descriptor( 275 | name='RequiredPrediction', 276 | full_name='waymo.open_dataset.RequiredPrediction', 277 | filename=None, 278 | file=DESCRIPTOR, 279 | containing_type=None, 280 | create_key=_descriptor._internal_create_key, 281 | fields=[ 282 | _descriptor.FieldDescriptor( 283 | name='track_index', full_name='waymo.open_dataset.RequiredPrediction.track_index', index=0, 284 | number=1, type=5, cpp_type=1, label=1, 285 | has_default_value=False, default_value=0, 286 | message_type=None, enum_type=None, containing_type=None, 287 | is_extension=False, extension_scope=None, 288 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 289 | _descriptor.FieldDescriptor( 290 | name='difficulty', full_name='waymo.open_dataset.RequiredPrediction.difficulty', index=1, 291 | number=2, type=14, cpp_type=8, label=1, 292 | has_default_value=False, default_value=0, 293 | message_type=None, enum_type=None, containing_type=None, 294 | is_extension=False, extension_scope=None, 295 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 296 | ], 297 | extensions=[ 298 | ], 299 | nested_types=[], 300 | enum_types=[ 301 | _REQUIREDPREDICTION_DIFFICULTYLEVEL, 302 | ], 303 | serialized_options=None, 304 | is_extendable=False, 305 | syntax='proto2', 306 | extension_ranges=[], 307 | oneofs=[ 308 | ], 309 | serialized_start=556, 310 | serialized_end=728, 311 | ) 312 | 313 | 314 | _SCENARIO = _descriptor.Descriptor( 315 | name='Scenario', 316 | full_name='waymo.open_dataset.Scenario', 317 | filename=None, 318 | file=DESCRIPTOR, 319 | containing_type=None, 320 | create_key=_descriptor._internal_create_key, 321 | fields=[ 322 | _descriptor.FieldDescriptor( 323 | name='scenario_id', full_name='waymo.open_dataset.Scenario.scenario_id', index=0, 324 | number=5, type=9, cpp_type=9, label=1, 325 | has_default_value=False, default_value=b"".decode('utf-8'), 326 | message_type=None, enum_type=None, containing_type=None, 327 | is_extension=False, extension_scope=None, 328 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 329 | _descriptor.FieldDescriptor( 330 | name='timestamps_seconds', full_name='waymo.open_dataset.Scenario.timestamps_seconds', index=1, 331 | number=1, type=1, cpp_type=5, label=3, 332 | has_default_value=False, default_value=[], 333 | message_type=None, enum_type=None, containing_type=None, 334 | is_extension=False, extension_scope=None, 335 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 336 | _descriptor.FieldDescriptor( 337 | name='current_time_index', full_name='waymo.open_dataset.Scenario.current_time_index', index=2, 338 | number=10, type=5, cpp_type=1, label=1, 339 | has_default_value=False, default_value=0, 340 | message_type=None, enum_type=None, containing_type=None, 341 | is_extension=False, extension_scope=None, 342 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 343 | _descriptor.FieldDescriptor( 344 | name='tracks', full_name='waymo.open_dataset.Scenario.tracks', index=3, 345 | number=2, type=11, cpp_type=10, label=3, 346 | has_default_value=False, default_value=[], 347 | message_type=None, enum_type=None, containing_type=None, 348 | is_extension=False, extension_scope=None, 349 | serialized_options=None, file=DESCRIPTOR, 
create_key=_descriptor._internal_create_key), 350 | _descriptor.FieldDescriptor( 351 | name='dynamic_map_states', full_name='waymo.open_dataset.Scenario.dynamic_map_states', index=4, 352 | number=7, type=11, cpp_type=10, label=3, 353 | has_default_value=False, default_value=[], 354 | message_type=None, enum_type=None, containing_type=None, 355 | is_extension=False, extension_scope=None, 356 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 357 | _descriptor.FieldDescriptor( 358 | name='map_features', full_name='waymo.open_dataset.Scenario.map_features', index=5, 359 | number=8, type=11, cpp_type=10, label=3, 360 | has_default_value=False, default_value=[], 361 | message_type=None, enum_type=None, containing_type=None, 362 | is_extension=False, extension_scope=None, 363 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 364 | _descriptor.FieldDescriptor( 365 | name='sdc_track_index', full_name='waymo.open_dataset.Scenario.sdc_track_index', index=6, 366 | number=6, type=5, cpp_type=1, label=1, 367 | has_default_value=False, default_value=0, 368 | message_type=None, enum_type=None, containing_type=None, 369 | is_extension=False, extension_scope=None, 370 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 371 | _descriptor.FieldDescriptor( 372 | name='objects_of_interest', full_name='waymo.open_dataset.Scenario.objects_of_interest', index=7, 373 | number=4, type=5, cpp_type=1, label=3, 374 | has_default_value=False, default_value=[], 375 | message_type=None, enum_type=None, containing_type=None, 376 | is_extension=False, extension_scope=None, 377 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 378 | _descriptor.FieldDescriptor( 379 | name='tracks_to_predict', full_name='waymo.open_dataset.Scenario.tracks_to_predict', index=8, 380 | number=11, type=11, cpp_type=10, label=3, 381 | has_default_value=False, default_value=[], 382 | message_type=None, enum_type=None, containing_type=None, 383 | is_extension=False, extension_scope=None, 384 | serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), 385 | ], 386 | extensions=[ 387 | ], 388 | nested_types=[], 389 | enum_types=[ 390 | ], 391 | serialized_options=None, 392 | is_extendable=False, 393 | syntax='proto2', 394 | extension_ranges=[], 395 | oneofs=[ 396 | ], 397 | serialized_start=731, 398 | serialized_end=1107, 399 | ) 400 | 401 | _TRACK.fields_by_name['object_type'].enum_type = _TRACK_OBJECTTYPE 402 | _TRACK.fields_by_name['states'].message_type = _OBJECTSTATE 403 | _TRACK_OBJECTTYPE.containing_type = _TRACK 404 | _DYNAMICMAPSTATE.fields_by_name['lane_states'].message_type = map__pb2._TRAFFICSIGNALLANESTATE 405 | _REQUIREDPREDICTION.fields_by_name['difficulty'].enum_type = _REQUIREDPREDICTION_DIFFICULTYLEVEL 406 | _REQUIREDPREDICTION_DIFFICULTYLEVEL.containing_type = _REQUIREDPREDICTION 407 | _SCENARIO.fields_by_name['tracks'].message_type = _TRACK 408 | _SCENARIO.fields_by_name['dynamic_map_states'].message_type = _DYNAMICMAPSTATE 409 | _SCENARIO.fields_by_name['map_features'].message_type = map__pb2._MAPFEATURE 410 | _SCENARIO.fields_by_name['tracks_to_predict'].message_type = _REQUIREDPREDICTION 411 | DESCRIPTOR.message_types_by_name['ObjectState'] = _OBJECTSTATE 412 | DESCRIPTOR.message_types_by_name['Track'] = _TRACK 413 | DESCRIPTOR.message_types_by_name['DynamicMapState'] = _DYNAMICMAPSTATE 414 | 
DESCRIPTOR.message_types_by_name['RequiredPrediction'] = _REQUIREDPREDICTION 415 | DESCRIPTOR.message_types_by_name['Scenario'] = _SCENARIO 416 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 417 | 418 | ObjectState = _reflection.GeneratedProtocolMessageType('ObjectState', (_message.Message,), { 419 | 'DESCRIPTOR' : _OBJECTSTATE, 420 | '__module__' : 'scenario_pb2' 421 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.ObjectState) 422 | }) 423 | _sym_db.RegisterMessage(ObjectState) 424 | 425 | Track = _reflection.GeneratedProtocolMessageType('Track', (_message.Message,), { 426 | 'DESCRIPTOR' : _TRACK, 427 | '__module__' : 'scenario_pb2' 428 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.Track) 429 | }) 430 | _sym_db.RegisterMessage(Track) 431 | 432 | DynamicMapState = _reflection.GeneratedProtocolMessageType('DynamicMapState', (_message.Message,), { 433 | 'DESCRIPTOR' : _DYNAMICMAPSTATE, 434 | '__module__' : 'scenario_pb2' 435 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.DynamicMapState) 436 | }) 437 | _sym_db.RegisterMessage(DynamicMapState) 438 | 439 | RequiredPrediction = _reflection.GeneratedProtocolMessageType('RequiredPrediction', (_message.Message,), { 440 | 'DESCRIPTOR' : _REQUIREDPREDICTION, 441 | '__module__' : 'scenario_pb2' 442 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.RequiredPrediction) 443 | }) 444 | _sym_db.RegisterMessage(RequiredPrediction) 445 | 446 | Scenario = _reflection.GeneratedProtocolMessageType('Scenario', (_message.Message,), { 447 | 'DESCRIPTOR' : _SCENARIO, 448 | '__module__' : 'scenario_pb2' 449 | # @@protoc_insertion_point(class_scope:waymo.open_dataset.Scenario) 450 | }) 451 | _sym_db.RegisterMessage(Scenario) 452 | 453 | 454 | # @@protoc_insertion_point(module_scope) 455 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | joblib==1.1.0 2 | matplotlib==3.5.2 3 | numba==0.55.1 4 | numpy==1.21.5 5 | opencv_python==4.5.5.64 6 | pandas==1.4.2 7 | Pillow==9.2.0 8 | pyarrow==8.0.0 9 | pyproj==3.3.1 10 | rich==12.4.4 11 | scikit_learn==1.1.1 12 | scipy==1.6.2 13 | -------------------------------------------------------------------------------- /src/pipeline.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/HDGT/97eb8f45601a5c1b87b8f99c05a47751c7da4af8/src/pipeline.PNG -------------------------------------------------------------------------------- /training/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import dgl 6 | import math 7 | import dgl.function as fn 8 | from functools import partial 9 | import math 10 | class BN1D(nn.Module): 11 | def __init__(self, d_in): 12 | super().__init__() 13 | self.bn = nn.BatchNorm1d(d_in) 14 | def forward(self, x): 15 | if len(x.shape) == 3: 16 | return self.bn(x.transpose(1, 2)).transpose(1, 2) 17 | if len(x.shape) == 2: 18 | return self.bn(x.unsqueeze(-1)).squeeze(-1) 19 | 20 | class SEBlock(nn.Module): 21 | def __init__(self, channels): 22 | super().__init__() 23 | self.conv_se1 = nn.Conv1d(channels, channels, kernel_size=1, bias=True) 24 | self.conv_se2 = nn.Conv1d(channels, channels, kernel_size=1, bias=True) 25 | self.act = nn.ReLU(inplace=True) 26 | self.avg_pool = nn.AdaptiveAvgPool1d(1) 27 | self.sigmoid = 
torch.nn.Sigmoid() 28 | def forward(self, x): 29 | return self.sigmoid(self.conv_se2(self.act(self.conv_se1(self.avg_pool(x))))) * x 30 | 31 | class SEBasicBlock(nn.Module): 32 | def __init__(self, in_c, out_c, temporal_length, stride=1): 33 | super().__init__() 34 | self.in_c = in_c 35 | self.out_c = out_c 36 | self.conv1 = nn.Conv1d(kernel_size=3, in_channels=in_c, out_channels=out_c, stride=1, padding=1, bias=False) 37 | self.bn1 = nn.BatchNorm1d(in_c) 38 | self.conv2 = nn.Conv1d(kernel_size=3, in_channels=out_c, out_channels=out_c, stride=stride, padding=1, bias=False) 39 | self.bn2 = nn.BatchNorm1d(out_c) 40 | self.act = nn.ReLU(inplace=True) 41 | 42 | self.stride = stride 43 | self.downsample = None 44 | self.se = SEBlock(in_c) 45 | if stride != 1: 46 | if temporal_length == 6 or temporal_length == 2: 47 | self.downsample = nn.Sequential(torch.nn.AvgPool1d(kernel_size=2, stride=2), nn.Conv1d(kernel_size=1, in_channels=in_c, out_channels=out_c, stride=1), nn.BatchNorm1d(out_c)) 48 | else: 49 | self.downsample = nn.Sequential(torch.nn.AvgPool1d(kernel_size=2, stride=2, padding=1, count_include_pad=False), nn.Conv1d(kernel_size=1, in_channels=in_c, out_channels=out_c, stride=1), nn.BatchNorm1d(out_c)) 50 | elif in_c != out_c: 51 | self.downsample = nn.Sequential(nn.Conv1d(kernel_size=1, in_channels=in_c, out_channels=out_c, stride=1), nn.BatchNorm1d(out_c)) 52 | 53 | def forward(self, x): 54 | identity = x 55 | if self.downsample is not None: 56 | identity = self.downsample(identity) 57 | out = self.conv1(self.act(self.bn1(x))) 58 | out = self.conv2(self.act(self.bn2(x))) 59 | if self.in_c == self.out_c: 60 | out = self.se(out) 61 | out = out + identity 62 | return out 63 | 64 | 65 | class TemporalBlock(nn.Module): 66 | def __init__(self, inplanes, temporal_length, args): 67 | super(TemporalBlock, self).__init__() 68 | self.layers = [SEBasicBlock(in_c=inplanes, out_c=inplanes, temporal_length=temporal_length, stride=1)] 69 | self.layers += [SEBasicBlock(in_c=inplanes, out_c=inplanes, temporal_length=temporal_length, stride=2)] 70 | self.layers = nn.ModuleList(self.layers) 71 | def forward(self, x): 72 | out = x 73 | for _ in range(len(self.layers)): 74 | out = self.layers[_](out) 75 | return out 76 | 77 | class AgentTemporalEncoder(nn.Module): 78 | def __init__(self, args): 79 | super().__init__() 80 | kernel_size_dic = {0:11, 1:6, 2:3, 3:2, 4:1, 5:1, 6:1, 7:1, 8:1, 9:1, 10:1} 81 | self.temporal_layers = torch.nn.ModuleList([TemporalBlock(args.hidden_dim, kernel_size_dic[_], args) for _ in range(2)]) 82 | def forward(self, feat): 83 | feat = self.temporal_layers[0](feat) 84 | for _ in range(1, len(self.temporal_layers)): 85 | feat = self.temporal_layers[_](feat) 86 | feat = feat[..., -1] ## Last timestep 87 | return feat 88 | 89 | class BN1D(nn.Module): 90 | def __init__(self, d_in): 91 | super().__init__() 92 | self.bn = nn.BatchNorm1d(d_in) 93 | def forward(self, x): 94 | if len(x.shape) == 3: 95 | return self.bn(x.transpose(1, 2)).transpose(1, 2) 96 | if len(x.shape) == 2: 97 | is_single_input = (x.shape[0] == 1 and self.training == True) 98 | if is_single_input: 99 | self.bn.eval() 100 | res = self.bn(x.unsqueeze(-1)).squeeze(-1) 101 | if is_single_input: 102 | self.bn.train() 103 | return res 104 | 105 | class MLP(nn.Module): 106 | def __init__(self, d_in, d_hid, d_out, norm=None, dropout=0.0, prenorm=False): 107 | super().__init__() 108 | self.w_1 = nn.Linear(d_in, d_hid) 109 | self.w_2 = nn.Linear(d_hid, d_out) # position-wise 110 | self.act = nn.ReLU(inplace=True) 111 | 
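## Note on norm placement (cf. forward() below): with prenorm=True the norm wraps the
## d_in-dimensional input before w_1 (pre-norm); otherwise it is applied to the
## d_hid-dimensional activation after w_1 (post-norm). Since the dropout branch further
## down is commented out, self.dropout ends up as nn.Identity() regardless of the
## dropout argument. A hypothetical call, assuming hidden_dim=128:
##     MLP(d_in=128, d_hid=512, d_out=128, norm=BN1D)(x)   # x: (N, 128) -> (N, 128)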
self.prenorm = prenorm 112 | if norm is None: 113 | self.norm = nn.Identity() 114 | elif prenorm: 115 | self.norm = norm(d_in) 116 | else: 117 | self.norm = norm(d_hid) 118 | #if dropout != 0.0: 119 | #self.dropout = nn.Dropout(dropout) 120 | #else: 121 | self.dropout = nn.Identity() 122 | def forward(self, x): 123 | if self.prenorm: 124 | output = self.dropout(self.w_2(self.act(self.w_1(self.norm(x))))) 125 | else: 126 | output = self.dropout(self.w_2(self.act(self.norm(self.w_1(x))))) 127 | return output 128 | 129 | 130 | class PointNet(nn.Module): 131 | def __init__(self, args): 132 | super().__init__() 133 | self.fc1 = MLP(d_in=args.hidden_dim//4, d_hid=args.hidden_dim//4, d_out=args.hidden_dim//8, norm=BN1D) 134 | self.fc2 = MLP(d_in=args.hidden_dim//8, d_hid=args.hidden_dim//8, d_out=args.hidden_dim//8, norm=BN1D) 135 | self.fc3 = MLP(d_in=args.hidden_dim//4, d_hid=args.hidden_dim//4, d_out=args.hidden_dim//4, norm=BN1D) 136 | def forward(self, x): 137 | out = self.fc1(x) 138 | out = torch.cat([out, self.fc2(out).max(-2)[0].unsqueeze(-2).repeat(1, x.shape[-2], 1)], dim=-1) 139 | out = torch.cat([out, self.fc3(out).max(-2)[0].unsqueeze(-2).repeat(1, x.shape[-2], 1)], dim=-1) 140 | return out 141 | 142 | class Agent2embedding(nn.Module): 143 | def __init__(self, input_dim, args): 144 | super().__init__() 145 | self.hidden_dim = args.hidden_dim 146 | self.act = nn.ReLU(inplace=True) 147 | ## -4 - minus coor and mask, args.hidden_dim//4 - coordinate feature 148 | self.fea_MLP = MLP(input_dim-4+args.hidden_dim//4, args.hidden_dim//2, args.hidden_dim, norm=BN1D) 149 | def forward(self, input_dic, shared_coor_encoder): 150 | #0-2 x, y, z, 3-4 vx, vy, 5-6 cos, sin, 7-9 witdth, length, height 10 mask 151 | x = input_dic["graph_lis"].ndata["a_n_fea"]["agent"] 152 | #z, vx, vy, cos, sin, width, length, height, mask 153 | coor_fea = shared_coor_encoder(x[..., :3]) 154 | fea = torch.cat([coor_fea, x[..., 3:-1]], dim=-1) 155 | fea = self.fea_MLP(fea) 156 | return fea.transpose(1, 2) 157 | 158 | ## Centerline (Lane) Embedding 159 | class Lane2embedding(nn.Module): 160 | def __init__(self, args): 161 | super().__init__() 162 | self.hidden_dim = args.hidden_dim 163 | self.act = nn.ReLU(inplace=True) 164 | ## Encoder Coordinate Information 165 | self.pointnet = PointNet(args) 166 | self.point_out_fc = nn.Linear(args.hidden_dim//2, args.hidden_dim) 167 | self.type_emb = torch.nn.Embedding(num_embeddings=4, embedding_dim=args.hidden_dim//4) 168 | self.stop_fc = nn.Linear(args.hidden_dim//4, args.hidden_dim//4) 169 | self.signal_fc = nn.Linear(args.hidden_dim//4, args.hidden_dim//4) 170 | self.signal_emb = torch.nn.Embedding(num_embeddings=9, embedding_dim=args.hidden_dim//2) 171 | self.signal_gru = torch.nn.GRU(input_size=args.hidden_dim//2, hidden_size=args.hidden_dim//2, num_layers=1, batch_first=True) 172 | 173 | ## [coor, type, stop_signal] 174 | self.lane_n_out_fc = MLP(d_in=args.hidden_dim+args.hidden_dim//4+self.hidden_dim//4*3+self.hidden_dim//2, d_hid=args.hidden_dim*4, d_out=args.hidden_dim, norm=BN1D) 175 | 176 | self.boudary_emb = torch.nn.Embedding(num_embeddings=11, embedding_dim=args.hidden_dim//4) 177 | ## [boundary, n_fea, rel_pos_fea] 178 | self.lane_e_out_fc = MLP(d_in=args.hidden_dim+args.hidden_dim//4+args.hidden_dim//4, d_hid=args.hidden_dim*4, d_out=args.hidden_dim, norm=BN1D) 179 | 180 | def forward(self, input_dic, shared_coor_encoder, shared_rel_encoder): 181 | coor_fea = input_dic["graph_lis"].ndata["l_n_coor_fea"]["lane"] 182 | coor_fea = shared_coor_encoder(coor_fea) 183 | 
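## At this point coor_fea should hold one encoded feature per interpolated polyline
## point of every lane node: as elsewhere in this file, the shared coordinate encoder
## maps each 3-D point to a hidden_dim//4 vector, so the shape is roughly
## (num_lane_nodes, points_per_polyline, hidden_dim//4), with 21 points per lane piece
## produced by preprocess_waymo.py. The PointNet below lifts these to hidden_dim//2 per
## point, and the max(dim=1) just after pools them into a single permutation-invariant
## descriptor per lane centerline piece.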
coor_fea = self.pointnet(coor_fea) 184 | coor_fea = self.act(self.point_out_fc(coor_fea.max(dim=1)[0])) 185 | type_fea = self.type_emb(input_dic["graph_lis"].ndata["l_n_type_fea"]["lane"]) 186 | 187 | lane_n_num = coor_fea.shape[0] 188 | ## If there is no stop sign/traffic signal controlling the lane, the cooresponding features are all zeros 189 | stop_signal_fea = torch.zeros((lane_n_num, self.hidden_dim//4*3+self.hidden_dim//2), device="cuda:"+str(input_dic["gpu"])) 190 | if "lane_n_stop_sign_fea_lis" in input_dic: 191 | stop_sign_fea = input_dic["lane_n_stop_sign_fea_lis"] 192 | stop_sign_fea = shared_coor_encoder(stop_sign_fea) 193 | stop_sign_fea = self.act(self.stop_fc(stop_sign_fea)) 194 | stop_signal_fea[input_dic["lane_n_stop_sign_index_lis"]][..., :self.hidden_dim//4] += stop_sign_fea 195 | if "lane_n_signal_fea_lis" in input_dic: 196 | signal_fea = input_dic["lane_n_signal_fea_lis"] 197 | signal_coor_fea = self.act(self.signal_fc(shared_coor_encoder(signal_fea[..., :3]))) 198 | signal_dynamic = self.signal_emb(signal_fea[..., 3:].long()) 199 | signal_dynamic, _ = self.signal_gru(signal_dynamic) 200 | signal_fea = torch.cat([signal_coor_fea, self.act(signal_dynamic[:, -1, :])], dim=-1) 201 | stop_signal_fea[input_dic["lane_n_signal_index_lis"]][..., self.hidden_dim//4:self.hidden_dim//4*2+self.hidden_dim//2] += signal_fea 202 | 203 | output_n_fea = self.lane_n_out_fc(torch.cat([coor_fea, type_fea, stop_signal_fea], dim=-1)) 204 | 205 | ## Lane Edge Feature Encoding 206 | lane_e_num_lis_by_etype = np.cumsum([0] + [len(input_dic["graph_lis"].edata["l_e_fea"][_]) for _ in [("lane", "l2a", "agent"), ("lane", "left", "lane"), ("lane", "right", "lane"), ("lane", "prev", "lane"), ("lane", "follow", "lane")]]) 207 | lane_e_rel_pos = torch.cat([input_dic["graph_lis"].edata["l_e_fea"][_] for _ in [("lane", "l2a", "agent"), ("lane", "left", "lane"), ("lane", "right", "lane"), ("lane", "prev", "lane"), ("lane", "follow", "lane")]], dim=0) 208 | lane_e_rel_pos_fea = shared_rel_encoder(lane_e_rel_pos) 209 | lane_e_num = lane_e_rel_pos_fea.shape[0] 210 | lane_src_indices = torch.cat([input_dic["graph_lis"].edges(etype=_)[0] for _ in ["l2a", "left", "right", "prev", "follow"]], dim=0) 211 | lane_e_src_n_fea = output_n_fea[lane_src_indices] 212 | 213 | output_e_fea = torch.cat([torch.zeros((lane_e_num, self.hidden_dim//4), device="cuda:"+str(input_dic["gpu"])), lane_e_rel_pos_fea, lane_e_src_n_fea,], dim=-1) 214 | 215 | boudary_emb = torch.cat([input_dic["graph_lis"].edata["boundary_type"][_] for _ in [("lane", "left", "lane"), ("lane", "right", "lane"), ("lane", "prev", "lane"), ("lane", "follow", "lane")]], dim=0) 216 | boudary_emb = self.boudary_emb(boudary_emb) 217 | output_e_fea[lane_e_num_lis_by_etype[1]:, :self.hidden_dim//4] += boudary_emb 218 | output_e_fea = self.lane_e_out_fc(output_e_fea) 219 | 220 | input_dic["graph_lis"].ndata["l_n_hidden"] = {"lane":output_n_fea} 221 | for _index, _ in enumerate([("lane", "l2a", "agent"), ("lane", "left", "lane"), ("lane", "right", "lane"), ("lane", "prev", "lane"), ("lane", "follow", "lane")]): 222 | input_dic["graph_lis"].edata["l_e_hidden"] = {_:output_e_fea[lane_e_num_lis_by_etype[_index]:lane_e_num_lis_by_etype[_index+1]]} 223 | return None 224 | 225 | class Polygon2embedding(nn.Module): 226 | def __init__(self, args): 227 | super().__init__() 228 | self.hidden_dim = args.hidden_dim 229 | self.pointnet = PointNet(args) 230 | self.type_emb = torch.nn.Embedding(num_embeddings=14, embedding_dim=args.hidden_dim//2) 231 | self.out_fc = 
MLP(args.hidden_dim//2+args.hidden_dim//2, args.hidden_dim*4, args.hidden_dim, norm=BN1D) 232 | 233 | def forward(self, input_dic, shared_coor_encoder): 234 | coor_fea = input_dic["graph_lis"].edata['g2a_e_fea'][("polygon", "g2a", "agent")] 235 | coor_fea = shared_coor_encoder(coor_fea) 236 | coor_fea = self.pointnet(coor_fea).max(dim=1)[0] 237 | type_fea = input_dic["graph_lis"].edata['g2a_e_type'][("polygon", "g2a", "agent")] 238 | type_fea = self.type_emb(type_fea) 239 | fea = torch.cat([coor_fea, type_fea], axis=-1) 240 | fea = self.out_fc(fea) 241 | input_dic["graph_lis"].edata["g_e_hidden"] = {("polygon", "g2a", "agent"):fea} 242 | 243 | 244 | class ScaledDotProductAttention(torch.nn.Module): 245 | """ Scaled Dot-Product Attention """ 246 | def __init__(self, temperature, args): 247 | super().__init__() 248 | self.temperature = temperature 249 | self.dropout = torch.nn.Dropout(args.dropout) 250 | self.softmax = torch.nn.Softmax(dim=2) 251 | def forward(self, q, k, v, mask=None): 252 | attn = torch.bmm(q, k.transpose(1, 2)) 253 | n = float(v.shape[1]) 254 | attn = attn / self.temperature * math.log(n+1, 32) 255 | if mask is not None: 256 | attn = attn.masked_fill(mask, -1e10) 257 | attn = self.softmax(attn) 258 | attn = self.dropout(attn) 259 | output = torch.bmm(attn, v) 260 | return output, attn 261 | 262 | class PositionwiseFeedForward(nn.Module): 263 | ''' A two-feed-forward-layer module ''' 264 | def __init__(self, d_in, d_hid, dropout, args): 265 | super().__init__() 266 | self.d_in = d_in 267 | self.out_dim = d_in 268 | self.norm = nn.LayerNorm(d_in) 269 | self.w1 = nn.Linear(d_in, d_hid) 270 | self.w2 = nn.Linear(d_hid, d_in) 271 | self.w3 = nn.Linear(d_in, d_hid) 272 | self.act = nn.SiLU() 273 | if dropout != 0.0: 274 | self.dropout = nn.Dropout(dropout) 275 | else: 276 | self.dropout = nn.Identity() 277 | def forward(self, x): 278 | residual = x 279 | output = self.norm(x) 280 | output = self.w2(self.act(self.w1(output)) * self.w3(output)) 281 | output = self.dropout(output) + residual 282 | return output 283 | 284 | 285 | 286 | class LaneHetGNN(nn.Module): 287 | def __init__(self, args): 288 | super().__init__() 289 | self.etype_num = 5 290 | self.norm = nn.LayerNorm 291 | self.node_mlp = nn.ModuleDict({ 292 | _:MLP(args.hidden_dim, args.hidden_dim, args.hidden_dim, norm=nn.LayerNorm, prenorm=True) 293 | for _ in ["left", "right", "prev", "follow"]}) 294 | 295 | self.node_mlp["a2l"] = nn.ModuleList([MLP(args.hidden_dim, args.hidden_dim, args.hidden_dim, norm=nn.LayerNorm, prenorm=True) 296 | for _ in range(3)]) 297 | 298 | self.node_fc = nn.Linear(self.etype_num*args.hidden_dim, args.hidden_dim) 299 | self.node_ffn = PositionwiseFeedForward(args.hidden_dim, args.hidden_dim*4, args.dropout, args) 300 | 301 | self.etype_dic = {} 302 | for etype in ["left", "right", "prev", "follow", "a2l"][:self.etype_num]: 303 | self.etype_dic[etype] = (partial(self.message_func, etype=etype), partial(self.reduce_func, etype=etype)) 304 | 305 | self.edge_MLP = MLP(args.hidden_dim*3, args.hidden_dim*4, args.hidden_dim, norm=nn.LayerNorm, dropout=args.dropout, prenorm=True) 306 | 307 | self.agent_edge_MLP = nn.ModuleList([MLP(args.hidden_dim*3, args.hidden_dim*4, args.hidden_dim, norm=nn.LayerNorm, dropout=args.dropout, prenorm=True) 308 | for _ in range(3) 309 | ]) 310 | 311 | def forward(self, input_dic): 312 | #lane_n_fea = input_dic["graph_lis"].ndata["l_n_hidden"]["lane"] 313 | with input_dic["graph_lis"].local_scope(): 314 | self.gpu = input_dic["gpu"] 315 | self.a_e_type_dict = 
input_dic["a_e_type_dict"] 316 | 317 | input_dic["graph_lis"].multi_update_all(etype_dict=self.etype_dic, cross_reducer="stack") ## Stack all features of all types of in-edges 318 | output_lane_n_fea = input_dic["graph_lis"].ndata["l_n_hidden_out"]["lane"] 319 | 320 | output_lane_n_fea = self.node_fc(output_lane_n_fea.view(output_lane_n_fea.shape[0], -1)) + input_dic["graph_lis"].ndata["l_n_hidden"]["lane"] 321 | output_lane_n_fea = self.node_ffn(output_lane_n_fea) 322 | 323 | output_e_fea = [self.edge_MLP(torch.cat([input_dic["graph_lis"].ndata["l_n_hidden"]["lane"][input_dic["graph_lis"].edges(etype=_)[0]], input_dic["graph_lis"].edges[_].data["l_e_hidden"], input_dic["graph_lis"].ndata["l_n_hidden"]["lane"][input_dic["graph_lis"].edges(etype=_)[1]]], dim=-1))+input_dic["graph_lis"].edges[_].data["l_e_hidden"] for _ in ["left", "right", "prev", "follow"]] 324 | 325 | agent2lane_e_fea = torch.cat([input_dic["graph_lis"].ndata["a_n_hidden"]["agent"][input_dic["graph_lis"].edges(etype="a2l")[0]], input_dic["graph_lis"].edges["a2l"].data["a_e_hidden"], input_dic["graph_lis"].ndata["l_n_hidden"]["lane"][input_dic["graph_lis"].edges(etype="a2l")[1]]], dim=-1) 326 | output_agent2lane_e_fea = torch.zeros_like(input_dic["graph_lis"].edges["a2l"].data["a_e_hidden"]) 327 | for agent_type_index in range(3): 328 | if len(self.a_e_type_dict["a2l"][agent_type_index]) != 0: 329 | output_agent2lane_e_fea[self.a_e_type_dict["a2l"][agent_type_index]] = self.agent_edge_MLP[agent_type_index](agent2lane_e_fea[self.a_e_type_dict["a2l"][agent_type_index]]) +input_dic["graph_lis"].edges["a2l"].data["a_e_hidden"][self.a_e_type_dict["a2l"][agent_type_index]] 330 | output_e_fea.append(output_agent2lane_e_fea) 331 | return output_lane_n_fea, output_e_fea 332 | 333 | def message_func(self, edges, etype): 334 | if etype != "a2l": 335 | return {"l_e_hidden_"+etype:self.node_mlp[etype](edges.data["l_e_hidden"])} 336 | else: 337 | tmp_out = torch.zeros_like((edges.data["a_e_hidden"])) 338 | tmp_input = edges.data["a_e_hidden"] 339 | for agent_type_index in range(3): 340 | if len(self.a_e_type_dict[etype][agent_type_index]) != 0: 341 | tmp_out[self.a_e_type_dict[etype][agent_type_index]] = self.node_mlp[etype][agent_type_index](tmp_input[self.a_e_type_dict[etype][agent_type_index]]) 342 | return {"l_e_hidden_"+etype:tmp_out} 343 | 344 | def reduce_func(self, nodes, etype): 345 | return {"l_n_hidden_out":nodes.mailbox["l_e_hidden_"+etype].max(dim=1)[0]} 346 | 347 | 348 | 349 | 350 | class AgentHetGNN(nn.Module): 351 | def __init__(self, args): 352 | super().__init__() 353 | self.norm = nn.LayerNorm 354 | self.hidden_dim = args.hidden_dim 355 | self.head_dim = args.head_dim 356 | self.d_model = args.hidden_dim 357 | self.n_head = self.d_model//args.head_dim 358 | d_k = self.head_dim 359 | d_v = self.head_dim 360 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5), args=args) 361 | 362 | self.wqs = nn.ModuleList([nn.Sequential( 363 | self.norm(self.d_model), 364 | nn.Linear(self.d_model, self.n_head * d_k * 3, bias=False), 365 | ) 366 | for _ in range(3) 367 | ]) 368 | 369 | self.wkvs = nn.ModuleDict({ 370 | _: nn.Sequential( 371 | self.norm(self.d_model), 372 | nn.Linear(self.d_model, self.n_head * d_k * 2, bias=False) 373 | ) 374 | for _ in ["l2a", "g2a"] 375 | }) 376 | 377 | self.wkvs["other"] = nn.ModuleList([ 378 | nn.Sequential( 379 | self.norm(self.d_model), 380 | nn.Linear(self.d_model, self.n_head * d_k * 2, bias=False) 381 | ) 382 | for _ in range(3) 383 | ]) 384 | 385 | self.attn_fcs = 
nn.ModuleList([nn.ModuleDict({ 386 | _:nn.Sequential( 387 | nn.Linear(self.n_head * d_v, self.d_model, bias=True), 388 | nn.ReLU(inplace=True) 389 | ) 390 | for _ in ["l2a", "g2a", "other"] 391 | }) 392 | for agent_type_index in range(3)] 393 | ) 394 | 395 | 396 | self.self_fc = nn.ModuleList([ 397 | nn.Sequential( 398 | nn.Linear(self.hidden_dim, self.hidden_dim, bias=True), 399 | nn.ReLU(inplace=True) 400 | ) 401 | for _ in range(3)] 402 | ) 403 | 404 | self.out_fc = nn.ModuleList([nn.Sequential( 405 | nn.Linear(self.hidden_dim*4, self.hidden_dim, bias=True), 406 | nn.Dropout(args.dropout) if args.dropout > 0.0 else nn.Identity() 407 | ) 408 | for _ in range(3)]) 409 | 410 | self.out_ffn = nn.ModuleList([PositionwiseFeedForward(args.hidden_dim, args.hidden_dim*4, args.dropout, nn.LayerNorm) 411 | for _ in range(3) 412 | ]) 413 | self.etype_dic = {} 414 | for etype in ["other", "l2a", "g2a"]: 415 | self.etype_dic[etype] = (partial(self.message_func, etype=etype), partial(self.reduce_func, etype=etype)) 416 | self.etype2hidden_name = {"other":"a_e_hidden", "l2a":"l_e_hidden", "g2a":"g_e_hidden"} 417 | self.etype2src_name = {"other":"agent", "l2a":"lane", "g2a":"polygon"} 418 | 419 | self.edge_MLP = { 420 | "l2a":MLP(args.hidden_dim*3, args.hidden_dim*4, args.hidden_dim, nn.LayerNorm, dropout=args.dropout, prenorm=True), 421 | "g2a":MLP(args.hidden_dim*2, args.hidden_dim*4, args.hidden_dim, nn.LayerNorm, dropout=args.dropout, prenorm=True), 422 | } 423 | self.edge_MLP["self"] = nn.ModuleList([MLP(args.hidden_dim, args.hidden_dim, args.hidden_dim, nn.LayerNorm, dropout=args.dropout, prenorm=True) 424 | for _ in range(3) 425 | ]) 426 | 427 | self.edge_MLP["other"] = nn.ModuleList([MLP(args.hidden_dim*3, args.hidden_dim*4, args.hidden_dim, nn.LayerNorm, dropout=args.dropout, prenorm=True) 428 | for _ in range(3) 429 | ]) 430 | self.edge_MLP = nn.ModuleDict(self.edge_MLP) 431 | 432 | def forward(self, input_dic): 433 | self.a_e_type_dict = input_dic["a_e_type_dict"] 434 | self.a_n_type_lis = input_dic["a_n_type_lis"] 435 | with input_dic["graph_lis"].local_scope(): 436 | self.device = "cuda:"+str(input_dic["gpu"]) 437 | 438 | input_q = input_dic["graph_lis"].ndata["a_n_hidden"]["agent"] 439 | q = torch.zeros((input_dic["graph_lis"].ndata["a_n_hidden"]["agent"].shape[0], self.n_head * self.head_dim * 3), device=self.device) 440 | for agent_type_index in range(3): 441 | if len(self.a_n_type_lis[agent_type_index]) != 0: 442 | q[self.a_n_type_lis[agent_type_index]] = self.wqs[agent_type_index](input_q[self.a_n_type_lis[agent_type_index]]) 443 | other_q, lane_q, polygon_q = q.view(input_q.shape[0], self.n_head, self.head_dim*3).split(dim=-1, split_size=self.head_dim) 444 | 445 | input_dic["graph_lis"].nodes["agent"].data["other_q"] = other_q 446 | input_dic["graph_lis"].nodes["agent"].data["l2a_q"] = lane_q 447 | input_dic["graph_lis"].nodes["agent"].data["g2a_q"] = polygon_q 448 | input_dic["graph_lis"].multi_update_all(etype_dict=self.etype_dic, cross_reducer="stack") 449 | all_out_fea = input_dic["graph_lis"].ndata["a_n_hidden_out"]["agent"] 450 | 451 | other_out_tmp_in = all_out_fea[:, 0, :] 452 | lane_out_fea_tmp_in = all_out_fea[:, 1, :] 453 | polygon_out_fea_tmp_in = all_out_fea[:, 2, :] 454 | input_self_fea = input_dic["graph_lis"].edges["self"].data["a_e_hidden"] 455 | all_out_n_fea = torch.zeros_like(input_dic["graph_lis"].nodes["agent"].data["a_n_hidden"]) 456 | for agent_type_index in range(3): 457 | if len(self.a_n_type_lis[agent_type_index]) != 0: 458 | other_out_fea = 
self.attn_fcs[agent_type_index]["other"](other_out_tmp_in[self.a_n_type_lis[agent_type_index]]) 459 | lane_out_fea = self.attn_fcs[agent_type_index]["l2a"](lane_out_fea_tmp_in[self.a_n_type_lis[agent_type_index]]) 460 | polygon_out_fea = self.attn_fcs[agent_type_index]["g2a"](polygon_out_fea_tmp_in[self.a_n_type_lis[agent_type_index]]) 461 | self_fea = self.self_fc[agent_type_index](input_self_fea[self.a_n_type_lis[agent_type_index]]) 462 | 463 | out_n_fea = torch.stack([self_fea, other_out_fea, lane_out_fea, polygon_out_fea], dim=1) 464 | out_n_fea = self.out_fc[agent_type_index](out_n_fea.view(out_n_fea.shape[0], -1)) + input_dic["graph_lis"].nodes["agent"].data["a_n_hidden"][self.a_n_type_lis[agent_type_index]] 465 | out_n_fea = self.out_ffn[agent_type_index](out_n_fea) 466 | 467 | all_out_n_fea[self.a_n_type_lis[agent_type_index]] = out_n_fea 468 | 469 | out_e_fea_lis = [] 470 | self_e_fea_tmp_in = input_dic["graph_lis"].edges["self"].data["a_e_hidden"] 471 | other_e_fea_tmp_in = torch.cat([input_dic["graph_lis"].ndata["a_n_hidden"]["agent"][input_dic["graph_lis"].edges(etype="other")[0]], input_dic["graph_lis"].edges["other"].data["a_e_hidden"], input_dic["graph_lis"].ndata["a_n_hidden"]["agent"][input_dic["graph_lis"].edges(etype="other")[1]]], dim=-1) 472 | 473 | self_e_fea_tmp_out = torch.zeros_like(self_e_fea_tmp_in) 474 | other_e_fea_tmp_out = torch.zeros((other_e_fea_tmp_in.shape[0], self.hidden_dim), device=self.device) 475 | for agent_type_index in range(3): 476 | if len(self.a_e_type_dict["self"][agent_type_index]) != 0: 477 | self_e_fea_tmp_out[self.a_e_type_dict["self"][agent_type_index]] = self.edge_MLP["self"][agent_type_index](self_e_fea_tmp_in[self.a_e_type_dict["self"][agent_type_index]]) 478 | if len(self.a_e_type_dict["other"][agent_type_index]) != 0: 479 | other_e_fea_tmp_out[self.a_e_type_dict["other"][agent_type_index]] = self.edge_MLP["other"][agent_type_index](other_e_fea_tmp_in[self.a_e_type_dict["other"][agent_type_index]]) 480 | out_e_fea_lis.append(self_e_fea_tmp_out + input_dic["graph_lis"].edges["self"].data["a_e_hidden"]) 481 | out_e_fea_lis.append(other_e_fea_tmp_out + input_dic["graph_lis"].edges["other"].data["a_e_hidden"]) 482 | 483 | l2a_out_e_fea = torch.cat([input_dic["graph_lis"].ndata["l_n_hidden"]["lane"][input_dic["graph_lis"].edges(etype="l2a")[0]], input_dic["graph_lis"].edges["l2a"].data["l_e_hidden"], input_dic["graph_lis"].ndata["a_n_hidden"]["agent"][input_dic["graph_lis"].edges(etype="l2a")[1]]], dim=-1) 484 | l2a_out_e_fea = self.edge_MLP["l2a"](l2a_out_e_fea) 485 | out_e_fea_lis.append(l2a_out_e_fea + input_dic["graph_lis"].edges["l2a"].data["l_e_hidden"]) 486 | 487 | g2a_out_e_fea = torch.cat([input_dic["graph_lis"].ndata["a_n_hidden"]["agent"][input_dic["graph_lis"].edges(etype="g2a")[1]], input_dic["graph_lis"].edges["g2a"].data["g_e_hidden"]], dim=-1) 488 | g2a_out_e_fea = self.edge_MLP["g2a"](g2a_out_e_fea) 489 | out_e_fea_lis.append(g2a_out_e_fea + input_dic["graph_lis"].edges["g2a"].data["g_e_hidden"]) 490 | return all_out_n_fea, out_e_fea_lis 491 | 492 | def message_func(self, edges, etype): 493 | if etype != "other": 494 | k, v = self.wkvs[etype](edges.data[self.etype2hidden_name[etype]]).view(-1, self.n_head, self.head_dim*2).split(dim=-1, split_size=self.head_dim) 495 | else: 496 | tmp_input = edges.data[self.etype2hidden_name[etype]] 497 | tmp_output = torch.zeros((tmp_input.shape[0], self.n_head * self.head_dim * 2), device=self.device) 498 | for agent_type_index in range(3): 499 | if 
len(self.a_e_type_dict["other"][agent_type_index]) != 0: 500 |                     tmp_output[self.a_e_type_dict["other"][agent_type_index]] = self.wkvs[etype][agent_type_index](tmp_input[self.a_e_type_dict["other"][agent_type_index]]) 501 |             k, v = tmp_output.view(-1, self.n_head, self.head_dim*2).split(dim=-1, split_size=self.head_dim) 502 |         return {etype+"_k":k, etype+"_v":v} 503 | 504 |     def reduce_func(self, nodes, etype): 505 |         node_num, neighbor_num, n_head, hidden = nodes.mailbox[etype+"_k"].shape 506 |         q = nodes.data[etype+"_q"].view(node_num*self.n_head, -1).unsqueeze(1) 507 |         k = nodes.mailbox[etype+"_k"].transpose(1, 2).reshape(node_num*self.n_head, neighbor_num, -1) 508 |         v = nodes.mailbox[etype+"_v"].transpose(1, 2).reshape(node_num*self.n_head, neighbor_num, -1) 509 |         output, attn = self.attention(q, k, v, mask=None) 510 |         return {"a_n_hidden_out":output.view(node_num, -1)} 511 | 512 | 513 | 514 | class HDGT_encoder(nn.Module): 515 |     def __init__(self, input_dim, args): 516 |         super().__init__() 517 |         self.hidden_dim = args.hidden_dim 518 |         self.shared_coor_encoder = MLP(d_in=3, d_hid=args.hidden_dim//8, d_out=args.hidden_dim//4, norm=None) ## Encode x,y,z 519 |         self.shared_rel_encoder = MLP(d_in=5, d_hid=args.hidden_dim//8, d_out=args.hidden_dim//4, norm=None) ## Encode Delta (x, y, z, cos(psi), sin(psi)) 520 | 521 |         self.agent_emb = Agent2embedding(input_dim, args) 522 |         self.temporal_encoders = torch.nn.ModuleList([AgentTemporalEncoder(args) for _ in range(3)]) ## different temporal encoder for different agent type 523 |         ## Input: [Node Feature, Rel_Pos] 524 |         self.agent_e_fea_MLPs = torch.nn.ModuleList([MLP(args.hidden_dim+args.hidden_dim//4, args.hidden_dim*4, args.hidden_dim, norm=BN1D) for _ in range(3)]) ## Init Agent Edge Feature 525 | 526 |         self.lane_emb = Lane2embedding(args) 527 |         self.polygon_emb = Polygon2embedding(args) 528 | 529 |         self.num_of_gnn_layer = args.num_of_gnn_layer 530 |         self.lane_gnns = torch.nn.ModuleList([LaneHetGNN(args=args) for _ in range(self.num_of_gnn_layer)]) 531 |         self.agent_gnns = torch.nn.ModuleList([AgentHetGNN(args=args) for _ in range(self.num_of_gnn_layer)]) 532 | 533 | 534 |     def forward(self, input_dic): 535 |         ## Init Agent Node 536 |         agent_n_emb = self.agent_emb(input_dic, self.shared_coor_encoder) 537 |         agent_n_type_indices = [torch.where((input_dic["graph_lis"].ndata["a_n_type"]["agent"]) == _) for _ in range(3)] 538 |         agent_n_fea = torch.zeros((agent_n_emb.shape[0], agent_n_emb.shape[1]), device="cuda:"+str(input_dic["gpu"])) 539 |         for _ in range(3): 540 |             agent_n_fea[agent_n_type_indices[_]] = self.temporal_encoders[_](agent_n_emb[agent_n_type_indices[_]]) 541 | 542 |         ## Init Agent-Related Edge 543 |         agent_e_type_lis = torch.cat([input_dic["graph_lis"].edata["a_e_type"][_] for _ in [('agent', 'self', 'agent'), ('agent', 'other', 'agent'), ('agent', 'a2l', 'lane')]], dim=0) 544 |         agent_e_fea_rel_pos = torch.cat([input_dic["graph_lis"].edata["a_e_fea"][_] for _ in [('agent', 'self', 'agent'), ('agent', 'other', 'agent'), ('agent', 'a2l', 'lane')]], dim=0) 545 |         agent_e_type_indices = [torch.where(agent_e_type_lis == _) for _ in range(3)] 546 |         agent_e_src_n_fea = torch.cat([agent_n_fea[input_dic["graph_lis"].edges(etype=_)[0],...] 
for _ in ["self", "other", "a2l"]], dim=0) 547 | agent_e_fea_rel_pos = self.shared_rel_encoder(agent_e_fea_rel_pos) 548 | agent_e_fea = torch.zeros((agent_e_fea_rel_pos.shape[0], self.hidden_dim), device="cuda:"+str(input_dic["gpu"])) 549 | for _ in range(3): 550 | agent_e_fea[agent_e_type_indices[_]] = self.agent_e_fea_MLPs[_](torch.cat([agent_e_fea_rel_pos[agent_e_type_indices[_]], agent_e_src_n_fea[agent_e_type_indices[_]]], dim=-1)) 551 | input_dic["graph_lis"].ndata["a_n_hidden"] = {"agent":agent_n_fea} 552 | agent_e_num_lis_by_etype = np.cumsum([0] + [len(input_dic["graph_lis"].edata["a_e_type"][_]) for _ in [('agent', 'self', 'agent'), ('agent', 'other', 'agent'), ('agent', 'a2l', 'lane')]]) 553 | for _index, _ in enumerate([('agent', 'self', 'agent'), ('agent', 'other', 'agent'), ('agent', 'a2l', 'lane')]): 554 | input_dic["graph_lis"].edata["a_e_hidden"] = {_:agent_e_fea[agent_e_num_lis_by_etype[_index]:agent_e_num_lis_by_etype[_index+1]]} 555 | 556 | self.lane_emb(input_dic, self.shared_coor_encoder, self.shared_rel_encoder) 557 | self.polygon_emb(input_dic, self.shared_coor_encoder) 558 | 559 | 560 | for i in range(self.num_of_gnn_layer): 561 | output_lane_n_fea, output_in_lane_e_fea = self.lane_gnns[i](input_dic) 562 | output_agent_n_fea, output_in_agent_e_fea = self.agent_gnns[i](input_dic) 563 | input_dic["graph_lis"].nodes["lane"].data["l_n_hidden"] = output_lane_n_fea 564 | for _index, _ in enumerate(["left", "right", "prev", "follow"]): 565 | input_dic["graph_lis"].edges[_].data["l_e_hidden"] = output_in_lane_e_fea[_index] 566 | input_dic["graph_lis"].edges["a2l"].data["a_e_hidden"] = output_in_lane_e_fea[-1] 567 | 568 | input_dic["graph_lis"].nodes["agent"].data["a_n_hidden"] = output_agent_n_fea 569 | for _index, _ in enumerate([("self", "a_e_hidden"), ("other", "a_e_hideen"), ("l2a", "l_e_hidden"), ("g2a", "g_e_hidden")]): 570 | input_dic["graph_lis"].edges[_[0]].data[_[1]] = output_in_agent_e_fea[_index] 571 | return input_dic["graph_lis"] 572 | 573 | 574 | 575 | class HDGT_model(nn.Module): 576 | def __init__(self, input_dim, args): 577 | super().__init__() 578 | self.input_dim = input_dim 579 | self.hidden_dim = args.hidden_dim 580 | self.args = args 581 | self.num_prediction = args.num_prediction 582 | 583 | self.encoder = HDGT_encoder(input_dim, args) 584 | self.decoder = torch.nn.ModuleList([RefineDecoder(args) for _ in range(3)]) 585 | 586 | def forward(self, input_dic): 587 | output_het_graph = self.encoder(input_dic) 588 | 589 | neighbor_size_lis = input_dic["neighbor_size_lis"] 590 | all_agent_raw_traj = input_dic["graph_lis"].nodes["agent"].data["a_n_fea"][..., :2].clone() 591 | cumsum_neighbor_size_lis = np.cumsum(neighbor_size_lis, axis=0).tolist() 592 | cumsum_neighbor_size_lis = [0] + cumsum_neighbor_size_lis 593 | pred_num_lis = input_dic["pred_num_lis"] 594 | agent_node_fea = output_het_graph.nodes["agent"].data["a_n_hidden"] 595 | all_agent_id = input_dic["graph_lis"].nodes("agent") 596 | agent_n_type_lis = output_het_graph.ndata["a_n_type"]["agent"] 597 | 598 | ## Obtain the node feature of target agents 599 | targat_agent_indice_lis = [] 600 | target_agent_indice_bool_type_lis = [[] for _ in range(3)] 601 | targat_agent_fea = [[] for _ in range(3)] 602 | target_agent_id = [[] for _ in range(3)] 603 | for i in range(1, len(cumsum_neighbor_size_lis)): 604 | now_agent_type_lis = agent_n_type_lis[cumsum_neighbor_size_lis[i-1]:cumsum_neighbor_size_lis[i]] 605 | now_agent_id = 
all_agent_id[cumsum_neighbor_size_lis[i-1]:cumsum_neighbor_size_lis[i-1]+pred_num_lis[i-1]] 606 | now_agent_node_fea = agent_node_fea[cumsum_neighbor_size_lis[i-1]:cumsum_neighbor_size_lis[i-1]+pred_num_lis[i-1]] 607 | now_target_agent_indice_bool_type_lis = [now_agent_type_lis[:pred_num_lis[i-1]]==_ for _ in range(3)] 608 | 609 | targat_agent_indice_lis += list(range(cumsum_neighbor_size_lis[i-1], cumsum_neighbor_size_lis[i-1]+pred_num_lis[i-1])) 610 | for _ in range(3): 611 | target_agent_indice_bool_type_lis[_].append(now_target_agent_indice_bool_type_lis[_]) 612 | targat_agent_fea[_].append(now_agent_node_fea[now_target_agent_indice_bool_type_lis[_]]) 613 | target_agent_id[_].append(now_agent_id[now_target_agent_indice_bool_type_lis[_]]) 614 | 615 | target_agent_indice_bool_type_lis = [torch.cat(target_agent_indice_bool_type_lis[_], dim=0) for _ in range(3)] 616 | targat_agent_fea = [torch.cat(targat_agent_fea[_], dim=0) for _ in range(3)] 617 | target_agent_id = [torch.cat(target_agent_id[_], dim=0) for _ in range(3)] 618 | 619 | prediction = [] 620 | for _ in range(3): 621 | if targat_agent_fea[_].shape[0] == 0: 622 | prediction.append((torch.zeros((0, 1)), torch.zeros((0, 1)))) 623 | else: 624 | prediction.append(self.decoder[_](targat_agent_fea[_], target_agent_id[_], all_agent_raw_traj, input_dic)) 625 | agent_cls_res = [prediction[_][0] for _ in range(3)] 626 | agent_reg_res = [prediction[_][1] for _ in range(3)] 627 | return agent_reg_res, agent_cls_res, target_agent_indice_bool_type_lis 628 | 629 | 630 | 631 | class RefineCNN(nn.Module): 632 | def __init__(self, in_c, dilation=1, args=None): 633 | super(RefineCNN, self).__init__() 634 | self.in_c = in_c 635 | self.conv1 = nn.Conv1d(in_c, in_c, kernel_size=3, dilation=dilation, padding=dilation) 636 | self.gn1 = nn.GroupNorm(num_groups=in_c, num_channels=in_c) 637 | self.conv2 = nn.Conv1d(in_c, in_c, kernel_size=3, dilation=dilation, padding=dilation) 638 | self.gn2 = nn.GroupNorm(num_groups=in_c, num_channels=in_c) 639 | self.act = nn.ReLU(inplace=True) 640 | self.se = SEBlock(in_c) 641 | def forward(self, x): 642 | identity = x 643 | out = x 644 | out = self.act(self.gn1(self.conv1(out))) 645 | out = self.act(self.gn2(self.conv2(out))) 646 | out = self.se(out) + identity 647 | out = self.act(out) 648 | return out 649 | 650 | 651 | 652 | class RefineContextLayer(nn.Module): 653 | def __init__(self, args): 654 | super().__init__() 655 | self.num_prediction = args.num_prediction 656 | self.hidden_dim = args.hidden_dim 657 | self.head_dim = args.head_dim 658 | self.d_model = int(args.hidden_dim) 659 | self.n_head = self.d_model // self.head_dim 660 | d_k = self.head_dim 661 | d_v = self.head_dim 662 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5), args=args) 663 | self.wq = nn.Sequential( 664 | nn.LayerNorm(self.hidden_dim), 665 | nn.Linear(args.hidden_dim, self.n_head * d_k, bias=False) 666 | ) 667 | 668 | 669 | self.wkvs = nn.ModuleList([ 670 | nn.Sequential( 671 | nn.LayerNorm(self.hidden_dim), 672 | nn.Linear(args.hidden_dim, self.n_head * d_k * 2, bias=False), 673 | ) 674 | for _ in range(3) 675 | ]) 676 | self.attn_fc = nn.Sequential( 677 | nn.Linear(self.n_head * d_v, self.d_model, bias=True), 678 | nn.Dropout(args.dropout), 679 | ) 680 | 681 | self.ffn = PositionwiseFeedForward(self.d_model, self.d_model*4, args.dropout, args) 682 | 683 | def forward(self, raw_q, raw_kv_lis, raw_kv_indices, input_dic): 684 | all_q_lis = self.wq(raw_q).view(raw_q.shape[0], self.num_prediction*self.n_head, 
self.head_dim) 685 | all_kv_lis = [self.wkvs[_](torch.cat(raw_kv_lis[_], dim=0)) for _ in range(3)] 686 | all_kv_lis = [[_[raw_kv_indices[_index][indices_i-1]:raw_kv_indices[_index][indices_i]] for indices_i in range(1, len(raw_kv_indices[_index]))] for _index, _ in enumerate(all_kv_lis)] 687 | all_kv_lis = [torch.cat([all_kv_lis[0][_], all_kv_lis[1][_], all_kv_lis[2][_]], dim=0).view(-1, self.n_head, self.head_dim*2).unsqueeze(1).repeat(1, self.num_prediction, 1, 1).view(-1, self.num_prediction*self.n_head, self.head_dim*2).transpose(0, 1) for _ in range(raw_q.shape[0])] 688 | 689 | all_out_q_lis = self.attn_fc(torch.cat([self.attention(q=all_q_lis[_].unsqueeze(1), k=all_kv_lis[_][..., :self.head_dim], v=all_kv_lis[_][..., self.head_dim:])[0].view(-1, self.num_prediction, self.n_head*self.head_dim) if all_kv_lis[_].shape[0]!=0 else torch.zeros_like(raw_q[0:1, :, :]) for _ in range(raw_q.shape[0])], dim=0)) + raw_q 690 | return self.ffn(all_out_q_lis) 691 | 692 | 693 | 694 | class RefineContextModule(nn.Module): 695 | def __init__(self, args): 696 | super().__init__() 697 | self.num_prediction = args.num_prediction 698 | self.hidden_dim = args.hidden_dim 699 | self.modal_emb = nn.parameter.Parameter(torch.zeros(args.num_prediction, args.hidden_dim)) 700 | torch.nn.init.normal(self.modal_emb) 701 | self.init_q = MLP(args.hidden_dim*2, args.hidden_dim*4, args.hidden_dim, nn.LayerNorm, prenorm=True) 702 | self.etype2hidden_name = {"other":"a_e_hidden", "l2a":"l_e_hidden", "g2a":"g_e_hidden"} 703 | self.refine_context_layer = nn.ModuleList([RefineContextLayer(args) for _ in range(2)]) 704 | 705 | def forward(self, agent_ids, input_dic): 706 | with input_dic["graph_lis"].local_scope(): 707 | self.device = "cuda:"+str(input_dic["gpu"]) 708 | raw_agent_fea = input_dic["graph_lis"].ndata["a_n_hidden"]["agent"][agent_ids] 709 | num_agent = raw_agent_fea.shape[0] 710 | 711 | raw_q = torch.cat([raw_agent_fea.unsqueeze(1).repeat(1, self.num_prediction, 1), self.modal_emb.unsqueeze(0).repeat(num_agent, 1, 1)], dim=-1) 712 | raw_q = self.init_q(raw_q) + self.modal_emb 713 | 714 | raw_kv_lis = [[torch.zeros((0, self.hidden_dim), device=self.device)]*num_agent for _ in range(3)] 715 | for agent_index, agent_id in enumerate(agent_ids): 716 | for etype_index, etype in enumerate(["other", "l2a", "g2a"]): 717 | now_type_eid_lis = input_dic["graph_lis"].in_edges(etype=etype, v=agent_id, form="eid") 718 | if len(now_type_eid_lis) > 0: 719 | raw_kv_lis[etype_index][agent_index] = input_dic["graph_lis"].edges[etype].data[self.etype2hidden_name[etype]][now_type_eid_lis] 720 | 721 | raw_kv_length = [[int(__.shape[0]) for __ in _] for _ in raw_kv_lis] 722 | raw_kv_indices = [np.cumsum([0]+_) for _ in raw_kv_length] 723 | for _ in range(len(self.refine_context_layer)): 724 | raw_q = self.refine_context_layer[_](raw_q, raw_kv_lis, raw_kv_indices, input_dic) + self.modal_emb 725 | return raw_q 726 | 727 | 728 | class RefineLayer(nn.Module): 729 | def __init__(self, d_in, d_hid, args): 730 | super().__init__() 731 | self.in_linear = nn.Linear(d_in, d_hid) 732 | self.context_mlp = MLP(args.hidden_dim, args.hidden_dim//2, d_hid, norm=nn.LayerNorm, prenorm=True) 733 | self.fuse_linear = nn.Linear(d_hid * 2, d_hid) 734 | self.cnns = nn.Sequential( 735 | RefineCNN(d_hid, dilation=1, args=args), 736 | RefineCNN(d_hid, dilation=2, args=args), 737 | RefineCNN(d_hid, dilation=5, args=args), 738 | RefineCNN(d_hid, dilation=1, args=args), 739 | ) 740 | self.out_linear = MLP(d_hid, d_hid, 2, norm=None) 741 | def forward(self, x, 
context): 742 | x = x.view(x.shape[0], x.shape[1], 91, -1) 743 | output = self.in_linear(x) 744 | context = self.context_mlp(context).unsqueeze(-2).repeat(1, 1, 91, 1) 745 | output = self.fuse_linear(torch.cat([context, output], dim=-1)) 746 | bs, num_mode, t_len, hid = output.shape 747 | output = output.view(bs*num_mode, t_len, hid).transpose(1, 2) 748 | output = self.cnns(output).transpose(1, 2).view(bs, num_mode, t_len, hid)[:, :, 11:, :] 749 | output = self.out_linear(output) 750 | return output 751 | 752 | class RefineDecoder(nn.Module): 753 | def __init__(self, args): 754 | super().__init__() 755 | self.hidden_dim = args.hidden_dim 756 | self.num_prediction = int(args.num_prediction) 757 | 758 | self.refine_num = int(args.refine_num) 759 | self.refine_layer_lis = nn.ModuleList( 760 | [RefineLayer(2, self.hidden_dim//4, args) for _ in range(self.refine_num)] 761 | ) 762 | self.is_output_vel = (args.output_vel == "True") 763 | self.is_cumsum_vel = (args.cumsum_vel == "True") 764 | 765 | self.refine_context_attn = RefineContextModule(args) 766 | 767 | self.reg_mlp = nn.Sequential( 768 | nn.Linear(args.hidden_dim, args.hidden_dim*2), 769 | nn.ReLU(inplace=True), 770 | nn.Linear(args.hidden_dim*2, args.hidden_dim*4), 771 | nn.ReLU(inplace=True), 772 | nn.Linear(args.hidden_dim*4, 80*2), 773 | ) 774 | 775 | self.cls_mlp = nn.Sequential( 776 | nn.Linear(args.hidden_dim, args.hidden_dim), 777 | nn.ReLU(inplace=True), 778 | nn.Linear(args.hidden_dim, args.hidden_dim//2), 779 | nn.ReLU(inplace=True), 780 | nn.Linear(args.hidden_dim//2, 1), 781 | ) 782 | 783 | 784 | def forward(self, target_agent_fea, agent_ids, agent_raw_traj, input_dic): 785 | refine_context = self.refine_context_attn(agent_ids, input_dic)/10.0 786 | reg_res = self.reg_mlp(refine_context).view(target_agent_fea.shape[0], self.num_prediction, 80, 2) 787 | cls_res = self.cls_mlp(refine_context).view(target_agent_fea.shape[0], self.num_prediction) 788 | if self.is_output_vel and self.is_cumsum_vel: 789 | reg_res = torch.cumsum(reg_res, dim=-2) 790 | reg_res_lis = [reg_res] 791 | if self.refine_num > 0: 792 | now_agent_raw_traj = agent_raw_traj[agent_ids].unsqueeze(1).repeat(1, self.num_prediction, 1, 1) 793 | for _ in range(self.refine_num): 794 | now_input = torch.cat([now_agent_raw_traj, reg_res.detach()], dim=-2).view(reg_res.shape[0], self.num_prediction, -1) ## Full Traj 795 | reg_res = reg_res + self.refine_layer_lis[_](now_input, refine_context).view(reg_res.shape[0], self.num_prediction, 80, 2) 796 | reg_res_lis.append(reg_res) 797 | return cls_res, torch.stack(reg_res_lis, dim=1) 798 | 799 | 800 | 801 | 802 | 803 | from torch.optim import Optimizer 804 | from torch.optim.lr_scheduler import LambdaLR 805 | class WarmupLinearSchedule(LambdaLR): 806 | def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1): 807 | self.warmup_steps = warmup_steps 808 | self.t_total = t_total 809 | super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 810 | 811 | def lr_lambda(self, step): 812 | if step < self.warmup_steps: 813 | return float(step) / float(max(1, self.warmup_steps)) 814 | return 1.0#max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps))) 815 | 816 | import warnings 817 | def _no_grad_trunc_normal_(tensor, mean, std, a, b): 818 | # Cut & paste from PyTorch official master until it's in a few official releases - RW 819 | # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 820 | def norm_cdf(x): 821 | # 
Computes standard normal cumulative distribution function 822 | return (1. + math.erf(x / math.sqrt(2.))) / 2. 823 | 824 | if (mean < a - 2 * std) or (mean > b + 2 * std): 825 | warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " 826 | "The distribution of values may be incorrect.", 827 | stacklevel=2) 828 | 829 | with torch.no_grad(): 830 | # Values are generated by using a truncated uniform distribution and 831 | # then using the inverse CDF for the normal distribution. 832 | # Get upper and lower cdf values 833 | l = norm_cdf((a - mean) / std) 834 | u = norm_cdf((b - mean) / std) 835 | 836 | # Uniformly fill tensor with values from [l, u], then translate to 837 | # [2l-1, 2u-1]. 838 | tensor.uniform_(2 * l - 1, 2 * u - 1) 839 | 840 | # Use inverse cdf transform for normal distribution to get truncated 841 | # standard normal 842 | tensor.erfinv_() 843 | 844 | # Transform to proper mean, std 845 | tensor.mul_(std * math.sqrt(2.)) 846 | tensor.add_(mean) 847 | 848 | # Clamp to ensure it's in the proper range 849 | tensor.clamp_(min=a, max=b) 850 | return tensor 851 | 852 | 853 | def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): 854 | # type: (Tensor, float, float, float, float) -> Tensor 855 | r"""Fills the input Tensor with values drawn from a truncated 856 | normal distribution. The values are effectively drawn from the 857 | normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` 858 | with values outside :math:`[a, b]` redrawn until they are within 859 | the bounds. The method used for generating the random values works 860 | best when :math:`a \leq \text{mean} \leq b`. 861 | Args: 862 | tensor: an n-dimensional `torch.Tensor` 863 | mean: the mean of the normal distribution 864 | std: the standard deviation of the normal distribution 865 | a: the minimum cutoff value 866 | b: the maximum cutoff value 867 | Examples: 868 | >>> w = torch.empty(3, 5) 869 | >>> nn.init.trunc_normal_(w) 870 | """ 871 | return _no_grad_trunc_normal_(tensor, mean, std, a, b) 872 | 873 | 874 | def weights_init(m): 875 | with torch.no_grad(): 876 | classname = m.__class__.__name__ 877 | if classname.find('Conv1d') != -1: 878 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 879 | elif classname.find('LayerNorm') != -1: 880 | nn.init.constant_(m.weight, 1) 881 | nn.init.constant_(m.bias, 0) 882 | elif classname.find('BatchNorm') != -1: 883 | nn.init.constant_(m.weight, 1) 884 | nn.init.constant_(m.bias, 0) 885 | elif classname.find('GroupNorm') != -1: 886 | nn.init.constant_(m.weight, 1) 887 | nn.init.constant_(m.bias, 0) 888 | elif classname.find('Linear') != -1: 889 | if m.bias is not None: 890 | nn.init.constant_(m.bias, 0) 891 | nn.init.xavier_normal_(m.weight) 892 | elif classname.find('Embedding') != -1: 893 | trunc_normal_(m.weight, mean=0, std=0.02) 894 | -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["NCCL_P2P_DISABLE"] = "1" 3 | import random 4 | from typing import DefaultDict 5 | import warnings 6 | warnings.filterwarnings("ignore") 7 | import gc 8 | import numpy as np 9 | import argparse 10 | import time 11 | import sys 12 | import pickle 13 | import shutil 14 | import importlib 15 | import torch.nn as nn 16 | import torch.nn.parallel 17 | import torch.distributed as dist 18 | import torch.multiprocessing as mp 19 | import torch.nn.functional as F 20 | from 
torch.utils.data import Dataset, DataLoader 21 | #from waymo_data.collate_func import * 22 | import io 23 | #from metricss_soft_map import soft_map ## not used below and the module is not included in this repo 24 | import scipy.special 25 | import scipy.interpolate as interp 26 | from waymo_dataset import * 27 | 28 | parser = argparse.ArgumentParser('Interface for HDGT Training') 29 | ##### Optimizer - Scheduler 30 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') 31 | parser.add_argument('--weight_decay', type=float, default=1e-6, help='weight decay') 32 | parser.add_argument('--batch_size', type=int, default=16, help='batch size') 33 | parser.add_argument('--val_batch_size', type=int, default=128, help='validation batch size') 34 | parser.add_argument('--n_epoch', type=int, default=30, help='number of epochs') 35 | parser.add_argument('--warmup', type=float, default=1.0, help='the number of epoch for warmup') 36 | parser.add_argument('--lr_decay_epoch', type=str, default="4-8-16-24-26", help='the index of epoch where the lr decays to lr*0.5') 37 | parser.add_argument('--num_prediction', type=int,default=6, help='the number of modality') 38 | parser.add_argument('--cls_weight', type=float,default=0.1, help='the weight of classification loss') 39 | parser.add_argument('--reg_weight', type=float,default=50.0, help='the weight of regression loss') 40 | 41 | #### Speed Up 42 | parser.add_argument('--num_of_gnn_layer', type=int, default=6, help='the number of HDGT layer') 43 | parser.add_argument('--hidden_dim', type=int, default=256, help='init hidden dimension') 44 | parser.add_argument('--head_dim', type=int, default=32, help='the dimension of attention head') 45 | parser.add_argument('--dropout', type=float, default=0.0, help='dropout probability') 46 | parser.add_argument('--num_worker', type=int, default=8, help='number of worker per dataloader') 47 | 48 | #### Setting 49 | parser.add_argument('--agent_drop', type=float, default='0.0', help='the ratio of randomly dropping agent') 50 | parser.add_argument('--data_folder', type=str,default="hdgt_waymo", help='training set') 51 | 52 | parser.add_argument('--refine_num', type=int, default=5, help='temporally refine the trajectory') 53 | parser.add_argument('--output_vel', type=str, default="True", help='output in form of velocity') 54 | parser.add_argument('--cumsum_vel', type=str, default="True", help='cumulate velocity for reg loss') 55 | 56 | 57 | #### Initialize 58 | parser.add_argument('--checkpoint', type=str, default="none", help='load checkpoint') 59 | parser.add_argument('--start_epoch', type=int, default=1, help='the index of start epoch (for resume training)') 60 | parser.add_argument('--dev_mode', type=str, default="False", help='develop_mode') 61 | 62 | parser.add_argument('--ddp_mode', type=str, default="False", help='False, True, multi_node') 63 | parser.add_argument('--port', type=str, default="31243", help='DDP') 64 | 65 | parser.add_argument('--amp', type=str, default="none", help='AMP dtype: none, fp16, or bf16') 66 | 67 | #### Log 68 | parser.add_argument('--val_every_train_step', type=int, default=-1, help='every number of training step to conduct one evaluation') 69 | parser.add_argument('--name', type=str, default="hdgt_waymo_dev", help='the name of this setting') 70 | args = parser.parse_args() 71 | os.environ["DGLBACKEND"] = "pytorch" 72 | 73 | 74 | class Logger(): 75 |     def __init__(self, lognames): 76 |         self.terminal = sys.stdout 77 |         self.logs = [] 78 |         for log_name in lognames: 79 |             self.logs.append(open(log_name, 'w')) 80 |     def write(self, message): 81 | 
self.terminal.write(message) 82 | for log in self.logs: 83 | log.write(message) 84 | log.flush() 85 | def flush(self): 86 | pass 87 | 88 | def euclid(label, pred): 89 | return torch.sqrt((label[...,0]-pred[...,0])**2 + (label[...,1]-pred[...,1])**2) 90 | 91 | def euclid_np(label, pred): 92 | return np.sqrt((label[...,0]-pred[...,0])**2 + (label[...,1]-pred[...,1])**2) 93 | 94 | def cal_ADE(label, pred): 95 | return euclid_np(label,pred).mean() 96 | 97 | def cal_FDE(label, pred): 98 | return euclid_np(label[:,-1,:], pred[:,-1,:]).mean() 99 | 100 | def cal_ade_fde_mr(labels, preds, masks): 101 | if labels.shape[0] == 0: 102 | return None, None 103 | l2_norm = euclid_np(labels, preds) 104 | 105 | masks_sum = masks.sum(1) 106 | ade_indices = masks_sum != 0 107 | ade_cnt = ade_indices.sum() 108 | ade = ((l2_norm[ade_indices] * masks[ade_indices]).sum(1)/masks_sum[ade_indices]).mean() 109 | 110 | fde_indices = masks[:, -1] != 0 111 | fde_cnt = fde_indices.sum() 112 | fde = 0.0 113 | mr = 0.0 114 | if fde_cnt != 0: 115 | fde = l2_norm[fde_indices, -1] 116 | mr = (fde > 2.0).mean() 117 | fde = fde.mean() 118 | return [ade, fde, mr], [ade_cnt, fde_cnt, fde_cnt] 119 | 120 | def cal_min6_ade_fde_mr(preds, labels, masks): 121 | if labels.shape[0] == 0: 122 | return None, None 123 | l2_norm = euclid_np(labels[:, np.newaxis, :, :], preds) 124 | ## ade6 125 | masks_sum = masks.sum(1) 126 | ade_indices = masks_sum != 0 127 | ade_cnt = ade_indices.sum() 128 | ade6 = ((l2_norm[ade_indices] * masks[ade_indices, np.newaxis, :]).sum(-1)/masks_sum[ade_indices][:, np.newaxis]).min(-1).mean() 129 | 130 | fde_indices = masks[:, -1] != 0 131 | fde_cnt = fde_indices.sum() 132 | fde6 = 0.0 133 | mr6 = 0.0 134 | if fde_cnt != 0: 135 | fde6 = l2_norm[fde_indices, :, -1].min(-1) 136 | mr6 = (fde6 > 2.0).mean() 137 | fde6 = fde6.mean() 138 | return [ade6, fde6, mr6], [ade_cnt, fde_cnt, fde_cnt] 139 | 140 | class AverageMeter(object): 141 | """Computes and stores the average and current value""" 142 | def __init__(self): 143 | self.reset() 144 | def reset(self): 145 | self.val = 0 146 | self.avg = 0 147 | self.sum = 0 148 | self.count = 0 149 | def update(self, val, n=1): 150 | self.val = val 151 | self.sum += val * n 152 | self.count += n 153 | self.avg = self.sum / self.count 154 | 155 | 156 | def main(): 157 | args = parser.parse_args() 158 | ###Distributed 159 | gpu_count = torch.cuda.device_count() 160 | global_seed = int(args.port) ## Import!!!! for coherent data splitting across process 161 | if gpu_count > 1: 162 | if args.ddp_mode == "multi_node": 163 | main_worker(int(os.environ["LOCAL_RANK"]), int(os.environ["WORLD_SIZE"]), global_seed, args) 164 | else: 165 | mp.spawn(main_worker, nprocs=gpu_count, args=(gpu_count, global_seed, args)) 166 | else: 167 | main_worker(0, gpu_count, global_seed, args) 168 | 169 | ## Running for each GPU 170 | def main_worker(gpu, gpu_count, global_seed, args): 171 | if args.ddp_mode == "multi_node": 172 | global_rank = int(os.environ["RANK"]) 173 | init_port = "tcp://"+os.environ["MASTER_ADDR"]+":"+os.environ["MASTER_PORT"] 174 | else: 175 | global_rank = gpu 176 | init_port = "tcp://127.0.0.1:"+args.port 177 | print(f"Use GPU: {gpu} for training. 
Global Rank:{global_rank} Global World Size:{gpu_count} Init Port {init_port}") 178 | print("Process Id:", os.getpid()) 179 | seed_num = random.randint(0, 1000000) 180 | torch.manual_seed(seed_num+global_rank) 181 | random.seed(seed_num+global_rank) 182 | np.random.seed(seed_num+global_rank) 183 | 184 | if gpu_count > 1: 185 | dist.init_process_group(backend="nccl", world_size=gpu_count, init_method=init_port, rank=global_rank) 186 | torch.cuda.set_device(gpu) 187 | device = torch.device("cuda:"+str(gpu)) 188 | 189 | 190 | snapshot_dir = None 191 | if global_rank == 0: 192 | setting_name = args.name 193 | log_dir = "logs/" + str(setting_name+"_"+time.strftime("%Y-%m-%d-%H_%M_%S",time.localtime(time.time()))) 194 | 195 | if not os.path.isdir(log_dir): 196 | os.makedirs(log_dir) 197 | sys.stdout = Logger([f"{setting_name}.log", os.path.join(log_dir, f"{setting_name}.log")]) 198 | snapshot_dir = os.path.join(log_dir, "snapshot") 199 | if not os.path.isdir(snapshot_dir): 200 | os.makedirs(snapshot_dir) 201 | print("Log Directory:", os.path.join(log_dir, sys.argv[0])) 202 | print(args) 203 | shutil.copyfile(__file__, os.path.join(log_dir, "train.py")) 204 | shutil.copyfile("model.py", os.path.join(log_dir, "model.py")) 205 | model_module = importlib.import_module("model") 206 | model = model_module.HDGT_model(input_dim=11, args=args) 207 | model.apply(model_module.weights_init) 208 | 209 | print("Start Load Dataset") 210 | train_dataloader, val_dataloader, train_sample_num, val_sample_num = obtain_dataset(global_rank, gpu_count, global_seed, args) 211 | 212 | checkpoint = None 213 | if args.checkpoint != "none": 214 | print("Load:", args.checkpoint, gpu) 215 | checkpoint = torch.load(args.checkpoint, map_location="cpu") 216 | model.load_state_dict(checkpoint["model_state_dict"]) 217 | 218 | model = model.to(device, non_blocking=True) 219 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=args.weight_decay) 220 | step_per_epoch = train_sample_num // args.batch_size // gpu_count + 1 221 | epoch = args.n_epoch 222 | 223 | warmup = args.warmup 224 | if args.start_epoch > 1: 225 | warmup = 0.0 226 | 227 | scheduler = model_module.WarmupLinearSchedule(optimizer, step_per_epoch*warmup, step_per_epoch*epoch) 228 | if args.checkpoint != "none": 229 | optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) 230 | for param_group in optimizer.param_groups: 231 | param_group["lr"] = args.lr 232 | param_group["betas"] = (0.9, 0.95) 233 | print("lr") 234 | 235 | if gpu_count > 1: 236 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu], output_device=gpu, find_unused_parameters=True) 237 | reg_criterion = torch.nn.SmoothL1Loss(reduction="none").to(device) 238 | if gpu == 0: 239 | print("start") 240 | 241 | if args.amp == "fp16": 242 | amp_data_type = torch.float16 243 | scaler = torch.cuda.amp.GradScaler() 244 | print("Use AMP with Type:", amp_data_type) 245 | elif args.amp == "bf16": 246 | amp_data_type = torch.bfloat16 247 | scaler = torch.cuda.amp.GradScaler() 248 | print("Use AMP with Type:", amp_data_type) 249 | else: 250 | amp_data_type = torch.float32 ## Not enabled 251 | scaler = None 252 | 253 | for epoch in range(args.start_epoch, args.n_epoch+1): 254 | run_model(dataloader=train_dataloader, num_sample=train_sample_num, model=model, optimizer=optimizer, scheduler=scheduler, epoch=epoch, gpu=gpu, global_rank=global_rank, gpu_count=gpu_count, is_train=True, args=args, val_dataloader=val_dataloader, val_sample_num=val_sample_num, 
snapshot_dir=snapshot_dir, scaler=scaler, amp_data_type=amp_data_type) 255 | 256 | 257 | def run_model(dataloader, num_sample, model, optimizer, scheduler, epoch, gpu, global_rank, gpu_count, is_train, args, val_dataloader=None, val_sample_num=None, snapshot_dir=None, scaler=None, amp_data_type=None): 258 |     length_lis = [30, 50, 80] 259 |     agent_type_lis = ["VEHICLE", "PEDESTRIAN", "CYCLIST"] 260 |     metric_type_lis = ["ade", "fde", "mr", "ade6", "fde6", "mr6",] 261 | 262 |     acc_metric_type_lis = ["accs", "acc6s"] 263 |     recorder = {} 264 |     for agent_type in agent_type_lis: 265 |         for acc_metric_type in acc_metric_type_lis: 266 |             recorder[agent_type+"_"+acc_metric_type] = AverageMeter() 267 |         for length in length_lis: 268 |             for metric_type in metric_type_lis: 269 |                 recorder[agent_type+"_"+str(length)+"_"+metric_type] = AverageMeter() 270 |     recorder["loss"] = AverageMeter() 271 |     recorder["reg_loss"] = AverageMeter() 272 |     recorder["cls_loss"] = AverageMeter() 273 |     losses_recoder, reg_losses_, cls_losses = AverageMeter(), AverageMeter(), AverageMeter() 274 | 275 |     start_time = time.time() 276 |     if is_train: 277 |         model.train() 278 |         batch_size = args.batch_size 279 |     else: 280 |         model.eval() 281 |         batch_size = args.val_batch_size 282 |         gpu_count = 1 283 | 284 |     is_dev = (args.dev_mode == "True") 285 |     is_output_vel = (args.output_vel == "True") 286 |     is_cumsum_vel = (args.cumsum_vel == "True") 287 | 288 | 289 |     if is_dev: 290 |         val_every_train_step = 1 291 |         print_freq = 1 292 |     else: 293 |         val_every_train_step = args.val_every_train_step 294 |         print_freq = 100 295 | 296 |     num_prediction = args.num_prediction 297 |     refine_num = args.refine_num 298 |     cls_weight = args.cls_weight 299 |     reg_weight = args.reg_weight 300 |     device = torch.device("cuda:"+str(gpu)) 301 | 302 |     reg_criteria = torch.nn.SmoothL1Loss(reduction="none").to(device) 303 | 304 |     lr_decay_epoch = args.lr_decay_epoch.split("-") 305 |     lr_decay_epoch = [int(_) for _ in lr_decay_epoch] 306 | 307 |     decay_coefficient = 1.0 308 |     for decay_epoch in lr_decay_epoch: 309 |         if epoch >= decay_epoch: 310 |             decay_coefficient = decay_coefficient * 0.5 311 |     scheduler.base_lrs = [args.lr * decay_coefficient] 312 |     for param_group in optimizer.param_groups: 313 |         param_group["lr"] = args.lr * decay_coefficient 314 | 315 |     use_amp = (scaler is not None) ## enable autocast only when AMP (fp16/bf16) was requested and a GradScaler was created 316 |     with torch.set_grad_enabled(is_train): 317 |         for batch_index, data in enumerate(dataloader, 0): 318 |             data["is_train"] = is_train 319 |             data["gpu"] = gpu 320 |             for tensor_name in data["cuda_tensor_lis"]: 321 |                 data[tensor_name] = data[tensor_name].to("cuda:"+str(gpu), non_blocking=True) 322 |             optimizer.zero_grad() 323 |             num_of_sample = len(data["pred_num_lis"]) 324 | 325 |             with torch.cuda.amp.autocast(enabled=use_amp, dtype=amp_data_type): 326 |                 agent_reg_res, agent_cls_res, pred_indice_bool_type_lis = model(data) 327 |                 reg_labels = [data["label_lis"][pred_indice_bool_type_lis[_]] for _ in range(3)] 328 | 329 |                 auxiliary_labels = [data["auxiliary_label_lis"][pred_indice_bool_type_lis[_]] for _ in range(3)] 330 |                 auxiliary_labels_future = [data["auxiliary_label_future_lis"][pred_indice_bool_type_lis[_]] for _ in range(3)] 331 |                 label_masks = [data["label_mask_lis"][pred_indice_bool_type_lis[_]] for _ in range(3)] 332 | 333 |                 agent_closest_index_lis = [[] for _ in range(3)] 334 |                 loss = 0.0 335 |                 reg_loss = 0.0 336 |                 cls_loss = 0.0 337 |                 total_num_of_mask = 0.0 338 |                 total_num_of_agent = 0.0 339 |                 for agent_type_index in range(3): 340 |                     if agent_reg_res[agent_type_index].shape[0] == 0: 341 | 
continue 342 | num_of_mask_per_agent = label_masks[agent_type_index].sum(dim=-1) 343 | mask_sum = num_of_mask_per_agent.sum() 344 | if mask_sum != 0: 345 | dist_between_pred_label = reg_criteria(agent_reg_res[agent_type_index], reg_labels[agent_type_index].unsqueeze(1).unsqueeze(1).repeat(1, refine_num+1, num_prediction, 1, 1)).mean(-1) ## N_Agent, N_refine, num_prediction, 80 346 | dist_between_pred_label = (dist_between_pred_label * label_masks[agent_type_index].unsqueeze(1).unsqueeze(1)).sum(-1) / (num_of_mask_per_agent.unsqueeze(-1).unsqueeze(-1)+1) ## N_Agent, N_refine, num_prediction 347 | agent_closest_index = dist_between_pred_label[:, -1, :].argmin(dim=-1) 348 | 349 | reg_loss += (dist_between_pred_label[torch.arange(agent_closest_index.shape[0]), :, agent_closest_index]).sum() / (refine_num+1) 350 | 351 | log_pis = agent_cls_res[agent_type_index] 352 | log_pis = log_pis - torch.logsumexp(log_pis, dim=-1, keepdim=True) 353 | log_pi = log_pis[torch.arange(agent_closest_index.shape[0]), agent_closest_index].sum() 354 | cls_loss += (-log_pi) 355 | agent_closest_index_lis[agent_type_index] = (agent_closest_index) 356 | total_num_of_agent += len(agent_closest_index) 357 | total_num_of_mask += mask_sum 358 | 359 | loss += (cls_loss / total_num_of_agent * cls_weight + reg_loss / total_num_of_mask * reg_weight) 360 | reg_loss_cnt = total_num_of_mask 361 | cls_loss_cnt = total_num_of_agent 362 | 363 | if is_train: 364 | if loss != 0: 365 | if torch.isnan(loss): 366 | print("Bad Gradients!", epoch, batch_index) 367 | optimizer.zero_grad() 368 | del data 369 | del agent_reg_res, agent_cls_res, pred_indice_bool_type_lis 370 | del loss 371 | continue 372 | if scaler is not None: 373 | scaler.scale(loss).backward() 374 | scaler.unscale_(optimizer) 375 | torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0) 376 | scaler.step(optimizer) 377 | scaler.update() 378 | else: 379 | loss.backward() 380 | torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0) 381 | optimizer.step() 382 | scheduler.step() 383 | 384 | if global_rank == 0: 385 | with torch.no_grad(): 386 | if loss != 0: 387 | recorder["loss"].update(loss.item(), num_of_sample) 388 | recorder["cls_loss"].update(cls_loss.item()/cls_loss_cnt * cls_weight, cls_loss_cnt) 389 | recorder["reg_loss"].update(reg_loss.item()/reg_loss_cnt * reg_weight, reg_loss_cnt) 390 | 391 | neighbor_size_lis = data["pred_num_lis"] 392 | cumsum_neighbor_size_lis = np.cumsum(neighbor_size_lis, axis=0).tolist() 393 | cumsum_neighbor_size_lis = [0] + cumsum_neighbor_size_lis 394 | for agent_type_index in range(3): 395 | now_agent_cls_res = agent_cls_res[agent_type_index] 396 | if now_agent_cls_res.shape[0] == 0: 397 | continue 398 | 399 | now_agent_reg_res = agent_reg_res[agent_type_index][:, -1, ...].detach().cpu().numpy() 400 | now_labels = reg_labels[agent_type_index].detach().cpu().numpy() 401 | now_auxiliary_labels = auxiliary_labels[agent_type_index].detach().cpu().numpy() 402 | now_auxiliary_labels_future = auxiliary_labels_future[agent_type_index].detach().cpu().numpy() 403 | now_label_masks = label_masks[agent_type_index].detach().cpu().numpy() 404 | now_agent_closest_index = agent_closest_index_lis[agent_type_index].detach().cpu().numpy() 405 | now_cls_sorted_index = now_agent_cls_res.argsort(dim=-1, descending=True).detach().cpu().numpy() 406 | now_agent_cls_res = now_agent_cls_res.detach().cpu().numpy() 407 | 408 | cls_acc = 0.0 409 | cls_acc6 = 0.0 410 | best_preds = [0] * len(now_labels) 411 | best_6preds = [0] * len(now_labels) 412 | for item_index 
in range(len(now_labels)): 413 | if now_agent_closest_index[item_index] == now_cls_sorted_index[item_index][0]: 414 | cls_acc += 1.0 415 | if now_agent_closest_index[item_index] in now_cls_sorted_index[item_index][:6].tolist(): 416 | cls_acc6 += 1.0 417 | best_preds[item_index] = now_agent_reg_res[item_index, ...][now_cls_sorted_index[item_index][0], :, :] 418 | best_6preds[item_index] = now_agent_reg_res[item_index][now_cls_sorted_index[item_index][:6], :, :] 419 | cls_acc /= now_agent_reg_res.shape[0] 420 | cls_acc6 /= now_agent_reg_res.shape[0] 421 | recorder[agent_type_lis[agent_type_index]+"_"+"accs"].update(cls_acc, now_agent_reg_res.shape[0]) 422 | recorder[agent_type_lis[agent_type_index]+"_"+"acc6s"].update(cls_acc6, now_agent_reg_res.shape[0]) 423 | best_preds = np.stack(best_preds, axis=0) 424 | best_6preds = np.stack(best_6preds, axis=0) 425 | 426 | for length_indices in range(3): 427 | res_lis, res_cnt_lis = cal_ade_fde_mr(best_preds[:, :length_lis[length_indices], :][:, 4::5, :], now_labels[:, :length_lis[length_indices], :][:, 4::5, :], now_label_masks[:, :length_lis[length_indices]][:, 4::5]) 428 | if res_lis: 429 | for metric_indices, metric_type in enumerate(["ade", "fde", "mr"]): 430 | if res_cnt_lis[metric_indices] > 0: 431 | recorder[agent_type_lis[agent_type_index]+"_"+str(length_lis[length_indices])+"_"+metric_type].update(res_lis[metric_indices], res_cnt_lis[metric_indices]) 432 | 433 | res_lis, res_cnt_lis = cal_min6_ade_fde_mr(best_6preds[:, :, :length_lis[length_indices], :][:, :, 4::5, :], now_labels[:, :length_lis[length_indices], :][:, 4::5, :], now_label_masks[:, :length_lis[length_indices]][:, 4::5]) 434 | if res_lis: 435 | for metric_indices, metric_type in enumerate(["ade6", "fde6", "mr6"]): 436 | if res_cnt_lis[metric_indices] > 0: 437 | recorder[agent_type_lis[agent_type_index]+"_"+str(length_lis[length_indices])+"_"+metric_type].update(res_lis[metric_indices], res_cnt_lis[metric_indices]) 438 | if (is_train and ((batch_index+1) % print_freq) == 0): 439 | print_dic = {metric_type:0.0 for metric_type in metric_type_lis} 440 | sub_print_dic = {} 441 | for agent_type in agent_type_lis: 442 | for metric_type in metric_type_lis: 443 | sub_print_dic[agent_type + "_" + metric_type] = 0 444 | for length in length_lis: 445 | sub_print_dic[agent_type + "_" + metric_type] += recorder[agent_type+"_"+str(length)+"_"+metric_type].avg 446 | 447 | detail_text = "" 448 | for agent_type in agent_type_lis: 449 | for length in length_lis: 450 | for metric_type in metric_type_lis: 451 | print_dic[metric_type] += recorder[agent_type+"_"+str(length)+"_"+metric_type].avg 452 | detail_text += ", "+agent_type+"_"+str(length)+"_" + metric_type + " {:.4f}".format(recorder[agent_type+"_"+str(length)+"_"+metric_type].avg) 453 | for agent_type in agent_type_lis: 454 | for acc_metric_type in acc_metric_type_lis: 455 | detail_text += ", "+agent_type+"_" + acc_metric_type + " "+str(recorder[agent_type+"_"+acc_metric_type].avg) 456 | print_dic = {k:v/9.0 for k, v in print_dic.items()} 457 | sub_print_dic = {k:v/3.0 for k, v in sub_print_dic.items()} 458 | 459 | print_text = ' Epoch: [{0}][{1}/{2}-Batch {3}], '.format(epoch, (batch_index+1)*batch_size*gpu_count, num_sample, batch_index) 460 | print_text += "Loss {:.8f}, ".format(recorder["loss"].avg) 461 | print_text += "Cls Loss {:.8f}, ".format(recorder["cls_loss"].avg) 462 | print_text += "Reg Loss {:.8f}, ".format(recorder["reg_loss"].avg) 463 | for k, v in print_dic.items(): 464 | print_text += k + " {:.4f}, ".format(v) 465 | for k, v 
in sub_print_dic.items(): 466 | print_text += k + " {:.4f}, ".format(v) 467 | print_text += "Time(s): {:.4f}".format(time.time()-start_time) 468 | print_text += detail_text 469 | 470 | print_text += ", LR: {:.4e}".format(scheduler.get_last_lr()[0]) 471 | print(print_text, flush=True) 472 | 473 | ## Reinit Train Recorder 474 | recorder = {} 475 | for agent_type in agent_type_lis: 476 | for acc_metric_type in acc_metric_type_lis: 477 | recorder[agent_type+"_"+acc_metric_type] = AverageMeter() 478 | for length in length_lis: 479 | for metric_type in metric_type_lis: 480 | recorder[agent_type+"_"+str(length)+"_"+metric_type] = AverageMeter() 481 | recorder["loss"] = AverageMeter() 482 | recorder["reg_loss"] = AverageMeter() 483 | recorder["cls_loss"] = AverageMeter() 484 | 485 | if is_train and ((batch_index+1) % val_every_train_step == 0 or (val_every_train_step <= 0 and batch_index == (len(dataloader)-1))) and not is_dev: 486 | if gpu_count > 1: 487 | val_model = model.module 488 | else: 489 | val_model = model 490 | val_model.eval() 491 | run_model(val_dataloader, val_sample_num, val_model, optimizer, scheduler, epoch, gpu, global_rank, gpu_count, is_train=False, args=args, scaler=scaler, amp_data_type=amp_data_type) 492 | model.train() 493 | file_path = os.path.join(snapshot_dir, "Epoch_"+str(epoch)+"_batch"+str(batch_index)+".pt") 494 | checkpoint = {} 495 | 496 | if gpu_count > 1: 497 | checkpoint["model_state_dict"] = model.module.state_dict() 498 | else: 499 | checkpoint["model_state_dict"] = model.state_dict() 500 | 501 | checkpoint['optimizer_state_dict'] = optimizer.state_dict() 502 | torch.save(checkpoint, file_path) 503 | print("Epoch %d Batch %d Save Model"%(epoch, batch_index)) 504 | del data 505 | if (not is_train): 506 | print_dic = {metric_type:0.0 for metric_type in metric_type_lis} 507 | sub_print_dic = {} 508 | for agent_type in agent_type_lis: 509 | for metric_type in metric_type_lis: 510 | sub_print_dic[agent_type + "_" + metric_type] = 0 511 | for length in length_lis: 512 | sub_print_dic[agent_type + "_" + metric_type] += recorder[agent_type+"_"+str(length)+"_"+metric_type].avg 513 | 514 | detail_text = "" 515 | for agent_type in agent_type_lis: 516 | for length in length_lis: 517 | for metric_type in metric_type_lis: 518 | print_dic[metric_type] += recorder[agent_type+"_"+str(length)+"_"+metric_type].avg 519 | detail_text += ", "+agent_type+"_"+str(length)+"_" + metric_type + " {:.4f}".format(recorder[agent_type+"_"+str(length)+"_"+metric_type].avg) 520 | for agent_type in agent_type_lis: 521 | for acc_metric_type in acc_metric_type_lis: 522 | detail_text += ", "+agent_type+"_" + acc_metric_type + " "+str(recorder[agent_type+"_"+acc_metric_type].avg) 523 | print_dic = {k:v/9.0 for k, v in print_dic.items()} 524 | sub_print_dic = {k:v/3.0 for k, v in sub_print_dic.items()} 525 | 526 | print_text = '****Val Epoch: [{0}][{1}/{2}], '.format(epoch, (batch_index+1)*batch_size*gpu_count, num_sample) 527 | print_text += "Loss {:.8f}, ".format(recorder["loss"].avg) 528 | print_text += "Cls Loss {:.8f}, ".format(recorder["cls_loss"].avg) 529 | print_text += "Reg Loss {:.8f}, ".format(recorder["reg_loss"].avg) 530 | for k, v in print_dic.items(): 531 | print_text += k + " {:.4f}, ".format(v) 532 | for k, v in sub_print_dic.items(): 533 | print_text += k + " {:.4f}, ".format(v) 534 | print_text += "Time(s): {:.4f}".format(time.time()-start_time) 535 | print_text += detail_text 536 | print(print_text, flush=True) 537 | 538 | if __name__ == '__main__': 539 | main() 540 | 
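## ----------------------------------------------------------------------------
## Editor's note: the block below is an illustrative sketch added for clarity,
## not part of the original repository. It condenses the winner-takes-all
## objective computed in run_model above: for each agent, only the mode closest
## to the ground truth at the last refinement stage receives the smooth-L1
## regression loss, and the classification head is trained to put probability
## mass on that winning mode. All names here (wta_loss, reg_res, cls_res, ...)
## are hypothetical; the real code additionally splits agents by type and
## guards empty batches.
import torch
import torch.nn.functional as F

def wta_loss(reg_res, cls_res, label, mask, reg_weight=50.0, cls_weight=0.1):
    ## reg_res: [N, R, K, T, 2] trajectories (R refinement stages, K modes)
    ## cls_res: [N, K] mode scores; label: [N, T, 2]; mask: [N, T] step validity
    dist = F.smooth_l1_loss(reg_res, label[:, None, None].expand_as(reg_res),
                            reduction="none").mean(-1)                         ## [N, R, K, T]
    per_mode = (dist * mask[:, None, None]).sum(-1) / (mask.sum(-1)[:, None, None] + 1)
    closest = per_mode[:, -1].argmin(dim=-1)                                   ## winner mode per agent
    reg_loss = per_mode[torch.arange(len(closest)), :, closest].sum() / per_mode.shape[1]
    log_pis = cls_res - torch.logsumexp(cls_res, dim=-1, keepdim=True)
    cls_loss = -log_pis[torch.arange(len(closest)), closest].sum()
    return cls_loss / len(closest) * cls_weight + reg_loss / (mask.sum() + 1e-9) * reg_weight
## Supervising only the closest mode keeps the K hypotheses diverse instead of
## collapsing them onto the average of all plausible futures.
## ----------------------------------------------------------------------------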
-------------------------------------------------------------------------------- /training/waymo_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import torch 4 | from torch.utils.data import Dataset, DataLoader 5 | from torch.utils.data.sampler import BatchSampler 6 | import random 7 | import numpy as np 8 | import pandas as pd 9 | from functools import partial 10 | 11 | 12 | 13 | class HDGTDataset(Dataset): 14 | def __init__(self, dataset_path, data_folder, is_train, num_of_data, train_sample_batch_lookup): 15 | self.dataset_path = dataset_path 16 | self.data_folder = data_folder 17 | self.num_of_data = num_of_data 18 | self.is_train = is_train 19 | self.train_sample_batch_lookup = train_sample_batch_lookup ## To check which folder the sample is in 20 | def __getitem__(self, idx): 21 | for i in range(1, len(self.train_sample_batch_lookup)): 22 | if idx >= self.train_sample_batch_lookup[i-1]["cumulative_sample_cnt"] and idx < self.train_sample_batch_lookup[i]["cumulative_sample_cnt"]: 23 | batch_index = i-1 24 | break 25 | file_name = os.path.join(self.dataset_path, self.train_sample_batch_lookup[batch_index+1]["data_folder"], self.data_folder+"_case"+str(idx-self.train_sample_batch_lookup[batch_index]["cumulative_sample_cnt"])+".pkl") 26 | with open(file_name, "rb") as f: 27 | sample = pickle.load(f) 28 | return sample 29 | 30 | def __len__(self): 31 | return self.num_of_data 32 | 33 | ## To make sure each batch has approximately the same number of nodes 34 | class BalancedBatchSampler(BatchSampler): 35 | def __init__(self, input_size_lis, seed_num, gpu, gpu_cnt, batch_size, is_train): 36 | self.batch_size = batch_size 37 | input_size_lis = input_size_lis 38 | sorted_index = input_size_lis.argsort()[::-1].tolist() 39 | self.index_lis = [] 40 | self.is_train = is_train 41 | for i in range(self.batch_size): 42 | self.index_lis.append(sorted_index[int(len(sorted_index)//self.batch_size * i):int(len(sorted_index)//self.batch_size * (i+1))]) 43 | if len(sorted_index)//self.batch_size * self.batch_size < len(sorted_index): 44 | self.index_lis[-1] = self.index_lis[-1] + sorted_index[len(sorted_index)//self.batch_size * self.batch_size:] 45 | self.seed_num = seed_num 46 | self.gpu = gpu 47 | self.sample_per_gpu = len(self.index_lis[0])//gpu_cnt 48 | 49 | def __iter__(self): 50 | if self.is_train: 51 | for i in range(len(self.index_lis)): 52 | random.Random(self.seed_num+i).shuffle(self.index_lis[i]) 53 | self.seed_num += 1 54 | for i in range(int(self.gpu*self.sample_per_gpu), int((self.gpu+1)*self.sample_per_gpu)): 55 | yield [self.index_lis[j][i] for j in range(self.batch_size)] 56 | def __len__(self): 57 | return self.sample_per_gpu 58 | 59 | 60 | @torch.no_grad() 61 | def obtain_dataset(gpu, gpu_count, seed_num, args): 62 | dataset_path = os.path.join(os.path.dirname(os.getcwd()), "dataset", "waymo") 63 | if args.dev_mode == "True": 64 | seed_num = 0 65 | print(gpu, seed_num, flush=True) 66 | 67 | data_folder = args.data_folder 68 | train_folder = "training" 69 | num_of_train_folder = 12 70 | val_folder = "validation" 71 | 72 | ## Initialize 73 | train_num_of_agent_arr = [] 74 | train_sample_batch_lookup = [{"cumulative_sample_cnt":0}] 75 | for train_patch_index in range(num_of_train_folder): 76 | with open(os.path.join(dataset_path, train_folder, data_folder+str(train_patch_index), data_folder+"_number_of_case.pkl"), "rb") as f: 77 | train_num_of_agent_arr.append(pickle.load(f)) 78 | 
train_sample_batch_lookup.append({"cumulative_sample_cnt":train_sample_batch_lookup[-1]["cumulative_sample_cnt"]+train_num_of_agent_arr[-1].shape[0], "data_folder":os.path.join("training", data_folder+str(train_patch_index))}) 79 | train_num_of_agent_arr = np.concatenate(train_num_of_agent_arr, axis=0) 80 | 81 | val_num_of_agent_arr = [] 82 | val_sample_batch_lookup = [{"cumulative_sample_cnt":0}] 83 | with open(os.path.join(dataset_path, val_folder, data_folder+str(num_of_train_folder), data_folder+"_number_of_case.pkl"), "rb") as f: 84 | val_num_of_agent_arr = pickle.load(f) 85 | val_sample_batch_lookup.append({"cumulative_sample_cnt":val_num_of_agent_arr.shape[0], "data_folder":os.path.join("validation", data_folder+str(num_of_train_folder))}) 86 | 87 | 88 | if args.dev_mode == "True": 89 | args.num_worker = 0 90 | dev_train_num = 2 91 | train_num_of_agent_arr = train_num_of_agent_arr[:dev_train_num] 92 | val_num_of_agent_arr = val_num_of_agent_arr[:dev_train_num] 93 | 94 | train_sampler = BalancedBatchSampler(train_num_of_agent_arr, seed_num=seed_num, gpu=gpu, gpu_cnt=gpu_count, batch_size=args.batch_size, is_train=True) 95 | if gpu == 0: 96 | val_sampler = BalancedBatchSampler(val_num_of_agent_arr, seed_num=seed_num, gpu=0, gpu_cnt=1, batch_size=args.val_batch_size, is_train=False) 97 | print("train sample num:", len(train_num_of_agent_arr), "val sample num:", len(val_num_of_agent_arr), flush=True) 98 | 99 | 100 | train_dataset = HDGTDataset(dataset_path=dataset_path, data_folder=args.data_folder, is_train=True, 101 | num_of_data=len(train_num_of_agent_arr)//gpu_count, train_sample_batch_lookup=train_sample_batch_lookup) 102 | setting_dic = {} 103 | train_dataloader = DataLoader(train_dataset, pin_memory=True, collate_fn=partial(HDGT_collate_fn, setting_dic=setting_dic, args=args, is_train=True), batch_sampler=train_sampler, num_workers=args.num_worker) 104 | train_sample_num = len(train_dataset) * gpu_count 105 | 106 | val_dataloader = None 107 | val_sample_num = 0 108 | if gpu == 0: 109 | val_worker_num = args.num_worker 110 | # if args.is_local == "multi_node" or args.is_local == "FalseM": 111 | # val_worker_num *= 7 112 | val_dataset = HDGTDataset(dataset_path=dataset_path, data_folder=args.data_folder, is_train=False, num_of_data=len(val_num_of_agent_arr), train_sample_batch_lookup=val_sample_batch_lookup) 113 | val_dataloader = DataLoader(val_dataset, pin_memory=True, collate_fn=partial(HDGT_collate_fn, setting_dic=setting_dic, args=args, is_train=False), batch_sampler=val_sampler, num_workers=val_worker_num) 114 | val_sample_num = len(val_dataset) 115 | if gpu == 0: 116 | print('data loaded', flush=True) 117 | return train_dataloader, val_dataloader, train_sample_num, val_sample_num 118 | 119 | 120 | import numpy as np 121 | import torch 122 | import dgl 123 | import random 124 | import math 125 | 126 | def euclid_np(label, pred): 127 | return np.sqrt((label[...,0]-pred[...,0])**2 + (label[...,1]-pred[...,1])**2) 128 | 129 | uv_dict = {} 130 | ## Sparse adj mat of fully connected graph of neighborhood size 131 | def return_uv(neighborhood_size): 132 | global uv_dict 133 | if neighborhood_size in uv_dict: 134 | return uv_dict[neighborhood_size] 135 | else: 136 | v = torch.LongTensor([[_]*(neighborhood_size-1) for _ in range(neighborhood_size)]).view(-1) 137 | u = torch.LongTensor([list(range(0, _)) +list(range(_+1,neighborhood_size)) for _ in range(neighborhood_size)]).view(-1) 138 | uv_dict[neighborhood_size] = (u, v) 139 | return (u, v) 140 | 141 | def 
generate_heterogeneous_graph(agent_fea, map_fea, agent_map_size_lis): 142 | max_in_edge_per_type = 32 ## For saving GPU memory 143 | uv_dic = {} 144 | uv_dic[("agent", "self", "agent")] = [list(range(agent_fea.shape[0])), list(range(agent_fea.shape[0]))] ## Self-loop 145 | num_of_agent = agent_fea.shape[0] 146 | ## Agent Adj 147 | uv_dic[("agent", "other", "agent")] = [[], []] 148 | for agent_index_i in range(num_of_agent): 149 | final_dist_between_agent = euclid_np(agent_fea[agent_index_i, -1, :][np.newaxis, :2], agent_fea[:, -1, :2]) 150 | nearby_agent_index = np.where(final_dist_between_agent < np.maximum(agent_map_size_lis[agent_index_i][np.newaxis], agent_map_size_lis))[0] 151 | nearby_agent_index = np.delete(nearby_agent_index, obj=np.where(nearby_agent_index == agent_index_i)) 152 | if len(nearby_agent_index) > max_in_edge_per_type: 153 | final_dist_between_agent_sorted_nearby_index = np.argsort(final_dist_between_agent[nearby_agent_index]) 154 | nearby_agent_index = nearby_agent_index[final_dist_between_agent_sorted_nearby_index][:max_in_edge_per_type] 155 | nearby_agent_index = nearby_agent_index.tolist() 156 | if len(nearby_agent_index) > 0: 157 | uv_dic[("agent", "other", "agent")][0] += [agent_index_i]*(len(nearby_agent_index)) 158 | uv_dic[("agent", "other", "agent")][1] += nearby_agent_index 159 | 160 | 161 | polygon_index_cnt = 0 162 | graphindex2polygonindex = {} 163 | uv_dic[("polygon", "g2a", "agent")] = [[], []] 164 | ## Agent_Polygon Adj 165 | if len(map_fea[1]) > 0: 166 | dist_between_agent_polygon = np.stack([(euclid_np(agent_fea[:, -1, :][:, np.newaxis, :], _[1][np.newaxis, :, :]).min(1)) for _ in map_fea[1]], axis=-1) 167 | all_agent_nearby_polygon_index_lis = dist_between_agent_polygon < agent_map_size_lis[:, np.newaxis] 168 | for agent_index_i in range(num_of_agent): 169 | nearby_polygon_index_lis = np.where(all_agent_nearby_polygon_index_lis[agent_index_i, :])[0] 170 | if len(nearby_polygon_index_lis) > max_in_edge_per_type: 171 | current_dist_between_agent_polygon = dist_between_agent_polygon[agent_index_i, :] 172 | nearby_polygon_index_lis_sorted = np.argsort(current_dist_between_agent_polygon[nearby_polygon_index_lis]) 173 | nearby_polygon_index_lis = nearby_polygon_index_lis[nearby_polygon_index_lis_sorted][:max_in_edge_per_type] 174 | nearby_polygon_index_lis = nearby_polygon_index_lis.tolist() 175 | for now_cnt, nearby_polygon_index in enumerate(nearby_polygon_index_lis): 176 | uv_dic[("polygon", "g2a", "agent")][0].append(polygon_index_cnt) 177 | uv_dic[("polygon", "g2a", "agent")][1].append(agent_index_i) 178 | graphindex2polygonindex[polygon_index_cnt] = nearby_polygon_index 179 | polygon_index_cnt += 1 180 | 181 | laneindex2graphindex = {} 182 | graphindex_cnt = 0 183 | uv_dic[("lane", "l2a", "agent")] = [[], []] 184 | uv_dic[("agent", "a2l", "lane")] = [[], []] 185 | ## Agent-Map Adj 186 | if len(map_fea[0]) > 0: 187 | all_polyline_coor = np.array([_["xyz"] for _ in map_fea[0]]) 188 | final_dist_between_agent_lane = euclid_np(agent_fea[:, -1, :2][:, np.newaxis, np.newaxis, :], all_polyline_coor[np.newaxis, :, :, :]).min(2) 189 | all_agent_nearby_lane_index_lis = final_dist_between_agent_lane < agent_map_size_lis[:, np.newaxis] 190 | for agent_index_i in range(num_of_agent): 191 | nearby_road_index_lis = np.where(all_agent_nearby_lane_index_lis[agent_index_i, :])[0]#.tolist() 192 | if len(nearby_road_index_lis) > max_in_edge_per_type: 193 | current_dist_between_agent_lane = final_dist_between_agent_lane[agent_index_i] 194 | nearby_road_index_lis_sorted 
= np.argsort(current_dist_between_agent_lane[nearby_road_index_lis]) 195 | nearby_road_index_lis = nearby_road_index_lis[nearby_road_index_lis_sorted][:max_in_edge_per_type] 196 | nearby_road_index_lis = nearby_road_index_lis.tolist() 197 | for now_cnt, nearby_road_index in enumerate(nearby_road_index_lis): 198 | if nearby_road_index not in laneindex2graphindex: 199 | laneindex2graphindex[nearby_road_index] = graphindex_cnt 200 | graphindex_cnt += 1 201 | uv_dic[("agent", "a2l", "lane")][0].append(agent_index_i) 202 | uv_dic[("lane", "l2a", "agent")][1].append(agent_index_i) 203 | uv_dic[("lane", "l2a", "agent")][0].append(laneindex2graphindex[nearby_road_index]) 204 | uv_dic[("agent", "a2l", "lane")][1].append(laneindex2graphindex[nearby_road_index]) 205 | 206 | lane2lane_boundary_dic = {} 207 | ## Map-Map Adj 208 | for etype in ["left", "right", "prev", "follow"]: 209 | uv_dic[("lane", etype, "lane")] = [[], []] 210 | lane2lane_boundary_dic[("lane", etype, "lane")] = [] 211 | if len(map_fea[0]) > 0: 212 | all_in_graph_lane = list(laneindex2graphindex.keys()) 213 | for in_graph_lane in all_in_graph_lane: 214 | info_dic = map_fea[0][in_graph_lane] 215 | for etype in ["left", "right", "prev", "follow"]: 216 | neighbors = [_ for _ in info_dic[etype] if _[0] in laneindex2graphindex] 217 | lane2lane_boundary_dic[("lane", etype, "lane")] += [_[1] for _ in neighbors] 218 | neighbors = [_[0] for _ in neighbors] 219 | uv_dic[("lane", etype, "lane")][0] += [laneindex2graphindex[in_graph_lane]] * len(neighbors) 220 | uv_dic[("lane", etype, "lane")][1] += [laneindex2graphindex[_] for _ in neighbors] 221 | 222 | output_dic = {} 223 | for _ in uv_dic: 224 | uv_dic[_] = (torch.LongTensor(uv_dic[_][0]), torch.LongTensor(uv_dic[_][1])) 225 | 226 | output_dic["uv_dic"] = uv_dic 227 | output_dic["graphindex2polylineindex"] = {v: k for k, v in laneindex2graphindex.items()} 228 | output_dic["graphindex2polygonindex"] = graphindex2polygonindex 229 | output_dic["boundary_type_dic"] = {k:torch.LongTensor(v) for k, v in lane2lane_boundary_dic.items()} 230 | return output_dic 231 | 232 | def rotate(data, cos_theta, sin_theta): 233 | data[..., 0], data[..., 1] = data[..., 0]*cos_theta - data[..., 1]*sin_theta, data[..., 1]*cos_theta + data[..., 0]*sin_theta 234 | return data 235 | 236 | def normal_agent_feature(feature, ref_coor, ref_psi, cos_theta, sin_theta): 237 | feature[..., :3] -= ref_coor[:, np.newaxis, :] 238 | feature[..., 0], feature[..., 1] = feature[..., 0]*cos_theta - feature[..., 1]*sin_theta, feature[..., 1]*cos_theta + feature[..., 0]*sin_theta 239 | feature[..., 3], feature[..., 4] = feature[..., 3]*cos_theta - feature[..., 4]*sin_theta, feature[..., 4]*cos_theta + feature[..., 3]*sin_theta 240 | feature[..., 5] -= ref_psi 241 | cos_psi = np.cos(feature[..., 5]) 242 | sin_psi = np.sin(feature[..., 5]) 243 | feature = np.concatenate([feature[..., :5], cos_psi[...,np.newaxis], sin_psi[...,np.newaxis], feature[..., 6:]], axis=-1) 244 | return feature 245 | 246 | def normal_polygon_feature(all_polygon_coor, all_polygon_type, ref_coor, cos_theta, sin_theta): 247 | now_polygon_coor = all_polygon_coor - ref_coor 248 | rotate(now_polygon_coor, cos_theta, sin_theta) 249 | return now_polygon_coor, all_polygon_type 250 | 251 | def normal_lane_feature(now_polyline_coor, now_polyline_type, now_polyline_speed_limit, now_polyline_stop, now_polyline_signal, polyline_index, ref_coor, cos_theta, sin_theta): 252 | output_polyline_coor = now_polyline_coor[polyline_index] - ref_coor[:, np.newaxis, :] 253 | 
rotate(output_polyline_coor, cos_theta, sin_theta) 254 | output_stop_fea = {i:np.array(now_polyline_stop[_][0]) for i, _ in enumerate(polyline_index) if len(now_polyline_stop[_]) != 0} 255 | output_signal_fea = {i:np.array(now_polyline_signal[_][0]) for i, _ in enumerate(polyline_index) if len(now_polyline_signal[_]) != 0} 256 | output_stop_index, output_stop_fea = list(output_stop_fea.keys()), list(output_stop_fea.values()) 257 | 258 | if len(output_stop_fea) != 0: 259 | output_stop_fea = np.stack(output_stop_fea, axis=0) 260 | output_stop_fea -= ref_coor[output_stop_index] 261 | if type(cos_theta) == np.float64: 262 | rotate(output_stop_fea, cos_theta, sin_theta) 263 | else: 264 | rotate(output_stop_fea, cos_theta[output_stop_index].flatten(), sin_theta[output_stop_index].flatten()) 265 | 266 | output_signal_index, output_signal_fea = list(output_signal_fea.keys()), list(output_signal_fea.values()) 267 | if len(output_signal_fea) != 0: 268 | output_signal_fea = np.stack(output_signal_fea, axis=0) 269 | output_signal_fea[..., :3] -= ref_coor[output_signal_index] 270 | if type(cos_theta) == np.float64: 271 | rotate(output_signal_fea, cos_theta, sin_theta) 272 | else: 273 | rotate(output_signal_fea, cos_theta[output_signal_index].flatten(), sin_theta[output_signal_index].flatten()) 274 | return output_polyline_coor, now_polyline_type[polyline_index], now_polyline_speed_limit[polyline_index], output_stop_fea, output_stop_index, output_signal_fea, output_signal_index 275 | 276 | def return_rel_e_feature(src_ref_coor, dst_ref_coor, src_ref_psi, dst_ref_psi): 277 | rel_coor = src_ref_coor - dst_ref_coor 278 | if rel_coor.ndim == 0 or rel_coor.ndim == 1: 279 | rel_coor = np.atleast_1d(rel_coor)[np.newaxis, :] 280 | rel_coor = rotate(rel_coor, np.cos(-dst_ref_psi), np.sin(-dst_ref_psi)) 281 | rel_psi = np.atleast_1d(src_ref_psi - dst_ref_psi)[:, np.newaxis] 282 | rel_sin_theta = np.sin(rel_psi) 283 | rel_cos_theta = np.cos(rel_psi) 284 | return np.concatenate([rel_coor, rel_sin_theta, rel_cos_theta], axis=-1) 285 | 286 | 287 | map_size_lis = {1.0:30, 2.0:10, 3.0:20} 288 | @torch.no_grad() 289 | def HDGT_collate_fn(batch, setting_dic, args, is_train): 290 | agent_drop = args.agent_drop 291 | 292 | agent_feature_lis = [item["agent_feature"] for item in batch] 293 | agent_type_lis = [item["agent_type"] for item in batch] 294 | #agent_map_size_lis = [np.vectorize(setting_dic["agenttype2mapsize"].get)(_) for _ in agent_type_lis] 295 | pred_num_lis = np.array([item["pred_num"] for item in batch]) 296 | label_lis = [item["label"] for item in batch] 297 | auxiliary_label_lis = [item["auxiliary_label"] for item in batch] 298 | label_mask_lis = [item["label_mask"] for item in batch] 299 | other_label_lis = [item["other_label"] for item in batch] 300 | other_label_mask_lis = [item["other_label_mask"] for item in batch] 301 | map_fea_lis = [item["map_fea"] for item in batch] 302 | case_id_lis = [item["scene_id"] for item in batch] 303 | object_id_lis = [item["obejct_id_lis"] for item in batch] 304 | 305 | if agent_drop > 0 and is_train: 306 | for i in range(len(agent_feature_lis)): 307 | keep_index = (np.random.random(agent_feature_lis[i].shape[0]) > agent_drop) 308 | while keep_index[:pred_num_lis[i]].sum() == 0: 309 | keep_index = (np.random.random(agent_feature_lis[i].shape[0]) > agent_drop) 310 | origin_pred_num = pred_num_lis[i] 311 | original_agent_num = agent_feature_lis[i].shape[0] 312 | target_keep_index = keep_index[:origin_pred_num] 313 | agent_feature_lis[i] = agent_feature_lis[i][keep_index] 314 
| agent_type_lis[i] = agent_type_lis[i][keep_index] 315 | pred_num_lis[i] = int(target_keep_index.sum()) 316 | 317 | label_lis[i] = label_lis[i][target_keep_index] 318 | auxiliary_label_lis[i] = auxiliary_label_lis[i][target_keep_index] 319 | label_mask_lis[i] = label_mask_lis[i][target_keep_index] 320 | if origin_pred_num != original_agent_num: 321 | other_label_lis[i] = other_label_lis[i][keep_index[origin_pred_num:]] 322 | other_label_mask_lis[i] = other_label_mask_lis[i][keep_index[origin_pred_num:]] 323 | 324 | neighbor_size = np.array([int(agent_feature_lis[i].shape[0]) for i in range(len(agent_feature_lis))]) 325 | 326 | out_lane_n_stop_sign_fea_lis = [] 327 | out_lane_n_stop_sign_index_lis = [] 328 | out_lane_n_signal_fea_lis = [] 329 | out_lane_n_signal_index_lis = [] 330 | 331 | out_normal_lis = [] 332 | out_graph_lis = [] 333 | out_label_lis = [] 334 | out_label_mask_lis = [] 335 | out_auxiliary_label_lis = [] 336 | out_auxiliary_label_future_lis = [] 337 | out_other_label_lis = [] 338 | out_other_label_mask_lis = [] 339 | lane_n_cnt = 0 340 | 341 | for i in range(len(agent_feature_lis)): 342 | all_agent_obs_final_v = np.sqrt(agent_feature_lis[i][:, -1, 3]**2+agent_feature_lis[i][:, -1, 4]**2) 343 | all_agent_map_size = np.vectorize(map_size_lis.__getitem__)(agent_type_lis[i]) 344 | all_agent_map_size = all_agent_obs_final_v * 8.0 + all_agent_map_size 345 | 346 | graph_dic = generate_heterogeneous_graph(agent_feature_lis[i], map_fea_lis[i], all_agent_map_size) 347 | g = dgl.heterograph(data_dict=graph_dic["uv_dic"]) 348 | g.edata['boundary_type'] = graph_dic["boundary_type_dic"] 349 | 350 | polylinelaneindex = list(graph_dic["graphindex2polylineindex"].values()) 351 | polygonlaneindex = list(graph_dic["graphindex2polygonindex"].values()) 352 | now_agent_feature = agent_feature_lis[i] 353 | now_agent_type = agent_type_lis[i] 354 | 355 | ### Type 0 edge a2a self-loop 356 | type0_u, type0_v = g.edges(etype="self") 357 | now_t0_v_feature = now_agent_feature[type0_v, :, :] 358 | now_t0_e_feature = now_agent_feature[type0_u].copy() 359 | if len(type0_v) == 1: 360 | now_t0_v_feature = now_t0_v_feature[np.newaxis, :, :] 361 | now_t0_e_feature = now_t0_e_feature[np.newaxis, :, :] 362 | now_t0_e_feature = return_rel_e_feature(now_t0_e_feature[:, -1, :3], now_t0_v_feature[:, -1, :3], now_t0_e_feature[:, -1, 5], now_t0_v_feature[:, -1, 5]) 363 | g.edata['a_e_fea'] = {("agent", "self", "agent"):torch.as_tensor(now_t0_e_feature.astype(np.float32))} 364 | g.edata['a_e_type'] = {("agent", "self", "agent"):torch.as_tensor((now_agent_type[type0_u].ravel()-1).astype(np.int32)).long()} 365 | 366 | ### Type 1 edge a2a other agent 367 | type1_u, type1_v = g.edges(etype="other") 368 | if len(type1_v) > 0: 369 | now_t1_v_feature = now_agent_feature[type1_v, :, :] 370 | now_t1_e_feature = now_agent_feature[type1_u].copy() 371 | if len(type1_v) == 1: 372 | now_t1_v_feature = now_t1_v_feature[np.newaxis, :, :] 373 | now_t1_e_feature = now_t1_e_feature[np.newaxis, :, :] 374 | now_t1_e_feature = return_rel_e_feature(now_t1_e_feature[:, -1, :3], now_t1_v_feature[:, -1, :3], now_t1_e_feature[:, -1, 5], now_t1_v_feature[:, -1, 5]) 375 | g.edata['a_e_fea'] = {("agent", "other", "agent"):torch.as_tensor(now_t1_e_feature.astype(np.float32))} 376 | g.edata['a_e_type'] = {("agent", "other", "agent"):torch.as_tensor((now_agent_type[type1_u].ravel()-1).astype(np.int32)).long()} 377 | else: 378 | g.edata['a_e_fea'] = {("agent", "other", "agent"):torch.zeros((0, 5))} 379 | g.edata['a_e_type'] = {("agent", "other", 
"agent"):torch.zeros((0, )).long()} 380 | 381 | ### Type 2 Edge: Agent -> Lane a2l 382 | if len(polylinelaneindex) > 0: 383 | now_polyline_info = [map_fea_lis[i][0][_] for _ in polylinelaneindex] 384 | now_polyline_coor = np.stack([_["xyz"] for _ in now_polyline_info], axis=0) 385 | now_polyline_yaw = np.array([_["yaw"] for _ in now_polyline_info]) 386 | now_polyline_type = np.array([_["type"] for _ in now_polyline_info]) 387 | now_polyline_speed_limit = np.array([_["speed_limit"] for _ in now_polyline_info]) 388 | now_polyline_stop = [_["stop"] for _ in now_polyline_info] 389 | now_polyline_signal = [_["signal"] for _ in now_polyline_info] 390 | now_polyline_mean_coor = now_polyline_coor[:, 2, :] 391 | type2_u = g.edges(etype="a2l")[0]#[0][cumu_edge_type_cnt_lis[2]:cumu_edge_type_cnt_lis[3]] 392 | type2_v = g.edges(etype="a2l")[1]#[1][cumu_edge_type_cnt_lis[2]:cumu_edge_type_cnt_lis[3]] - now_agent_feature.shape[0] - len(polygonlaneindex) 393 | if len(type2_v) > 0: 394 | now_t2_e_feature = now_agent_feature[type2_u].copy() 395 | if len(now_t2_e_feature.shape) == 2: 396 | now_t2_e_feature = now_t2_e_feature[np.newaxis, :, :] 397 | now_t2_e_feature = return_rel_e_feature(now_t2_e_feature[:, -1, :3], now_polyline_mean_coor[type2_v], now_t2_e_feature[:, -1, 5], now_polyline_yaw[type2_v]) 398 | g.edata['a_e_fea'] = {("agent", "a2l", "lane"):torch.as_tensor(now_t2_e_feature.astype(np.float32))} 399 | g.edata['a_e_type'] = {("agent", "a2l", "lane"):torch.as_tensor((now_agent_type[type2_u].ravel()-1).astype(np.int32)).long()} 400 | 401 | 402 | ### Type 3 Edge: Polygon -> Agent g2a 403 | type3_u = g.edges(etype="g2a")[0]#[cumu_edge_type_cnt_lis[3]:cumu_edge_type_cnt_lis[4]] - now_agent_feature.shape[0] 404 | type3_v = g.edges(etype="g2a")[1]#[cumu_edge_type_cnt_lis[3]:cumu_edge_type_cnt_lis[4]] 405 | if len(type3_v) > 0: 406 | now_polygon_type = np.array([map_fea_lis[i][1][_][0] for _ in polygonlaneindex]) 407 | now_polygon_coor = np.stack([map_fea_lis[i][1][_][1] for _ in polygonlaneindex], axis=0) 408 | now_t3_v_feature = now_agent_feature[type3_v] 409 | if len(now_t3_v_feature.shape) == 2: 410 | now_t3_v_feature = now_t3_v_feature[np.newaxis, :, :] 411 | ref_coor = now_t3_v_feature[:, -1, :3][:, np.newaxis, :] 412 | ref_psi = now_t3_v_feature[:, -1, 5][:, np.newaxis].copy() 413 | sin_theta = np.sin(-ref_psi) 414 | cos_theta = np.cos(-ref_psi) 415 | now_t3_e_coor_feature, now_t3_e_type_feature = normal_polygon_feature(now_polygon_coor, now_polygon_type, ref_coor, cos_theta, sin_theta) 416 | g.edata['g2a_e_fea'] = {("polygon", "g2a", "agent"):torch.as_tensor(now_t3_e_coor_feature.astype(np.float32))} 417 | g.edata['g2a_e_type'] = {("polygon", "g2a", "agent"):torch.as_tensor(now_t3_e_type_feature.ravel().astype(np.int32)).long()} 418 | 419 | ### Type 4 Edge: Lane -> Agent 420 | if len(polylinelaneindex) > 0: 421 | type4_u = g.edges(etype="l2a")[0] 422 | type4_v = g.edges(etype="l2a")[1] 423 | if len(type4_v) > 0: 424 | now_t4_v_feature = now_agent_feature[type4_v] 425 | if len(now_t4_v_feature.shape) == 2: 426 | now_t4_v_feature = now_t4_v_feature[np.newaxis, :, :] 427 | now_t4_e_feature = return_rel_e_feature(now_polyline_mean_coor[type4_u], now_t4_v_feature[:, -1, :3], now_polyline_yaw[type4_u], now_t4_v_feature[:, -1, 5]) 428 | g.edata['l_e_fea'] = {("lane", "l2a", "agent"):torch.as_tensor(now_t4_e_feature.astype(np.float32))} 429 | 430 | ### Type 5 Edge: Lane -> Lane 431 | if len(polylinelaneindex) > 0: 432 | for etype in ["left", "right", "prev", "follow"]: 433 | type5_u = 
g.edges(etype=etype)[0] 434 | type5_v = g.edges(etype=etype)[1] 435 | if len(type5_v) > 0: 436 | now_t5_e_feature = return_rel_e_feature(now_polyline_mean_coor[type5_u], now_polyline_mean_coor[type5_v], now_polyline_yaw[type5_u], now_polyline_yaw[type5_v]) 437 | g.edata['l_e_fea'] = {("lane", etype, "lane"):torch.as_tensor(now_t5_e_feature.astype(np.float32))} 438 | 439 | now_pred_num = pred_num_lis[i] 440 | selected_pred_indices = list(range(0, now_pred_num)) 441 | non_pred_indices = list(range(now_pred_num, now_agent_feature.shape[0])) 442 | 443 | ## Label + Full Agent Feature 444 | now_full_agent_n_feature = now_agent_feature[selected_pred_indices].copy() 445 | ref_coor = now_full_agent_n_feature[:, -1,:3].copy() 446 | now_label = label_lis[i][selected_pred_indices].copy() 447 | now_auxiliary_label = auxiliary_label_lis[i][selected_pred_indices].copy() 448 | now_label = now_label - ref_coor[:, np.newaxis, :2] 449 | ref_psi = now_full_agent_n_feature[:, -1, 5][:, np.newaxis].copy() 450 | normal_val = np.concatenate([ref_coor[..., :2], ref_psi], axis=-1) 451 | out_normal_lis.append(normal_val) 452 | 453 | sin_theta = np.sin(-ref_psi) 454 | cos_theta = np.cos(-ref_psi) 455 | rotate(now_label, cos_theta, sin_theta) 456 | rotate(now_auxiliary_label, cos_theta, sin_theta) 457 | now_auxiliary_label[..., 2] = now_auxiliary_label[..., 2] - ref_psi 458 | 459 | now_full_agent_n_feature = normal_agent_feature(now_full_agent_n_feature, ref_coor, ref_psi, cos_theta, sin_theta) 460 | now_auxiliary_label_future = now_auxiliary_label.copy() 461 | now_auxiliary_label = np.stack([now_full_agent_n_feature[..., 3], now_full_agent_n_feature[..., 4], now_agent_feature[selected_pred_indices, :, 5]-ref_psi, now_full_agent_n_feature[..., -1]], axis=-1) 462 | 463 | 464 | now_all_agent_n_feature = now_full_agent_n_feature 465 | if now_pred_num < now_agent_feature.shape[0]: 466 | now_other_agent_n_feature = now_agent_feature[non_pred_indices].copy() 467 | ref_coor = now_other_agent_n_feature[:, -1, :3] 468 | ref_psi = now_other_agent_n_feature[:, -1, 5][:, np.newaxis].copy() 469 | sin_theta = np.sin(-ref_psi) 470 | cos_theta = np.cos(-ref_psi) 471 | now_other_agent_n_feature = normal_agent_feature(now_other_agent_n_feature, ref_coor, ref_psi, cos_theta, sin_theta) 472 | now_all_agent_n_feature = np.concatenate([now_all_agent_n_feature, now_other_agent_n_feature], axis=0) 473 | g.ndata["a_n_fea"] = {"agent":torch.as_tensor(now_all_agent_n_feature.astype(np.float32))} 474 | g.ndata["a_n_type"] = {"agent":torch.as_tensor((now_agent_type-1).astype(np.int32)).long()} 475 | 476 | ## Lane Node Feature 477 | if len(polylinelaneindex) > 0: 478 | ref_coor = now_polyline_mean_coor 479 | ref_psi = now_polyline_yaw[:, np.newaxis].copy() 480 | sin_theta = np.sin(-ref_psi) 481 | cos_theta = np.cos(-ref_psi) 482 | now_lane_n_coor_feature, now_lane_n_type_feature, now_lane_n_speed_limit_feature, now_lane_n_stop_feature, now_lane_n_stop_index, now_lane_n_signal_feature, now_lane_n_signal_index = normal_lane_feature(now_polyline_coor, now_polyline_type, now_polyline_speed_limit, now_polyline_stop, now_polyline_signal, list(range(len(now_polyline_coor))), ref_coor, cos_theta, sin_theta) 483 | g.ndata["l_n_coor_fea"] = {"lane":torch.as_tensor(now_lane_n_coor_feature.astype(np.float32))} 484 | g.ndata["l_n_type_fea"] = {"lane":torch.as_tensor(now_lane_n_type_feature.astype(np.int32)).long()} 485 | 486 | ## Polyline Feature 487 | if len(polylinelaneindex) > 0: 488 | if len(now_lane_n_stop_index) != 0: 489 | 
out_lane_n_stop_sign_fea_lis.append(now_lane_n_stop_feature) 490 | out_lane_n_stop_sign_index_lis.append(np.array(now_lane_n_stop_index) + lane_n_cnt) 491 | if len(now_lane_n_signal_index) != 0: 492 | out_lane_n_signal_fea_lis.append(now_lane_n_signal_feature) 493 | out_lane_n_signal_index_lis.append(np.array(now_lane_n_signal_index)+lane_n_cnt) 494 | lane_n_cnt += now_lane_n_coor_feature.shape[0] 495 | 496 | out_graph_lis.append(g) 497 | out_label_lis.append(now_label) 498 | out_label_mask_lis.append(label_mask_lis[i][selected_pred_indices]) 499 | out_auxiliary_label_lis.append(now_auxiliary_label) 500 | out_auxiliary_label_future_lis.append(now_auxiliary_label_future) 501 | 502 | output_dic = {} 503 | #0-x, 1-y, 2-vx, 3-vy, 4-cos_psi, 5-sin_psi, 6-length, 7-width, 8-type, 9-mask 504 | output_dic["cuda_tensor_lis"] = ["graph_lis"] 505 | output_dic["cuda_tensor_lis"] += ["label_lis", "label_mask_lis", "auxiliary_label_lis", "auxiliary_label_future_lis"] 506 | if len(out_lane_n_stop_sign_fea_lis) > 0: 507 | output_dic["cuda_tensor_lis"] += ["lane_n_stop_sign_fea_lis", "lane_n_stop_sign_index_lis"] 508 | out_lane_n_stop_sign_index_lis = np.concatenate(out_lane_n_stop_sign_index_lis, axis=0) 509 | output_dic["lane_n_stop_sign_fea_lis"] = torch.as_tensor(np.concatenate(out_lane_n_stop_sign_fea_lis, axis=0).astype(np.float32)) 510 | output_dic["lane_n_stop_sign_index_lis"] = torch.as_tensor(out_lane_n_stop_sign_index_lis.astype(np.int32)).long() 511 | 512 | if len(out_lane_n_signal_fea_lis) > 0: 513 | output_dic["cuda_tensor_lis"] += ["lane_n_signal_fea_lis", "lane_n_signal_index_lis"] 514 | out_lane_n_signal_index_lis = np.concatenate(out_lane_n_signal_index_lis, axis=0) 515 | output_dic["lane_n_signal_fea_lis"] = torch.as_tensor(np.concatenate(out_lane_n_signal_fea_lis, axis=0).astype(np.float32)) 516 | output_dic["lane_n_signal_index_lis"] = torch.as_tensor(out_lane_n_signal_index_lis.astype(np.int32)).long() 517 | output_dic["label_lis"] = torch.as_tensor(np.concatenate(out_label_lis, axis=0).astype(np.float32)) 518 | output_dic["auxiliary_label_lis"] = torch.as_tensor(np.concatenate(out_auxiliary_label_lis, axis=0).astype(np.float32)) 519 | output_dic["auxiliary_label_future_lis"] = torch.as_tensor(np.concatenate(out_auxiliary_label_future_lis, axis=0).astype(np.float32)) 520 | 521 | output_dic["label_mask_lis"] = torch.as_tensor(np.concatenate(out_label_mask_lis, axis=0).astype(np.float32)) 522 | 523 | output_g = dgl.batch(out_graph_lis) 524 | a_e_type_dict = {} 525 | for out_etype in ["self", "a2l", "other"]: 526 | a_e_type_dict[out_etype] = [] 527 | for agent_type_index in range(3): 528 | a_e_type_dict[out_etype].append(torch.where(output_g.edges[out_etype].data["a_e_type"]==agent_type_index)[0]) 529 | a_n_type_lis = [torch.where(output_g.nodes["agent"].data["a_n_type"]==_)[0] for _ in range(3)] 530 | output_dic["a_e_type_dict"] = a_e_type_dict 531 | output_dic["a_n_type_lis"] = a_n_type_lis 532 | output_dic["graph_lis"] = output_g 533 | output_dic["neighbor_size_lis"] = neighbor_size 534 | output_dic["pred_num_lis"] = pred_num_lis 535 | output_dic["case_id_lis"] = case_id_lis 536 | output_dic["object_id_lis"] = object_id_lis 537 | output_dic["normal_lis"] = np.concatenate(out_normal_lis, axis=0) 538 | if "fname" in batch[0]: 539 | all_filename = [item["fname"] for item in batch] 540 | output_dic["fname"] = [] 541 | for _ in range(len(all_filename)): 542 | output_dic["fname"] += [all_filename[_]]*pred_num_lis[_] 543 | del batch 544 | return output_dic 
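Note (editorial): `HDGT_collate_fn` converts each scene into a heterogeneous DGL graph with node types `agent`/`lane`/`polygon` and edge types `self`, `other`, `a2l`, `l2a`, `g2a` plus the four lane-lane relations (`left`, `right`, `prev`, `follow`), then merges the per-scene graphs with `dgl.batch`. Below is a toy sketch, with made-up indices rather than real Waymo data, of the `data_dict` layout that `generate_heterogeneous_graph` hands to `dgl.heterograph`.

```python
import torch
import dgl

# Two agents, one lane, one polygon; all indices are illustrative only.
uv_dic = {
    ("agent", "self", "agent"): (torch.tensor([0, 1]), torch.tensor([0, 1])),   # self-loops
    ("agent", "other", "agent"): (torch.tensor([0, 1]), torch.tensor([1, 0])),  # agent <-> agent
    ("agent", "a2l", "lane"): (torch.tensor([0]), torch.tensor([0])),           # agent -> lane
    ("lane", "l2a", "agent"): (torch.tensor([0]), torch.tensor([0])),           # lane -> agent
    ("polygon", "g2a", "agent"): (torch.tensor([0]), torch.tensor([1])),        # polygon -> agent
}
g = dgl.heterograph(data_dict=uv_dic)
print(g.ntypes, g.canonical_etypes)  # agent/lane/polygon nodes and the edge types above
```

Per-scene graphs built in this format are what `dgl.batch(out_graph_lis)` combines into the single `graph_lis` entry of `output_dic` above.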
--------------------------------------------------------------------------------