├── .DS_Store ├── .gitignore ├── README.md ├── assets ├── architecture.jpg ├── bootstrap.min.css ├── corpus.js ├── editing.jpg ├── font.css ├── gpt_example.jpg ├── jquery.min.js ├── main_result.jpg ├── prosim_demo_video_1.mp4 ├── prosim_demo_video_2.gif ├── prosim_demo_video_2.mp4 ├── prosim_llm.jpg ├── prosim_model.jpg ├── prosim_table_1.jpg ├── prosim_table_2.jpg ├── style.css └── teaser.jpg ├── demo_dataset └── trajdata_cache │ └── waymo_train │ ├── maps │ ├── waymo_train_0.pb │ ├── waymo_train_0_2.00px_m.dill │ ├── waymo_train_0_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ ├── 1.1.1 │ │ ├── 2.0.0 │ │ ├── 2.0.1 │ │ ├── 2.1.0 │ │ └── 2.1.1 │ ├── waymo_train_0_kdtrees.dill │ ├── waymo_train_1.pb │ ├── waymo_train_10.pb │ ├── waymo_train_10_2.00px_m.dill │ ├── waymo_train_10_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ ├── 1.1.1 │ │ ├── 2.0.0 │ │ ├── 2.0.1 │ │ ├── 2.1.0 │ │ └── 2.1.1 │ ├── waymo_train_10_kdtrees.dill │ ├── waymo_train_11.pb │ ├── waymo_train_11_2.00px_m.dill │ ├── waymo_train_11_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ ├── 1.1.1 │ │ ├── 2.0.0 │ │ ├── 2.0.1 │ │ ├── 2.1.0 │ │ └── 2.1.1 │ ├── waymo_train_11_kdtrees.dill │ ├── waymo_train_12.pb │ ├── waymo_train_12_2.00px_m.dill │ ├── waymo_train_12_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ └── 1.1.1 │ ├── waymo_train_12_kdtrees.dill │ ├── waymo_train_13.pb │ ├── waymo_train_13_2.00px_m.dill │ ├── waymo_train_13_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ └── 1.1.1 │ ├── waymo_train_13_kdtrees.dill │ ├── waymo_train_14.pb │ ├── waymo_train_14_2.00px_m.dill │ ├── waymo_train_14_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ ├── 1.1.1 │ │ ├── 2.0.0 │ │ ├── 2.0.1 │ │ ├── 2.1.0 │ │ └── 2.1.1 │ ├── waymo_train_14_kdtrees.dill │ ├── waymo_train_15.pb │ ├── waymo_train_15_2.00px_m.dill │ ├── waymo_train_15_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 0.2.0 │ │ ├── 0.2.1 │ │ ├── 0.3.0 │ │ ├── 0.3.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ ├── 1.1.1 │ │ ├── 1.2.0 │ │ ├── 1.2.1 │ │ ├── 1.3.0 │ │ ├── 1.3.1 │ │ ├── 2.0.0 │ │ ├── 2.0.1 │ │ ├── 2.1.0 │ │ ├── 2.1.1 │ │ ├── 2.2.0 │ │ ├── 2.2.1 │ │ ├── 2.3.0 │ │ └── 2.3.1 │ ├── waymo_train_15_kdtrees.dill │ ├── waymo_train_1_2.00px_m.dill │ ├── waymo_train_1_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ └── 1.1.1 │ ├── waymo_train_1_kdtrees.dill │ ├── waymo_train_2.pb │ ├── waymo_train_2_2.00px_m.dill │ ├── waymo_train_2_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ ├── 1.1.1 │ │ ├── 2.0.0 │ │ ├── 2.0.1 │ │ ├── 2.1.0 │ │ └── 2.1.1 │ ├── waymo_train_2_kdtrees.dill │ ├── waymo_train_3.pb │ ├── waymo_train_3_2.00px_m.dill │ ├── waymo_train_3_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ ├── 1.1.1 │ │ ├── 2.0.0 │ │ ├── 2.0.1 │ │ ├── 2.1.0 │ │ └── 2.1.1 │ ├── 
waymo_train_3_kdtrees.dill │ ├── waymo_train_4.pb │ ├── waymo_train_4_2.00px_m.dill │ ├── waymo_train_4_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ └── 1.1.1 │ ├── waymo_train_4_kdtrees.dill │ ├── waymo_train_5.pb │ ├── waymo_train_5_2.00px_m.dill │ ├── waymo_train_5_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ └── 1.1.1 │ ├── waymo_train_5_kdtrees.dill │ ├── waymo_train_6.pb │ ├── waymo_train_6_2.00px_m.dill │ ├── waymo_train_6_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ └── 1.1.1 │ ├── waymo_train_6_kdtrees.dill │ ├── waymo_train_7.pb │ ├── waymo_train_7_2.00px_m.dill │ ├── waymo_train_7_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ └── 1.1.1 │ ├── waymo_train_7_kdtrees.dill │ ├── waymo_train_8.pb │ ├── waymo_train_8_2.00px_m.dill │ ├── waymo_train_8_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ └── 1.1.1 │ ├── waymo_train_8_kdtrees.dill │ ├── waymo_train_9.pb │ ├── waymo_train_9_2.00px_m.dill │ ├── waymo_train_9_2.00px_m.zarr │ │ ├── .zarray │ │ ├── 0.0.0 │ │ ├── 0.0.1 │ │ ├── 0.1.0 │ │ ├── 0.1.1 │ │ ├── 1.0.0 │ │ ├── 1.0.1 │ │ ├── 1.1.0 │ │ ├── 1.1.1 │ │ ├── 2.0.0 │ │ ├── 2.0.1 │ │ ├── 2.1.0 │ │ └── 2.1.1 │ └── waymo_train_9_kdtrees.dill │ ├── scene_0 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_1 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_10 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_11 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_12 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_13 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_14 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_15 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_2 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_3 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_4 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_5 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_6 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_7 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_8 │ ├── 
agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ ├── scene_9 │ ├── agent_data_dt0.10.feather │ ├── scene_index_dt0.10.pkl │ ├── scene_metadata_dt0.10.dill │ └── tls_data_dt0.10.feather │ └── scenes_list.dill ├── index.html ├── install_local_env.sh ├── prosim ├── __init__.py ├── config │ ├── __init__.py │ ├── cond_sampler │ │ ├── base.yaml │ │ ├── text_goal_dragpoint_v_action_tag_0.25.yaml │ │ └── uncondition.yaml │ ├── default.py │ └── path_cfg.py ├── core │ ├── basic.py │ └── registry.py ├── create_dataset.py ├── data_list │ └── demo_waymo.txt ├── dataset │ ├── __init__.py │ ├── basic.py │ ├── condition_utils.py │ ├── data_utils.py │ ├── format_utils.py │ ├── imitation.py │ ├── motion_tag_utils.py │ ├── prompt_utils.py │ ├── prosim_instruct_520k │ │ ├── waymo_train_IDs.pkl │ │ └── waymo_val_IDs.pkl │ └── text_utils.py ├── demo │ ├── vis.py │ └── vis_from_dict.py ├── exp │ └── demo │ │ ├── vis_debug.yaml │ │ └── visualization.yaml ├── loss │ ├── __init__.py │ ├── loss_func.py │ └── offroad_loss.py ├── main.py ├── metrics │ ├── __init__.py │ ├── base.py │ └── motion_pred.py ├── models │ ├── __init__.py │ ├── base.py │ ├── condition_transformer │ │ ├── __init__.py │ │ ├── attn_utils.py │ │ ├── base.py │ │ ├── base_llm.py │ │ ├── condition_attns.py │ │ ├── condition_encoders.py │ │ └── text_attns.py │ ├── decoder │ │ ├── __init__.py │ │ ├── base.py │ │ └── sym_coord.py │ ├── layers │ │ ├── attention_layer.py │ │ ├── fourier_embedding.py │ │ └── mlp.py │ ├── policy │ │ ├── __init__.py │ │ ├── act_decoder.py │ │ ├── base.py │ │ └── temporal_ar.py │ ├── prompt_encoder │ │ ├── __init__.py │ │ └── base.py │ ├── prompt_generator │ │ ├── __init__.py │ │ └── generators.py │ ├── scene_encoder │ │ ├── __init__.py │ │ ├── attn_fusion.py │ │ ├── base.py │ │ ├── map_encoder.py │ │ ├── obs_encoder.py │ │ └── pointnet_encoder.py │ ├── traj_sam.py │ └── utils │ │ ├── __init__.py │ │ ├── data.py │ │ ├── geometry.py │ │ ├── graph.py │ │ ├── pos_enc.py │ │ ├── visualization.py │ │ └── weight_init.py ├── rollout │ ├── __init__.py │ ├── baseline.py │ ├── callbacks.py │ ├── distributed_utils.py │ ├── gpu_utils.py │ ├── metrics.py │ ├── package_submission.py │ ├── run_distributed_rollout.py │ ├── utils.py │ └── waymo_utils.py └── trainer.py └── prosim_demo ├── cfg ├── no_text.yaml ├── waymo_demo.yaml └── with_text.yaml ├── load_prosim_instruct_520k.ipynb └── text_prompt_inference.ipynb /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__/ 3 | .vscode/ 4 | .idea/ 5 | *.log 6 | *.pth 7 | *.ckpt 8 | *.pt 9 | *.npy 10 | *.npz 11 | *.zip 12 | *.tar.gz 13 | .DS_Store 14 | 15 | */data_indexes/* -------------------------------------------------------------------------------- /assets/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/assets/architecture.jpg -------------------------------------------------------------------------------- /assets/editing.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/assets/editing.jpg -------------------------------------------------------------------------------- /assets/font.css: -------------------------------------------------------------------------------- 1 | /* Homepage Font */ 2 | 3 | /* latin-ext */ 4 | @font-face { 5 | font-family: 'Lato'; 6 | font-style: normal; 7 | font-weight: 400; 8 | src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjxAwXjeu.woff2) format('woff2'); 9 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF; 10 | } 11 | 12 | /* latin */ 13 | @font-face { 14 | font-family: 'Lato'; 15 | font-style: normal; 16 | font-weight: 400; 17 | src: local('Lato Regular'), local('Lato-Regular'), url(https://fonts.gstatic.com/s/lato/v16/S6uyw4BMUTPHjx4wXg.woff2) format('woff2'); 18 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; 19 | } 20 | 21 | /* latin-ext */ 22 | @font-face { 23 | font-family: 'Lato'; 24 | font-style: normal; 25 | font-weight: 700; 26 | src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwaPGR_p.woff2) format('woff2'); 27 | unicode-range: U+0100-024F, U+0259, U+1E00-1EFF, U+2020, U+20A0-20AB, U+20AD-20CF, U+2113, U+2C60-2C7F, U+A720-A7FF; 28 | } 29 | 30 | /* latin */ 31 | @font-face { 32 | font-family: 'Lato'; 33 | font-style: normal; 34 | font-weight: 700; 35 | src: local('Lato Bold'), local('Lato-Bold'), url(https://fonts.gstatic.com/s/lato/v16/S6u9w4BMUTPHh6UVSwiPGQ.woff2) format('woff2'); 36 | unicode-range: U+0000-00FF, U+0131, U+0152-0153, U+02BB-02BC, U+02C6, U+02DA, U+02DC, U+2000-206F, U+2074, U+20AC, U+2122, U+2191, U+2193, U+2212, U+2215, U+FEFF, U+FFFD; 37 | } 38 | -------------------------------------------------------------------------------- /assets/gpt_example.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/assets/gpt_example.jpg -------------------------------------------------------------------------------- /assets/main_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/assets/main_result.jpg -------------------------------------------------------------------------------- /assets/prosim_demo_video_1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/assets/prosim_demo_video_1.mp4 -------------------------------------------------------------------------------- /assets/prosim_demo_video_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/assets/prosim_demo_video_2.gif -------------------------------------------------------------------------------- /assets/prosim_demo_video_2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/assets/prosim_demo_video_2.mp4 
-------------------------------------------------------------------------------- /assets/prosim_llm.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/assets/prosim_llm.jpg -------------------------------------------------------------------------------- /assets/prosim_model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/assets/prosim_model.jpg -------------------------------------------------------------------------------- /assets/prosim_table_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/assets/prosim_table_1.jpg -------------------------------------------------------------------------------- /assets/prosim_table_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/assets/prosim_table_2.jpg -------------------------------------------------------------------------------- /assets/style.css: -------------------------------------------------------------------------------- 1 | /* Body */ 2 | body { 3 | background: #f0f2f5; 4 | color: #ffffff; 5 | font-family: 'Lato', Verdana, Helvetica, sans-serif; 6 | font-weight: 300; 7 | font-size: 14pt; 8 | } 9 | 10 | /* Hyperlinks */ 11 | a { 12 | text-decoration: none; 13 | } 14 | 15 | a:link { 16 | color: #1772d0; 17 | } 18 | 19 | a:visited { 20 | color: #1772d0; 21 | } 22 | 23 | a:active { 24 | color: red; 25 | } 26 | 27 | a:hover { 28 | color: #f09228; 29 | } 30 | 31 | /* Pre-formatted Text */ 32 | pre { 33 | margin: 5pt 0; 34 | border: 0; 35 | font-size: 12pt; 36 | background: #fcfcfc; 37 | } 38 | 39 | *:focus { 40 | outline: 0; 41 | } 42 | 43 | .video-container { 44 | /*aspect-ratio: 16/9;*/ 45 | /*outline: unset;*/ 46 | /*overflow: hidden;*/ 47 | /*top: -200px;*/ 48 | clip-path: inset(1px 1px); 49 | /*border: 1pt blanchedalmond;*/ 50 | } 51 | 52 | /* Project Page Style */ 53 | /* Section */ 54 | .section { 55 | width: 800pt; 56 | min-height: 100pt; 57 | margin: 15pt auto; 58 | padding: 30pt 40pt; 59 | border: 1pt hidden #000; 60 | text-align: justify; 61 | color: #000000; 62 | background: #ffffff; 63 | } 64 | 65 | /*.section .text{*/ 66 | /* width: 90%;*/ 67 | /* text-align: center;*/ 68 | /*}*/ 69 | 70 | /* Header (Title and Logo) */ 71 | .section .header { 72 | min-height: 80pt; 73 | margin-top: 0pt; 74 | } 75 | 76 | .section .header .logo { 77 | width: 80pt; 78 | margin-left: 10pt; 79 | float: left; 80 | } 81 | 82 | .section .header .logo img { 83 | width: 80pt; 84 | object-fit: cover; 85 | } 86 | 87 | .section .header .title { 88 | margin: 0 0; 89 | text-align: center; 90 | font-size: 22pt; 91 | } 92 | 93 | /* Author */ 94 | .section .author { 95 | margin: 0 0; 96 | /*padding-left: 15pt;*/ 97 | text-align: center; 98 | font-size: 14pt; 99 | } 100 | 101 | /* Institution */ 102 | .section .institution { 103 | margin: 10pt 0; 104 | padding-left: 0pt; 105 | text-align: center; 106 | font-size: 12pt; 107 | } 108 | 109 | /* Hyperlink (such as Paper and Code) */ 110 | .section .link { 111 | margin: 5pt 0; 112 | text-align: center; 113 | font-size: 16pt; 114 | } 115 | 116 | /* Teaser */ 117 | .section .teaser { 118 | margin: 20pt 0; 119 | 
text-align: center; 120 | } 121 | 122 | .section .teaser img { 123 | width: 90%; 124 | } 125 | 126 | /* Section Title */ 127 | .section .title { 128 | text-align: center; 129 | font-size: 22pt; 130 | margin: 5pt 0 15pt 0; /* top right bottom left */ 131 | } 132 | 133 | /* Section Body */ 134 | .section .body { 135 | margin-bottom: 15pt; 136 | text-align: justify; 137 | font-size: 14pt; 138 | } 139 | 140 | /* BibTeX */ 141 | .section .bibtex { 142 | margin: 5pt 0; 143 | text-align: left; 144 | font-size: 22pt; 145 | } 146 | 147 | /* Related Work */ 148 | .section .ref { 149 | margin: 20pt 0 10pt 0; /* top right bottom left */ 150 | text-align: left; 151 | font-size: 18pt; 152 | font-weight: bold; 153 | } 154 | 155 | /* Citation */ 156 | .section .citation { 157 | min-height: 60pt; 158 | margin: 10pt 0; 159 | } 160 | 161 | .section .citation .image { 162 | width: 120pt; 163 | float: left; 164 | } 165 | 166 | .section .citation .image img { 167 | max-height: 60pt; 168 | width: 120pt; 169 | object-fit: cover; 170 | } 171 | 172 | .section .citation .comment { 173 | margin-left: 130pt; 174 | text-align: left; 175 | font-size: 14pt; 176 | } 177 | 178 | .txtt{ 179 | color: #000000; 180 | font-family: 'Courier', sans-serif; 181 | font-stretch: condensed; 182 | font-weight: 600; 183 | font-size: 95%; 184 | } -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/assets/teaser.jpg -------------------------------------------------------------------------------- /demo_dataset/trajdata_cache/waymo_train/maps/waymo_train_0.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/demo_dataset/trajdata_cache/waymo_train/maps/waymo_train_0.pb -------------------------------------------------------------------------------- /demo_dataset/trajdata_cache/waymo_train/maps/waymo_train_0_2.00px_m.dill: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/demo_dataset/trajdata_cache/waymo_train/maps/waymo_train_0_2.00px_m.dill -------------------------------------------------------------------------------- /demo_dataset/trajdata_cache/waymo_train/maps/waymo_train_0_2.00px_m.zarr/.zarray: -------------------------------------------------------------------------------- 1 | { 2 | "chunks": [ 3 | 1, 4 | 468, 5 | 203 6 | ], 7 | "compressor": { 8 | "blocksize": 0, 9 | "clevel": 5, 10 | "cname": "lz4", 11 | "id": "blosc", 12 | "shuffle": 1 13 | }, 14 | "dtype": " Callable: 36 | def wrap(to_register): 37 | if assert_type is not None: 38 | assert issubclass( 39 | to_register, assert_type 40 | ), "{} must be a subclass of {}".format( 41 | to_register, assert_type 42 | ) 43 | register_name = to_register.__name__ if name is None else name 44 | 45 | cls.mapping[_type][register_name] = to_register 46 | return to_register 47 | 48 | if to_register is None: 49 | return wrap 50 | else: 51 | return wrap(to_register) 52 | 53 | @classmethod 54 | def register_dataset(cls, to_register=None, *, name: Optional[str] = None): 55 | return cls._register_impl( 56 | "dataset", to_register, name, assert_type=Dataset 57 | ) 58 | 59 | @classmethod 60 | def register_metric(cls, to_register=None, *, 
name: Optional[str] = None): 61 | return cls._register_impl( 62 | "metric", to_register, name, assert_type=Metric 63 | ) 64 | 65 | @classmethod 66 | def register_model(cls, to_register=None, *, name: Optional[str] = None): 67 | return cls._register_impl( 68 | "model", to_register, name, assert_type=LightningModule 69 | ) 70 | 71 | @classmethod 72 | def register_scene_encoder(cls, to_register=None, *, name: Optional[str] = None): 73 | return cls._register_impl( 74 | "scene_encoder", to_register, name, assert_type=nn.Module 75 | ) 76 | 77 | @classmethod 78 | def register_prompt_encoder(cls, to_register=None, *, name: Optional[str] = None): 79 | return cls._register_impl( 80 | "prompt_encoder", to_register, name, assert_type=nn.Module 81 | ) 82 | 83 | @classmethod 84 | def register_decoder(cls, to_register=None, *, name: Optional[str] = None): 85 | return cls._register_impl( 86 | "decoder", to_register, name, assert_type=nn.Module 87 | ) 88 | 89 | @classmethod 90 | def register_hist_encoder(cls, to_register=None, *, name: Optional[str] = None): 91 | return cls._register_impl( 92 | "hist_encoder", to_register, name, assert_type=nn.Module 93 | ) 94 | 95 | @classmethod 96 | def register_policy(cls, to_register=None, *, name: Optional[str] = None): 97 | return cls._register_impl( 98 | "policy", to_register, name, assert_type=nn.Module 99 | ) 100 | 101 | @classmethod 102 | def _get_impl(cls, _type: str, name: str) -> Type: 103 | return cls.mapping[_type].get(name, None) 104 | 105 | @classmethod 106 | def get_dataset(cls, name: str) -> Type[Dataset]: 107 | return cls._get_impl("dataset", name) 108 | 109 | @classmethod 110 | def get_metric(cls, name: str) -> Type[Metric]: 111 | return cls._get_impl("metric", name) 112 | 113 | @classmethod 114 | def get_model(cls, name: str) -> Type[LightningModule]: 115 | return cls._get_impl("model", name) 116 | 117 | @classmethod 118 | def get_scene_encoder(cls, name: str) -> Type[nn.Module]: 119 | return cls._get_impl("scene_encoder", name) 120 | 121 | @classmethod 122 | def get_prompt_encoder(cls, name: str) -> Type[nn.Module]: 123 | return cls._get_impl("prompt_encoder", name) 124 | 125 | @classmethod 126 | def get_hist_encoder(cls, name: str) -> Type[nn.Module]: 127 | return cls._get_impl("hist_encoder", name) 128 | 129 | @classmethod 130 | def get_policy(cls, name: str) -> Type[nn.Module]: 131 | return cls._get_impl("policy", name) 132 | 133 | @classmethod 134 | def get_decoder(cls, name: str) -> Type[nn.Module]: 135 | return cls._get_impl("decoder", name) 136 | 137 | 138 | registry = Registry() 139 | -------------------------------------------------------------------------------- /prosim/create_dataset.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.getcwd()) 5 | sys.path.append(os.path.dirname(os.getcwd())) 6 | sys.path.append(os.path.dirname(os.path.dirname(os.getcwd()))) 7 | 8 | import argparse 9 | import random 10 | 11 | import numpy as np 12 | import torch 13 | import wandb 14 | 15 | from torch.utils.data import DataLoader 16 | 17 | from prosim.config.default import Config, get_config 18 | from prosim.core.registry import registry 19 | 20 | def main(): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument( 23 | "--exp-config", 24 | type=str, 25 | required=True, 26 | help="path to config yaml containing info about experiment", 27 | ) 28 | parser.add_argument( 29 | "opts", 30 | default=None, 31 | nargs=argparse.REMAINDER, 32 | help="Modify config 
options from command line",
33 |     )
34 | 
35 |     args = parser.parse_args()
36 |     run_exp(**vars(args))
37 | 
38 | 
39 | def execute_exp(config: Config) -> None:
40 |     r"""Builds the datasets and data loaders specified by the config.
41 | 
42 |     Args:
43 |         config: experiment Config
44 |     """
45 |     task_config = config
46 | 
47 |     dataset_configs = {'train': config.TRAIN,
48 |                        'val': config.VAL, 'test': config.TEST}
49 |     dataset_type = task_config.DATASET.TYPE
50 | 
51 |     data_loaders = {}
52 |     for mode, split_cfg in dataset_configs.items():
53 |         dataset = registry.get_dataset(dataset_type)(task_config, split_cfg.SPLIT)
54 |         batch_size = split_cfg.BATCH_SIZE
55 |         data_loaders[mode] = DataLoader(dataset, batch_size=batch_size, shuffle=split_cfg.SHUFFLE, pin_memory=True, drop_last=split_cfg.DROP_LAST, num_workers=split_cfg.NUM_WORKERS, collate_fn=dataset.get_collate_fn())
56 | 
57 | def run_exp(exp_config: str, opts=None) -> None:
58 |     r"""Runs the dataset-creation experiment for the given config.
59 | 
60 |     Args:
61 |         exp_config: path to config file.
62 |         opts: list of strings of additional config options.
63 | 
64 |     Returns:
65 |         None.
66 |     """
67 | 
68 |     config = get_config(exp_config, opts)
69 |     execute_exp(config)
70 | 
71 | if __name__ == "__main__":
72 |     main()
--------------------------------------------------------------------------------
/prosim/data_list/demo_waymo.txt:
--------------------------------------------------------------------------------
1 | scene_12760
2 | scene_8280
--------------------------------------------------------------------------------
/prosim/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import *
2 | from .imitation import *
--------------------------------------------------------------------------------
/prosim/dataset/imitation.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from tqdm import tqdm
4 | from typing import Union
5 | from functools import partial
6 | 
7 | from prosim.core.registry import registry
8 | from prosim.dataset.basic import ProSimDataset
9 | from prosim.dataset.data_utils import get_vectorized_lanes
10 | from prosim.dataset.format_utils import ImitationBatchFormat
11 | 
12 | @registry.register_dataset(name='prosim_imitation')
13 | class ProSimImitationDataset(ProSimDataset):
14 |     def __init__(self, config, split, **args):
15 |         super().__init__(config, split, **args)
16 | 
17 |     def _get_trajdata_cfg(self, cfg, split):
18 |         td_cfg = super()._get_trajdata_cfg(cfg, split)
19 |         td_cfg['centric'] = 'scene'
20 | 
21 |         # add batch formatting augmentation for imitation learning
22 |         if cfg.DATASET.NO_PROCESSING:
23 |             print('do not add batch formatting augmentation for imitation learning')
24 |         else:
25 |             td_cfg['augmentations'] = [ImitationBatchFormat(cfg, split)]
26 | 
27 |         return td_cfg
28 | 
29 |     def _get_vec_lane_func(self, data_cfg):
30 |         MAP_RANGE = data_cfg.MAP.RANGE[self.split.upper()]
31 | 
32 |         vec_lane_func = partial(get_vectorized_lanes,
33 |                                 data_cfg=data_cfg,
34 |                                 map_range=MAP_RANGE)
35 | 
36 |         return vec_lane_func
--------------------------------------------------------------------------------
/prosim/dataset/prompt_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from prosim.dataset.data_utils import rotate
4 | from prosim.models.utils.data import extract_agent_obs_from_center_obs
5 | 
6 | class PromptGenerator:
7 |     def
__init__(self, config): 8 | self.config = config 9 | 10 | def get_prompt_dim(self): 11 | raise NotImplementedError 12 | 13 | def prompt_for_scene_batch(self): 14 | raise NotImplementedError 15 | 16 | def _append_tgt_agent_info(self, result, tgt_agent_info): 17 | result['prompt_mask'] = tgt_agent_info['prompt_mask'] 18 | result['agent_ids'] = tgt_agent_info['agent_ids'] 19 | result['position'] = tgt_agent_info['position'] 20 | result['heading'] = tgt_agent_info['heading'] 21 | 22 | return result 23 | 24 | def _get_batch_tgt_agent_info(self, batch): 25 | # extract agent information with tgt_agent_idx 26 | # B: batch size 27 | # N: max number of agents in the batch 28 | # outputs: 29 | # goal point - (B, N, 2) 30 | # position - (B, N, 2) 31 | # heading - (B, N, 1) 32 | # prompt_mask - (B, N, 1) 33 | # agent_ids - List[List[str]] 34 | 35 | tgt_agent_idxs = batch.tgt_agent_idxs 36 | B = len(tgt_agent_idxs) 37 | N = max([len(tgt_idx) for tgt_idx in tgt_agent_idxs]) 38 | device = batch.agent_hist.device 39 | 40 | agent_info = {} 41 | 42 | agent_info['goal_point'] = torch.zeros((B, N, 2), device=device) 43 | agent_info['position'] = torch.zeros((B, N, 2), device=device) 44 | agent_info['heading'] = torch.zeros((B, N, 1), device=device) 45 | agent_info['prompt_mask'] = torch.zeros((B, N), device=device).to(torch.bool) 46 | agent_info['agent_type'] = torch.zeros((B, N), device=device).to(torch.long) 47 | 48 | agent_info['agent_extend'] = torch.zeros((B, N, 2), device=device) 49 | agent_info['agent_vel'] = torch.zeros((B, N, 2), device=device) 50 | 51 | agent_info['agent_ids'] = [] 52 | 53 | for b in range(B): 54 | o_idx = tgt_agent_idxs[b] 55 | n = len(o_idx) 56 | 57 | b_idx = torch.ones(n, device=device).long() * b 58 | t_idx = batch.agent_fut_len[b, o_idx] - 1 59 | 60 | if batch.agent_fut.shape[2] > 0: 61 | agent_info['goal_point'][b, :n] = batch.agent_fut[b_idx, o_idx, t_idx].as_format('x,y').float() 62 | 63 | agent_info['position'][b, :n] = batch.agent_hist[b_idx, o_idx, -1].as_format('x,y').float() 64 | agent_info['heading'][b, :n] = batch.agent_hist[b_idx, o_idx, -1].as_format('h').float() 65 | agent_info['prompt_mask'][b, :n] = True 66 | agent_info['agent_type'][b, :n] = batch.agent_type[b_idx, o_idx] 67 | agent_info['agent_ids'].append([batch.agent_names[b][idx] for idx in o_idx]) 68 | agent_info['agent_extend'][b, :n] = batch.agent_hist_extent[b_idx, o_idx, -1, :2] 69 | agent_info['agent_vel'][b, :n] = batch.agent_hist[b_idx, o_idx, -1].as_format('xd,yd').float() 70 | 71 | return agent_info 72 | 73 | def prompt_for_batch(self, batch): 74 | tgt_agent_info = self._get_batch_tgt_agent_info(batch) 75 | 76 | prompt_dict = self.prompt_for_scene_batch(tgt_agent_info) 77 | 78 | prompt_dict['agent_type'] = tgt_agent_info['agent_type'] 79 | 80 | if 'prompt' in prompt_dict: 81 | assert torch.isnan(prompt_dict['prompt'][prompt_dict['prompt_mask']]).any() == False 82 | 83 | return prompt_dict 84 | 85 | def prompt_for_rollout_batch(self, query_names, center_obs, prompt_value=None): 86 | agent_obs = extract_agent_obs_from_center_obs(query_names, center_obs) 87 | result = self._prompt_for_rollout_batch_helper(query_names, agent_obs, prompt_value) 88 | result['agent_type'] = agent_obs['type'] 89 | 90 | return result 91 | 92 | 93 | class AgentStatusGenerator(PromptGenerator): 94 | def __init__(self, config): 95 | super().__init__(config) 96 | self.config = config 97 | self.prompt_dim = 0 98 | 99 | if self.config.USE_VEL: 100 | self.prompt_dim += 2 101 | 102 | if self.config.USE_EXTEND: 103 | self.prompt_dim 
+= 2 104 | 105 | if self.config.USE_AGENT_TYPE: 106 | self.prompt_dim += 3 107 | 108 | def get_prompt_dim(self): 109 | return self.prompt_dim 110 | 111 | def prompt_for_scene_batch(self, tgt_agent_info): 112 | ''' 113 | # extract agent status from tgt_agent_info with tgt_agent_idx 114 | # B: batch size 115 | # N: max number of agents in the batch 116 | 117 | # outputs: 118 | # prompt - (B, N, D) 119 | # position - (B, N, 2) 120 | # heading - (B, N, 1) 121 | # prompt_mask - (B, N, 1) 122 | 123 | ''' 124 | 125 | result = {} 126 | 127 | prompt = [] 128 | if self.config.USE_VEL: 129 | abs_agent_vel = tgt_agent_info['agent_vel'] 130 | agent_heading = tgt_agent_info['heading'] 131 | agent_vel = rotate(abs_agent_vel[..., 0], abs_agent_vel[..., 1], -agent_heading[..., 0]) 132 | prompt.append(agent_vel) 133 | 134 | if self.config.USE_EXTEND: 135 | prompt.append(tgt_agent_info['agent_extend']) 136 | 137 | if self.config.USE_AGENT_TYPE: 138 | agent_type = tgt_agent_info['agent_type'] 139 | agent_type_one_hot = torch.zeros(agent_type.shape[0], agent_type.shape[1], 3).to(agent_type.device) 140 | for type_id in [1,2,3]: 141 | agent_type_one_hot[..., type_id-1] = (agent_type == type_id).float() 142 | 143 | prompt.append(agent_type_one_hot) 144 | 145 | prompt = torch.cat(prompt, dim=-1) 146 | result['prompt'] = prompt 147 | 148 | result = self._append_tgt_agent_info(result, tgt_agent_info) 149 | 150 | return result -------------------------------------------------------------------------------- /prosim/dataset/prosim_instruct_520k/waymo_train_IDs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/prosim/dataset/prosim_instruct_520k/waymo_train_IDs.pkl -------------------------------------------------------------------------------- /prosim/dataset/prosim_instruct_520k/waymo_val_IDs.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ariostgx/ProSim/78a398c1859b10fbabcc455f1266cf0103f51605/prosim/dataset/prosim_instruct_520k/waymo_val_IDs.pkl -------------------------------------------------------------------------------- /prosim/dataset/text_utils.py: -------------------------------------------------------------------------------- 1 | AGENT_TEMPLATE = "" 2 | MAX_AGENT_NUM = 128 -------------------------------------------------------------------------------- /prosim/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss_func import * -------------------------------------------------------------------------------- /prosim/main.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.getcwd()) 5 | sys.path.append(os.path.dirname(os.getcwd())) 6 | sys.path.append(os.path.dirname(os.path.dirname(os.getcwd()))) 7 | 8 | import argparse 9 | import random 10 | 11 | import numpy as np 12 | import torch 13 | import wandb 14 | 15 | from prosim.config.default import Config, get_config 16 | from prosim.trainer import BaseTrainer as Trainer 17 | 18 | 19 | def main(): 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument( 22 | "--run-type", 23 | choices=["train", "eval", "data_debug"], 24 | required=True, 25 | help="run type of the experiment (train or eval)", 26 | ) 27 | parser.add_argument( 28 | "--exp-config", 29 | type=str, 30 | required=True, 31 | help="path to config yaml 
containing info about experiment",
32 |     )
33 |     parser.add_argument(
34 |         "--cluster",
35 |         default='local', # 'local', 'ngc', 'slurm',
36 |         type=str,
37 |         help="which cluster to run on",
38 |     )
39 |     parser.add_argument(
40 |         "opts",
41 |         default=None,
42 |         nargs=argparse.REMAINDER,
43 |         help="Modify config options from command line",
44 |     )
45 | 
46 |     args = parser.parse_args()
47 |     run_exp(**vars(args))
48 | 
49 | 
50 | def execute_exp(config: Config, run_type: str):
51 |     r"""Runs the specified config with the specified run type.
52 |     Args:
53 |         config: experiment Config
54 |         run_type: str {train, eval or data_debug}
55 |     """
56 |     random.seed(config.SEED)
57 |     np.random.seed(config.SEED)
58 |     torch.manual_seed(config.SEED)
59 | 
60 |     wandb_api_key = os.environ.get('WANDB_API_KEY')
61 |     if wandb_api_key:
62 |         wandb.login(key=wandb_api_key)
63 | 
64 |     trainer = Trainer(config)
65 | 
66 |     if run_type == "train":
67 |         trainer.train()
68 |     elif run_type == "eval":
69 |         trainer.eval()
70 |     elif run_type == 'data_debug':
71 |         trainer.data_debug()
72 | 
73 |     if config.LOGGER == 'wandb':
74 |         wandb.finish()
75 | 
76 |     return trainer.save_dir
77 | 
78 | def run_exp(exp_config: str, run_type: str, opts, cluster) -> None:
79 |     r"""Runs the experiment given a run type and config.
80 | 
81 |     Args:
82 |         exp_config: path to config file.
83 |         run_type: "train", "eval" or "data_debug".
84 |         opts: list of strings of additional config options.
85 |         cluster: which cluster to run on ('local', 'ngc' or 'slurm').
86 |     Returns:
87 |         None.
88 |     """
89 | 
90 |     config = get_config(exp_config, opts, cluster)
91 |     execute_exp(config, run_type)
92 | 
93 | if __name__ == "__main__":
94 |     main()
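The main() entry point above is a thin argparse wrapper around run_exp, so the same experiment can also be launched programmatically. A minimal sketch, assuming the demo config shipped in this repo and a local run (the chosen config path and the empty opts list are illustrative, not mandated by the code above):

from prosim.main import run_exp

# equivalent to:
#   python prosim/main.py --run-type train --exp-config prosim/exp/demo/visualization.yaml
run_exp(
    exp_config='prosim/exp/demo/visualization.yaml',  # demo config from this repo
    run_type='train',                                 # 'train', 'eval' or 'data_debug'
    opts=[],                                          # optional "KEY VALUE" config overrides
    cluster='local',                                  # 'local', 'ngc' or 'slurm'
)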
--------------------------------------------------------------------------------
/prosim/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 | from .motion_pred import *
--------------------------------------------------------------------------------
/prosim/metrics/base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytorch_lightning as pl
3 | import torch
4 | import torch.nn as nn
5 | 
6 | # from ax import Metric
7 | import numpy as np
8 | from scipy.optimize import linear_sum_assignment
9 | from torchmetrics import Accuracy, MeanMetric, Metric
10 | 
11 | from prosim.core.registry import registry
12 | 
13 | from pytorch_lightning.callbacks import Callback
14 | from prosim.rollout.distributed_utils import check_mem_usage, print_system_mem_usage, get_gpu_memory_usage
15 | 
16 | class metric_callback(Callback):
17 |     def __init__(self, config):
18 |         self.config = config
19 |         super().__init__()
20 | 
21 |     def _shared_update(self, trainer, pl_module, outputs, batch, mode, dataloader_idx):
22 |         cond_sets = pl_module.eval_dataset_cond_sets
23 | 
24 |         model_output = outputs['model_output']
25 |         for task in pl_module.tasks:
26 |             for metric_name in pl_module.metrics[mode][task].keys():
27 |                 if len(cond_sets) == 0:
28 |                     pl_module.metrics[mode][task][metric_name](batch, model_output[task])
29 |                 else:
30 |                     cond_set = cond_sets[dataloader_idx]
31 |                     pl_module.metrics[mode][task][metric_name][cond_set](batch, model_output[task])
32 | 
33 |     def _shared_log(self, trainer, pl_module, mode):
34 |         on_step = False
35 |         on_epoch = True
36 |         sync_dist = trainer.num_devices > 1
37 | 
38 |         cond_sets = pl_module.eval_dataset_cond_sets
39 | 
40 |         for task in pl_module.tasks:
41 |             for metric_name in pl_module.metrics[mode][task].keys():
42 | 
43 |                 if len(cond_sets) == 0:
44 |                     metric_value = pl_module.metrics[mode][task][metric_name].compute()
45 |                     for subname, subvalue in metric_value.items():
46 |                         pl_module.log('{}/metric-{}-{}-{}'.format(mode, task, metric_name, subname), subvalue.detach().cpu().item(), on_epoch=on_epoch, on_step=on_step, sync_dist=sync_dist)
47 |                 else:
48 |                     for cond_set in cond_sets:
49 |                         metric_value = pl_module.metrics[mode][task][metric_name][cond_set].compute()
50 |                         for subname, subvalue in metric_value.items():
51 |                             pl_module.log('{}/metric-{}-{}-{}-{}'.format(mode, task, metric_name, cond_set, subname), subvalue.detach().cpu().item(), on_epoch=on_epoch, on_step=on_step, sync_dist=sync_dist)
52 | 
53 |     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None):
54 |         self._shared_update(trainer, pl_module, outputs, batch, 'val', dataloader_idx)
55 | 
56 |     def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None):
57 |         self._shared_update(trainer, pl_module, outputs, batch, 'test', dataloader_idx)
58 | 
59 |     def on_validation_epoch_end(self, trainer, pl_module) -> None:
60 |         self._shared_log(trainer, pl_module, 'val')
61 | 
62 |     def on_test_epoch_end(self, trainer, pl_module) -> None:
63 |         self._shared_log(trainer, pl_module, 'test')
64 | 
65 | 
66 | @registry.register_metric(name='debug')
67 | class Debug(Metric):
68 |     def __init__(self, config):
69 |         super().__init__()
70 |         self.config = config
71 | 
72 |     def update(self, batch, model_output):
73 |         pass
74 | 
75 |     def compute(self):
76 |         return torch.tensor(0)
77 | 
78 |     def reset(self):
79 |         pass
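metrics/base.py shows the registration pattern: subclass torchmetrics.Metric, take the experiment config in __init__, and let metric_callback drive update/compute. Below is a hedged sketch of a custom metric, assuming (this is not repo code) that model_output carries a 'traj' tensor and that the last future waypoint of batch.agent_fut can be indexed directly; compute returns a dict because _shared_log iterates over metric_value.items():

import torch
from torchmetrics import Metric

from prosim.core.registry import registry

@registry.register_metric(name='mean_final_distance')
class MeanFinalDistance(Metric):  # hypothetical example metric, for illustration only
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.add_state('total', default=torch.tensor(0.0), dist_reduce_fx='sum')
        self.add_state('count', default=torch.tensor(0.0), dist_reduce_fx='sum')

    def update(self, batch, model_output):
        # assumed tensor layouts; the real batch/model_output schema lives elsewhere in the repo
        dist = (model_output['traj'][..., -1, :2] - batch.agent_fut[..., -1, :2]).norm(dim=-1)
        self.total += dist.sum()
        self.count += dist.numel()

    def compute(self):
        return {'mean': self.total / self.count.clamp(min=1)}

The class is then retrievable as registry.get_metric('mean_final_distance'), mirroring how the built-in 'debug' metric above is looked up.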
--------------------------------------------------------------------------------
/prosim/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .traj_sam import *
2 | from .decoder import *
3 | from .scene_encoder import *
4 | from .prompt_encoder import *
5 | from .policy import *
--------------------------------------------------------------------------------
/prosim/models/condition_transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import ConditionTransformer
--------------------------------------------------------------------------------
/prosim/models/condition_transformer/attn_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | def obtain_valid_edge_node_idex(edge_mask, node_mask):
4 |     '''
5 |     Obtain global valid edge indices from an edge mask and a node mask
6 | 
7 |     Global edge indices:
8 | 
9 |         flattened node indices in a batch [0, P), where P is the number of valid nodes in the batch
10 |         P = node_mask.sum()
11 | 
12 |     Usage:
13 |         We can directly use these indices to index the valid nodes in the batch:
14 |             start_nodes = node_emds[node_mask][edge_node_idex[:, 0]]
15 |             end_nodes = node_emds[node_mask][edge_node_idex[:, 1]]
16 | 
17 | 
18 |     Input:
19 |         edge_mask (tensor): [B, N, N] - binary edge mask
20 |         node_mask (tensor): [B, N] - binary node mask
21 | 
22 |     Output:
23 |         edge_node_index (tensor): [2, E] - valid edge indices in global node indices
24 |             E = edge_mask.sum() - number of valid edges
25 |     '''
26 |     B, N = node_mask.shape
27 | 
28 |     device = edge_mask.device
29 | 
30 |     # Global node indices
31 |     node_indices = torch.ones(B, N, dtype=torch.long, device=device) * -1  # Size: [B, N]
32 |     flat_node_indices = torch.arange(node_mask.sum(), device=device)  # Re-index to [0, P)
33 |     node_indices[node_mask] = flat_node_indices  # Update global indices, Size: [B, N]
34 | 
35 |     # Flatten edges and filter valid ones
36 |     valid_edges = edge_mask.nonzero()  # Size: [E, 3] (E is number of valid edges)
37 |     valid_edges = valid_edges[edge_mask[valid_edges[:, 0], valid_edges[:, 1], valid_edges[:, 2]]]
38 | 
39 |     # Create a map from 2D node indices to flattened 1D indices
40 |     node_map = -torch.ones(B, N, dtype=torch.long, device=device)  # Initialize with -1 (invalid)
41 |     node_map[node_mask] = torch.arange(flat_node_indices.size(0), device=device)  # Map to valid indices
42 | 
43 |     # Map to 1D indices
44 |     start = flat_node_indices[node_map[valid_edges[:, 0], valid_edges[:, 1]]]  # Size: [E]
45 |     end = flat_node_indices[node_map[valid_edges[:, 0], valid_edges[:, 2]]]  # Size: [E]
46 |     edge_node_index = torch.stack([start, end], dim=0)  # Size: [2, E]
47 | 
48 |     return edge_node_index
49 | 
50 | if __name__ == '__main__':
51 |     # Settings
52 |     B, N = 2, 4  # Batches = 2, Nodes per batch = 4
53 |     # Edge matrix (randomly generated for demo)
54 |     edge_mask = torch.rand(B, N, N) > 0.5  # Size: [B, N, N]
55 |     # Valid mask (randomly generated for demo)
56 |     node_mask = torch.rand(B, N) > 0.3  # Size: [B, N]
57 | 
58 |     print('Edge mask:' + str(edge_mask))
59 |     print('Node mask:' + str(node_mask))
60 |     print('Valid edge indices:' + str(obtain_valid_edge_node_idex(edge_mask, node_mask)))
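To make the contract above concrete, here is a small usage sketch (new code, not from the repo). One caveat the docstring leaves implicit: the function assumes every marked edge connects two valid nodes, so a random edge_mask should first be masked by node validity:

import torch

from prosim.models.condition_transformer.attn_utils import obtain_valid_edge_node_idex

B, N, D = 2, 4, 8
node_emds = torch.randn(B, N, D)
node_mask = torch.rand(B, N) > 0.3                          # [B, N] valid-node mask
edge_mask = torch.rand(B, N, N) > 0.5
edge_mask &= node_mask[:, :, None] & node_mask[:, None, :]  # keep edges between valid nodes only

edge_index = obtain_valid_edge_node_idex(edge_mask, node_mask)  # [2, E]
valid_nodes = node_emds[node_mask]                          # [P, D], P = node_mask.sum()
start_feats = valid_nodes[edge_index[0]]                    # [E, D] features of edge sources
end_feats = valid_nodes[edge_index[1]]                      # [E, D] features of edge targets

This [2, E] layout is the same edge_index convention consumed by the torch-geometric AttentionLayer in prosim/models/layers/attention_layer.py.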
**kwargs)
49 | 
50 |         else:
51 |             prompt_condition_emd = kwargs['prompt_emd']
52 | 
53 |         # apply condition encoders for text prompt_idx condition types
54 |         if len(self.text_types) > 0 and self.text_types[0] in condition_data.keys():
55 |             text_cond = condition_data[self.text_types[0]]
56 |             prompt_condition_emd, prompt_loss = self.text_attn(text_cond, prompt_condition_emd, **kwargs)
57 |         else:
58 |             print('No condition data for text condition transformer')
59 |             prompt_loss = None
60 | 
61 |         return prompt_condition_emd, prompt_loss
--------------------------------------------------------------------------------
/prosim/models/condition_transformer/base_llm.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | from prosim.models.layers.mlp import MLP
8 | 
9 | from transformers import LlamaTokenizer, LlamaForCausalLM
10 | 
11 | import os
12 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
13 | class LLMEncoder(nn.Module):
14 |     def __init__(self, config):
15 |         super().__init__()
16 |         self.config = config
17 |         self.hidden_dim = config.MODEL.HIDDEN_DIM
18 |         self.llm_config = self.config.MODEL.CONDITION_TRANSFORMER.CONDITION_ENCODER.TEXT.LLM
19 |         self.max_txt_len = self.llm_config.MAX_TXT_LEN
20 | 
21 |         self._config_models()
22 | 
23 |     def maybe_autocast(self, device, dtype=torch.float16):
24 |         # if on cpu, don't use autocast
25 |         # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
26 |         enable_autocast = device != torch.device("cpu")
27 | 
28 |         if enable_autocast:
29 |             return torch.cuda.amp.autocast(dtype=dtype)
30 |         else:
31 |             return contextlib.nullcontext()
32 | 
33 | 
34 |     def _config_models(self):
35 |         self.hidden_dim = self.config.MODEL.HIDDEN_DIM
36 |         llm_dim = self.llm_config.HIDDEN_DIM
37 |         use_gpu = self.config.GPU is not None and len(self.config.GPU) > 0
38 | 
39 |         if self.llm_config.USE_PROMPT_TOKEN:
40 |             self.prompt_to_llm_emd = MLP([self.hidden_dim, self.hidden_dim, llm_dim], ret_before_act=True, without_norm=True)
41 | 
42 |         self.llm_to_cond_emd = MLP([llm_dim, self.hidden_dim, self.hidden_dim], ret_before_act=True, without_norm=True)
43 | 
44 |         llm_path = self.llm_config.MODEL_PATH[self.llm_config.MODEL.upper()]
45 | 
46 |         self.llm_tokenizer = LlamaTokenizer.from_pretrained(llm_path, use_fast=False, truncation_side="left")
47 |         torch_dtype = torch.float16 if use_gpu else torch.float32
48 | 
49 |         self.llm_model = LlamaForCausalLM.from_pretrained(llm_path, torch_dtype=torch_dtype, low_cpu_mem_usage=True)
50 |         # freeze the llm
51 |         for name, param in self.llm_model.named_parameters():
52 |             param.requires_grad = False
53 | 
54 |         # special tokens follow the LLaMA/LAVIS convention
55 |         self.llm_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
56 |         self.llm_tokenizer.add_special_tokens({'bos_token': '</s>'})
57 |         self.llm_tokenizer.add_special_tokens({'eos_token': '</s>'})
58 |         self.llm_tokenizer.add_special_tokens({'unk_token': '</s>'})
59 |         self.llm_tokenizer.padding_side = "right"
60 |         self.llm_tokenizer.truncation_side = 'left'
61 | 
62 |         self.llm_model.resize_token_embeddings(len(self.llm_tokenizer))
63 | 
64 |     def _get_llm_text_emd(self, texts: list[str], device: str = 'cuda'):
65 |         text_inputs = self.llm_tokenizer(
66 |             texts,
67 |             return_tensors="pt",
68 |             padding="longest",
69 |             truncation=True,
70 |             max_length=self.max_txt_len)
71 | 
72 |         text_emds = self.llm_model.get_input_embeddings()(text_inputs.input_ids.to(device))  # (V, T, D)
73 | 
74 |         return {'llm_input': text_emds,
'attn_mask': text_inputs.attention_mask.to(device)} 74 | 75 | def _get_llm_feature(self, llm_input_emd, llm_input_mask): 76 | device = llm_input_emd.device 77 | 78 | with self.maybe_autocast(device): 79 | with torch.no_grad(): 80 | output = self.llm_model( 81 | inputs_embeds=llm_input_emd, 82 | attention_mask=llm_input_mask, 83 | output_hidden_states=True, 84 | return_dict=True, 85 | ) 86 | 87 | hidden_state = output.hidden_states[-1] # (V, T, D) 88 | 89 | seq_idx = (torch.sum(llm_input_mask, dim=1) - 1).to(torch.long) 90 | b_idx = torch.arange(seq_idx.size(0)).to(hidden_state.device) 91 | llm_emd = hidden_state[b_idx, seq_idx, :] # (V, D) 92 | return llm_emd 93 | 94 | def _append_prompt_token(self, llm_input_dict, prompt_emd): 95 | llm_inputs = [] 96 | llm_attn_masks = [] 97 | 98 | device = llm_input_dict['attn_mask'].device 99 | 100 | for vid in range(prompt_emd.shape[0]): 101 | token_cnt = llm_input_dict['attn_mask'][vid].sum().item() 102 | llm_inputs.append( 103 | torch.cat([ 104 | llm_input_dict['llm_input'][vid, :token_cnt], 105 | prompt_emd[vid].unsqueeze(0), 106 | llm_input_dict['llm_input'][vid, token_cnt:] 107 | ], dim=0) 108 | ) 109 | llm_attn_masks.append( 110 | torch.cat([ 111 | llm_input_dict['attn_mask'][vid, :token_cnt], 112 | torch.ones(1, device=device), 113 | llm_input_dict['attn_mask'][vid, token_cnt:] 114 | ], dim=0) 115 | ) 116 | 117 | result = {} 118 | 119 | result['llm_input'] = torch.stack(llm_inputs, dim=0) 120 | result['attn_mask'] = torch.stack(llm_attn_masks, dim=0) 121 | 122 | return result 123 | 124 | def forward(self, valid_texts, bidxs, nidxs, **kwargs): 125 | ''' 126 | valid_texts: list of strings, each string is a valid text (size V) 127 | ''' 128 | device = bidxs.device 129 | 130 | llm_input_dict = self._get_llm_text_emd(valid_texts, device) 131 | 132 | if self.llm_config.USE_PROMPT_TOKEN: 133 | all_prompt_emd = kwargs['prompt_emd'] # (B, N, hidden_dim) 134 | prompt_emd = all_prompt_emd[bidxs, nidxs] # (V, hidden_dim) 135 | prompt_emd = self.prompt_to_llm_emd(prompt_emd) # (V, D) 136 | 137 | llm_input_dict = self._append_prompt_token(llm_input_dict, prompt_emd) 138 | 139 | llm_emd = self._get_llm_feature(llm_input_dict['llm_input'], llm_input_dict['attn_mask']) # (V, D) 140 | 141 | cond_emds = self.llm_to_cond_emd(llm_emd.to(torch.float32)) # (V, hidden_dim) 142 | 143 | return cond_emds -------------------------------------------------------------------------------- /prosim/models/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .sym_coord import * -------------------------------------------------------------------------------- /prosim/models/decoder/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from prosim.models.layers.mlp import MLP 5 | from prosim.core.registry import registry 6 | 7 | @registry.register_decoder(name='base') 8 | class Decoder(nn.Module): 9 | def __init__(self, config): 10 | super().__init__() 11 | self.config = config 12 | self.hidden_dim = config.HIDDEN_DIM 13 | self.goal_cfg = config.DECODER.GOAL_PRED 14 | self.K = self.goal_cfg.K 15 | self._config_models() 16 | 17 | def _config_models(self): 18 | if self.goal_cfg.ENABLE: 19 | self.goal_prob_head = MLP([self.hidden_dim, self.hidden_dim//2, self.K], ret_before_act=True) 20 | self.goal_point_head = MLP([self.hidden_dim, self.hidden_dim//2, self.K * 2], ret_before_act=True) 21 | 22 | def _goal_pred(self, 
goal_input, prompt_enc, result):
23 |         '''
24 |         Input:
25 |             goal_input: [B, N, D]
26 |         Output:
27 |             result['goal_prob']: [B, N, K]
28 |             result['goal_point']: [B, N, K, 2]
29 |         '''
30 | 
31 |         B, N, D = goal_input.shape
32 | 
33 |         K = self.K
34 |         result_goal_prob = torch.zeros(B, N, K).to(goal_input.device)
35 |         result_goal_point = torch.zeros(B, N, K, 2).to(goal_input.device)
36 | 
37 |         prompt_mask = prompt_enc['prompt_mask']
38 | 
39 |         # [Q, D]
40 |         valid_k_input = goal_input[prompt_mask]
41 | 
42 |         # [Q, K]
43 |         goal_prob = self.goal_prob_head(valid_k_input).view(-1, K)
44 | 
45 |         # [Q, K, 2]
46 |         goal_point = self.goal_point_head(valid_k_input).view(-1, K, 2)
47 | 
48 |         result_goal_prob[prompt_mask] = goal_prob
49 |         result_goal_point[prompt_mask] = goal_point
50 | 
51 |         # [B, N, K]
52 |         result['goal_prob'] = result_goal_prob
53 | 
54 |         # [B, N, K, 2]
55 |         result['goal_point'] = result_goal_point
56 | 
57 |         return result
58 | 
59 | 
60 |     def _fusion(self, scene_emb, prompt_emd, prompt_mask):
61 |         raise NotImplementedError
62 | 
63 |     def forward(self, scene_emb, prompt_enc):
64 |         result = {}
65 | 
66 |         prompt_emd = prompt_enc['prompt_emd']
67 |         prompt_mask = prompt_enc['prompt_mask']
68 | 
69 |         result['emd'] = self._fusion(scene_emb, prompt_emd, prompt_mask)
70 | 
71 |         if self.goal_cfg.ENABLE:
72 |             # fused per-agent embedding is the [B, N, D] goal input; prompt_enc provides prompt_mask
73 |             self._goal_pred(result['emd'], prompt_enc, result)
74 | 
75 |         return result
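When GOAL_PRED is enabled, downstream code has to reduce the K goal hypotheses to a single goal per agent. A self-contained sketch (illustrative only; the repo's actual selection logic lives in the policy/rollout code) of picking the most likely goal from the outputs documented above:

import torch

B, N, K = 2, 3, 6
result = {
    'goal_prob': torch.randn(B, N, K),      # [B, N, K] unnormalized mode scores
    'goal_point': torch.randn(B, N, K, 2),  # [B, N, K, 2] candidate goal points
}

probs = result['goal_prob'].softmax(dim=-1)  # normalize over the K modes
best_k = probs.argmax(dim=-1)                # [B, N] index of the best mode
best_goal = torch.gather(
    result['goal_point'], 2,
    best_k[..., None, None].expand(-1, -1, 1, 2)
).squeeze(2)                                 # [B, N, 2] selected goal per agent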
--------------------------------------------------------------------------------
/prosim/models/layers/attention_layer.py:
--------------------------------------------------------------------------------
1 | # Mostly copied from https://github.com/ZikangZhou/QCNet/blob/main/layers/attention_layer.py
2 | 
3 | from typing import Optional, Tuple, Union
4 | 
5 | import time
6 | import torch
7 | import torch.nn as nn
8 | from torch_geometric.nn.conv import MessagePassing
9 | from torch_geometric.utils import softmax
10 | from prosim.models.utils import weight_init
11 | 
12 | 
13 | class AttentionLayer(MessagePassing):
14 |     def __init__(self,
15 |                  hidden_dim: int,
16 |                  num_heads: int,
17 |                  head_dim: int,
18 |                  dropout: float,
19 |                  bipartite: bool,
20 |                  has_pos_emb: bool,
21 |                  **kwargs) -> None:
22 |         super(AttentionLayer, self).__init__(aggr='add', node_dim=0, **kwargs)
23 |         self.num_heads = num_heads
24 |         self.head_dim = head_dim
25 |         self.has_pos_emb = has_pos_emb
26 |         self.scale = head_dim ** -0.5
27 | 
28 |         self.to_q = nn.Linear(hidden_dim, head_dim * num_heads)
29 |         self.to_k = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
30 |         self.to_v = nn.Linear(hidden_dim, head_dim * num_heads)
31 |         if has_pos_emb:
32 |             self.to_k_r = nn.Linear(hidden_dim, head_dim * num_heads, bias=False)
33 |             self.to_v_r = nn.Linear(hidden_dim, head_dim * num_heads)
34 |         self.to_s = nn.Linear(hidden_dim, head_dim * num_heads)
35 |         self.to_g = nn.Linear(head_dim * num_heads + hidden_dim, head_dim * num_heads)
36 |         self.to_out = nn.Linear(head_dim * num_heads, hidden_dim)
37 |         self.attn_drop = nn.Dropout(dropout)
38 |         self.ff_mlp = nn.Sequential(
39 |             nn.Linear(hidden_dim, hidden_dim * 4),
40 |             nn.ReLU(inplace=True),
41 |             nn.Dropout(dropout),
42 |             nn.Linear(hidden_dim * 4, hidden_dim),
43 |         )
44 |         if bipartite:
45 |             self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
46 |             self.attn_prenorm_x_dst = nn.LayerNorm(hidden_dim)
47 |         else:
48 |             self.attn_prenorm_x_src = nn.LayerNorm(hidden_dim)
49 |             self.attn_prenorm_x_dst = self.attn_prenorm_x_src
50 |         if has_pos_emb:
51 |             self.attn_prenorm_r = nn.LayerNorm(hidden_dim)
52 |         self.attn_postnorm = nn.LayerNorm(hidden_dim)
53 |         self.ff_prenorm = nn.LayerNorm(hidden_dim)
54 |         self.ff_postnorm = nn.LayerNorm(hidden_dim)
55 | 
56 |     def forward(self,
57 |                 x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
58 |                 r: Optional[torch.Tensor],
59 |                 edge_index: torch.Tensor) -> torch.Tensor:
60 |         if isinstance(x, torch.Tensor):
61 |             x_src = self.attn_prenorm_x_src(x)
62 |             x_dst = self.attn_prenorm_x_dst(x)
63 |         else:
64 |             x_src, x_dst = x
65 |             x_src = self.attn_prenorm_x_src(x_src)
66 |             x_dst = self.attn_prenorm_x_dst(x_dst)
67 |             x = x[1]
68 |         if self.has_pos_emb and r is not None:
69 |             r = self.attn_prenorm_r(r)
70 | 
71 |         x_src = x_src.to(x.dtype)
72 |         x_dst = x_dst.to(x.dtype)
73 |         r = r.to(x.dtype) if r is not None else None  # guard: r is None without positional embeddings
74 | 
75 |         # cast postnorm to the same dtype as q_i (needed for torch.bfloat16)
76 |         x = x + self.attn_postnorm(self._attn_block(x_src, x_dst, r, edge_index)).to(x.dtype)
77 |         x = x + self.ff_postnorm(self._ff_block(self.ff_prenorm(x))).to(x.dtype)
78 |         return x
79 | 
80 |     def message(self,
81 |                 q_i: torch.Tensor,
82 |                 k_j: torch.Tensor,
83 |                 v_j: torch.Tensor,
84 |                 r: Optional[torch.Tensor],
85 |                 index: torch.Tensor,
86 |                 ptr: Optional[torch.Tensor]) -> torch.Tensor:
87 |         if self.has_pos_emb and r is not None:
88 |             k_j = k_j + self.to_k_r(r).view(-1, self.num_heads, self.head_dim)
89 |             v_j = v_j + self.to_v_r(r).view(-1, self.num_heads, self.head_dim)
90 |         sim = (q_i * k_j).sum(dim=-1) * self.scale
91 |         attn = softmax(sim, index, ptr)
92 |         attn = self.attn_drop(attn)
93 | 
94 |         # cast attn to the same dtype as q_i (needed for torch.bfloat16)
95 |         # start = time.time()
96 |         attn = attn.to(q_i.dtype)
97 |         # cast_time = time.time() - start
98 |         # print('\t\t\tAttn message dtype cast_time:', cast_time)
99 | 
100 |         return v_j * attn.unsqueeze(-1)
101 | 
102 |     def update(self,
103 |                inputs: torch.Tensor,
104 |                x_dst: torch.Tensor) -> torch.Tensor:
105 |         inputs = inputs.view(-1, self.num_heads * self.head_dim)
106 |         g = torch.sigmoid(self.to_g(torch.cat([inputs, x_dst], dim=-1)))
107 |         return inputs + g * (self.to_s(x_dst) - inputs)
108 | 
109 |     def _attn_block(self,
110 |                     x_src: torch.Tensor,
111 |                     x_dst: torch.Tensor,
112 |                     r: Optional[torch.Tensor],
113 |                     edge_index: torch.Tensor) -> torch.Tensor:
114 |         q = self.to_q(x_dst).view(-1, self.num_heads, self.head_dim)
115 |         k = self.to_k(x_src).view(-1, self.num_heads, self.head_dim)
116 |         v = self.to_v(x_src).view(-1, self.num_heads, self.head_dim)
117 |         agg = self.propagate(edge_index=edge_index, x_dst=x_dst, q=q, k=k, v=v, r=r)
118 |         return self.to_out(agg)
119 | 
120 |     def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
121 |         return self.ff_mlp(x)
--------------------------------------------------------------------------------
/prosim/models/layers/fourier_embedding.py:
--------------------------------------------------------------------------------
1 | # Mostly copied from https://github.com/ZikangZhou/QCNet/blob/main/layers/fourier_embedding.py
2 | import math
3 | from typing import List, Optional
4 | 
5 | import torch
6 | import torch.nn as nn
7 | 
8 | from prosim.models.utils.weight_init import weight_init
9 | 
10 | 
11 | class FourierEmbedding(nn.Module):
12 |     def __init__(self,
13 |                  input_dim: int,
14 |                  hidden_dim: int,
15 |                  num_freq_bands: int) -> None:
16 |         super(FourierEmbedding, self).__init__()
17 |         self.input_dim = input_dim
18 |         self.hidden_dim = hidden_dim
19 | 
20 |         self.freqs = nn.Embedding(input_dim, num_freq_bands) if input_dim != 0 else None
21 |         self.mlps = nn.ModuleList(
22 |             [nn.Sequential(
23 |                 nn.Linear(num_freq_bands * 2 + 1, hidden_dim),
24 |
nn.LayerNorm(hidden_dim), 25 | nn.ReLU(inplace=True), 26 | nn.Linear(hidden_dim, hidden_dim), 27 | ) 28 | for _ in range(input_dim)]) 29 | self.to_out = nn.Sequential( 30 | nn.LayerNorm(hidden_dim), 31 | nn.ReLU(inplace=True), 32 | nn.Linear(hidden_dim, hidden_dim), 33 | ) 34 | self.apply(weight_init) 35 | 36 | def forward(self, 37 | continuous_inputs: Optional[torch.Tensor] = None, 38 | categorical_embs: Optional[List[torch.Tensor]] = None) -> torch.Tensor: 39 | if continuous_inputs is None: 40 | if categorical_embs is not None: 41 | x = torch.stack(categorical_embs).sum(dim=0) 42 | else: 43 | raise ValueError('Both continuous_inputs and categorical_embs are None') 44 | else: 45 | x = continuous_inputs.unsqueeze(-1) * self.freqs.weight * 2 * math.pi 46 | # Warning: if your data are noisy, don't use learnable sinusoidal embedding 47 | x = torch.cat([x.cos(), x.sin(), continuous_inputs.unsqueeze(-1)], dim=-1) 48 | continuous_embs: List[Optional[torch.Tensor]] = [None] * self.input_dim 49 | for i in range(self.input_dim): 50 | continuous_embs[i] = self.mlps[i](x[:, i]) 51 | x = torch.stack(continuous_embs).sum(dim=0) 52 | if categorical_embs is not None: 53 | x = x + torch.stack(categorical_embs).sum(dim=0) 54 | return self.to_out(x) 55 | 56 | class FourierEmbeddingFix(nn.Module): 57 | def __init__(self, num_pos_feats=128, temperature=10000) -> None: 58 | super(FourierEmbeddingFix, self).__init__() 59 | 60 | self.num_pos_feats = num_pos_feats 61 | self.temperature = temperature 62 | 63 | def forward(self, continuous_inputs: Optional[torch.Tensor] = None) -> torch.Tensor: 64 | pos = continuous_inputs 65 | 66 | scale = 2 * math.pi 67 | pos = pos * scale 68 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=pos.device) 69 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 70 | 71 | D = pos.shape[-1] 72 | pos_dims = [] 73 | for i in range(D): 74 | pos_dim_i = pos[..., i, None] / dim_t 75 | pos_dim_i = torch.stack((pos_dim_i[..., 0::2].sin(), pos_dim_i[..., 1::2].cos()), dim=-1).flatten(-2) 76 | pos_dims.append(pos_dim_i) 77 | posemb = torch.cat(pos_dims, dim=-1) 78 | 79 | return posemb -------------------------------------------------------------------------------- /prosim/models/policy/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * -------------------------------------------------------------------------------- /prosim/models/policy/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from prosim.core.registry import registry 4 | from .temporal_ar import PolicyNoRNN 5 | 6 | act_decoders = {'policy_no_rnn': PolicyNoRNN} 7 | 8 | @registry.register_policy(name='rel_pe_temporal') 9 | class Policy_RelPE_Temporal(nn.Module): 10 | def __init__(self, config): 11 | super().__init__() 12 | self.config = config 13 | self.p_config = config.MODEL.POLICY 14 | self._config_models() 15 | 16 | def _config_models(self): 17 | self.act_decoder = act_decoders[self.p_config.ACT_DECODER.TYPE](self.config) 18 | 19 | def forward(self, policy_emd, batch_obs, batch_map, batch_pos, pair_names, latent_state): 20 | return self.act_decoder(policy_emd, batch_obs, batch_map, batch_pos, pair_names, latent_state) 21 | 22 | def format_latent_state(self, lante_state_dict, all_batch_pair_names): 23 | return self.act_decoder.format_latent_state(lante_state_dict, all_batch_pair_names) 24 | 
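The decoder selected from act_decoders above lives in the next file; its central helper, _plain_batch_to_temporal, densifies a flat batch of per-(agent, time step) rows into an [N, T, D] grid keyed by pair names. Below is a minimal, runnable sketch of that scatter with hypothetical pair names (the real ones follow the same '<agent_name>-<timestep>' format but come from the rollout pipeline):

import torch

pair_names = ['sc0-a-0', 'sc0-a-1', 'sc0-b-1']   # hypothetical; agent 'b' misses t=0
batch = torch.arange(6, dtype=torch.float32).view(3, 2)          # [B=3, D=2]

agent_names = ['-'.join(n.split('-')[:-1]) for n in pair_names]  # strip the timestep
time_steps = [int(n.split('-')[-1]) for n in pair_names]
agents = sorted(set(agent_names))                                # N = 2 unique agents
times = sorted(set(time_steps))                                  # T = 2 unique steps

# scatter each flat row into its (agent, timestep) slot; absent slots stay zero
batch_T = torch.zeros(len(agents), len(times), batch.shape[-1])
rows = [agents.index(n) for n in agent_names]
cols = [times.index(t) for t in time_steps]
batch_T[rows, cols] = batch

print(batch_T[agents.index('sc0-b'), 0])   # tensor([0., 0.]) -- the padded slot

The returned (rows, cols) index pair is what lets PolicyNoRNN._temporal_pred gather the same rows straight back out via fuse_feature_T[T_idx[0], T_idx[1]].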
-------------------------------------------------------------------------------- /prosim/models/policy/temporal_ar.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from .act_decoder import AttnRelPE 6 | 7 | class TemporalARDecoder(AttnRelPE): 8 | def _plain_batch_to_temporal(self, batch, pair_names): 9 | ''' 10 | convert the plain batch to a temporal batch according to the pair_names 11 | 12 | Input: Batch of shape [B, D], pair_names of shape [B] 13 | Output: Batch of shape [N, T, D], batch_idx of shape [2, B] 14 | 15 | shape change: [B, D] -> [N, T, D] 16 | B: original batch size (all available agents across available time steps) 17 | N: the number of unique agents 18 | T: the number of unique time steps 19 | B <= N * T (some agents may not appear at all T time steps); e.g. pair_names ['a-0', 'a-1', 'b-1'] gives N=2, T=2, and slot (b, t=0) stays zero 20 | ''' 21 | 22 | agent_names = ['-'.join(name.split('-')[:-1]) for name in pair_names] 23 | time_steps = [int(name.split('-')[-1]) for name in pair_names] 24 | 25 | unique_agent_names = sorted(list(set(agent_names))) 26 | unique_time_steps = sorted(list(set(time_steps))) 27 | 28 | N = len(unique_agent_names) 29 | T = len(unique_time_steps) 30 | 31 | D = batch.shape[-1] 32 | batch_T = torch.zeros(N, T, D, dtype=batch.dtype, device=batch.device) 33 | 34 | agent_idx = [unique_agent_names.index(name) for name in agent_names] 35 | time_idx = [unique_time_steps.index(time) for time in time_steps] 36 | 37 | batch_T[agent_idx, time_idx] = batch 38 | 39 | return batch_T, [agent_idx, time_idx] 40 | 41 | def forward(self, policy_emd, batch_obs, batch_map, batch_pos, pair_names, latent_state): 42 | # fuse scene context into per-(agent, time step) features, then reshape them to [N, T, D] 43 | 44 | context_emd = self._extract_context(policy_emd) 45 | policy_batch_idx = policy_emd['batch_idx'] 46 | fuse_feature = self.attn_fuse(context_emd, policy_batch_idx, batch_obs, batch_map, batch_pos) 47 | fuse_feature_T, T_idx = self._plain_batch_to_temporal(fuse_feature, pair_names) 48 | 49 | pred_feature, latent_state = self._temporal_pred(context_emd, fuse_feature_T, latent_state, T_idx) 50 | 51 | result = self._compute_traj(pred_feature, policy_emd) 52 | 53 | if 'goal' in policy_emd: 54 | result['goal'] = policy_emd['goal'] 55 | 56 | if 'goal_prob' in policy_emd: 57 | result['goal_prob'] = policy_emd['goal_prob'] 58 | result['goal_point'] = policy_emd['goal_point'] 59 | result['select_idx'] = policy_emd['select_idx'] 60 | 61 | result['latent_state'] = latent_state 62 | 63 | return result 64 | 65 | 66 | class PolicyNoRNN(TemporalARDecoder): 67 | def _temporal_pred(self, context_emd, fuse_feature_T, latent_state, T_idx): 68 | pred_feature = fuse_feature_T[T_idx[0], T_idx[1]] 69 | return pred_feature, latent_state 70 | 71 | def format_latent_state(self, lante_state_dict, all_batch_pair_names): 72 | 73 | return None 74 | 75 | def forward(self, policy_emd, batch_obs, batch_map, batch_pos, pair_names, latent_state): 76 | context_emd = self._extract_context(policy_emd) 77 | policy_batch_idx = policy_emd['batch_idx'] 78 | fuse_feature = self.attn_fuse(context_emd, policy_batch_idx, batch_obs, batch_map, batch_pos) 79 | fuse_feature_T, T_idx = self._plain_batch_to_temporal(fuse_feature, pair_names) 80 | pred_feature, latent_state = self._temporal_pred(None, fuse_feature_T, latent_state, T_idx) 81 | 82 | result = self._compute_traj(pred_feature, policy_emd) 83 | 84 | if 'goal' in policy_emd: 85 | result['goal'] = policy_emd['goal'] 86 | 87 | if 'goal_prob' in policy_emd: 88 | result['goal_prob'] = 
policy_emd['goal_prob'] 89 | result['goal_point'] = policy_emd['goal_point'] 90 | 91 | result['latent_state'] = latent_state 92 | return result 93 | -------------------------------------------------------------------------------- /prosim/models/prompt_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * -------------------------------------------------------------------------------- /prosim/models/prompt_encoder/base.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from prosim.models.layers.mlp import MLP 5 | from prosim.core.registry import registry 6 | @registry.register_prompt_encoder(name='agent_status') 7 | class PromptEncoder(nn.Module): 8 | def __init__(self, config): 9 | super().__init__() 10 | self.config = config 11 | self._config_models() 12 | self.state_encoder.apply(self._init_weights) 13 | 14 | def _init_weights(self, module): 15 | if isinstance(module, (nn.Linear, nn.Embedding)): 16 | module.weight.data.normal_(mean=0.0, std=0.02) 17 | if isinstance(module, nn.Linear) and module.bias is not None: 18 | module.bias.data.zero_() 19 | elif isinstance(module, nn.LayerNorm): 20 | module.bias.data.zero_() 21 | module.weight.data.fill_(1.0) 22 | 23 | def _config_prompt_models(self): 24 | status_cfg = self.config.PROMPT.AGENT_STATUS 25 | input_dim = 0 26 | if status_cfg.USE_VEL: 27 | input_dim += 2 28 | if status_cfg.USE_EXTEND: 29 | input_dim += 2 30 | if status_cfg.USE_AGENT_TYPE: 31 | input_dim += 3 32 | self.state_encoder = MLP([input_dim, self.config.MODEL.HIDDEN_DIM, self.config.MODEL.HIDDEN_DIM], ret_before_act=True) 33 | 34 | def _config_models(self): 35 | self._config_prompt_models() 36 | 37 | def _prompt_encode(self, prompt_input): 38 | prompt = prompt_input['prompt'] 39 | prompt_mask = prompt_input['prompt_mask'] 40 | device = next(self.parameters()).device 41 | 42 | prompt = prompt.to(device) 43 | prompt_mask = prompt_mask.to(device) 44 | 45 | prompt_emd = self.state_encoder(prompt) 46 | return prompt_emd, prompt_mask 47 | 48 | def forward(self, prompt_input): 49 | prompt_emd, prompt_mask = self._prompt_encode(prompt_input) 50 | 51 | return prompt_emd, prompt_mask -------------------------------------------------------------------------------- /prosim/models/prompt_generator/__init__.py: -------------------------------------------------------------------------------- 1 | from .generators import * -------------------------------------------------------------------------------- /prosim/models/scene_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .attn_fusion import * -------------------------------------------------------------------------------- /prosim/models/scene_encoder/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | from prosim.core.registry import registry 7 | from .map_encoder import map_encoders 8 | from .obs_encoder import obs_encoders 9 | 10 | @registry.register_scene_encoder(name='base') 11 | class SceneEncoder(nn.Module): 12 | def __init__(self, config, model_cfg): 13 | super().__init__() 14 | self.config = config 15 | self.model_cfg = model_cfg 16 | self.hidden_dim = self.config.MODEL.HIDDEN_DIM 17 | self._config_models() 18 | 19 | def _config_models(self): 20 | 
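# note: both sub-encoders are looked up by config string from the type -> class
# dicts imported above (map_encoders in map_encoder.py, obs_encoders in obs_encoder.py)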
self.map_encoder = map_encoders[self.model_cfg.MAP_TYPE](self.config, self.model_cfg) 21 | self.obs_encoder = obs_encoders[self.model_cfg.OBS_TYPE](self.config, self.model_cfg) 22 | 23 | self._config_fusion() 24 | 25 | def _config_fusion(self): 26 | raise NotImplementedError 27 | 28 | def _scene_fusion(self, batch_map, batch_obs, map_emd, obs_emd, map_mask, obs_mask): 29 | raise NotImplementedError 30 | 31 | def forward(self, batch_obs, batch_map): 32 | # inputs: 33 | # batch_obs encodes the observations of the agents (incl. other agents): {'input', 'mask'} 34 | # batch_map encodes the map information around the agents: {'input', 'mask'} 35 | 36 | # output dict: 37 | # scene_tokens: the tokens of the scene elements 38 | # scene_mask: the mask of the scene elements 39 | # scene_emd: a D-dim emb for each scene in the batch 40 | 41 | map_emd, map_mask = self.map_encoder(batch_map) 42 | obs_emd, obs_mask = self.obs_encoder(batch_obs) 43 | 44 | result = self._scene_fusion(batch_map, batch_obs, map_emd, obs_emd, map_mask, obs_mask) 45 | 46 | return result 47 | -------------------------------------------------------------------------------- /prosim/models/scene_encoder/map_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from prosim.models.layers.mlp import MLP 3 | from .pointnet_encoder import PointNetPolylineEncoder 4 | 5 | class MLP_MAP_ENCODER(nn.Module): 6 | def __init__(self, cfg, model_cfg): 7 | super().__init__() 8 | self.config = cfg 9 | self.model_cfg = model_cfg 10 | self.hidden_dim = self.config.MODEL.HIDDEN_DIM 11 | self.pool_func = self.config.MODEL.MAP_ENCODER.MLP.POOL 12 | 13 | self._config_models() 14 | 15 | def _config_models(self): 16 | self.lane_encode = MLP([4, 256, 512, self.hidden_dim], ret_before_act=True) 17 | 18 | self.type_embedding = nn.Embedding(4, self.hidden_dim) 19 | self.traf_embedding = nn.Embedding(4, self.hidden_dim) 20 | 21 | def map_lane_encode(self, lane_inp): 22 | polyline = lane_inp[..., :4] 23 | polyline_type = lane_inp[..., 4].to(int) 24 | polyline_traf = lane_inp[..., 5].to(int) + 1 25 | 26 | polyline_type_embed = self.type_embedding(polyline_type) 27 | polyline_traf_embed = self.traf_embedding(polyline_traf) 28 | 29 | lane_enc = self.lane_encode(polyline) + polyline_traf_embed + polyline_type_embed 30 | 31 | return lane_enc 32 | 33 | def _pool_hist(self, lane_enc, lane_mask): 34 | # lane_enc: [B, M, N, D] 35 | 36 | if self.pool_func == 'mean': 37 | lane_enc = lane_enc.masked_fill(~lane_mask[..., None], 0.0) 38 | lane_enc = lane_enc.sum(dim=2) / lane_mask.sum(dim=2, keepdim=True) 39 | lane_enc = lane_enc.masked_fill(~lane_mask.any(dim=2, keepdim=True), 0.0) 40 | 41 | elif self.pool_func == 'max': 42 | lane_enc = lane_enc.masked_fill(~lane_mask[..., None], -1e9) 43 | lane_enc = lane_enc.max(dim=2)[0] 44 | 45 | else: 46 | raise NotImplementedError 47 | 48 | return lane_enc 49 | 50 | 51 | def forward(self, batch_map): 52 | # encode the map into scene embedding 53 | map_input = batch_map['input'] 54 | map_mask = batch_map['mask'].any(dim=-1) # [B, M] 55 | 56 | lane_enc = self.map_lane_encode(map_input) 57 | 58 | # lane_enc: [B, M, N, D] 59 | # M: number of lanes 60 | # N: the number of points for each lane 61 | if len(lane_enc.shape) == 4: 62 | lane_mask = map_input[..., 4] > 0 # [B, M, N] 63 | lane_enc = self._pool_hist(lane_enc, lane_mask) # [B, M, D] 64 | 65 | return lane_enc, map_mask 66 | 67 | class POINTNET_MAP_ENCODER(PointNetPolylineEncoder): 68 | def __init__(self, cfg, model_cfg): 69 | in_dim = 
6 70 | 71 | if cfg.DATASET.FORMAT.MAP.WITH_TYPE_EMB: 72 | in_dim += 3 73 | 74 | if cfg.DATASET.FORMAT.MAP.WITH_DIR: 75 | in_dim += 2 76 | 77 | hidden_dim = cfg.MODEL.HIDDEN_DIM 78 | layer_cfg = cfg.MODEL.MAP_ENCODER.POINTNET 79 | super().__init__(in_dim, hidden_dim, layer_cfg) 80 | 81 | def forward(self, batch_map): 82 | # encode the map into scene embedding 83 | map_input = batch_map['input'] 84 | map_mask = batch_map['mask'] # [B, M, P] 85 | 86 | lane_enc = super().forward(map_input, map_mask) 87 | 88 | return lane_enc, map_mask.any(dim=-1) # [B, M] 89 | 90 | map_encoders = {'mlp': MLP_MAP_ENCODER, 'pointnet': POINTNET_MAP_ENCODER} -------------------------------------------------------------------------------- /prosim/models/scene_encoder/obs_encoder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from prosim.models.layers.mlp import MLP 3 | from .pointnet_encoder import PointNetPolylineEncoder 4 | 5 | def get_obs_input_dim(cfg): 6 | in_dim = len(cfg.DATASET.FORMAT.HISTORY.ELEMENTS.split(',')) 7 | 8 | if cfg.DATASET.FORMAT.HISTORY.WITH_EXTEND: 9 | in_dim += 2 10 | 11 | if cfg.DATASET.FORMAT.HISTORY.WITH_AGENT_TYPE: 12 | in_dim += 3 13 | 14 | if cfg.DATASET.FORMAT.HISTORY.WITH_TIME_EMB: 15 | in_dim += cfg.DATASET.FORMAT.HISTORY.STEPS 16 | 17 | return in_dim 18 | 19 | class MLP_OBV_ENCODER(nn.Module): 20 | def __init__(self, cfg, model_cfg): 21 | self.cfg = cfg 22 | self.hidden_dim = cfg.MODEL.HIDDEN_DIM 23 | self.pool_func = cfg.MODEL.OBS_ENCODER.MLP.POOL 24 | super().__init__() 25 | self._config_enc() 26 | 27 | def _config_enc(self): 28 | hist_dim = get_obs_input_dim(self.cfg) 29 | 30 | if self.pool_func == 'none': 31 | hist_step = self.cfg.DATASET.FORMAT.HISTORY.STEPS 32 | input_dim = hist_step * hist_dim 33 | else: 34 | input_dim = hist_dim 35 | 36 | self.hist_encoder = MLP([input_dim, self.hidden_dim // 2, self.hidden_dim], ret_before_act=True) 37 | 38 | def _pool_hist(self, hist_enc, hist_mask): 39 | # hist_enc: [B, N, T, D] 40 | # hist_mask: [B, N, T] 41 | 42 | if self.pool_func == 'mean': 43 | hist_enc = hist_enc.masked_fill(~hist_mask[..., None], 0.0) 44 | hist_enc = hist_enc.sum(dim=2) / hist_mask.sum(dim=2, keepdim=True) 45 | hist_enc = hist_enc.masked_fill(~hist_mask.any(dim=2, keepdim=True), 0.0) 46 | 47 | elif self.pool_func == 'max': 48 | hist_enc = hist_enc.masked_fill(~hist_mask[..., None], -1e9) 49 | hist_enc = hist_enc.max(dim=2)[0] 50 | 51 | else: 52 | raise NotImplementedError 53 | 54 | return hist_enc 55 | 56 | def forward(self, batch_obs): 57 | B, N = batch_obs['input'].shape[:2] 58 | obs_mask = batch_obs['mask'].all(dim=-1) # [B, N, T] 59 | 60 | # avoid propagation of nan values 61 | obs_input = batch_obs['input'].masked_fill(~obs_mask[..., None], 0.0) # [B, N, T, d] 62 | 63 | if self.pool_func == 'none': 64 | obs_input = obs_input.reshape(B, N, -1) # [B, N, T*d] 65 | obs_mask = obs_mask.all(dim=-1) # [B, N] 66 | hist_enc = self.hist_encoder(obs_input) # [B, N, D] 67 | 68 | else: 69 | hist_enc = self.hist_encoder(obs_input) # [B, N, T, D] 70 | hist_enc = self._pool_hist(hist_enc, obs_mask) # [B, N, D] 71 | obs_mask = obs_mask.any(dim=-1) # [B, N] 72 | 73 | return hist_enc, obs_mask 74 | 75 | class POINTNET_OBV_ENCODER(PointNetPolylineEncoder): 76 | def __init__(self, cfg, model_cfg): 77 | in_dim = get_obs_input_dim(cfg) 78 | hidden_dim = cfg.MODEL.HIDDEN_DIM 79 | layer_cfg = cfg.MODEL.OBS_ENCODER.POINTNET 80 | super().__init__(in_dim, hidden_dim, layer_cfg) 81 | 82 | def forward(self, batch_obs): 83 
| obs_input = batch_obs['input'] 84 | obs_mask = batch_obs['mask'].all(dim=-1) # [B, N, T] 85 | 86 | hist_enc = super().forward(obs_input, obs_mask) 87 | return hist_enc, obs_mask.any(dim=-1) # [B, N] 88 | 89 | 90 | obs_encoders = {'mlp': MLP_OBV_ENCODER, 'pointnet': POINTNET_OBV_ENCODER} -------------------------------------------------------------------------------- /prosim/models/scene_encoder/pointnet_encoder.py: -------------------------------------------------------------------------------- 1 | # Copied from: https://github.com/sshaoshuai/MTR/blob/master/mtr/models/utils/polyline_encoder.py 2 | # Motion Transformer (MTR): https://arxiv.org/abs/2209.13508 3 | # Published at NeurIPS 2022 4 | # Written by Shaoshuai Shi 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | from prosim.models.layers.mlp import MLP 10 | 11 | # # in_channels, hidden_dim, num_layers=3, num_pre_layers=1 12 | 13 | class PointNetPolylineEncoder(nn.Module): 14 | def __init__(self, in_dim, hidden_dim, layer_cfg): 15 | super().__init__() 16 | 17 | num_layers = layer_cfg.NUM_MLP_LAYERS 18 | num_pre_layers = layer_cfg.NUM_PRE_LAYERS 19 | 20 | self.pre_mlps = MLP([in_dim] + [hidden_dim] * num_pre_layers, ret_before_act=False) 21 | self.mlps = MLP([hidden_dim * 2] + [hidden_dim] * (num_layers - num_pre_layers), ret_before_act=False) 22 | self.out_mlps = MLP([hidden_dim] * 3, without_norm=True, ret_before_act=True) 23 | 24 | def forward(self, polylines, polylines_mask): 25 | """ 26 | Args: 27 | polylines (batch_size, num_polylines, num_points_each_polylines, C): 28 | polylines_mask (batch_size, num_polylines, num_points_each_polylines): 29 | 30 | Returns: 31 | """ 32 | batch_size, num_polylines, num_points_each_polylines, C = polylines.shape 33 | 34 | # print('polylines.dtype:', polylines.dtype) 35 | # print('polylines_mask.dtype:', polylines_mask.dtype) 36 | 37 | # pre-mlp 38 | polylines_feature_valid = self.pre_mlps(polylines[polylines_mask]) # (N, C) 39 | model_dtype = polylines_feature_valid.dtype 40 | polylines_feature = polylines.new_zeros(batch_size, num_polylines, num_points_each_polylines, polylines_feature_valid.shape[-1], dtype=model_dtype) 41 | polylines_feature[polylines_mask] = polylines_feature_valid 42 | 43 | # get global feature 44 | pooled_feature = polylines_feature.max(dim=2)[0] 45 | polylines_feature = torch.cat((polylines_feature, pooled_feature[:, :, None, :].repeat(1, 1, num_points_each_polylines, 1)), dim=-1) 46 | 47 | # mlp 48 | polylines_feature_valid = self.mlps(polylines_feature[polylines_mask]) 49 | feature_buffers = polylines_feature.new_zeros(batch_size, num_polylines, num_points_each_polylines, polylines_feature_valid.shape[-1], dtype=model_dtype) 50 | feature_buffers[polylines_mask] = polylines_feature_valid 51 | 52 | # max-pooling 53 | feature_buffers = feature_buffers.max(dim=2)[0] # (batch_size, num_polylines, C) 54 | 55 | # out-mlp 56 | valid_mask = (polylines_mask.sum(dim=-1) > 0) 57 | feature_buffers_valid = self.out_mlps(feature_buffers[valid_mask]) # (N, C) 58 | feature_buffers = feature_buffers.new_zeros(batch_size, num_polylines, feature_buffers_valid.shape[-1], dtype=model_dtype) 59 | feature_buffers[valid_mask] = feature_buffers_valid 60 | 61 | # print('feature_buffers.dtype:', feature_buffers.dtype) 62 | 63 | return feature_buffers -------------------------------------------------------------------------------- /prosim/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .visualization import 
visualization_callback 2 | from .weight_init import weight_init -------------------------------------------------------------------------------- /prosim/models/utils/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from trajdata.utils.state_utils import StateTensor 3 | 4 | from pytorch_lightning.callbacks import Callback 5 | from prosim.dataset.data_utils import transform_to_frame_offset_rot 6 | 7 | def extract_history_obs_from_batch(batch): 8 | NAN_PADDING = -1e3 9 | 10 | ego_hist = batch.agent_hist.as_format('x,y').float().unsqueeze(1) 11 | neigh_hist = batch.neigh_hist.as_format('x,y').float() 12 | 13 | B = ego_hist.shape[0] 14 | N = neigh_hist.shape[1] 15 | device = ego_hist.device 16 | 17 | ego_hist = ego_hist.reshape(B, 1, -1) 18 | 19 | if N > 0: 20 | neigh_hist = neigh_hist.reshape(B, N, -1) 21 | neigh_mask = torch.zeros([B, N]).to(device).to(torch.bool) 22 | for b in range(B): 23 | neigh_mask[b, :batch.num_neigh[b]] = True 24 | neigh_hist[b, batch.num_neigh[b]:] = NAN_PADDING 25 | 26 | neigh_hist[neigh_hist.isnan()] = NAN_PADDING 27 | 28 | hist_input = torch.cat([ego_hist, neigh_hist], dim=1) 29 | hist_mask = torch.cat([torch.ones([B, 1]).to(device).to(torch.bool), neigh_mask], dim=1) 30 | else: 31 | hist_input = ego_hist 32 | hist_mask = torch.ones([B, 1]).to(device).to(torch.bool) 33 | 34 | return hist_input, hist_mask 35 | 36 | def extract_agent_obs_from_center_obs(query_names, center_obs): 37 | agent_obs = {} 38 | 39 | center_name = center_obs.agent_name 40 | agent_names = center_obs.neigh_names[0] + center_name 41 | 42 | control_idx = [agent_names.index(name) for name in query_names] 43 | 44 | obs_format = center_obs.neigh_fut._format 45 | 46 | all_hist = torch.concat([center_obs.neigh_hist, center_obs.agent_hist[:, None]], dim=1) 47 | agent_obs['hist'] = StateTensor.from_array(all_hist[:, control_idx], obs_format) 48 | 49 | T = center_obs.neigh_fut.shape[2] 50 | all_fut = torch.concat([center_obs.neigh_fut[:, :, :T], center_obs.agent_fut[:, None, :T]], dim=1) 51 | agent_obs['fut'] = StateTensor.from_array(all_fut[:, control_idx], obs_format) 52 | 53 | all_fut_len = torch.concat([center_obs.neigh_fut_len, center_obs.agent_fut_len[:, None]], dim=1) 54 | agent_obs['fut_len'] = all_fut_len[:, control_idx] 55 | 56 | all_agent_type = torch.concat([center_obs.neigh_types, center_obs.agent_type[:, None]], dim=1) 57 | agent_obs['type'] = all_agent_type[:, control_idx] 58 | 59 | return agent_obs 60 | 61 | def get_agent_pos_dict(agent_hist): 62 | result = {} 63 | 64 | device = agent_hist.device 65 | 66 | agent_curr = agent_hist[:, :, -1] 67 | 68 | result['position'] = agent_curr.as_format('x,y').float().clone() 69 | result['heading'] = agent_curr.as_format('h').float().clone() 70 | 71 | return result -------------------------------------------------------------------------------- /prosim/models/utils/geometry.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/ZikangZhou/QCNet/blob/main/utils/geometry.py#L19 2 | 3 | import math 4 | import torch 5 | 6 | def angle_between_2d_vectors( 7 | 8 | ctr_vector: torch.Tensor, 9 | nbr_vector: torch.Tensor) -> torch.Tensor: 10 | return torch.atan2(ctr_vector[..., 0] * nbr_vector[..., 1] - ctr_vector[..., 1] * nbr_vector[..., 0], 11 | (ctr_vector[..., :2] * nbr_vector[..., :2]).sum(dim=-1)) 12 | 13 | def wrap_angle( 14 | angle: torch.Tensor, 15 | min_val: float = -math.pi, 16 | max_val: float = math.pi) -> torch.Tensor: 
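# shifts the angle by max_val, wraps with the modulo into [0, max_val - min_val),
# then shifts back, so the result always lies in [min_val, max_val)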
17 | return min_val + (angle + max_val) % (max_val - min_val) 18 | 19 | def batch_rotate_2D(xy, theta): 20 | x1 = xy[..., 0] * torch.cos(theta) - xy[..., 1] * torch.sin(theta) 21 | y1 = xy[..., 1] * torch.cos(theta) + xy[..., 0] * torch.sin(theta) 22 | return torch.stack([x1, y1], dim=-1) 23 | 24 | def rel_traj_coord_to_last_step(traj): 25 | """ 26 | Convert an arbitrary trajectory to a trajectory relative to its last step. 27 | Args: 28 | traj: tensor of shape (B, traj_len, 4) 29 | x, y, sin, cos 30 | """ 31 | traj_theta = torch.atan2(traj[..., 2], traj[..., 3]) 32 | 33 | origin = traj[..., -1, :] 34 | 35 | xy_offset = traj[..., :2] - origin[..., None, :2] 36 | 37 | xy_offset = batch_rotate_2D(xy_offset, -traj_theta[..., -1:]) 38 | 39 | theta_offset = wrap_angle(traj_theta - traj_theta[..., -1:]) 40 | sin = torch.sin(theta_offset) 41 | cos = torch.cos(theta_offset) 42 | 43 | rel_traj = torch.cat([xy_offset, sin[..., None], cos[..., None]], dim=-1) 44 | 45 | return rel_traj 46 | 47 | def rel_vel_coord_to_last_step(traj, vel): 48 | """ 49 | Convert a sequence of velocities into the frame of the trajectory's last step. 50 | Args: 51 | traj: tensor of shape (B, traj_len, 4) 52 | vel: tensor of shape (B, traj_len, 2) 53 | dx, dy 54 | """ 55 | traj_theta = torch.atan2(traj[..., 2], traj[..., 3]) 56 | 57 | rel_vel = batch_rotate_2D(vel, -traj_theta[..., -1:]) 58 | 59 | return rel_vel -------------------------------------------------------------------------------- /prosim/models/utils/graph.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/ZikangZhou/QCNet/blob/main/utils/graph.py 2 | 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import torch 6 | from torch_geometric.utils import coalesce 7 | from torch_geometric.utils import degree 8 | 9 | 10 | def add_edges( 11 | from_edge_index: torch.Tensor, 12 | to_edge_index: torch.Tensor, 13 | from_edge_attr: Optional[torch.Tensor] = None, 14 | to_edge_attr: Optional[torch.Tensor] = None, 15 | replace: bool = True) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 16 | from_edge_index = from_edge_index.to(device=to_edge_index.device, dtype=to_edge_index.dtype) 17 | mask = ((to_edge_index[0].unsqueeze(-1) == from_edge_index[0].unsqueeze(0)) & 18 | (to_edge_index[1].unsqueeze(-1) == from_edge_index[1].unsqueeze(0))) 19 | if replace: 20 | to_mask = mask.any(dim=1) 21 | if from_edge_attr is not None and to_edge_attr is not None: 22 | from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype) 23 | to_edge_attr = torch.cat([to_edge_attr[~to_mask], from_edge_attr], dim=0) 24 | to_edge_index = torch.cat([to_edge_index[:, ~to_mask], from_edge_index], dim=1) 25 | else: 26 | from_mask = mask.any(dim=0) 27 | if from_edge_attr is not None and to_edge_attr is not None: 28 | from_edge_attr = from_edge_attr.to(device=to_edge_attr.device, dtype=to_edge_attr.dtype) 29 | to_edge_attr = torch.cat([to_edge_attr, from_edge_attr[~from_mask]], dim=0) 30 | to_edge_index = torch.cat([to_edge_index, from_edge_index[:, ~from_mask]], dim=1) 31 | return to_edge_index, to_edge_attr 32 | 33 | 34 | def merge_edges( 35 | edge_indices: List[torch.Tensor], 36 | edge_attrs: Optional[List[torch.Tensor]] = None, 37 | reduce: str = 'add') -> Tuple[torch.Tensor, Optional[torch.Tensor]]: 38 | edge_index = torch.cat(edge_indices, dim=1) 39 | if edge_attrs is not None: 40 | edge_attr = torch.cat(edge_attrs, dim=0) 41 | else: 42 | edge_attr = None 43 | return coalesce(edge_index=edge_index, 
edge_attr=edge_attr, reduce=reduce) 44 | -------------------------------------------------------------------------------- /prosim/models/utils/pos_enc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | def pos2posemb2d(pos, num_pos_feats=128, temperature=10000): 5 | """ 6 | Copied from https://github.com/OpenDriveLab/UniAD/blob/main/projects/mmdet3d_plugin/models/utils/functional.py 7 | Convert 2D position into positional embeddings. 8 | 9 | Args: 10 | pos (torch.Tensor): Input 2D position tensor. 11 | num_pos_feats (int, optional): Number of positional features. Default is 128. 12 | temperature (int, optional): Temperature factor for positional embeddings. Default is 10000. 13 | 14 | Returns: 15 | torch.Tensor: Positional embeddings tensor. 16 | """ 17 | scale = 2 * math.pi 18 | pos = pos * scale 19 | dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device) 20 | dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) 21 | pos_x = pos[..., 0, None] / dim_t 22 | pos_y = pos[..., 1, None] / dim_t 23 | pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2) 24 | pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2) 25 | posemb = torch.cat((pos_y, pos_x), dim=-1) 26 | return posemb 27 | 28 | def pos2posemb(pos, num_pos_feats=128, temperature=10000): 29 | 30 | scale = 2 * math.pi 31 | pos = pos * scale 32 | dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device) 33 | dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) 34 | 35 | D = pos.shape[-1] 36 | posembs = [] 37 | for i in range(D): 38 | pos_i = pos[..., i, None] / dim_t 39 | pos_i = torch.stack((pos_i[..., 0::2].sin(), pos_i[..., 1::2].cos()), dim=-1).flatten(-2) 40 | posembs.append(pos_i) 41 | 42 | posemb = torch.cat(posembs, dim=-1) 43 | 44 | return posemb -------------------------------------------------------------------------------- /prosim/models/utils/weight_init.py: -------------------------------------------------------------------------------- 1 | # Mostly copied from https://github.com/ZikangZhou/QCNet/blob/main/utils/weight_init.py 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def weight_init(m: nn.Module) -> None: 7 | if isinstance(m, nn.Linear): 8 | nn.init.xavier_uniform_(m.weight) 9 | if m.bias is not None: 10 | nn.init.zeros_(m.bias) 11 | elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 12 | fan_in = m.in_channels / m.groups 13 | fan_out = m.out_channels / m.groups 14 | bound = (6.0 / (fan_in + fan_out)) ** 0.5 15 | nn.init.uniform_(m.weight, -bound, bound) 16 | if m.bias is not None: 17 | nn.init.zeros_(m.bias) 18 | elif isinstance(m, nn.Embedding): 19 | nn.init.normal_(m.weight, mean=0.0, std=0.02) 20 | elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 21 | nn.init.ones_(m.weight) 22 | nn.init.zeros_(m.bias) 23 | elif isinstance(m, nn.LayerNorm): 24 | nn.init.ones_(m.weight) 25 | nn.init.zeros_(m.bias) 26 | elif isinstance(m, nn.MultiheadAttention): 27 | if m.in_proj_weight is not None: 28 | fan_in = m.embed_dim 29 | fan_out = m.embed_dim 30 | bound = (6.0 / (fan_in + fan_out)) ** 0.5 31 | nn.init.uniform_(m.in_proj_weight, -bound, bound) 32 | else: 33 | nn.init.xavier_uniform_(m.q_proj_weight) 34 | nn.init.xavier_uniform_(m.k_proj_weight) 35 | nn.init.xavier_uniform_(m.v_proj_weight) 36 | if m.in_proj_bias is not None: 37 | nn.init.zeros_(m.in_proj_bias) 38 | nn.init.xavier_uniform_(m.out_proj.weight) 39 | if 
m.out_proj.bias is not None: 40 | nn.init.zeros_(m.out_proj.bias) 41 | if m.bias_k is not None: 42 | nn.init.normal_(m.bias_k, mean=0.0, std=0.02) 43 | if m.bias_v is not None: 44 | nn.init.normal_(m.bias_v, mean=0.0, std=0.02) 45 | elif isinstance(m, (nn.LSTM, nn.LSTMCell)): 46 | for name, param in m.named_parameters(): 47 | if 'weight_ih' in name: 48 | for ih in param.chunk(4, 0): 49 | nn.init.xavier_uniform_(ih) 50 | elif 'weight_hh' in name: 51 | for hh in param.chunk(4, 0): 52 | nn.init.orthogonal_(hh) 53 | elif 'weight_hr' in name: 54 | nn.init.xavier_uniform_(param) 55 | elif 'bias_ih' in name: 56 | nn.init.zeros_(param) 57 | elif 'bias_hh' in name: 58 | nn.init.zeros_(param) 59 | nn.init.ones_(param.chunk(4, 0)[1]) 60 | elif isinstance(m, (nn.GRU, nn.GRUCell)): 61 | for name, param in m.named_parameters(): 62 | if 'weight_ih' in name: 63 | for ih in param.chunk(3, 0): 64 | nn.init.xavier_uniform_(ih) 65 | elif 'weight_hh' in name: 66 | for hh in param.chunk(3, 0): 67 | nn.init.orthogonal_(hh) 68 | elif 'bias_ih' in name: 69 | nn.init.zeros_(param) 70 | elif 'bias_hh' in name: 71 | nn.init.zeros_(param) -------------------------------------------------------------------------------- /prosim/rollout/__init__.py: -------------------------------------------------------------------------------- 1 | from .callbacks import * -------------------------------------------------------------------------------- /prosim/rollout/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from shapely import affinity 5 | from shapely.geometry import Polygon 6 | from pandas.core.series import Series 7 | 8 | from trajdata.simulation.sim_metrics import SimMetric 9 | from typing import ( 10 | Any, 11 | Callable, 12 | Dict, 13 | Final, 14 | Iterable, 15 | List, 16 | Optional, 17 | Set, 18 | Tuple, 19 | Union, 20 | ) 21 | class CrashDetect(SimMetric): 22 | def __init__(self, tgt_agent_ids: List[str], agent_extends: Dict[str, List[float]], iou_threshold=0.1, mode='sim') -> None: 23 | super().__init__("crash_detect") 24 | 25 | self.tgt_agent_ids = tgt_agent_ids 26 | self.agent_extends = agent_extends 27 | self.iou_threshold = iou_threshold 28 | self.mode = mode 29 | 30 | def _get_box_polygon(self, agent: Series, extents: List[float]): 31 | box_points = np.array([[-0.5, -0.5], [0.5, -0.5], [0.5, 0.5], [-0.5, 0.5]]) 32 | box_points[:, 0] = box_points[:, 0] * extents[1] 33 | box_points[:, 1] = box_points[:, 1] * extents[0] 34 | box = Polygon(box_points) 35 | 36 | # Get the agent polygon 37 | box = affinity.rotate(box, agent['heading'], origin='centroid') 38 | box = affinity.translate(box, agent['x'], agent['y']) 39 | return box 40 | 41 | def _poly_iou_check(self, poly1: Polygon, poly2: Polygon, iou_threshold: float): 42 | if not poly1.intersects(poly2): 43 | return False 44 | 45 | union = poly1.union(poly2).area 46 | inter = poly1.intersection(poly2).area 47 | iou = inter / union 48 | 49 | return iou > iou_threshold 50 | 51 | def __call__(self, gt_df: pd.DataFrame, sim_df: pd.DataFrame): 52 | data = sim_df if self.mode == 'sim' else gt_df 53 | 54 | all_scene_ts = data.index.get_level_values('scene_ts').unique() 55 | 56 | crash_logs = {agent_id: [] for agent_id in self.tgt_agent_ids} 57 | 58 | for scene_ts in all_scene_ts: 59 | scene_data = data.xs(scene_ts, level='scene_ts') 60 | frame_agent_ids = scene_data.index.get_level_values('agent_id').unique() 61 | 62 | # create polygons for all agents in this frame 63 | 
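# agent_boxes maps agent_id -> shapely Polygon of that agent's oriented bounding box at scene_ts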
agent_boxes = {} 64 | for agent_id in frame_agent_ids: 65 | agent = scene_data.loc[agent_id] 66 | 67 | # Get the extents for the agent 68 | agent_extents = self.agent_extends[agent_id] 69 | 70 | # Get the box polygon 71 | box = self._get_box_polygon(agent, agent_extents) 72 | agent_boxes[agent_id] = box 73 | 74 | for target_id in self.tgt_agent_ids: 75 | # target agent does not exist in this frame 76 | if target_id not in frame_agent_ids: 77 | crash_logs[target_id].append(0) 78 | continue 79 | 80 | target_box = agent_boxes[target_id] 81 | 82 | for agent_id in frame_agent_ids: 83 | if agent_id == target_id: 84 | continue 85 | 86 | agent_box = agent_boxes[agent_id] 87 | 88 | if self._poly_iou_check(target_box, agent_box, self.iou_threshold): 89 | crash_logs[target_id].append(1) 90 | break 91 | else: 92 | # the loop finished without a break: no overlap above the IoU threshold this frame 93 | crash_logs[target_id].append(0) 94 | 95 | crash_detect = {} 96 | for agent_id in self.tgt_agent_ids: 97 | crash_detect[agent_id] = int(crash_logs[agent_id].count(1) > 0) 98 | 99 | return crash_detect 100 | 101 | class GoalReach(SimMetric): 102 | def __init__(self, tgt_agent_ids: List[str], dist_threshold=2.0) -> None: 103 | super().__init__("goal_reach") 104 | 105 | self.tgt_agent_ids = tgt_agent_ids 106 | self.dist_threshold = dist_threshold 107 | 108 | def __call__(self, gt_df: pd.DataFrame, sim_df: pd.DataFrame): 109 | gt_ts = gt_df.index.get_level_values('scene_ts').unique() 110 | sim_ts = sim_df.index.get_level_values('scene_ts').unique() 111 | last_ts = min(gt_ts[-1], sim_ts[-1]) 112 | 113 | goal_reach = {} 114 | 115 | gt_last_frame = gt_df.xs(last_ts, level='scene_ts') 116 | gt_agent_ids = gt_last_frame.index.get_level_values('agent_id').unique() 117 | 118 | sim_last_frame = sim_df.xs(last_ts, level='scene_ts') 119 | sim_agent_ids = sim_last_frame.index.get_level_values('agent_id').unique() 120 | 121 | for agent_id in self.tgt_agent_ids: 122 | if agent_id not in gt_agent_ids or agent_id not in sim_agent_ids: 123 | continue 124 | 125 | gt_agent = gt_last_frame.loc[agent_id] 126 | sim_agent = sim_last_frame.loc[agent_id] 127 | 128 | gt_pos = np.array([gt_agent['x'], gt_agent['y']]) 129 | sim_pos = np.array([sim_agent['x'], sim_agent['y']]) 130 | 131 | dist = np.linalg.norm(gt_pos - sim_pos) 132 | 133 | goal_reach[agent_id] = dist < self.dist_threshold 134 | 135 | return goal_reach -------------------------------------------------------------------------------- /prosim/rollout/package_submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tqdm 3 | import glob 4 | from pathlib import Path 5 | import tarfile 6 | import argparse 7 | from multiprocessing import Pool 8 | 9 | from waymo_open_dataset.protos.sim_agents_submission_pb2 import SimAgentsChallengeSubmission 10 | from waymo_open_dataset.protos.sim_agents_submission_pb2 import ScenarioRollouts 11 | 12 | import sys 13 | 14 | sys.path.append(os.getcwd()) 15 | sys.path.append(os.path.dirname(os.getcwd())) 16 | sys.path.append(os.path.dirname(os.path.dirname(os.getcwd()))) 17 | 18 | from prosim.rollout.baseline import rollout_baseline, get_waymo_scene_object 19 | 20 | import psutil 21 | 22 | def print_system_mem_usage(): 23 | # Get the memory details 24 | memory = psutil.virtual_memory() 25 | 26 | # Total memory 27 | total_memory = memory.total 28 | 29 | # Available memory 30 | available_memory = memory.available 31 | 32 | # Used memory 33 | used_memory = memory.used 34 | 35 | # Memory usage percentage 36 | memory_percent = 
memory.percent 37 | 38 | print(f"Total Memory: {total_memory / (1024 * 1024 * 1024):.2f} GB", flush=True) 39 | print(f"Available Memory: {available_memory / (1024 * 1024 * 1024):.2f} GB", flush=True) 40 | print(f"Used Memory: {used_memory / (1024 * 1024 * 1024):.2f} GB", flush=True) 41 | print(f"Memory Usage: {memory_percent}%", flush=True) 42 | 43 | 44 | argparser = argparse.ArgumentParser() 45 | argparser.add_argument('--root', type=str, required=True) 46 | argparser.add_argument('--mode', type=str, default='val') 47 | argparser.add_argument('--baseline', type=bool, default=False) 48 | argparser.add_argument('--num_workers', type=int, default=150) 49 | argparser.add_argument('--method_name', type=str, default='debug') 50 | args = argparser.parse_args() 51 | 52 | save_root = Path(args.root) 53 | 54 | rollout_folder = save_root / 'rollout' 55 | rollout_files = glob.glob(str(rollout_folder / '*.pb')) 56 | 57 | if args.mode == 'val': 58 | mode_name = 'validation' 59 | elif args.mode == 'test': 60 | mode_name = 'testing' 61 | scene_template = Path('/lustre/fsw/portfolios/nvr/users/shuhant/waymo_v_1_2_0/scenario') / f'{mode_name}_splitted' 62 | file_template = f'{mode_name}_splitted_' + '{}.tfrecords' 63 | scene_template = scene_template / file_template 64 | 65 | if args.mode == 'val': 66 | assert len(rollout_files) == 44097 67 | 68 | def load_rollout_from_file(file_name): 69 | with open(file_name, 'rb') as f: 70 | string = f.read() 71 | return ScenarioRollouts.FromString(string) 72 | 73 | def run_baseline_rollout(scene_id): 74 | file_path = str(scene_template).format(scene_id) 75 | file_path = Path(file_path) 76 | file_path = os.path.expanduser(file_path) 77 | waymo_scene = get_waymo_scene_object(file_path) 78 | rollout = rollout_baseline(waymo_scene) 79 | 80 | return rollout 81 | 82 | def get_scene_id_from_file(file): 83 | return int(file.split('/')[-1].split('.')[0].split('_')[-1]) 84 | 85 | def pakage_submission_file(worker_id): 86 | print(f'worker {worker_id} started!\n') 87 | worker_files = [] 88 | for file in rollout_files: 89 | scene_id = get_scene_id_from_file(file) 90 | if scene_id % num_workers == worker_id: 91 | worker_files.append(file) 92 | 93 | scenario_rollouts = [] 94 | 95 | if args.baseline: 96 | print('running baseline!') 97 | else: 98 | print('loading rollouts!') 99 | 100 | for file in tqdm.tqdm(worker_files, desc=f'worker {worker_id}'): 101 | if args.baseline: 102 | scene_id = get_scene_id_from_file(file) 103 | rollout = run_baseline_rollout(scene_id) 104 | else: 105 | rollout = load_rollout_from_file(file) 106 | print_system_mem_usage() 107 | scenario_rollouts.append(rollout) 108 | 109 | unique_method_name = 'extrapolate_baseline' if args.baseline else args.method_name 110 | 111 | shard_submission = SimAgentsChallengeSubmission( 112 | scenario_rollouts=scenario_rollouts, 113 | submission_type=SimAgentsChallengeSubmission.SIM_AGENTS_SUBMISSION, 114 | account_name='shuhan@utexas.edu', 115 | unique_method_name=unique_method_name, 116 | authors=['shuhant'], 117 | affiliation='utexas', 118 | description='null', 119 | method_link='https://waymo.com/open/' 120 | ) 121 | 122 | output_file_name = submission_folder / f'submission.binproto-{worker_id:05d}-of-{num_workers:05d}' 123 | 124 | with open(output_file_name, 'wb') as f: 125 | f.write(shard_submission.SerializeToString()) 126 | 127 | return output_file_name 128 | 129 | num_workers = args.num_workers 130 | 131 | if args.baseline: 132 | submission_folder = save_root / 'baseline_submission' 133 | else: 134 | submission_folder = 
save_root / 'submission' 135 | submission_folder.mkdir(parents=True, exist_ok=True) 136 | 137 | submission_tar = submission_folder / 'submission.tar.gz' 138 | 139 | # Create a pool of workers 140 | with Pool(num_workers) as pool: 141 | # Distribute the work among the workers 142 | file_names = pool.map(pakage_submission_file, range(num_workers)) 143 | 144 | # Once we have created all the shards, we can package them directly into a 145 | # tar.gz archive, ready for submission. 146 | with tarfile.open(submission_tar, 'w:gz') as tar: 147 | for file_name in file_names: 148 | tar.add(file_name, arcname=file_name.name) -------------------------------------------------------------------------------- /prosim/rollout/run_distributed_rollout.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | 4 | sys.path.append(os.getcwd()) 5 | sys.path.append(os.path.dirname(os.getcwd())) 6 | sys.path.append(os.path.dirname(os.path.dirname(os.getcwd()))) 7 | 8 | import argparse 9 | 10 | argparser = argparse.ArgumentParser() 11 | argparser.add_argument("--config", type=str, required=True) 12 | argparser.add_argument("--ckpt", type=str, required=True) 13 | argparser.add_argument('--rollout_name', type=str, required=True) 14 | argparser.add_argument('--save_metric', type=bool, default=True) 15 | argparser.add_argument('--save_rollout', type=bool, default=True) 16 | argparser.add_argument('--cluster', type=str, default='local') 17 | argparser.add_argument("--M", type=int, default=32) 18 | argparser.add_argument("--action_noise_std", type=float, default=0.0) 19 | argparser.add_argument("--traj_noise_std", type=float, default=0.0) 20 | argparser.add_argument("--top_k", type=int, default=3) 21 | argparser.add_argument("--smooth_dist", type=float, default=5.0) 22 | argparser.add_argument("--sampler_cfg", type=str, default=None) 23 | 24 | args = argparser.parse_args() 25 | 26 | from prosim.core.registry import registry 27 | from prosim.config.default import Config, get_config 28 | from prosim.rollout.distributed_utils import rollout_scene_distributed 29 | 30 | print(args.cluster) 31 | 32 | print('save_metric: ', args.save_metric) 33 | print('save_rollout: ', args.save_rollout) 34 | 35 | config = get_config(args.config, cluster=args.cluster) 36 | rollout_scene_distributed(config, args.M, args.ckpt, args.rollout_name, args.save_metric, args.save_rollout, args.top_k, args.traj_noise_std, args.action_noise_std, args.sampler_cfg, args.smooth_dist) --------------------------------------------------------------------------------
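For orientation, package_submission.py above shards its work with a plain scene_id % num_workers rule before bundling the shards. A stand-alone, runnable sketch of that partitioning pattern under hypothetical file names (the real ids come from get_scene_id_from_file):

from multiprocessing import Pool
from pathlib import Path

rollout_files = [f'rollout/rollout_{i}.pb' for i in range(10)]   # hypothetical names
num_workers = 3

def scene_id(path):
    # mirrors get_scene_id_from_file: the trailing integer in the file stem
    return int(Path(path).stem.split('_')[-1])

def shard_for_worker(worker_id):
    # each worker keeps exactly the files whose scene id maps to it
    return [f for f in rollout_files if scene_id(f) % num_workers == worker_id]

if __name__ == '__main__':
    with Pool(num_workers) as pool:
        shards = pool.map(shard_for_worker, range(num_workers))
    # the shards partition the input: nothing lost, nothing duplicated
    assert sorted(sum(shards, [])) == sorted(rollout_files)

Each worker then serializes its shard as submission.binproto-{worker:05d}-of-{total:05d}, and the final tar.gz of those shards is the archive the script prepares for submission.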