├── utils ├── pllava │ ├── convert_pllava_weights_to_hf.py │ └── __init__.py ├── caption_analysis.py └── utils.py ├── .vscode └── settings.json ├── romp ├── vis_human │ ├── sim3drender │ │ ├── __init__.py │ │ └── lib │ │ │ └── rasterize.h │ ├── __init__.py │ └── pyrenderer.py └── __init__.py ├── preprocessing_script_upload ├── odometry_processed │ ├── huang-2-2019-01-25_0_pos.npy │ ├── stlc-111-2019-04-19_0_pos.npy │ ├── tressider-2019-03-16_0_pos.npy │ ├── tressider-2019-03-16_1_pos.npy │ ├── tressider-2019-04-26_2_pos.npy │ ├── bytes-cafe-2019-02-07_0_pos.npy │ ├── forbes-cafe-2019-01-22_0_pos.npy │ ├── huang-lane-2019-02-12_0_pos.npy │ ├── jordan-hall-2019-04-22_0_pos.npy │ ├── meyer-green-2019-03-16_0_pos.npy │ ├── nvidia-aud-2019-04-18_0_pos.npy │ ├── clark-center-2019-02-28_0_pos.npy │ ├── clark-center-2019-02-28_1_pos.npy │ ├── gates-ai-lab-2019-02-08_0_pos.npy │ ├── gates-to-clark-2019-02-28_1_pos.npy │ ├── huang-basement-2019-01-25_0_pos.npy │ ├── memorial-court-2019-03-16_0_pos.npy │ ├── huang-2-2019-01-25_0_orientation.npy │ ├── stlc-111-2019-04-19_0_orientation.npy │ ├── bytes-cafe-2019-02-07_0_orientation.npy │ ├── cubberly-auditorium-2019-04-22_0_pos.npy │ ├── forbes-cafe-2019-01-22_0_orientation.npy │ ├── huang-lane-2019-02-12_0_orientation.npy │ ├── jordan-hall-2019-04-22_0_orientation.npy │ ├── meyer-green-2019-03-16_0_orientation.npy │ ├── nvidia-aud-2019-04-18_0_orientation.npy │ ├── svl-meeting-gates-2-2019-04-08_0_pos.npy │ ├── svl-meeting-gates-2-2019-04-08_1_pos.npy │ ├── tressider-2019-03-16_0_orientation.npy │ ├── tressider-2019-03-16_1_orientation.npy │ ├── tressider-2019-04-26_2_orientation.npy │ ├── clark-center-2019-02-28_0_orientation.npy │ ├── clark-center-2019-02-28_1_orientation.npy │ ├── gates-ai-lab-2019-02-08_0_orientation.npy │ ├── gates-159-group-meeting-2019-04-03_0_pos.npy │ ├── gates-basement-elevators-2019-01-17_1_pos.npy │ ├── gates-to-clark-2019-02-28_1_orientation.npy │ ├── huang-basement-2019-01-25_0_orientation.npy │ ├── memorial-court-2019-03-16_0_orientation.npy │ ├── packard-poster-session-2019-03-20_0_pos.npy │ ├── packard-poster-session-2019-03-20_1_pos.npy │ ├── packard-poster-session-2019-03-20_2_pos.npy │ ├── clark-center-intersection-2019-02-28_0_pos.npy │ ├── cubberly-auditorium-2019-04-22_0_orientation.npy │ ├── hewlett-packard-intersection-2019-01-24_0_pos.npy │ ├── svl-meeting-gates-2-2019-04-08_0_orientation.npy │ ├── svl-meeting-gates-2-2019-04-08_1_orientation.npy │ ├── gates-159-group-meeting-2019-04-03_0_orientation.npy │ ├── packard-poster-session-2019-03-20_0_orientation.npy │ ├── packard-poster-session-2019-03-20_1_orientation.npy │ ├── packard-poster-session-2019-03-20_2_orientation.npy │ ├── clark-center-intersection-2019-02-28_0_orientation.npy │ ├── gates-basement-elevators-2019-01-17_1_orientation.npy │ └── hewlett-packard-intersection-2019-01-24_0_orientation.npy ├── bev │ └── split2process.py └── preprocess_1st_jrdb.py ├── parser_model ├── test.py ├── calibration │ ├── defaults.yaml │ └── cameras.yaml ├── text_encoder.py ├── caption_getter_vllm.py ├── text_encoder_bert.py └── caption_getter.py ├── README.md ├── .gitignore ├── HiVT ├── metrics │ ├── __init__.py │ ├── mr.py │ ├── ade.py │ └── fde.py ├── losses │ ├── __init__.py │ ├── soft_target_cross_entropy_loss.py │ └── laplace_nll_loss.py ├── datasets │ ├── __init__.py │ ├── jrdb_dataset.py │ └── jrdb_dataset_poseViz.py ├── datamodules │ ├── __init__.py │ ├── jrdb_datamodule.py │ ├── jrdb_datamodule_LED.py │ ├── jrdb_datamodule_poseViz.py │ └── 
argoverse_v1_datamodule.py ├── models │ └── __init__.py ├── ___scripts_copy │ ├── eval.py │ ├── visualization.py │ ├── train_opp.py │ ├── README.md │ ├── validate.py │ ├── train_ethucy.py │ ├── _train_mart_ethucy.py │ ├── _train_socialTransmotion_ethucy.py │ ├── _train_socialTransmotion_ethucy_v1_1.py │ ├── _train_mart_ethucy_lightning2.py │ ├── _train_LED.py │ ├── train_sit.py │ ├── validate_viz_sit.py │ ├── validate_viz_saveSample.py │ ├── train.py │ ├── _train_socialTransmotion.py │ ├── train_joint.py │ ├── validate_viz_saveSample_sit.py │ ├── train_kd_finetune.py │ └── validate_viz_qual.py ├── visualization.py ├── train.py ├── validate.py ├── train_ethucy.py ├── ___models_copy │ └── __init__.py ├── _train_mart_ethucy.py ├── _train_socialTransmotion_ethucy.py ├── _train_LED.py ├── train_sit.py ├── _train_socialTransmotion.py └── train_kd_finetune.py ├── preprocessing_script ├── reorder_datas_ethucy.py ├── reorder_datas.py ├── reorder_datas_sit.py └── preprocess_0_ethucy.py └── bev └── split2process.py /utils/pllava/convert_pllava_weights_to_hf.py: -------------------------------------------------------------------------------- 1 | # Not yet -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "git.ignoreLimitWarning": true 3 | } -------------------------------------------------------------------------------- /romp/vis_human/sim3drender/__init__.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | from .renderer import Sim3DR 4 | -------------------------------------------------------------------------------- /romp/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import ROMP, romp_settings 2 | from .utils import WebcamVideoStream, ResultSaver 3 | -------------------------------------------------------------------------------- /romp/vis_human/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import setup_renderer, rendering_romp_bev_results 2 | from .vis_utils import mesh_color_left2right -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/huang-2-2019-01-25_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/huang-2-2019-01-25_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/stlc-111-2019-04-19_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/stlc-111-2019-04-19_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/tressider-2019-03-16_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/tressider-2019-03-16_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/tressider-2019-03-16_1_pos.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/tressider-2019-03-16_1_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/tressider-2019-04-26_2_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/tressider-2019-04-26_2_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/bytes-cafe-2019-02-07_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/bytes-cafe-2019-02-07_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/forbes-cafe-2019-01-22_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/forbes-cafe-2019-01-22_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/huang-lane-2019-02-12_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/huang-lane-2019-02-12_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/jordan-hall-2019-04-22_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/jordan-hall-2019-04-22_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/meyer-green-2019-03-16_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/meyer-green-2019-03-16_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/nvidia-aud-2019-04-18_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/nvidia-aud-2019-04-18_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/clark-center-2019-02-28_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/clark-center-2019-02-28_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/clark-center-2019-02-28_1_pos.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/clark-center-2019-02-28_1_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/gates-ai-lab-2019-02-08_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/gates-ai-lab-2019-02-08_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/gates-to-clark-2019-02-28_1_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/gates-to-clark-2019-02-28_1_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/huang-basement-2019-01-25_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/huang-basement-2019-01-25_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/memorial-court-2019-03-16_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/memorial-court-2019-03-16_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/huang-2-2019-01-25_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/huang-2-2019-01-25_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/stlc-111-2019-04-19_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/stlc-111-2019-04-19_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/bytes-cafe-2019-02-07_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/bytes-cafe-2019-02-07_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/cubberly-auditorium-2019-04-22_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/cubberly-auditorium-2019-04-22_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/forbes-cafe-2019-01-22_0_orientation.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/forbes-cafe-2019-01-22_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/huang-lane-2019-02-12_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/huang-lane-2019-02-12_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/jordan-hall-2019-04-22_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/jordan-hall-2019-04-22_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/meyer-green-2019-03-16_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/meyer-green-2019-03-16_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/nvidia-aud-2019-04-18_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/nvidia-aud-2019-04-18_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/svl-meeting-gates-2-2019-04-08_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/svl-meeting-gates-2-2019-04-08_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/svl-meeting-gates-2-2019-04-08_1_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/svl-meeting-gates-2-2019-04-08_1_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/tressider-2019-03-16_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/tressider-2019-03-16_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/tressider-2019-03-16_1_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/tressider-2019-03-16_1_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/tressider-2019-04-26_2_orientation.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/tressider-2019-04-26_2_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/clark-center-2019-02-28_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/clark-center-2019-02-28_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/clark-center-2019-02-28_1_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/clark-center-2019-02-28_1_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/gates-ai-lab-2019-02-08_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/gates-ai-lab-2019-02-08_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/gates-159-group-meeting-2019-04-03_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/gates-159-group-meeting-2019-04-03_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/gates-basement-elevators-2019-01-17_1_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/gates-basement-elevators-2019-01-17_1_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/gates-to-clark-2019-02-28_1_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/gates-to-clark-2019-02-28_1_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/huang-basement-2019-01-25_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/huang-basement-2019-01-25_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/memorial-court-2019-03-16_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/memorial-court-2019-03-16_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/packard-poster-session-2019-03-20_0_pos.npy: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/packard-poster-session-2019-03-20_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/packard-poster-session-2019-03-20_1_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/packard-poster-session-2019-03-20_1_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/packard-poster-session-2019-03-20_2_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/packard-poster-session-2019-03-20_2_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/clark-center-intersection-2019-02-28_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/clark-center-intersection-2019-02-28_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/cubberly-auditorium-2019-04-22_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/cubberly-auditorium-2019-04-22_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/hewlett-packard-intersection-2019-01-24_0_pos.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/hewlett-packard-intersection-2019-01-24_0_pos.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/svl-meeting-gates-2-2019-04-08_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/svl-meeting-gates-2-2019-04-08_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/svl-meeting-gates-2-2019-04-08_1_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/svl-meeting-gates-2-2019-04-08_1_orientation.npy -------------------------------------------------------------------------------- /parser_model/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | sys.path.append('/mnt/jaewoo4tb/textraj/') 4 | 5 | test = torch.load("/mnt/jaewoo4tb/textraj/preprocessed_2nd/v1_fps_2_5_frame_20/bytes-cafe-2019-02-07_0_agents_0_to_1726/0.pt") 6 | 7 | input() 
-------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/gates-159-group-meeting-2019-04-03_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/gates-159-group-meeting-2019-04-03_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/packard-poster-session-2019-03-20_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/packard-poster-session-2019-03-20_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/packard-poster-session-2019-03-20_1_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/packard-poster-session-2019-03-20_1_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/packard-poster-session-2019-03-20_2_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/packard-poster-session-2019-03-20_2_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/clark-center-intersection-2019-02-28_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/clark-center-intersection-2019-02-28_0_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/gates-basement-elevators-2019-01-17_1_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/gates-basement-elevators-2019-01-17_1_orientation.npy -------------------------------------------------------------------------------- /preprocessing_script_upload/odometry_processed/hewlett-packard-intersection-2019-01-24_0_orientation.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jaewoo97/KDTF/HEAD/preprocessing_script_upload/odometry_processed/hewlett-packard-intersection-2019-01-24_0_orientation.npy -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KDTF 2 | 3 | > **Status:** Initial commit complete ✅ 4 | > Instructions, specifications, and detailed documentation will be added later this week. 5 | 6 | This repository has just been initialized. 7 | Please check back soon for: 8 | 9 | - Project overview & objectives 10 | - Installation & usage instructions 11 | - Specifications and examples 12 | 13 | Stay tuned! 
14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | HiVT/__pycache__ 2 | HiVT/__viz_* 3 | HiVT/_logs* 4 | HiVT/.vscode 5 | HiVT/logs* 6 | **/__pycache__/ 7 | LAVIS 8 | preprocessed_1st 9 | preprocessed_2nd 10 | *egovlm.py 11 | preprocessing_script/*nba.py 12 | preprocessing_script/*nfl.py 13 | preprocessing_script/*soccer.py 14 | preprocessing_script/preprocess_2nd_jrdb_v2_joint.py 15 | preprocessing_script/preprocess_2nd_jrdb_v2_viz.py 16 | preprocessing_script/preprocess_2nd.py 17 | preprocessing_script/preprocess_2nd_v3.py 18 | preprocessing_script/*.sh 19 | HiVT/models_original 20 | HiVT/___models_copy 21 | HiVT/___scripts_copy 22 | # ignore all ROMP blobs 23 | romp/files/ 24 | -------------------------------------------------------------------------------- /HiVT/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from metrics.ade import ADE 15 | from metrics.fde import FDE 16 | from metrics.mr import MR 17 | -------------------------------------------------------------------------------- /HiVT/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from losses.laplace_nll_loss import LaplaceNLLLoss 15 | from losses.soft_target_cross_entropy_loss import SoftTargetCrossEntropyLoss 16 | -------------------------------------------------------------------------------- /HiVT/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # from datasets.argoverse_v1_dataset import ArgoverseV1Dataset 15 | from datasets.jrdb_dataset import jrdbDataset 16 | from datasets.jrdb_dataset_poseViz import jrdbDataset_poseViz -------------------------------------------------------------------------------- /HiVT/datamodules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # from datamodules.argoverse_v1_datamodule import ArgoverseV1DataModule 15 | from datamodules.jrdb_datamodule import jrdbDatamodule 16 | from datamodules.jrdb_datamodule_LED import jrdbDatamodule_LED 17 | from datamodules.jrdb_datamodule_poseViz import jrdbDatamodule_poseViz 18 | -------------------------------------------------------------------------------- /parser_model/calibration/defaults.yaml: -------------------------------------------------------------------------------- 1 | calibrated: 2 | # the lidar_to_rgb parameters allow tweaking of the transformation between lidar and rgb frames 3 | # the default transformation is taken from the TF Tree 4 | # NOTE: applied to the original (sensor/velodyne) frame [x forward, y left, z up]: 5 | lidar_upper_to_rgb: 6 | # in meters: [x,y,z] 7 | translation: [0, 0, -0.33529] 8 | # in radians: [x,y,z] 9 | rotation: [0, 0, 0.085] 10 | 11 | lidar_lower_to_rgb: 12 | 13 | translation: [0, 0, 0.13511] 14 | 15 | rotation: [0, 0, 0] 16 | image: 17 | # all in pixels 18 | width: 3760 19 | height: 480 20 | # y-axis forward pixel offset (e.g. 
3760/2 => 1880, b/c center of the cylindrical image is forward) 21 | # TODO: move into calibrated params, when auto-calibration is possible 22 | stitched_image_offset: 1880 23 | 24 | frames: 25 | # lookup for people transforms 26 | global: base_link 27 | # name of the rgb360 camera frame to which we wish to transform 28 | rgb360: occam 29 | -------------------------------------------------------------------------------- /preprocessing_script/reorder_datas_ethucy.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import shutil 4 | 5 | root_dir = "/mnt/jaewoo4tb/textraj/preprocessed_2nd/ethucy_v2_fps_2_5_frame_20_zara1_copy" 6 | train_dir = f"{root_dir}/train" 7 | val_dir = f"{root_dir}/val" 8 | 9 | if not os.path.exists(train_dir): 10 | os.makedirs(train_dir) 11 | if not os.path.exists(val_dir): 12 | os.makedirs(val_dir) 13 | 14 | scenes = os.listdir(root_dir) 15 | scenes.sort() 16 | scenes.remove("train") 17 | scenes.remove("val") 18 | 19 | val_idxs = [2] 20 | # ['eth', 'hotel', 'zara1', 'zara2', 'students'] 21 | for idx, scene in tqdm(enumerate(scenes)): 22 | sub_dir = f"{root_dir}/{scene}" 23 | files = os.listdir(sub_dir) 24 | for file in files: 25 | src = f"{sub_dir}/{file}" 26 | if idx in val_idxs: 27 | dst = f"{val_dir}/{str(idx).zfill(3)}_{file[:-3].zfill(4)}.pt" 28 | else: 29 | dst = f"{train_dir}/{str(idx).zfill(3)}_{file[:-3].zfill(4)}.pt" 30 | 31 | # shutil.copy(src, dst) 32 | shutil.move(src, dst) -------------------------------------------------------------------------------- /preprocessing_script/reorder_datas.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import shutil 4 | 5 | root_dir = "/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_egovlm" 6 | train_dir = f"{root_dir}/train" 7 | val_dir = f"{root_dir}/val" 8 | 9 | if not os.path.exists(train_dir): 10 | os.makedirs(train_dir) 11 | if not os.path.exists(val_dir): 12 | os.makedirs(val_dir) 13 | 14 | scenes = os.listdir(root_dir) 15 | scenes.sort() 16 | scenes.remove("train") 17 | scenes.remove("val") 18 | 19 | val_idxs = [0, 13, 14, 18] # bytes cafe, huang lane, jordan hall, poster session 0 => v2_fps_2_5_frame_20 20 | 21 | for idx, scene in tqdm(enumerate(scenes)): 22 | sub_dir = f"{root_dir}/{scene}" 23 | files = os.listdir(sub_dir) 24 | for file in files: 25 | if file[-3:] != ".pt": 26 | continue 27 | 28 | src = f"{sub_dir}/{file}" 29 | if idx in val_idxs: 30 | dst = f"{val_dir}/{str(idx).zfill(3)}_{file[:-3].zfill(4)}.pt" 31 | else: 32 | dst = f"{train_dir}/{str(idx).zfill(3)}_{file[:-3].zfill(4)}.pt" 33 | 34 | # shutil.copy(src, dst) 35 | shutil.move(src, dst) -------------------------------------------------------------------------------- /parser_model/text_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import CLIPTokenizer, CLIPModel, CLIPTextModel 3 | 4 | 5 | class TextEncoder: 6 | def __init__(self, model='openai/clip-vit-base-patch32'): # possible model: 'openai/clip-vit-base-patch16', 'openai/clip-vit-large-patch14' 7 | self.torch_device = "cuda" if torch.cuda.is_available() else "cpu" 8 | 9 | self.tokenizer = CLIPTokenizer.from_pretrained(model) 10 | self.text_encoder = CLIPTextModel.from_pretrained(model).to(self.torch_device) 11 | self.model = CLIPModel.from_pretrained(model).to(self.torch_device) 12 | 13 | def __call__(self, text: str) -> torch.Tensor: 
14 | text_inputs = self.tokenizer( 15 | text, 16 | padding="max_length", 17 | return_tensors="pt", 18 | ).to(self.torch_device) 19 | # text_embeddings = torch.flatten(self.text_encoder(text_inputs.input_ids.to(self.torch_device))['last_hidden_state'],1,-1) 20 | text_features = self.model.get_text_features(**text_inputs) 21 | return text_features 22 | 23 | 24 | if __name__ == "__main__": 25 | test = TextEncoder() 26 | emb = test("this is a test sentence") 27 | print(emb.shape) -------------------------------------------------------------------------------- /preprocessing_script/reorder_datas_sit.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import shutil 4 | 5 | root_dir = "/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_withPoseInfo_v2" 6 | train_dir = f"{root_dir}/train" 7 | val_dir = f"{root_dir}/val" 8 | 9 | if not os.path.exists(train_dir): 10 | os.makedirs(train_dir) 11 | if not os.path.exists(val_dir): 12 | os.makedirs(val_dir) 13 | 14 | scenes = os.listdir(root_dir) 15 | scenes.sort() 16 | scenes.remove("train") 17 | scenes.remove("val") 18 | 19 | # 0: Cafe_street_1-002_agents_0_to_200 / outdoor 20 | # 4: Cafeteria_3-004_agents_0_to_200 / indoor 21 | # 7: Corridor_1-010_agents_0_to_200 / indoor 22 | # 36: Lobby_6-001_agents_0_to_200 / indoor 23 | # 40: Outdoor_Alley_3-002_agents_0_to_200 / outdoor 24 | # 44: Three_way_Intersection_4-001_agents_0_to_200 / outdoor 25 | # 41: Subway_Entrance_2-004 / outdoor 26 | 27 | val_idxs = [0, 4, 7, 36, 40, 44, 41] 28 | 29 | for idx, scene in tqdm(enumerate(scenes)): 30 | sub_dir = f"{root_dir}/{scene}" 31 | files = os.listdir(sub_dir) 32 | for file in files: 33 | if file[-3:] != ".pt": 34 | continue 35 | 36 | src = f"{sub_dir}/{file}" 37 | if idx in val_idxs: 38 | dst = f"{val_dir}/{str(idx).zfill(3)}_{file[:-3].zfill(4)}.pt" 39 | else: 40 | dst = f"{train_dir}/{str(idx).zfill(3)}_{file[:-3].zfill(4)}.pt" 41 | 42 | # shutil.copy(src, dst) 43 | shutil.move(src, dst) -------------------------------------------------------------------------------- /HiVT/losses/soft_target_cross_entropy_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | 19 | class SoftTargetCrossEntropyLoss(nn.Module): 20 | 21 | def __init__(self, reduction: str = 'mean') -> None: 22 | super(SoftTargetCrossEntropyLoss, self).__init__() 23 | self.reduction = reduction 24 | 25 | def forward(self, 26 | pred: torch.Tensor, 27 | target: torch.Tensor) -> torch.Tensor: 28 | cross_entropy = torch.sum(-target * F.log_softmax(pred, dim=-1), dim=-1) 29 | if self.reduction == 'mean': 30 | return cross_entropy.mean() 31 | elif self.reduction == 'sum': 32 | return cross_entropy.sum() 33 | elif self.reduction == 'none': 34 | return cross_entropy 35 | else: 36 | raise ValueError('{} is not a valid value for reduction'.format(self.reduction)) 37 | -------------------------------------------------------------------------------- /HiVT/losses/laplace_nll_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import torch 15 | import torch.nn as nn 16 | 17 | 18 | class LaplaceNLLLoss(nn.Module): 19 | 20 | def __init__(self, 21 | eps: float = 1e-6, 22 | reduction: str = 'mean') -> None: 23 | super(LaplaceNLLLoss, self).__init__() 24 | self.eps = eps 25 | self.reduction = reduction 26 | 27 | def forward(self, 28 | pred: torch.Tensor, 29 | target: torch.Tensor) -> torch.Tensor: 30 | loc, scale = pred.chunk(2, dim=-1) 31 | scale = scale.clone() 32 | with torch.no_grad(): 33 | scale.clamp_(min=self.eps) 34 | nll = torch.log(2 * scale) + torch.abs(target - loc) / scale 35 | if self.reduction == 'mean': 36 | return nll.mean() 37 | elif self.reduction == 'sum': 38 | return nll.sum() 39 | elif self.reduction == 'none': 40 | return nll 41 | else: 42 | raise ValueError('{} is not a valid value for reduction'.format(self.reduction)) 43 | -------------------------------------------------------------------------------- /parser_model/caption_getter_vllm.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ["CUDA_VISIBLE_DEVICES"] = '4' 3 | os.environ['HF_HOME'] = '/mnt/jaewoo4tb/cache/' 4 | from io import BytesIO 5 | 6 | import requests 7 | from PIL import Image 8 | import numpy as np 9 | import time 10 | from vllm import LLM, SamplingParams 11 | 12 | # import copy 13 | import torch 14 | # import time 15 | def logging_time(original_fn): 16 | def wrapper_fn(*args, **kwargs): 17 | start_time = time.time() 18 | result = original_fn(*args, **kwargs) 19 | end_time = time.time() 20 | print("WorkingTime[{}]: {} sec".format(original_fn.__name__, end_time-start_time)) 21 | return result 22 | return wrapper_fn 23 | 24 | 25 | class CaptionModel: 26 | def __init__(self): 27 | self.llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=4096) 28 | 29 | @logging_time 30 | def get_caption(self, image, question): 31 | if isinstance(image, np.ndarray): 32 | image = 
Image.fromarray(image) 33 | with torch.no_grad(): 34 | sampling_params = SamplingParams(temperature=0.8, 35 | top_p=0.97, 36 | max_tokens=40) 37 | outputs = self.llm.generate( 38 | { 39 | "prompt": question, 40 | "multi_modal_data": { 41 | "image": image 42 | } 43 | }, 44 | sampling_params=sampling_params) 45 | generated_text = "" 46 | for o in outputs: 47 | generated_text += o.outputs[0].text 48 | return generated_text -------------------------------------------------------------------------------- /preprocessing_script/preprocess_0_ethucy.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | def video2frames(file_path): 5 | file_list = ["students003.avi","crowds_zara01.avi", "crowds_zara02.avi", "biwi_eth.avi","biwi_hotel.avi"] 6 | # file_list = ["hotel.avi"] 7 | img_list = ["students", "zara1", "zara2","eth", "hotel"] 8 | img_path = "/ssd4tb/ethucy/images/" 9 | for i in range(len(file_list)): 10 | file_name = file_path+file_list[i] 11 | cap = cv2.VideoCapture(file_name) 12 | video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) -1 13 | img_file_path = img_path+img_list[i] 14 | os.makedirs(img_file_path, exist_ok=True) 15 | count = 0 16 | print(f"start converting video {file_list[i]}") 17 | # os.makedirs(img_file_path) 18 | while cap.isOpened(): 19 | # extract the frame 20 | ret,frame = cap.read() 21 | if not ret: 22 | continue 23 | if count % 10 != 0: 24 | count = count+1 25 | if (count > (video_length)): 26 | cap.release() 27 | print(f"Done extracting frames. {count} frames extracted.") 28 | break 29 | continue 30 | cv2.imwrite(img_file_path+f"/{str(count).zfill(4)}.jpg",frame) 31 | count = count+1 32 | if (count > (video_length)): 33 | cap.release() 34 | print(f"Done extracting frames. {count} frames extracted.") 35 | break 36 | 37 | if __name__ == '__main__': 38 | file_path = "/ssd4tb/ethucy/" 39 | # crowds_zara01.avi students003.avi crowds_zara01.avi crowds_zara02.avi hotel.avi eth.avi 40 | video2frames(file_path) -------------------------------------------------------------------------------- /HiVT/metrics/mr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import Any, Callable, Optional 15 | 16 | import torch 17 | from torchmetrics import Metric 18 | 19 | 20 | class MR(Metric): 21 | 22 | def __init__(self, 23 | miss_threshold: float = 2.0, 24 | dist_sync_on_step: bool = False, 25 | process_group: Optional[Any] = None, 26 | dist_sync_fn: Callable = None) -> None: 27 | super(MR, self).__init__(dist_sync_on_step=dist_sync_on_step, 28 | process_group=process_group, dist_sync_fn=dist_sync_fn) 29 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') 30 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') 31 | self.miss_threshold = miss_threshold 32 | 33 | def update(self, 34 | pred: torch.Tensor, 35 | target: torch.Tensor) -> None: 36 | self.sum += (torch.norm(pred[:, -1] - target[:, -1], p=2, dim=-1) > self.miss_threshold).sum() 37 | self.count += pred.size(0) 38 | 39 | def compute(self) -> torch.Tensor: 40 | return self.sum / self.count 41 | -------------------------------------------------------------------------------- /HiVT/metrics/ade.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import Any, Callable, Optional 15 | 16 | import torch 17 | from torchmetrics import Metric 18 | 19 | 20 | class ADE(Metric): 21 | 22 | def __init__(self, 23 | dist_sync_on_step: bool = False, 24 | process_group: Optional[Any] = None, 25 | dist_sync_fn: Callable = None) -> None: 26 | super(ADE, self).__init__(dist_sync_on_step=dist_sync_on_step, 27 | process_group=process_group, dist_sync_fn=dist_sync_fn) 28 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') 29 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') 30 | 31 | def update(self, 32 | pred: torch.Tensor, 33 | target: torch.Tensor, 34 | reg_mask: torch.Tensor) -> None: 35 | reg_mask_count = reg_mask.sum(-1) 36 | reg_mask_count[reg_mask_count==0] = 1 37 | ade = torch.norm(pred - target, p=2, dim=-1) 38 | ade = ade * reg_mask # N X T 39 | ade = ade.sum(-1) / reg_mask_count 40 | self.sum += ade.sum() 41 | self.count += (reg_mask.sum(-1) > 0).sum() 42 | 43 | def compute(self) -> torch.Tensor: 44 | return self.sum / self.count 45 | -------------------------------------------------------------------------------- /parser_model/text_encoder_bert.py: -------------------------------------------------------------------------------- 1 | from transformers import BertTokenizer, BertModel 2 | import torch 3 | 4 | class TextEncoder_BERT: 5 | def __init__(self): 6 | self.torch_device = "cuda" if torch.cuda.is_available() else "cpu" 7 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 8 | self.model = BertModel.from_pretrained('bert-base-uncased') 9 | self.model.to(self.torch_device) 10 | 11 | def __call__(self, text: str) -> torch.Tensor: 12 | inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True) 13 | for k in inputs.keys(): 14 | if torch.is_tensor(inputs[k]): inputs[k] = inputs[k].to(self.torch_device) 15 | with torch.no_grad(): 16 | outputs = self.model(**inputs) 17 | assert outputs.last_hidden_state.shape[0] == 1 18 | return outputs.last_hidden_state[0,0,:] # return class token 19 | 20 | class TextEncoder_TinyBERT: 21 | def __init__(self): 22 | self.torch_device = "cuda" if torch.cuda.is_available() else "cpu" 23 | self.tokenizer = BertTokenizer.from_pretrained('huawei-noah/TinyBERT_General_4L_312D', model_max_length=256) 24 | self.model = BertModel.from_pretrained('huawei-noah/TinyBERT_General_4L_312D') 25 | self.model.to(self.torch_device) 26 | 27 | def __call__(self, text: str) -> torch.Tensor: 28 | inputs = self.tokenizer(text, return_tensors="pt", truncation=True, padding=True) 29 | for k in inputs.keys(): 30 | if torch.is_tensor(inputs[k]): inputs[k] = inputs[k].to(self.torch_device) 31 | with torch.no_grad(): 32 | outputs = self.model(**inputs) 33 | assert outputs.last_hidden_state.shape[0] == 1 34 | return outputs.last_hidden_state[0,0,:] # [CLS] token representation 35 | 36 | if __name__ == "__main__": 37 | test = TextEncoder_BERT() 38 | res = test("test string") 39 | print(res.shape) -------------------------------------------------------------------------------- /HiVT/metrics/fde.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Any, Callable, Optional 15 | 16 | import torch 17 | from torchmetrics import Metric 18 | 19 | 20 | class FDE(Metric): 21 | 22 | def __init__(self, 23 | dist_sync_on_step: bool = False, 24 | process_group: Optional[Any] = None, 25 | dist_sync_fn: Callable = None) -> None: 26 | super(FDE, self).__init__(dist_sync_on_step=dist_sync_on_step, 27 | process_group=process_group, dist_sync_fn=dist_sync_fn) 28 | self.add_state('sum', default=torch.tensor(0.0), dist_reduce_fx='sum') 29 | self.add_state('count', default=torch.tensor(0), dist_reduce_fx='sum') 30 | 31 | def update(self, 32 | pred: torch.Tensor, 33 | target: torch.Tensor, 34 | reg_mask: torch.Tensor) -> None: 35 | reg_mask[:, :-1] = False 36 | reg_mask_count = reg_mask.sum(-1) 37 | reg_mask_count[reg_mask_count==0] = 1 38 | ade = torch.norm(pred - target, p=2, dim=-1) 39 | ade = ade * reg_mask # N X T 40 | ade = ade.sum(-1) / reg_mask_count 41 | self.sum += ade.sum() 42 | self.count += (reg_mask.sum(-1) > 0).sum() 43 | 44 | def compute(self) -> torch.Tensor: 45 | return self.sum / self.count 46 | -------------------------------------------------------------------------------- /HiVT/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from models.decoder import GRUDecoder 15 | from models.decoder import MLPDecoder 16 | from models.embedding import MultipleInputEmbedding 17 | from models.embedding import SingleInputEmbedding 18 | from models.embedding import TrajPoseEmbedding, MultipleInputEmbedding_mask, MultipleInputEmbedding_mask_v2, MultipleInputEmbedding_mask_v3, MultipleInputEmbedding_mask_v4, MultipleInputEmbedding_mask_v5 19 | from models.global_interactor import GlobalInteractor 20 | from models.global_interactor import GlobalInteractorLayer 21 | from models.global_interactor_text import GlobalInteractor_text 22 | 23 | from models.local_encoder_onlyTraj_v3 import LocalEncoder_onlyTraj_v3 24 | from models.local_encoder_wPose import LocalEncoder_wPose 25 | from models.local_encoder_wText import LocalEncoder_wText 26 | from models.local_encoder_wPoseText_v3 import LocalEncoder_wPoseText_v3 27 | 28 | # ETH/UCY 29 | from models.local_encoder_wPose_ethucy import LocalEncoder_wPose_ethucy 30 | from models.local_encoder_wPoseText_ethucy import LocalEncoder_wPoseText_ethucy 31 | from models.local_encoder_wPoseText_ethucy_v1 import LocalEncoder_wPoseText_ethucy_v1 32 | from models.local_encoder_wText_ethucy import LocalEncoder_wText_ethucy 33 | # from models.local_encoder_original import LocalEncoder_original 34 | 35 | # SIT 36 | from models.local_encoder_wPoseText_sit import LocalEncoder_wPoseText_sit -------------------------------------------------------------------------------- /utils/pllava/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available 17 | 18 | 19 | _import_structure = {"configuration_pllava": ["PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP", "PllavaConfig"]} 20 | 21 | try: 22 | if not is_torch_available(): 23 | raise OptionalDependencyNotAvailable() 24 | except OptionalDependencyNotAvailable: 25 | pass 26 | else: 27 | _import_structure["modeling_pllava"] = [ 28 | "PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST", 29 | "PllavaForConditionalGeneration", 30 | "PllavaPreTrainedModel", 31 | ] 32 | _import_structure["processing_pllava"] = ["PllavaProcessor"] 33 | 34 | 35 | if TYPE_CHECKING: 36 | from .configuration_pllava import PLLAVA_PRETRAINED_CONFIG_ARCHIVE_MAP, PllavaConfig 37 | 38 | try: 39 | if not is_torch_available(): 40 | raise OptionalDependencyNotAvailable() 41 | except OptionalDependencyNotAvailable: 42 | pass 43 | else: 44 | from .modeling_pllava import ( 45 | PLLAVA_PRETRAINED_MODEL_ARCHIVE_LIST, 46 | PllavaForConditionalGeneration, 47 | PllavaPreTrainedModel, 48 | ) 49 | from .processing_pllava import PllavaProcessor 50 | 51 | 52 | else: 53 | import sys 54 | 55 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure) 56 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import os 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '4' 16 | from argparse import ArgumentParser 17 | 18 | import pytorch_lightning as pl 19 | from torch_geometric.data import DataLoader 20 | 21 | from datasets import ArgoverseV1Dataset 22 | from models.hivt import HiVT 23 | 24 | if __name__ == '__main__': 25 |     pl.seed_everything(2022) 26 | 27 |     parser = ArgumentParser() 28 |     parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/data/argoverse') 29 |     parser.add_argument('--batch_size', type=int, default=4) 30 |     parser.add_argument('--num_workers', type=int, default=8) 31 |     parser.add_argument('--pin_memory', type=bool, default=True) 32 |     parser.add_argument('--persistent_workers', type=bool, default=True) 33 |     parser.add_argument('--gpus', type=int, default=1) 34 |     parser.add_argument('--ckpt_path', type=str, default='/mnt/jaewoo4tb/HiVT/checkpoints/HiVT-64/checkpoints/epoch=63-step=411903.ckpt') 35 |     args = parser.parse_args() 36 | 37 |     trainer = pl.Trainer.from_argparse_args(args) 38 |     model = HiVT.load_from_checkpoint(checkpoint_path=args.ckpt_path, parallel=True) 39 |     val_dataset = ArgoverseV1Dataset(root=args.root, split='val', local_radius=model.hparams.local_radius) 40 |     dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, 41 |                             pin_memory=args.pin_memory, persistent_workers=args.persistent_workers) 42 |     trainer.validate(model, dataloader) 43 | -------------------------------------------------------------------------------- /utils/caption_analysis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import nltk 3 | from nltk import word_tokenize, pos_tag 4 | import torch 5 | from tqdm import tqdm 6 | import glob 7 | import matplotlib.pyplot as plt 8 | 9 | data_root = '/mnt/jaewoo4tb/textraj/preprocessed_1st/jrdb_v1/' 10 | 11 | nltk.download('punkt') 12 | nltk.download('averaged_perceptron_tagger') 13 | 14 | 15 | # analyze the preprocessed data 16 | files = glob.glob(os.path.join(data_root, '*.pt')) 17 | files.sort() 18 | 19 | words = [] 20 | filters = ['VBG'] 21 | 22 | idx = 0 23 | for file in tqdm(files): 24 |     words.append({}) 25 |     data = torch.load(file) 26 |     for frame, frame_data in data.items(): 27 |         for person, person_data in frame_data.items(): 28 |             if ('description' not in person_data) or (person_data['description'] is None): 29 |                 continue 30 | 31 |             caption = person_data['description'].lower() 32 |             sentence = word_tokenize(caption) 33 |             pos_tags = pos_tag(sentence) 34 |             for word, tag in pos_tags: 35 |                 if tag in filters: 36 |                     if word not in words[idx]: 37 |                         words[idx][word] = 1 38 |                     else: 39 |                         words[idx][word] += 1 40 | 41 |     idx += 1 42 | 43 | # concat train/val 44 | val_idxs = [0, 13, 14, 18] 45 | 46 | train_words = {} 47 | val_words = {} 48 | 49 | for idx, word in enumerate(words): 50 |     for w, c in word.items(): 51 |         if idx in val_idxs: 52 |             if w not in val_words: 53 |                 val_words[w] = c 54 |             else: 55 |                 val_words[w] += c 56 |         else: 57 |             if w not in train_words: 58 |                 train_words[w] = c 59 |             else: 60 |                 train_words[w] += c 61 | 62 | 63 | # plot histogram 64 | def plot_word_histogram(word_data, ax, split, max_bars=10): 65 |     sorted_word_data = dict(sorted(word_data.items(), key=lambda item: item[1], reverse=True)) 66 | 67 |     filtered_word_data = dict(list(sorted_word_data.items())[:max_bars]) 68 | 69 |     words = list(filtered_word_data.keys()) 70 |     frequencies = list(filtered_word_data.values()) 71 | 72 |     ax.bar(words, frequencies, color='skyblue') 73 | 
ax.set_xlabel('Word') 74 | ax.set_ylabel('Frequency') 75 | ax.set_title(f'{split} Word-Frequency Histogram') 76 | ax.tick_params(axis='x', rotation=45) 77 | 78 | fig, axs = plt.subplots(1, 2, figsize=(10, 6)) 79 | 80 | plot_word_histogram(train_words, axs[0], "train", 15) 81 | plot_word_histogram(val_words, axs[1], "val", 15) 82 | 83 | plt.tight_layout() 84 | plt.show() 85 | plt.savefig("jrdb.jpg", dpi=1000) 86 | 87 | breakpoint() -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | # function for suppress print 2 | import os 3 | import sys 4 | import numpy as np 5 | from scipy.spatial.transform import Rotation as R 6 | from typing import List, Optional, Tuple 7 | 8 | import torch 9 | import torch.nn as nn 10 | # from torch_geometric.data import Data 11 | def stopPrint(func, *args, **kwargs): 12 | with open(os.devnull,"w") as devNull: 13 | original = sys.stdout 14 | sys.stdout = devNull 15 | func(*args, **kwargs) 16 | sys.stdout = original 17 | 18 | # Rotation stuff 19 | def axis_angle_to_matrix(axis_angle): 20 | axis = axis_angle[:3] 21 | angle = np.linalg.norm(axis) 22 | if angle > 0: 23 | axis = axis / angle # Normalize the axis 24 | return R.from_rotvec(axis * angle).as_matrix() 25 | 26 | def get_heading_direction(theta): 27 | """ 28 | Get the heading direction from the SMPL theta parameters. 29 | 30 | Args: 31 | theta: (N, 3) array of axis-angle representations for N joints. 32 | 33 | Returns: 34 | heading_direction: (3,) array representing the heading direction. 35 | """ 36 | theta = theta.reshape(-1, 3) 37 | # Extract the root joint's rotation (usually the first joint in SMPL) 38 | root_rotation_axis_angle = theta[0] 39 | 40 | # Convert the root joint's axis-angle representation to a rotation matrix 41 | root_rotation_matrix = axis_angle_to_matrix(root_rotation_axis_angle) 42 | 43 | # Define the forward direction vector (assume [1, 0, 0] for X-forward) 44 | forward_vector = np.array([1, 0, 0]) 45 | 46 | # Apply the root joint's rotation to the forward vector 47 | heading_direction = root_rotation_matrix @ forward_vector 48 | 49 | return heading_direction.reshape(-1) 50 | 51 | def matrix_to_axis_angle(matrix): 52 | rot = R.from_matrix(matrix) 53 | return rot.as_rotvec() 54 | 55 | def apply_z_rotation_on_theta(theta, z_rotation_angle): 56 | """ 57 | Apply a rotation around the Z-axis to the SMPL theta parameter. 58 | 59 | Args: 60 | theta: (N, 3) array of axis-angle representations for N joints. 61 | z_rotation_angle: rotation angle around the Z-axis in radians. 62 | 63 | Returns: 64 | (N, 3) array of updated axis-angle representations. 65 | """ 66 | theta = theta.reshape(-1, 3) 67 | z_rotation_matrix = R.from_euler('z', z_rotation_angle).as_matrix() 68 | 69 | updated_theta = np.zeros_like(theta) 70 | for i in range(theta.shape[0]): 71 | rotation_matrix = axis_angle_to_matrix(theta[i]) 72 | new_rotation_matrix = z_rotation_matrix @ rotation_matrix 73 | new_axis_angle = matrix_to_axis_angle(new_rotation_matrix) 74 | updated_theta[i] = new_axis_angle 75 | 76 | return updated_theta.reshape(-1) 77 | 78 | -------------------------------------------------------------------------------- /HiVT/datamodules/jrdb_datamodule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Callable, Optional 15 | 16 | from pytorch_lightning import LightningDataModule 17 | from torch_geometric.loader import DataLoader 18 | 19 | from datasets import jrdbDataset 20 | 21 | 22 | class jrdbDatamodule(LightningDataModule): 23 | 24 | def __init__(self, 25 | root: str, 26 | train_batch_size: int, 27 | val_batch_size: int, 28 | shuffle: bool = True, 29 | num_workers: int = 8, 30 | pin_memory: bool = True, 31 | persistent_workers: bool = True, 32 | train_transform: Optional[Callable] = None, 33 | val_transform: Optional[Callable] = None) -> None: 34 | super(jrdbDatamodule, self).__init__() 35 | self.root = root 36 | self.train_batch_size = train_batch_size 37 | self.val_batch_size = val_batch_size 38 | self.shuffle = shuffle 39 | self.pin_memory = pin_memory 40 | self.persistent_workers = persistent_workers 41 | self.num_workers = num_workers 42 | self.train_transform = train_transform 43 | self.val_transform = val_transform 44 | 45 | def prepare_data(self) -> None: 46 | jrdbDataset(self.root, 'train', self.train_transform) 47 | jrdbDataset(self.root, 'val', self.val_transform) 48 | 49 | def setup(self, stage: Optional[str] = None) -> None: 50 | self.train_dataset = jrdbDataset(self.root, 'train', self.train_transform) 51 | self.val_dataset = jrdbDataset(self.root, 'val', self.val_transform) 52 | 53 | def train_dataloader(self): 54 | return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle, 55 | num_workers=self.num_workers, pin_memory=self.pin_memory, 56 | persistent_workers=self.persistent_workers) 57 | 58 | def val_dataloader(self): 59 | return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False, num_workers=self.num_workers, 60 | pin_memory=self.pin_memory, persistent_workers=self.persistent_workers) 61 | -------------------------------------------------------------------------------- /HiVT/datamodules/jrdb_datamodule_LED.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import Callable, Optional 15 | 16 | from pytorch_lightning import LightningDataModule 17 | from torch_geometric.loader import DataLoader 18 | 19 | from datasets import jrdbDataset 20 | 21 | 22 | class jrdbDatamodule_LED(LightningDataModule): 23 | 24 | def __init__(self, 25 | root: str, 26 | train_batch_size: int, 27 | val_batch_size: int, 28 | shuffle: bool = True, 29 | num_workers: int = 8, 30 | pin_memory: bool = True, 31 | persistent_workers: bool = True, 32 | train_transform: Optional[Callable] = None, 33 | val_transform: Optional[Callable] = None) -> None: 34 | super(jrdbDatamodule_LED, self).__init__() 35 | self.root = root 36 | self.train_batch_size = train_batch_size 37 | self.val_batch_size = val_batch_size 38 | self.shuffle = shuffle 39 | self.pin_memory = pin_memory 40 | self.persistent_workers = persistent_workers 41 | self.num_workers = num_workers 42 | self.train_transform = train_transform 43 | self.val_transform = val_transform 44 | 45 | def prepare_data(self) -> None: 46 | jrdbDataset(self.root, 'train', self.train_transform) 47 | jrdbDataset(self.root, 'val', self.val_transform) 48 | 49 | def setup(self, stage: Optional[str] = None) -> None: 50 | self.train_dataset = jrdbDataset(self.root, 'train', self.train_transform) 51 | self.val_dataset = jrdbDataset(self.root, 'val', self.val_transform) 52 | 53 | def train_dataloader(self): 54 | return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle, 55 | num_workers=self.num_workers, pin_memory=self.pin_memory, 56 | persistent_workers=self.persistent_workers, drop_last=True) 57 | 58 | def val_dataloader(self): 59 | return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False, num_workers=self.num_workers, 60 | pin_memory=self.pin_memory, persistent_workers=self.persistent_workers, drop_last=True) 61 | -------------------------------------------------------------------------------- /HiVT/datamodules/jrdb_datamodule_poseViz.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from typing import Callable, Optional 15 | 16 | from pytorch_lightning import LightningDataModule 17 | from torch_geometric.loader import DataLoader 18 | 19 | from datasets import jrdbDataset_poseViz 20 | 21 | 22 | class jrdbDatamodule_poseViz(LightningDataModule): 23 | 24 | def __init__(self, 25 | root: str, 26 | train_batch_size: int, 27 | val_batch_size: int, 28 | shuffle: bool = True, 29 | num_workers: int = 8, 30 | pin_memory: bool = True, 31 | persistent_workers: bool = True, 32 | train_transform: Optional[Callable] = None, 33 | val_transform: Optional[Callable] = None) -> None: 34 | super(jrdbDatamodule_poseViz, self).__init__() 35 | self.root = root 36 | self.train_batch_size = train_batch_size 37 | self.val_batch_size = val_batch_size 38 | self.shuffle = shuffle 39 | self.pin_memory = pin_memory 40 | self.persistent_workers = persistent_workers 41 | self.num_workers = num_workers 42 | self.train_transform = train_transform 43 | self.val_transform = val_transform 44 | 45 | def prepare_data(self) -> None: 46 | jrdbDataset_poseViz(self.root, 'train', self.train_transform) 47 | jrdbDataset_poseViz(self.root, 'val', self.val_transform) 48 | 49 | def setup(self, stage: Optional[str] = None) -> None: 50 | self.train_dataset = jrdbDataset_poseViz(self.root, 'train', self.train_transform) 51 | self.val_dataset = jrdbDataset_poseViz(self.root, 'val', self.val_transform) 52 | 53 | def train_dataloader(self): 54 | return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle, 55 | num_workers=self.num_workers, pin_memory=self.pin_memory, 56 | persistent_workers=self.persistent_workers) 57 | 58 | def val_dataloader(self): 59 | return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False, num_workers=self.num_workers, 60 | pin_memory=self.pin_memory, persistent_workers=self.persistent_workers) 61 | -------------------------------------------------------------------------------- /parser_model/caption_getter.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_VISIBLE_DEVICES"] = '4' 3 | os.environ['HF_HOME'] = '/mnt/jaewoo4tb/cache/' 4 | 5 | from PIL import Image 6 | import requests 7 | import copy 8 | import torch 9 | import time 10 | from llava.model.builder import load_pretrained_model 11 | from llava.mm_utils import get_model_name_from_path, process_images, tokenizer_image_token 12 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, IGNORE_INDEX 13 | from llava.conversation import conv_templates, SeparatorStyle 14 | def logging_time(original_fn): 15 | def wrapper_fn(*args, **kwargs): 16 | start_time = time.time() 17 | result = original_fn(*args, **kwargs) 18 | end_time = time.time() 19 | print("WorkingTime[{}]: {} sec".format(original_fn.__name__, end_time-start_time)) 20 | return result 21 | return wrapper_fn 22 | 23 | 24 | class CaptionModel: 25 | def __init__(self): 26 | pretrained = "lmms-lab/llama3-llava-next-8b" 27 | model_name = "llava_llama3" 28 | print("start") 29 | # pretrained = "lmms-lab/llava-next-72b" 30 | # model_name = "llava_llama3" 31 | self.device = "cuda" 32 | device_map = "auto" 33 | self.tokenizer, self.model, self.image_processor, max_length = load_pretrained_model(pretrained, None, model_name, device_map=device_map) # Add any other thing you want to pass in llava_model_args 34 | self.model.eval() 35 | self.model.tie_weights() 36 | 37 | @logging_time 38 | def get_caption(self, image, 
question): 39 | image_tensor = process_images([image], self.image_processor, self.model.config) 40 | image_tensor = [_image.to(dtype=torch.float16, device=self.device) for _image in image_tensor] 41 | conv_template = "llava_llama_3" # Make sure you use correct chat template for different models 42 | # conv_template = "qwen_1_5" # Make sure you use correct chat template for different models 43 | question = DEFAULT_IMAGE_TOKEN + "\n" + question 44 | conv = copy.deepcopy(conv_templates[conv_template]) 45 | conv.append_message(conv.roles[0], question) 46 | conv.append_message(conv.roles[1], None) 47 | prompt_question = conv.get_prompt() 48 | input_ids = tokenizer_image_token(prompt_question, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device) 49 | image_sizes = [image.size] 50 | cont = self.model.generate( 51 | input_ids, 52 | images=image_tensor, 53 | image_sizes=image_sizes, 54 | do_sample=False, 55 | temperature=0, 56 | max_new_tokens=50, 57 | pad_token_id=self.tokenizer.eos_token_id 58 | ) 59 | text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True) 60 | return text_outputs 61 | -------------------------------------------------------------------------------- /bev/split2process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | from bev.post_parser import remove_subjects 5 | 6 | def padding_image_overlap(image, overlap_ratio=0.46): 7 | h, w = image.shape[:2] 8 | pad_length = int(h* overlap_ratio) 9 | pad_w = w+2*pad_length 10 | pad_image = np.zeros((h, pad_w, 3), dtype=np.uint8) 11 | top, left = 0, pad_length 12 | bottom, right = h, w+pad_length 13 | pad_image[top:bottom, left:right] = image 14 | 15 | # due to BEV takes square input, so we convert top, bottom to the state that assuming square padding 16 | pad_height = (w - h)//2 17 | top = pad_height 18 | bottom = w - top 19 | left = 0 20 | right = w 21 | image_pad_info = torch.Tensor([top, bottom, left, right, h, w]) 22 | return pad_image, image_pad_info, pad_length 23 | 24 | def get_image_split_plan(image, overlap_ratio=0.46): 25 | h, w = image.shape[:2] 26 | aspect_ratio = w / h 27 | slide_time = int(np.ceil((aspect_ratio - 1) / (1 - overlap_ratio))) + 1 28 | 29 | crop_box = [] # left, right, top, bottom 30 | move_step = (1 - overlap_ratio) * h 31 | for ind in range(slide_time): 32 | if ind == (slide_time-1): 33 | left = w-h 34 | else: 35 | left = move_step * ind 36 | right = left+h 37 | crop_box.append([left, right, 0, h]) 38 | 39 | return np.array(crop_box).astype(np.int32) 40 | 41 | def exclude_boudary_subjects(outputs, drop_boundary_ratio, ptype='left', torlerance=0.05): 42 | if ptype=='left': 43 | drop_mask = outputs['cam'][:, 2] > (1 - drop_boundary_ratio + torlerance) 44 | elif ptype=='right': 45 | drop_mask = outputs['cam'][:, 2] < (drop_boundary_ratio - 1 - torlerance) 46 | remove_subjects(outputs, torch.where(drop_mask)[0]) 47 | 48 | def convert_crop_cam_params2full_image(cam_params, crop_bbox, image_shape): 49 | h, w = image_shape 50 | # adjust scale, cam 3: depth, y, x 51 | scale_adjust = (crop_bbox[[1,3]]-crop_bbox[[0,2]]).max() / max(h, w) 52 | cam_params *= scale_adjust 53 | 54 | # adjust x 55 | # crop_bbox[:2] -= pad_length 56 | bbox_mean_x = crop_bbox[:2].mean() 57 | cam_params[:,2] += bbox_mean_x / (w /2) - 1 58 | return cam_params 59 | 60 | def collect_outputs(outputs, all_outputs): 61 | keys = list(outputs.keys()) 62 | for key in keys: 63 | if key not in all_outputs: 64 | 
all_outputs[key] = outputs[key] 65 | else: 66 | if key in ['smpl_face']: 67 | continue 68 | if key in ['center_map']: 69 | all_outputs[key] = torch.cat([all_outputs[key], outputs[key]],3) 70 | continue 71 | if key in ['center_map_3d']: 72 | all_outputs[key] = torch.cat([all_outputs[key], outputs[key]],2) 73 | continue 74 | all_outputs[key] = torch.cat([all_outputs[key], outputs[key]],0) 75 | -------------------------------------------------------------------------------- /preprocessing_script_upload/bev/split2process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | from bev.post_parser import remove_subjects 5 | 6 | def padding_image_overlap(image, overlap_ratio=0.46): 7 | h, w = image.shape[:2] 8 | pad_length = int(h* overlap_ratio) 9 | pad_w = w+2*pad_length 10 | pad_image = np.zeros((h, pad_w, 3), dtype=np.uint8) 11 | top, left = 0, pad_length 12 | bottom, right = h, w+pad_length 13 | pad_image[top:bottom, left:right] = image 14 | 15 | # due to BEV takes square input, so we convert top, bottom to the state that assuming square padding 16 | pad_height = (w - h)//2 17 | top = pad_height 18 | bottom = w - top 19 | left = 0 20 | right = w 21 | image_pad_info = torch.Tensor([top, bottom, left, right, h, w]) 22 | return pad_image, image_pad_info, pad_length 23 | 24 | def get_image_split_plan(image, overlap_ratio=0.46): 25 | h, w = image.shape[:2] 26 | aspect_ratio = w / h 27 | slide_time = int(np.ceil((aspect_ratio - 1) / (1 - overlap_ratio))) + 1 28 | 29 | crop_box = [] # left, right, top, bottom 30 | move_step = (1 - overlap_ratio) * h 31 | for ind in range(slide_time): 32 | if ind == (slide_time-1): 33 | left = w-h 34 | else: 35 | left = move_step * ind 36 | right = left+h 37 | crop_box.append([left, right, 0, h]) 38 | 39 | return np.array(crop_box).astype(np.int32) 40 | 41 | def exclude_boudary_subjects(outputs, drop_boundary_ratio, ptype='left', torlerance=0.05): 42 | if ptype=='left': 43 | drop_mask = outputs['cam'][:, 2] > (1 - drop_boundary_ratio + torlerance) 44 | elif ptype=='right': 45 | drop_mask = outputs['cam'][:, 2] < (drop_boundary_ratio - 1 - torlerance) 46 | remove_subjects(outputs, torch.where(drop_mask)[0]) 47 | 48 | def convert_crop_cam_params2full_image(cam_params, crop_bbox, image_shape): 49 | h, w = image_shape 50 | # adjust scale, cam 3: depth, y, x 51 | scale_adjust = (crop_bbox[[1,3]]-crop_bbox[[0,2]]).max() / max(h, w) 52 | cam_params *= scale_adjust 53 | 54 | # adjust x 55 | # crop_bbox[:2] -= pad_length 56 | bbox_mean_x = crop_bbox[:2].mean() 57 | cam_params[:,2] += bbox_mean_x / (w /2) - 1 58 | return cam_params 59 | 60 | def collect_outputs(outputs, all_outputs): 61 | keys = list(outputs.keys()) 62 | for key in keys: 63 | if key not in all_outputs: 64 | all_outputs[key] = outputs[key] 65 | else: 66 | if key in ['smpl_face']: 67 | continue 68 | if key in ['center_map']: 69 | all_outputs[key] = torch.cat([all_outputs[key], outputs[key]],3) 70 | continue 71 | if key in ['center_map_3d']: 72 | all_outputs[key] = torch.cat([all_outputs[key], outputs[key]],2) 73 | continue 74 | all_outputs[key] = torch.cat([all_outputs[key], outputs[key]],0) 75 | -------------------------------------------------------------------------------- /HiVT/datamodules/argoverse_v1_datamodule.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Callable, Optional 15 | 16 | from pytorch_lightning import LightningDataModule 17 | from torch_geometric.data import DataLoader 18 | 19 | from datasets import ArgoverseV1Dataset 20 | 21 | 22 | class ArgoverseV1DataModule(LightningDataModule): 23 | 24 | def __init__(self, 25 | root: str, 26 | train_batch_size: int, 27 | val_batch_size: int, 28 | shuffle: bool = True, 29 | num_workers: int = 8, 30 | pin_memory: bool = True, 31 | persistent_workers: bool = True, 32 | train_transform: Optional[Callable] = None, 33 | val_transform: Optional[Callable] = None, 34 | local_radius: float = 50) -> None: 35 | super(ArgoverseV1DataModule, self).__init__() 36 | self.root = root 37 | self.train_batch_size = train_batch_size 38 | self.val_batch_size = val_batch_size 39 | self.shuffle = shuffle 40 | self.pin_memory = pin_memory 41 | self.persistent_workers = persistent_workers 42 | self.num_workers = num_workers 43 | self.train_transform = train_transform 44 | self.val_transform = val_transform 45 | self.local_radius = local_radius 46 | 47 | def prepare_data(self) -> None: 48 | ArgoverseV1Dataset(self.root, 'train', self.train_transform, self.local_radius) 49 | ArgoverseV1Dataset(self.root, 'val', self.val_transform, self.local_radius) 50 | 51 | def setup(self, stage: Optional[str] = None) -> None: 52 | self.train_dataset = ArgoverseV1Dataset(self.root, 'train', self.train_transform, self.local_radius) 53 | self.val_dataset = ArgoverseV1Dataset(self.root, 'val', self.val_transform, self.local_radius) 54 | 55 | def train_dataloader(self): 56 | return DataLoader(self.train_dataset, batch_size=self.train_batch_size, shuffle=self.shuffle, 57 | num_workers=self.num_workers, pin_memory=self.pin_memory, 58 | persistent_workers=self.persistent_workers) 59 | 60 | def val_dataloader(self): 61 | return DataLoader(self.val_dataset, batch_size=self.val_batch_size, shuffle=False, num_workers=self.num_workers, 62 | pin_memory=self.pin_memory, persistent_workers=self.persistent_workers) 63 | -------------------------------------------------------------------------------- /HiVT/visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | def viz_trajectory(pred, gt, data, output_dir, batch_idx): 8 | # print(f'Visualizing trajectory of batch {batch_idx}') 9 | if not os.path.isdir(os.path.join(output_dir, 'viz_results', 'trajectory')): os.makedirs(os.path.join(output_dir, 'viz_results', 'trajectory')) 10 | # pred: B*N, T, 15, 3 11 | B = data.num_graphs 12 | T1 = data.x.shape[1] 13 | T2 = gt.shape[1] 14 | x_gt = data.x.cpu().numpy() 15 | y_gt = gt.cpu().numpy() 16 | y_pred = pred[:, :, :2].cpu().detach().numpy() 17 | pos = data.positions.cpu().numpy() 18 | thetas = data.theta.cpu().numpy() 19 | 20 | 21 | for scene_idx in range(B): 22 | fig = 
plt.figure(figsize=(4, 4)) 23 | ax = fig.add_subplot(111) 24 | xy_mean = np.zeros((0, 2)) 25 | scene_filter = (data.batch==scene_idx).cpu().numpy() 26 | x_scene_gt = x_gt[scene_filter] 27 | y_scene_gt = y_gt[scene_filter] 28 | y_scene_pred = y_pred[scene_filter] 29 | pos0 = pos[scene_filter, 19, :] 30 | theta = thetas[scene_idx] 31 | inv_rot = np.array([[np.cos(theta), np.sin(theta)], 32 | [-np.sin(theta), np.cos(theta)]]) 33 | 34 | N = x_scene_gt.shape[0] 35 | if N < 1: continue 36 | 37 | for agent_idx in range(N): 38 | 39 | agent_pos0 = pos0[agent_idx, :] 40 | 41 | x_scene_agent_gt = np.cumsum(x_scene_gt[agent_idx, :, :], axis=0) 42 | x_scene_agent_gt = np.matmul(x_scene_agent_gt, inv_rot) 43 | x_scene_agent_gt = agent_pos0 - x_scene_agent_gt[19] + x_scene_agent_gt 44 | 45 | y_scene_agent_gt = y_scene_gt[agent_idx, :, :] 46 | y_scene_agent_gt = np.matmul(y_scene_agent_gt, inv_rot) 47 | y_scene_agent_gt = agent_pos0 - y_scene_agent_gt[0] + y_scene_agent_gt 48 | 49 | y_scene_agent_pred = y_scene_pred[agent_idx, :, :] 50 | y_scene_agent_pred = np.matmul(y_scene_agent_pred, inv_rot) 51 | y_scene_agent_pred = agent_pos0 - y_scene_agent_pred[0] + y_scene_agent_pred 52 | 53 | 54 | ax.plot(x_scene_agent_gt[:, 0], x_scene_agent_gt[:, 1], 'ko-', linewidth=0.25, markersize=0.5) 55 | ax.plot(y_scene_agent_gt[:, 0], y_scene_agent_gt[:, 1], 'bo-', linewidth=0.25, markersize=0.5) 56 | ax.plot(y_scene_agent_pred[:, 0], y_scene_agent_pred[:, 1], 'ro-', linewidth=0.25, markersize=0.5) 57 | 58 | #xy_mean = np.concatenate((xy_mean, np.expand_dims(x_scene_gt.mean(0)[:2], axis=0)), axis=0) 59 | #ax.set_xlim([xy_mean.mean(0)[0]-5, xy_mean.mean(0)[0]+5]) 60 | #ax.set_ylim([xy_mean.mean(0)[1]-5, xy_mean.mean(0)[1]+5]) 61 | ax.set_xlabel("x") 62 | ax.set_ylabel("y") 63 | fig.savefig(f'{output_dir}/viz_results/trajectory/batch_{batch_idx}_scene_{scene_idx}.png') 64 | plt.close() 65 | plt.cla() 66 | # import pdb;pdb.set_trace() 67 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/visualization.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | def viz_trajectory(pred, gt, data, output_dir, batch_idx): 8 | # print(f'Visualizing trajectory of batch {batch_idx}') 9 | if not os.path.isdir(os.path.join(output_dir, 'viz_results', 'trajectory')): os.makedirs(os.path.join(output_dir, 'viz_results', 'trajectory')) 10 | # pred: B*N, T, 15, 3 11 | B = data.num_graphs 12 | T1 = data.x.shape[1] 13 | T2 = gt.shape[1] 14 | x_gt = data.x.cpu().numpy() 15 | y_gt = gt.cpu().numpy() 16 | y_pred = pred[:, :, :2].cpu().detach().numpy() 17 | pos = data.positions.cpu().numpy() 18 | thetas = data.theta.cpu().numpy() 19 | 20 | 21 | for scene_idx in range(B): 22 | fig = plt.figure(figsize=(4, 4)) 23 | ax = fig.add_subplot(111) 24 | xy_mean = np.zeros((0, 2)) 25 | scene_filter = (data.batch==scene_idx).cpu().numpy() 26 | x_scene_gt = x_gt[scene_filter] 27 | y_scene_gt = y_gt[scene_filter] 28 | y_scene_pred = y_pred[scene_filter] 29 | pos0 = pos[scene_filter, 19, :] 30 | theta = thetas[scene_idx] 31 | inv_rot = np.array([[np.cos(theta), np.sin(theta)], 32 | [-np.sin(theta), np.cos(theta)]]) 33 | 34 | N = x_scene_gt.shape[0] 35 | if N < 1: continue 36 | 37 | for agent_idx in range(N): 38 | 39 | agent_pos0 = pos0[agent_idx, :] 40 | 41 | x_scene_agent_gt = np.cumsum(x_scene_gt[agent_idx, :, :], axis=0) 42 | x_scene_agent_gt = 
np.matmul(x_scene_agent_gt, inv_rot) 43 | x_scene_agent_gt = agent_pos0 - x_scene_agent_gt[19] + x_scene_agent_gt 44 | 45 | y_scene_agent_gt = y_scene_gt[agent_idx, :, :] 46 | y_scene_agent_gt = np.matmul(y_scene_agent_gt, inv_rot) 47 | y_scene_agent_gt = agent_pos0 - y_scene_agent_gt[0] + y_scene_agent_gt 48 | 49 | y_scene_agent_pred = y_scene_pred[agent_idx, :, :] 50 | y_scene_agent_pred = np.matmul(y_scene_agent_pred, inv_rot) 51 | y_scene_agent_pred = agent_pos0 - y_scene_agent_pred[0] + y_scene_agent_pred 52 | 53 | 54 | ax.plot(x_scene_agent_gt[:, 0], x_scene_agent_gt[:, 1], 'ko-', linewidth=0.25, markersize=0.5) 55 | ax.plot(y_scene_agent_gt[:, 0], y_scene_agent_gt[:, 1], 'bo-', linewidth=0.25, markersize=0.5) 56 | ax.plot(y_scene_agent_pred[:, 0], y_scene_agent_pred[:, 1], 'ro-', linewidth=0.25, markersize=0.5) 57 | 58 | #xy_mean = np.concatenate((xy_mean, np.expand_dims(x_scene_gt.mean(0)[:2], axis=0)), axis=0) 59 | #ax.set_xlim([xy_mean.mean(0)[0]-5, xy_mean.mean(0)[0]+5]) 60 | #ax.set_ylim([xy_mean.mean(0)[1]-5, xy_mean.mean(0)[1]+5]) 61 | ax.set_xlabel("x") 62 | ax.set_ylabel("y") 63 | fig.savefig(f'{output_dir}/viz_results/trajectory/batch_{batch_idx}_scene_{scene_idx}.png') 64 | plt.close() 65 | plt.cla() 66 | # import pdb;pdb.set_trace() 67 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/train_opp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | from models.hivt import HiVT 26 | 27 | if __name__ == '__main__': 28 | pl.seed_everything(2022) 29 | 30 | parser = ArgumentParser() 31 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/v2_fps_2_5_frame_20_oppositeRotate') 32 | # parser.add_argument('--embed_dim', type=int, default=64) 33 | parser.add_argument('--train_batch_size', type=int, default=64) 34 | parser.add_argument('--val_batch_size', type=int, default=64) 35 | parser.add_argument('--shuffle', type=bool, default=True) 36 | parser.add_argument('--num_workers', type=int, default=8) 37 | parser.add_argument('--pin_memory', type=bool, default=True) 38 | parser.add_argument('--persistent_workers', type=bool, default=True) 39 | parser.add_argument('--gpus', type=int, default=1) 40 | parser.add_argument('--max_epochs', type=int, default=64) 41 | parser.add_argument('--monitor', type=str, default='val_minFDE', choices=['val_minADE', 'val_minFDE']) 42 | parser.add_argument('--save_top_k', type=int, default=5) 43 | 44 | # # # # # # Editables 45 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs/', help='Directory where logs and checkpoints are saved') 46 | parser.add_argument('--encoder_type', type=str, default='traj', choices=['traj', 'traj_pose', 'traj_pose_text']) 47 | parser.add_argument('--log_name', type=str, default='traj_opp_v2', help='Directory where logs and checkpoints are saved') 48 | parser.add_argument('--embed_dim', type=int, default=128) 49 | 50 | parser = HiVT.add_model_specific_args(parser) 51 | args = parser.parse_args() 52 | 53 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 54 | trainer = pl.Trainer.from_argparse_args(args, 55 | callbacks=[model_checkpoint], 56 | default_root_dir=os.path.join(args.log_dir, args.log_name)) 57 | model = HiVT(**vars(args)) 58 | datamodule = jrdbDatamodule.from_argparse_args(args) 59 | trainer.fit(model, datamodule) 60 | 61 | 62 | -------------------------------------------------------------------------------- /HiVT/datasets/jrdb_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import os 15 | from itertools import permutations 16 | from itertools import product 17 | from typing import Callable, Dict, List, Optional, Tuple, Union 18 | 19 | import numpy as np 20 | import pandas as pd 21 | import torch 22 | torch.multiprocessing.set_start_method('spawn', force=True) 23 | from torch_geometric.data import Data 24 | from torch_geometric.data import Dataset 25 | from tqdm import tqdm 26 | import sys 27 | sys.path.append('/mnt/jaewoo4tb/textraj/') 28 | from HiVT.utils_hivt import TemporalData 29 | class jrdbDataset(Dataset): 30 | 31 | def __init__(self, 32 | root: str, 33 | split: str, 34 | transform: Optional[Callable] = None) -> None: 35 | print(f'Using jrdb dataset, split {split}') 36 | self._split = split 37 | if split == 'sample': 38 | self._directory = 'forecasting_sample' 39 | elif split == 'train': 40 | self._directory = 'train' 41 | elif split == 'val': 42 | self._directory = 'val' 43 | elif split == 'test': 44 | self._directory = 'test_obs' 45 | else: 46 | raise ValueError(split + ' is not valid') 47 | self.root = root 48 | # Instead of assigning to the property, assign to a private variable 49 | self._processed_file_names = sorted([os.path.splitext(f)[0] + '.pt' for f in os.listdir(self.processed_dir)]) 50 | self._processed_paths = [os.path.join(self.processed_dir, f) for f in self._processed_file_names] 51 | super(jrdbDataset, self).__init__(root, transform=transform) 52 | 53 | @property 54 | def raw_dir(self) -> str: 55 | return os.path.join(self.root, self._directory, 'data') 56 | 57 | @property 58 | def processed_dir(self) -> str: 59 | return os.path.join(self.root, self._directory) 60 | 61 | @property 62 | def raw_file_names(self) -> Union[str, List[str], Tuple]: 63 | return self._raw_file_names 64 | 65 | @property 66 | def processed_file_names(self) -> Union[str, List[str], Tuple]: 67 | return self._processed_file_names # Use the internal variable to return the value 68 | 69 | @property 70 | def processed_paths(self) -> List[str]: 71 | return self._processed_paths 72 | 73 | def len(self) -> int: 74 | return len(self.processed_paths) 75 | 76 | def get(self, idx) -> Data: 77 | return torch.load(self.processed_paths[idx]) 78 | # return TemporalData(**torch.load(self.processed_paths[idx])) 79 | -------------------------------------------------------------------------------- /HiVT/datasets/jrdb_dataset_poseViz.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import os 15 | from itertools import permutations 16 | from itertools import product 17 | from typing import Callable, Dict, List, Optional, Tuple, Union 18 | 19 | import numpy as np 20 | import pandas as pd 21 | import torch 22 | torch.multiprocessing.set_start_method('spawn', force=True) 23 | from torch_geometric.data import Data 24 | from torch_geometric.data import Dataset 25 | from tqdm import tqdm 26 | import sys 27 | sys.path.append('/mnt/jaewoo4tb/textraj/') 28 | from HiVT.utils_hivt import TemporalData 29 | class jrdbDataset_poseViz(Dataset): 30 | 31 | def __init__(self, 32 | root: str, 33 | split: str, 34 | transform: Optional[Callable] = None) -> None: 35 | print(f'Using jrdb dataset, split {split}') 36 | self._split = split 37 | if split == 'sample': 38 | self._directory = 'forecasting_sample' 39 | elif split == 'train': 40 | self._directory = 'train' 41 | elif split == 'val': 42 | self._directory = 'val_testPose' 43 | elif split == 'test': 44 | self._directory = 'test_obs' 45 | else: 46 | raise ValueError(split + ' is not valid') 47 | self.root = root 48 | # Instead of assigning to the property, assign to a private variable 49 | self._processed_file_names = [os.path.splitext(f)[0] + '.pt' for f in os.listdir(self.processed_dir)] 50 | self._processed_paths = [os.path.join(self.processed_dir, f) for f in self._processed_file_names] 51 | super(jrdbDataset_poseViz, self).__init__(root, transform=transform) 52 | 53 | @property 54 | def raw_dir(self) -> str: 55 | return os.path.join(self.root, self._directory, 'data') 56 | 57 | @property 58 | def processed_dir(self) -> str: 59 | return os.path.join(self.root, self._directory) 60 | 61 | @property 62 | def raw_file_names(self) -> Union[str, List[str], Tuple]: 63 | return self._raw_file_names 64 | 65 | @property 66 | def processed_file_names(self) -> Union[str, List[str], Tuple]: 67 | return self._processed_file_names # Use the internal variable to return the value 68 | 69 | @property 70 | def processed_paths(self) -> List[str]: 71 | return self._processed_paths 72 | 73 | def len(self) -> int: 74 | return len(self.processed_paths) 75 | 76 | def get(self, idx) -> Data: 77 | return torch.load(self.processed_paths[idx]) 78 | # return TemporalData(**torch.load(self.processed_paths[idx])) 79 | -------------------------------------------------------------------------------- /HiVT/train.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import os 3 | import torch 4 | torch.multiprocessing.set_sharing_strategy('file_system') 5 | 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.callbacks import ModelCheckpoint 8 | 9 | from datamodules import jrdbDatamodule 10 | from models.hivt_v2 import HiVT 11 | 12 | if __name__ == '__main__': 13 | pl.seed_everything(1110) 14 | 15 | parser = ArgumentParser() 16 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 17 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot') 18 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot') 19 | # parser.add_argument('--embed_dim', type=int, default=64) 20 | parser.add_argument('--train_batch_size', type=int, default=64) 21 | parser.add_argument('--val_batch_size', type=int, default=64) 22 | parser.add_argument('--shuffle', type=bool, 
default=True) 23 | parser.add_argument('--num_workers', type=int, default=1) 24 | parser.add_argument('--pin_memory', type=bool, default=True) 25 | parser.add_argument('--persistent_workers', type=bool, default=True) 26 | parser.add_argument('--gpus', type=int, default=1) 27 | parser.add_argument('--max_epochs', type=int, default=64) 28 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 29 | parser.add_argument('--save_top_k', type=int, default=5) 30 | 31 | # # # # # # Editables 32 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_jrdb_curFixed/', help='Directory where logs and checkpoints are saved') 33 | # parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 34 | parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj_v3', 'traj_pose', 'traj_pose_text_v3', 'traj_text', ]) 35 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 36 | parser.add_argument('--embed_dim', type=int, default=128) 37 | parser.add_argument('--emb_numHead', type=int, default=4) 38 | parser.add_argument('--emb_dropout', type=int, default=0.5) 39 | parser.add_argument('--emb_numLayers', type=int, default=2) 40 | parser.add_argument('--onlyFullTimestep', type=bool, default=False) 41 | parser.add_argument('--log_name', type=str, default='debug', help='Directory where logs and checkpoints are saved') 42 | 43 | parser = HiVT.add_model_specific_args(parser) 44 | args = parser.parse_args() 45 | 46 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 47 | trainer = pl.Trainer.from_argparse_args(args, 48 | callbacks=[model_checkpoint], 49 | default_root_dir=os.path.join(args.log_dir, args.log_name), 50 | # num_sanity_val_steps=0, 51 | ) 52 | model = HiVT(**vars(args)) 53 | datamodule = jrdbDatamodule.from_argparse_args(args) 54 | trainer.fit(model, datamodule) 55 | -------------------------------------------------------------------------------- /romp/vis_human/sim3drender/lib/rasterize.h: -------------------------------------------------------------------------------- 1 | #ifndef MESH_CORE_HPP_ 2 | #define MESH_CORE_HPP_ 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | using namespace std; 12 | 13 | class Point3D { 14 | public: 15 | float x; 16 | float y; 17 | float z; 18 | 19 | public: 20 | Point3D() : x(0.f), y(0.f), z(0.f) {} 21 | Point3D(float x_, float y_, float z_) : x(x_), y(y_), z(z_) {} 22 | 23 | void initialize(float x_, float y_, float z_){ 24 | this->x = x_; this->y = y_; this->z = z_; 25 | } 26 | 27 | Point3D cross(Point3D &p){ 28 | Point3D c; 29 | c.x = this->y * p.z - this->z * p.y; 30 | c.y = this->z * p.x - this->x * p.z; 31 | c.z = this->x * p.y - this->y * p.x; 32 | return c; 33 | } 34 | 35 | float dot(Point3D &p) { 36 | return this->x * p.x + this->y * p.y + this->z * p.z; 37 | } 38 | 39 | Point3D operator-(const Point3D &p) { 40 | Point3D np; 41 | np.x = this->x - p.x; 42 | np.y = this->y - p.y; 43 | np.z = this->z - p.z; 44 | return np; 45 | } 46 | 47 | }; 48 | 49 | class Point { 50 | public: 51 | float x; 52 | float y; 53 | 54 | public: 55 | Point() : x(0.f), y(0.f) {} 56 | Point(float x_, float y_) : x(x_), y(y_) {} 57 | float dot(Point p) { 58 | return this->x * p.x + this->y * p.y; 59 | } 60 | 61 | Point operator-(const 
Point &p) { 62 |         Point np; 63 |         np.x = this->x - p.x; 64 |         np.y = this->y - p.y; 65 |         return np; 66 |     } 67 | 68 |     Point operator+(const Point &p) { 69 |         Point np; 70 |         np.x = this->x + p.x; 71 |         np.y = this->y + p.y; 72 |         return np; 73 |     } 74 | 75 |     Point operator*(float s) { 76 |         Point np; 77 |         np.x = s * this->x; 78 |         np.y = s * this->y; 79 |         return np; 80 |     } 81 | }; 82 | 83 | 84 | bool is_point_in_tri(Point p, Point p0, Point p1, Point p2); 85 | 86 | void get_point_weight(float *weight, Point p, Point p0, Point p1, Point p2); 87 | 88 | void _get_tri_normal(float *tri_normal, float *vertices, int *triangles, int ntri, bool norm_flg); 89 | 90 | void _get_ver_normal(float *ver_normal, float *tri_normal, int *triangles, int nver, int ntri); 91 | 92 | void _get_normal(float *ver_normal, float *vertices, int *triangles, int nver, int ntri); 93 | 94 | void _rasterize_triangles( 95 |     float *vertices, int *triangles, float *depth_buffer, int *triangle_buffer, float *barycentric_weight, 96 |     int ntri, int h, int w); 97 | 98 | void _rasterize( 99 |     unsigned char *image, float *vertices, int *triangles, float *colors, 100 |     float *depth_buffer, int ntri, int h, int w, int c, float alpha, bool reverse); 101 | 102 | void _render_texture_core( 103 |     float *image, float *vertices, int *triangles, 104 |     float *texture, float *tex_coords, int *tex_triangles, 105 |     float *depth_buffer, 106 |     int nver, int tex_nver, int ntri, 107 |     int h, int w, int c, 108 |     int tex_h, int tex_w, int tex_c, 109 |     int mapping_type); 110 | 111 | void _write_obj_with_colors_texture(string filename, string mtl_name, 112 |     float *vertices, int *triangles, float *colors, float *uv_coords, 113 |     int nver, int ntri, int ntexver); 114 | 115 | #endif -------------------------------------------------------------------------------- /HiVT/___scripts_copy/README.md: -------------------------------------------------------------------------------- 1 | # HiVT: Hierarchical Vector Transformer for Multi-Agent Motion Prediction 2 | This repository contains the official implementation of [HiVT: Hierarchical Vector Transformer for Multi-Agent Motion Prediction](https://openaccess.thecvf.com/content/CVPR2022/papers/Zhou_HiVT_Hierarchical_Vector_Transformer_for_Multi-Agent_Motion_Prediction_CVPR_2022_paper.pdf) published in CVPR 2022. 3 | 4 | ![](assets/overview.png) 5 | 6 | ## Getting Started 7 | 8 | 1\. Clone this repository: 9 | ``` 10 | git clone https://github.com/ZikangZhou/HiVT.git 11 | cd HiVT 12 | ``` 13 | 14 | 2\. Create a conda environment and install the dependencies: 15 | ``` 16 | conda create -n HiVT python=3.8 17 | conda activate HiVT 18 | conda install pytorch==1.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge 19 | conda install pytorch-geometric==1.7.2 -c rusty1s -c conda-forge 20 | conda install pytorch-lightning==1.5.2 -c conda-forge 21 | ``` 22 | 23 | 3\. Download [Argoverse Motion Forecasting Dataset v1.1](https://www.argoverse.org/av1.html). After downloading and extracting the tar.gz files, the dataset directory should be organized as follows: 24 | ``` 25 | /path/to/dataset_root/ 26 | ├── train/ 27 | |   └── data/ 28 | |       ├── 1.csv 29 | |       ├── 2.csv 30 | |       ├── ... 31 | └── val/ 32 |     └── data/ 33 |         ├── 1.csv 34 |         ├── 2.csv 35 |         ├── ... 36 | ``` 37 | 38 | 4\. Install [Argoverse 1 API](https://github.com/argoai/argoverse-api) (a minimal install sketch follows below). 
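Step 4 above lists no command, so here is a minimal, hedged sketch of one way to install the API from source. The clone URL is the one linked in step 4; the editable install and the note about map files follow the upstream argoverse-api README, so treat this as a starting point and defer to that README for the authoritative steps:
```
# Clone the Argoverse 1 API and install it in editable mode.
git clone https://github.com/argoai/argoverse-api.git
cd argoverse-api
pip install -e .
# The upstream README also expects the HD map archives to be downloaded and
# extracted into the repository root before map-dependent features will work.
```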
39 | 40 | ## Training 41 | 42 | To train HiVT-64: 43 | ``` 44 | python train.py --root /path/to/dataset_root/ --embed_dim 64 45 | ``` 46 | 47 | To train HiVT-128: 48 | ``` 49 | python train.py --root /path/to/dataset_root/ --embed_dim 128 50 | ``` 51 | 52 | **Note**: When running the training script for the first time, it will take several hours to preprocess the data (~3.5 hours on my machine). Training on an RTX 2080 Ti GPU takes 35-40 minutes per epoch. 53 | 54 | During training, the checkpoints will be saved in `lightning_logs/` automatically. To monitor the training process: 55 | ``` 56 | tensorboard --logdir lightning_logs/ 57 | ``` 58 | 59 | ## Evaluation 60 | 61 | To evaluate the prediction performance: 62 | ``` 63 | python eval.py --root /path/to/dataset_root/ --batch_size 32 --ckpt_path /path/to/your_checkpoint.ckpt 64 | ``` 65 | 66 | ## Pretrained Models 67 | 68 | We provide the pretrained HiVT-64 and HiVT-128 in [checkpoints/](checkpoints). You can evaluate the pretrained models using the aforementioned evaluation command, or have a look at the training process via TensorBoard: 69 | ``` 70 | tensorboard --logdir checkpoints/ 71 | ``` 72 | 73 | ## Results 74 | 75 | ### Quantitative Results 76 | 77 | For this repository, the expected performance on Argoverse 1.1 validation set is: 78 | 79 | | Models | minADE | minFDE | MR | 80 | | :--- | :---: | :---: | :---: | 81 | | HiVT-64 | 0.69 | 1.03 | 0.10 | 82 | | HiVT-128 | 0.66 | 0.97 | 0.09 | 83 | 84 | ### Qualitative Results 85 | 86 | ![](assets/visualization.png) 87 | 88 | ## Citation 89 | 90 | If you found this repository useful, please consider citing our work: 91 | 92 | ``` 93 | @inproceedings{zhou2022hivt, 94 | title={HiVT: Hierarchical Vector Transformer for Multi-Agent Motion Prediction}, 95 | author={Zhou, Zikang and Ye, Luyao and Wang, Jianping and Wu, Kui and Lu, Kejie}, 96 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 97 | year={2022} 98 | } 99 | ``` 100 | 101 | ## License 102 | 103 | This repository is licensed under [Apache 2.0](LICENSE). 104 | 105 | -------------------------------------------------------------------------------- /HiVT/validate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '2' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | from torch_geometric.data import DataLoader 24 | 25 | from datasets import jrdbDataset 26 | 27 | from datamodules import jrdbDatamodule 28 | # from models.hivt_kd_ import HiVT 29 | # from models.hivt_v3 import HiVT 30 | from HiVT.models.hivt_kd_onlyAgent_v4_dummy_teachV3 import HiVT # teacher v3 31 | 32 | if __name__ == '__main__': 33 | pl.seed_everything(2024) 34 | 35 | parser = ArgumentParser() 36 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 37 | # parser.add_argument('--embed_dim', type=int, default=64) 38 | parser.add_argument('--train_batch_size', type=int, default=64) 39 | parser.add_argument('--val_batch_size', type=int, default=32) 40 | parser.add_argument('--shuffle', type=bool, default=True) 41 | parser.add_argument('--num_workers', type=int, default=1) 42 | parser.add_argument('--pin_memory', type=bool, default=True) 43 | parser.add_argument('--persistent_workers', type=bool, default=True) 44 | parser.add_argument('--gpus', type=int, default=1) 45 | parser.add_argument('--max_epochs', type=int, default=64) 46 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 47 | parser.add_argument('--save_top_k', type=int, default=5) 48 | 49 | # # # # # # Editables 50 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_withViz/', help='Directory where logs and checkpoints are saved') 51 | parser.add_argument('--encoder_type', type=str, default='traj_pose_text_intra', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 52 | parser.add_argument('--log_name', type=str, default='debug', help='Directory where logs and checkpoints are saved') 53 | parser.add_argument('--embed_dim', type=int, default=128) 54 | parser.add_argument('--emb_numHead', type=int, default=8) 55 | parser.add_argument('--emb_dropout', type=int, default=0.5) 56 | parser.add_argument('--emb_numLayers', type=int, default=1) 57 | 58 | parser.add_argument('--ckpt_path', type=str, default='your_ckpt.ckpt') 59 | 60 | parser = HiVT.add_model_specific_args(parser) 61 | args = parser.parse_args() 62 | 63 | trainer = pl.Trainer.from_argparse_args(args, 64 | default_root_dir=os.path.join(args.log_dir, args.log_name), 65 | # num_sanity_val_steps=0, 66 | ) 67 | model = HiVT.load_from_checkpoint(checkpoint_path=args.ckpt_path, parallel=False) 68 | val_dataset = jrdbDataset(root=args.root, split='val') 69 | dataloader = DataLoader(val_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers, 70 | pin_memory=args.pin_memory, persistent_workers=args.persistent_workers) 71 | trainer.validate(model, dataloader) 72 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/validate.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '2' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | from torch_geometric.data import DataLoader 24 | 25 | from datasets import jrdbDataset 26 | 27 | from datamodules import jrdbDatamodule 28 | # from models.hivt_kd_ import HiVT 29 | # from models.hivt_v3 import HiVT 30 | from HiVT.models.hivt_kd_onlyAgent_v4_dummy_teachV3 import HiVT # teacher v3 31 | 32 | if __name__ == '__main__': 33 | pl.seed_everything(2024) 34 | 35 | parser = ArgumentParser() 36 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 37 | # parser.add_argument('--embed_dim', type=int, default=64) 38 | parser.add_argument('--train_batch_size', type=int, default=64) 39 | parser.add_argument('--val_batch_size', type=int, default=32) 40 | parser.add_argument('--shuffle', type=bool, default=True) 41 | parser.add_argument('--num_workers', type=int, default=1) 42 | parser.add_argument('--pin_memory', type=bool, default=True) 43 | parser.add_argument('--persistent_workers', type=bool, default=True) 44 | parser.add_argument('--gpus', type=int, default=1) 45 | parser.add_argument('--max_epochs', type=int, default=64) 46 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 47 | parser.add_argument('--save_top_k', type=int, default=5) 48 | 49 | # # # # # # Editables 50 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_withViz/', help='Directory where logs and checkpoints are saved') 51 | parser.add_argument('--encoder_type', type=str, default='traj_pose_text_intra', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 52 | parser.add_argument('--log_name', type=str, default='debug', help='Directory where logs and checkpoints are saved') 53 | parser.add_argument('--embed_dim', type=int, default=128) 54 | parser.add_argument('--emb_numHead', type=int, default=8) 55 | parser.add_argument('--emb_dropout', type=int, default=0.5) 56 | parser.add_argument('--emb_numLayers', type=int, default=1) 57 | 58 | parser.add_argument('--ckpt_path', type=str, default='your_ckpt.ckpt') 59 | 60 | parser = HiVT.add_model_specific_args(parser) 61 | args = parser.parse_args() 62 | 63 | trainer = pl.Trainer.from_argparse_args(args, 64 | default_root_dir=os.path.join(args.log_dir, args.log_name), 65 | # num_sanity_val_steps=0, 66 | ) 67 | model = HiVT.load_from_checkpoint(checkpoint_path=args.ckpt_path, parallel=False) 68 | val_dataset = jrdbDataset(root=args.root, split='val') 69 | dataloader = DataLoader(val_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers, 70 | pin_memory=args.pin_memory, persistent_workers=args.persistent_workers) 71 | trainer.validate(model, dataloader) 72 | 
-------------------------------------------------------------------------------- /HiVT/train_ethucy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | from models.hivt_v2_ethucy import HiVT 26 | 27 | if __name__ == '__main__': 28 | pl.seed_everything(9999) 29 | 30 | parser = ArgumentParser() 31 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/ethucy_v2_fps_2_5_frame_20_eth') 32 | # parser.add_argument('--embed_dim', type=int, default=64) 33 | parser.add_argument('--train_batch_size', type=int, default=64) 34 | parser.add_argument('--val_batch_size', type=int, default=64) 35 | parser.add_argument('--shuffle', type=bool, default=True) 36 | parser.add_argument('--num_workers', type=int, default=1) 37 | parser.add_argument('--pin_memory', type=bool, default=True) 38 | parser.add_argument('--persistent_workers', type=bool, default=True) 39 | parser.add_argument('--gpus', type=int, default=1) 40 | parser.add_argument('--max_epochs', type=int, default=64) 41 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 42 | parser.add_argument('--save_top_k', type=int, default=5) 43 | 44 | # # # # # # Editables 45 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_ethucy_intra/eth/', help='Directory where logs and checkpoints are saved') 46 | # parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 47 | parser.add_argument('--encoder_type', type=str, default='traj_text', choices=['traj', 'traj_v2', 'traj_v3', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text', 'traj_pose_text_v1', 'traj_pose_text_v2', 'traj_pose_text_v3', 'traj_text']) 48 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 49 | parser.add_argument('--embed_dim', type=int, default=128) 50 | parser.add_argument('--emb_numHead', type=int, default=4) 51 | parser.add_argument('--emb_dropout', type=int, default=0.5) 52 | parser.add_argument('--emb_numLayers', type=int, default=2) 53 | parser.add_argument('--onlyFullTimestep', type=bool, default=False) 54 | parser.add_argument('--log_name', type=str, default='_V2InstCorrected_hivtV2_TraText_allTimestep_1.5', help='Directory where logs and checkpoints are saved') 55 | 56 | parser = HiVT.add_model_specific_args(parser) 57 | args = 
parser.parse_args() 58 | 59 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 60 | trainer = pl.Trainer.from_argparse_args(args, 61 | callbacks=[model_checkpoint], 62 | default_root_dir=os.path.join(args.log_dir, args.log_name), 63 | # num_sanity_val_steps=0, 64 | ) 65 | model = HiVT(**vars(args)) 66 | datamodule = jrdbDatamodule.from_argparse_args(args) 67 | trainer.fit(model, datamodule) 68 | -------------------------------------------------------------------------------- /HiVT/___models_copy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from models.decoder import GRUDecoder 15 | from models.decoder import MLPDecoder 16 | from models.embedding import MultipleInputEmbedding 17 | from models.embedding import SingleInputEmbedding 18 | from models.embedding import TrajPoseEmbedding, MultipleInputEmbedding_mask, MultipleInputEmbedding_mask_v2, MultipleInputEmbedding_mask_v3, MultipleInputEmbedding_mask_v4, MultipleInputEmbedding_mask_v5 19 | from models.global_interactor import GlobalInteractor 20 | from models.global_interactor import GlobalInteractorLayer 21 | from models.global_interactor_text import GlobalInteractor_text 22 | 23 | from models.local_encoder import AAEncoder 24 | from models.local_encoder import ALEncoder 25 | from models.local_encoder import LocalEncoder 26 | from models.local_encoder import TemporalEncoder 27 | from models.local_encoder import TemporalEncoderLayer 28 | from models.local_encoder_onlyTraj import LocalEncoder_onlyTraj 29 | from models.local_encoder_onlyTraj_v2 import LocalEncoder_onlyTraj_v2 30 | from models.local_encoder_onlyTraj_v3 import LocalEncoder_onlyTraj_v3 31 | from models.local_encoder_wPose import LocalEncoder_wPose 32 | from models.local_encoder_wPose_joint import LocalEncoder_wPose_joint 33 | from models.local_encoder_wPose_withAngle import LocalEncoder_wPose_withAngle 34 | from models.local_encoder_wText import LocalEncoder_wText 35 | from models.local_encoder_wText_intra import LocalEncoder_wText_intra 36 | from models.local_encoder_wPoseText import LocalEncoder_wPoseText 37 | from models.local_encoder_wPoseText_intra import LocalEncoder_wPoseText_intra 38 | from models.local_encoder_wPoseText_intra_joint import LocalEncoder_wPoseText_intra_joint 39 | from models.local_encoder_wPoseText_embedMask import LocalEncoder_wPoseText_embedMask 40 | from models.local_encoder_wPoseText_embedMask_teacher import LocalEncoder_wPoseText_embedMask_teacher 41 | from models.local_encoder_wPoseText_v1 import LocalEncoder_wPoseText_v1 42 | from models.local_encoder_wPoseText_v2 import LocalEncoder_wPoseText_v2 43 | from models.local_encoder_wPoseText_v3 import LocalEncoder_wPoseText_v3 44 | from models.local_encoder_wPose_8steps import LocalEncoder_wPose_8steps 45 | from models.local_encoder_wPose_withoutRotation import 
LocalEncoder_wPose_withoutRotation 46 | from models.local_encoder_wPoseText_woRotation import LocalEncoder_wPoseText_woRotation 47 | 48 | # Text ablation 49 | from models.local_encoder_wText_egoOther import LocalEncoder_wText_egoOther 50 | from models.local_encoder_wText_egoOtherInteraction import LocalEncoder_wText_egoOtherInteraction 51 | from models.local_encoder_wText_onlyEgo import LocalEncoder_wText_onlyEgo 52 | from models.local_encoder_wPoseText_onlyEgo import LocalEncoder_wPoseText_onlyEgo 53 | from models.local_encoder_wPoseText_egoOther import LocalEncoder_wPoseText_egoOther 54 | from models.local_encoder_wPoseText_egoOtherInteraction import LocalEncoder_wPoseText_egoOtherInteraction 55 | 56 | # ETH/UCY 57 | from models.local_encoder_wPose_ethucy import LocalEncoder_wPose_ethucy 58 | from models.local_encoder_wPoseText_ethucy import LocalEncoder_wPoseText_ethucy 59 | from models.local_encoder_wPoseText_ethucy_v1 import LocalEncoder_wPoseText_ethucy_v1 60 | from models.local_encoder_wText_ethucy import LocalEncoder_wText_ethucy 61 | # from models.local_encoder_original import LocalEncoder_original 62 | 63 | # SIT 64 | from models.local_encoder_wPoseText_sit import LocalEncoder_wPoseText_sit -------------------------------------------------------------------------------- /HiVT/___scripts_copy/train_ethucy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | from models.hivt_v2_ethucy import HiVT 26 | 27 | if __name__ == '__main__': 28 | pl.seed_everything(9999) 29 | 30 | parser = ArgumentParser() 31 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/ethucy_v2_fps_2_5_frame_20_eth') 32 | # parser.add_argument('--embed_dim', type=int, default=64) 33 | parser.add_argument('--train_batch_size', type=int, default=64) 34 | parser.add_argument('--val_batch_size', type=int, default=64) 35 | parser.add_argument('--shuffle', type=bool, default=True) 36 | parser.add_argument('--num_workers', type=int, default=1) 37 | parser.add_argument('--pin_memory', type=bool, default=True) 38 | parser.add_argument('--persistent_workers', type=bool, default=True) 39 | parser.add_argument('--gpus', type=int, default=1) 40 | parser.add_argument('--max_epochs', type=int, default=64) 41 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 42 | parser.add_argument('--save_top_k', type=int, default=5) 43 | 44 | # # # # # # Editables 45 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_ethucy_intra/eth/', help='Directory where logs and checkpoints are saved') 46 | # parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 47 | parser.add_argument('--encoder_type', type=str, default='traj_text', choices=['traj', 'traj_v2', 'traj_v3', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text', 'traj_pose_text_v1', 'traj_pose_text_v2', 'traj_pose_text_v3', 'traj_text']) 48 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 49 | parser.add_argument('--embed_dim', type=int, default=128) 50 | parser.add_argument('--emb_numHead', type=int, default=4) 51 | parser.add_argument('--emb_dropout', type=int, default=0.5) 52 | parser.add_argument('--emb_numLayers', type=int, default=2) 53 | parser.add_argument('--onlyFullTimestep', type=bool, default=False) 54 | parser.add_argument('--log_name', type=str, default='_V2InstCorrected_hivtV2_TraText_allTimestep_1.5', help='Directory where logs and checkpoints are saved') 55 | 56 | parser = HiVT.add_model_specific_args(parser) 57 | args = parser.parse_args() 58 | 59 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 60 | trainer = pl.Trainer.from_argparse_args(args, 61 | callbacks=[model_checkpoint], 62 | default_root_dir=os.path.join(args.log_dir, args.log_name), 63 | # num_sanity_val_steps=0, 64 | ) 65 | model = HiVT(**vars(args)) 66 | datamodule = jrdbDatamodule.from_argparse_args(args) 67 | trainer.fit(model, datamodule) 68 | -------------------------------------------------------------------------------- /HiVT/_train_mart_ethucy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | from models._mart_v3_ethucy import mart 26 | 27 | if __name__ == '__main__': 28 | pl.seed_everything(7575) 29 | 30 | parser = ArgumentParser() 31 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/ethucy_v2_fps_2_5_frame_20_eth') 32 | # parser.add_argument('--embed_dim', type=int, default=64) 33 | parser.add_argument('--train_batch_size', type=int, default=32) 34 | parser.add_argument('--val_batch_size', type=int, default=32) 35 | parser.add_argument('--shuffle', type=bool, default=True) 36 | parser.add_argument('--num_workers', type=int, default=1) 37 | parser.add_argument('--pin_memory', type=bool, default=True) 38 | parser.add_argument('--persistent_workers', type=bool, default=True) 39 | parser.add_argument('--gpus', type=int, default=1) 40 | parser.add_argument('--max_epochs', type=int, default=64) 41 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 42 | parser.add_argument('--save_top_k', type=int, default=5) 43 | 44 | # # # # # # Editables 45 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/_logs_mart/ethucy/eth', help='Directory where logs and checkpoints are saved') 46 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 47 | parser.add_argument('--hidden_dim', type=int, default=128) 48 | parser.add_argument('--decoder_hidden_dim', type=int, default=128) 49 | parser.add_argument('--model_dim', type=int, default=128) 50 | parser.add_argument('--dropout_mart', type=float, default=0.0) 51 | parser.add_argument('--num_head_mart', type=int, default=8) 52 | parser.add_argument('--num_layers', type=int, default=2) 53 | parser.add_argument('--aggregation', type=str, default='avg') 54 | parser.add_argument('--function_type', type=int, default=2) 55 | parser.add_argument('--hyper_scales', type=list, default=[1, 'adaptive']) 56 | parser.add_argument('--modality', type=str, default='traj') # traj, traj_pose 57 | parser.add_argument('--onlyFullTimestep', type=bool, default=False) 58 | parser.add_argument('--log_name', type=str, default='martV1_Tuned_allTimestep_sceneAveFixed', help='Directory where logs and checkpoints are saved') 59 | 60 | parser = mart.add_model_specific_args(parser) 61 | args = parser.parse_args() 62 | 63 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 64 | trainer = pl.Trainer.from_argparse_args(args, 65 | callbacks=[model_checkpoint], 66 | default_root_dir=os.path.join(args.log_dir, args.log_name), 67 | # num_sanity_val_steps=0, 68 | ) 69 | model = mart(**vars(args)) 70 | 
datamodule = jrdbDatamodule.from_argparse_args(args) 71 | trainer.fit(model, datamodule) 72 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/_train_mart_ethucy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | from models._mart_v3_ethucy import mart 26 | 27 | if __name__ == '__main__': 28 | pl.seed_everything(7575) 29 | 30 | parser = ArgumentParser() 31 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/ethucy_v2_fps_2_5_frame_20_eth') 32 | # parser.add_argument('--embed_dim', type=int, default=64) 33 | parser.add_argument('--train_batch_size', type=int, default=32) 34 | parser.add_argument('--val_batch_size', type=int, default=32) 35 | parser.add_argument('--shuffle', type=bool, default=True) 36 | parser.add_argument('--num_workers', type=int, default=1) 37 | parser.add_argument('--pin_memory', type=bool, default=True) 38 | parser.add_argument('--persistent_workers', type=bool, default=True) 39 | parser.add_argument('--gpus', type=int, default=1) 40 | parser.add_argument('--max_epochs', type=int, default=64) 41 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 42 | parser.add_argument('--save_top_k', type=int, default=5) 43 | 44 | # # # # # # Editables 45 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/_logs_mart/ethucy/eth', help='Directory where logs and checkpoints are saved') 46 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 47 | parser.add_argument('--hidden_dim', type=int, default=128) 48 | parser.add_argument('--decoder_hidden_dim', type=int, default=128) 49 | parser.add_argument('--model_dim', type=int, default=128) 50 | parser.add_argument('--dropout_mart', type=float, default=0.0) 51 | parser.add_argument('--num_head_mart', type=int, default=8) 52 | parser.add_argument('--num_layers', type=int, default=2) 53 | parser.add_argument('--aggregation', type=str, default='avg') 54 | parser.add_argument('--function_type', type=int, default=2) 55 | parser.add_argument('--hyper_scales', type=list, default=[1, 'adaptive']) 56 | parser.add_argument('--modality', type=str, default='traj') # traj, traj_pose 57 | parser.add_argument('--onlyFullTimestep', type=bool, default=False) 58 | parser.add_argument('--log_name', type=str, 
default='martV1_Tuned_allTimestep_sceneAveFixed', help='Directory where logs and checkpoints are saved') 59 | 60 | parser = mart.add_model_specific_args(parser) 61 | args = parser.parse_args() 62 | 63 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 64 | trainer = pl.Trainer.from_argparse_args(args, 65 | callbacks=[model_checkpoint], 66 | default_root_dir=os.path.join(args.log_dir, args.log_name), 67 | # num_sanity_val_steps=0, 68 | ) 69 | model = mart(**vars(args)) 70 | datamodule = jrdbDatamodule.from_argparse_args(args) 71 | trainer.fit(model, datamodule) 72 | -------------------------------------------------------------------------------- /HiVT/_train_socialTransmotion_ethucy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | from models._social_transmotion_v1_ethucy import socialTransmotion 26 | 27 | if __name__ == '__main__': 28 | pl.seed_everything(7979) 29 | 30 | parser = ArgumentParser() 31 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/ethucy_v2_fps_2_5_frame_20_students') 32 | # parser.add_argument('--embed_dim', type=int, default=64) 33 | parser.add_argument('--train_batch_size', type=int, default=32) 34 | parser.add_argument('--val_batch_size', type=int, default=32) 35 | parser.add_argument('--shuffle', type=bool, default=True) 36 | parser.add_argument('--num_workers', type=int, default=1) 37 | parser.add_argument('--pin_memory', type=bool, default=True) 38 | parser.add_argument('--persistent_workers', type=bool, default=True) 39 | parser.add_argument('--gpus', type=int, default=1) 40 | parser.add_argument('--max_epochs', type=int, default=64) 41 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 42 | parser.add_argument('--save_top_k', type=int, default=5) 43 | 44 | # # # # # # Editables 45 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/_logs_socialTransmotion/ethucy/students', help='Directory where logs and checkpoints are saved') 46 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 47 | parser.add_argument('--embed_dim', type=int, default=128) 48 | parser.add_argument('--seq_len', type=int, default=212) 49 | parser.add_argument('--token_num', type=int, default=25) 50 | parser.add_argument('--num_layers_local', type=int, default=6) 51 | 
parser.add_argument('--num_layers_global', type=int, default=3) 52 | parser.add_argument('--dim_feedforward', type=int, default=1024) 53 | parser.add_argument('--output_scale', type=int, default=1) 54 | parser.add_argument('--num_head_st', type=int, default=4) 55 | parser.add_argument('--modality', type=str, default='traj') # traj, traj_pose 56 | parser.add_argument('--onlyFullTimestep', type=bool, default=True) 57 | parser.add_argument('--log_name', type=str, default='socialTransmotionV1_tuned_onlyFullTimestep_fixed_noRotate', help='Directory where logs and checkpoints are saved') 58 | 59 | parser = socialTransmotion.add_model_specific_args(parser) 60 | args = parser.parse_args() 61 | 62 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 63 | trainer = pl.Trainer.from_argparse_args(args, 64 | callbacks=[model_checkpoint], 65 | default_root_dir=os.path.join(args.log_dir, args.log_name), 66 | # num_sanity_val_steps=0, 67 | ) 68 | model = socialTransmotion(**vars(args)) 69 | datamodule = jrdbDatamodule.from_argparse_args(args) 70 | trainer.fit(model, datamodule) 71 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/_train_socialTransmotion_ethucy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | from models._social_transmotion_v1_ethucy import socialTransmotion 26 | 27 | if __name__ == '__main__': 28 | pl.seed_everything(7979) 29 | 30 | parser = ArgumentParser() 31 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/ethucy_v2_fps_2_5_frame_20_students') 32 | # parser.add_argument('--embed_dim', type=int, default=64) 33 | parser.add_argument('--train_batch_size', type=int, default=32) 34 | parser.add_argument('--val_batch_size', type=int, default=32) 35 | parser.add_argument('--shuffle', type=bool, default=True) 36 | parser.add_argument('--num_workers', type=int, default=1) 37 | parser.add_argument('--pin_memory', type=bool, default=True) 38 | parser.add_argument('--persistent_workers', type=bool, default=True) 39 | parser.add_argument('--gpus', type=int, default=1) 40 | parser.add_argument('--max_epochs', type=int, default=64) 41 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 42 | parser.add_argument('--save_top_k', type=int, default=5) 43 | 44 | # # # # # # Editables 45 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/_logs_socialTransmotion/ethucy/students', help='Directory where logs and checkpoints are saved') 46 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 47 | parser.add_argument('--embed_dim', type=int, default=128) 48 | parser.add_argument('--seq_len', type=int, default=212) 49 | parser.add_argument('--token_num', type=int, default=25) 50 | parser.add_argument('--num_layers_local', type=int, default=6) 51 | parser.add_argument('--num_layers_global', type=int, default=3) 52 | parser.add_argument('--dim_feedforward', type=int, default=1024) 53 | parser.add_argument('--output_scale', type=int, default=1) 54 | parser.add_argument('--num_head_st', type=int, default=4) 55 | parser.add_argument('--modality', type=str, default='traj') # traj, traj_pose 56 | parser.add_argument('--onlyFullTimestep', type=bool, default=True) 57 | parser.add_argument('--log_name', type=str, default='socialTransmotionV1_tuned_onlyFullTimestep_fixed_noRotate', help='Directory where logs and checkpoints are saved') 58 | 59 | parser = socialTransmotion.add_model_specific_args(parser) 60 | args = parser.parse_args() 61 | 62 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 63 | trainer = pl.Trainer.from_argparse_args(args, 64 | callbacks=[model_checkpoint], 65 | default_root_dir=os.path.join(args.log_dir, args.log_name), 66 | # num_sanity_val_steps=0, 67 | ) 68 | model = socialTransmotion(**vars(args)) 69 | datamodule = jrdbDatamodule.from_argparse_args(args) 70 | trainer.fit(model, datamodule) 71 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/_train_socialTransmotion_ethucy_v1_1.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | from models._social_transmotion_v1_1_ethucy import socialTransmotion 26 | 27 | if __name__ == '__main__': 28 | pl.seed_everything(7979) 29 | 30 | parser = ArgumentParser() 31 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/ethucy_v2_fps_2_5_frame_20_students') 32 | # parser.add_argument('--embed_dim', type=int, default=64) 33 | parser.add_argument('--train_batch_size', type=int, default=32) 34 | parser.add_argument('--val_batch_size', type=int, default=32) 35 | parser.add_argument('--shuffle', type=bool, default=True) 36 | parser.add_argument('--num_workers', type=int, default=1) 37 | parser.add_argument('--pin_memory', type=bool, default=True) 38 | parser.add_argument('--persistent_workers', type=bool, default=True) 39 | parser.add_argument('--gpus', type=int, default=1) 40 | parser.add_argument('--max_epochs', type=int, default=64) 41 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 42 | parser.add_argument('--save_top_k', type=int, default=5) 43 | 44 | # # # # # # Editables 45 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/_logs_socialTransmotion/ethucy/students', help='Directory where logs and checkpoints are saved') 46 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 47 | parser.add_argument('--embed_dim', type=int, default=128) 48 | parser.add_argument('--seq_len', type=int, default=212) 49 | parser.add_argument('--token_num', type=int, default=25) 50 | parser.add_argument('--num_layers_local', type=int, default=6) 51 | parser.add_argument('--num_layers_global', type=int, default=3) 52 | parser.add_argument('--dim_feedforward', type=int, default=1024) 53 | parser.add_argument('--output_scale', type=int, default=1) 54 | parser.add_argument('--num_head_st', type=int, default=4) 55 | parser.add_argument('--modality', type=str, default='traj') # traj, traj_pose 56 | parser.add_argument('--onlyFullTimestep', type=bool, default=True) 57 | parser.add_argument('--log_name', type=str, default='socialTransmotionV1_1_tuned_onlyFullTimestep_fixed_noRotate', help='Directory where logs and checkpoints are saved') 58 | 59 | parser = socialTransmotion.add_model_specific_args(parser) 60 | args = parser.parse_args() 61 | 62 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 63 | trainer = pl.Trainer.from_argparse_args(args, 64 | callbacks=[model_checkpoint], 65 
| default_root_dir=os.path.join(args.log_dir, args.log_name), 66 | # num_sanity_val_steps=0, 67 | ) 68 | model = socialTransmotion(**vars(args)) 69 | datamodule = jrdbDatamodule.from_argparse_args(args) 70 | trainer.fit(model, datamodule) 71 | -------------------------------------------------------------------------------- /parser_model/calibration/cameras.yaml: -------------------------------------------------------------------------------- 1 | stitching: 2 | radius: 3360000 3 | rotation: 0 4 | scalewidth: 1831 5 | crop: 1 6 | cameras: 7 | # camera order matters! 8 | sensor_0: 9 | width: 752 10 | height: 480 11 | D: -0.336591 0.159742 0.00012697 -7.22557e-05 -0.0461953 12 | # K = fx 0 cx 13 | # 0 fy cy 14 | # 0 0 1 15 | K: > 16 | 476.71 0 350.738 17 | 0 479.505 209.532 18 | 0 0 1 19 | R: > 20 | 0.999994 0.000654539 0.00340293 21 | -0.000654519 1 -6.81963e-06 22 | -0.00340293 4.59231e-06 0.999994 23 | T: -0.0104242 -3.70974 -56.9177 24 | sensor_1: 25 | width: 752 26 | height: 480 27 | D: -0.335073 0.151959 -0.000232061 0.00032014 -0.0396825 28 | K: > 29 | 483.254 0 365.33 30 | 0 485.78 210.953 31 | 0 0 1 32 | R: > 33 | 0.305706 -0.00895443 -0.952084 34 | 0.0110396 0.999922 -0.00585963 35 | 0.952062 -0.0087193 0.305781 36 | T: 0.93957 -4.05131 -52.03 37 | sensor_2: 38 | width: 752 39 | height: 480 40 | D: -0.338469 0.156256 -0.000385467 0.000295485 -0.0401965 41 | K: > 42 | 483.911 0 355.144 43 | 0 486.466 223.026 44 | 0 0 1 45 | R: > 46 | -0.806828 0.0136361 -0.590629 47 | 0.00870468 0.999899 0.011194 48 | 0.590723 0.00389039 -0.806865 49 | T: -0.25753 -6.54978 -47.7311 50 | sensor_3: 51 | width: 752 52 | height: 480 53 | D: -0.330848 0.14747 8.59247e-05 0.000262599 -0.0385311 54 | K: > 55 | 475.807 0 339.53 56 | 0 478.371 188.481 57 | 0 0 1 58 | R: > 59 | -0.811334 0.0033829 0.584574 60 | 0.00046071 0.999987 -0.00514746 61 | -0.584583 -0.00390699 -0.811324 62 | T: 2.72207 -6.82928 -45.9778 63 | sensor_4: 64 | width: 752 65 | height: 480 66 | D: -0.34064 0.168338 0.000147292 0.000229372 -0.0516133 67 | K: > 68 | 485.046 0 368.864 69 | 0 488.185 208.215 70 | 0 0 1 71 | R: > 72 | 0.310275 0.00160497 0.950645 73 | -0.00648686 0.999979 0.000428942 74 | -0.950625 -0.00629979 0.310279 75 | T: -0.333857 -5.12974 -56.0573 76 | sensor_5: 77 | width: 752 78 | height: 480 79 | D: -0.338422 0.163703 -0.000376267 7.73351e-06 -0.0479871 80 | K: > 81 | 478.406 0 353.499 82 | 0 481.322 190.225 83 | 0 0 1 84 | R: > 85 | 0.999995 0.00282205 0.00163291 86 | -0.00282345 0.999996 0.000852931 87 | -0.00163049 -0.000857537 0.999998 88 | T: -0.903588 -126.851 -56.6256 89 | sensor_6: 90 | width: 752 91 | height: 480 92 | D: -0.340676 0.165511 -0.00035978 0.000181532 -0.0493721 93 | K: > 94 | 480.459 0 362.503 95 | 0 482.924 197.949 96 | 0 0 1 97 | R: > 98 | 0.308288 -0.0110391 -0.951229 99 | -0.000933102 0.999929 -0.0119067 100 | 0.951293 0.00455829 0.308256 101 | T: 1.74525 -127.214 -51.7722 102 | sensor_7: 103 | width: 752 104 | height: 480 105 | D: -0.344379 0.170343 -0.000137847 0.000141047 -0.0510536 106 | K: > 107 | 486.491 0 361.559 108 | 0 489.22 210.547 109 | 0 0 1 110 | R: > 111 | -0.808201 0.0313998 -0.588068 112 | 0.026057 0.999506 0.0175574 113 | 0.588329 -0.00113337 -0.808621 114 | T: -2.56535 -129.191 -47.5803 115 | sensor_8: 116 | width: 752 117 | height: 480 118 | D: -0.331228 0.144696 0.000117553 0.000566449 -0.0343506 119 | K: > 120 | 476.708 0 354.16 121 | 0 479.424 209.383 122 | 0 0 1 123 | R: > 124 | -0.807384 -0.00296577 0.590019 125 | -0.0122001 0.999857 -0.0116688 126 | -0.589901 
-0.0166195 -0.807305 127 | T: 3.39727 -129.381 -45.2409 128 | sensor_9: 129 | width: 752 130 | height: 480 131 | D: -0.345189 0.180808 0.000276465 0.000131868 -0.062103 132 | K: > 133 | 484.219 0 345.303 134 | 0 487.312 192.371 135 | 0 0 1 136 | R: > 137 | 0.308505 0.00370159 0.951215 138 | -0.00403535 0.999988 -0.00258261 139 | -0.951214 -0.00304174 0.308517 140 | T: 0.354966 -128.218 -54.0617 141 | 142 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/_train_mart_ethucy_lightning2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | from models._mart_ethucy import mart 26 | 27 | if __name__ == '__main__': 28 | pl.seed_everything(8989) 29 | 30 | parser = ArgumentParser() 31 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/ethucy_v2_fps_2_5_frame_20_zara2') 32 | # parser.add_argument('--embed_dim', type=int, default=64) 33 | parser.add_argument('--train_batch_size', type=int, default=64) 34 | parser.add_argument('--val_batch_size', type=int, default=64) 35 | parser.add_argument('--shuffle', type=bool, default=True) 36 | parser.add_argument('--num_workers', type=int, default=1) 37 | parser.add_argument('--pin_memory', type=bool, default=True) 38 | parser.add_argument('--persistent_workers', type=bool, default=True) 39 | parser.add_argument('--gpus', type=int, default=1) 40 | parser.add_argument('--max_epochs', type=int, default=64) 41 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 42 | parser.add_argument('--save_top_k', type=int, default=5) 43 | 44 | # # # # # # Editables 45 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/_logs_mart/ethucy/zara2', help='Directory where logs and checkpoints are saved') 46 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 47 | parser.add_argument('--hidden_dim', type=int, default=128) 48 | parser.add_argument('--decoder_hidden_dim', type=int, default=128) 49 | parser.add_argument('--model_dim', type=int, default=128) 50 | parser.add_argument('--dropout_mart', type=float, default=0.1) 51 | parser.add_argument('--num_head_mart', type=int, default=8) 52 | parser.add_argument('--num_layers', type=int, default=2) 53 | parser.add_argument('--aggregation', type=str, default='avg') 54 | parser.add_argument('--function_type', 
type=int, default=2) 55 | parser.add_argument('--hyper_scales', type=list, default=[1, 'adaptive']) 56 | parser.add_argument('--modality', type=str, default='traj') # traj, traj_pose 57 | parser.add_argument('--log_name', type=str, default='martV1_3e-4_FullLoss_Traj_FullLoss_dropout0.1_unifiedDecoding_128', help='Directory where logs and checkpoints are saved') 58 | 59 | parser = mart.add_model_specific_args(parser) 60 | args = parser.parse_args() 61 | 62 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 63 | trainer = pl.Trainer( 64 | max_epochs=args.max_epochs, 65 | callbacks=[model_checkpoint], 66 | default_root_dir=os.path.join(args.log_dir, args.log_name) 67 | ) 68 | model = mart(**vars(args)) 69 | datamodule = jrdbDatamodule( 70 | root=args.root, 71 | train_batch_size=args.train_batch_size, 72 | val_batch_size=args.val_batch_size, 73 | shuffle=args.shuffle, 74 | num_workers=args.num_workers, 75 | pin_memory=args.pin_memory, 76 | persistent_workers=args.persistent_workers 77 | ) 78 | trainer.fit(model, datamodule) 79 | -------------------------------------------------------------------------------- /HiVT/_train_LED.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
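# Note: the diffusion flags defined below ('--diff_steps', '--diff_beta_start',
# '--diff_beta_end', '--diff_beta_schedule') describe a DDPM-style noise schedule.
# Assuming models._LED consumes them in the usual way (a sketch, not this repository's code),
# a 'linear' schedule would be constructed roughly as:
#     betas = torch.linspace(args.diff_beta_start, args.diff_beta_end, args.diff_steps)
#     alphas = 1.0 - betas
#     alpha_bars = torch.cumprod(alphas, dim=0)  # cumulative product used for forward noising
# The actual mapping from these flags to the sampler lives in models._LED, not in this script.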
14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '2' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule_LED 25 | # from models._LED_onlyDiff import LED 26 | from models._LED import LED 27 | 28 | if __name__ == '__main__': 29 | pl.seed_everything(4433) 30 | 31 | parser = ArgumentParser() 32 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/ethucy_v2_fps_2_5_frame_20_eth') 33 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot') 34 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 35 | 36 | parser.add_argument('--train_batch_size', type=int, default=32) 37 | parser.add_argument('--val_batch_size', type=int, default=32) 38 | parser.add_argument('--shuffle', type=bool, default=True) 39 | parser.add_argument('--num_workers', type=int, default=4) 40 | parser.add_argument('--pin_memory', type=bool, default=True) 41 | parser.add_argument('--persistent_workers', type=bool, default=True) 42 | parser.add_argument('--gpus', type=int, default=1) 43 | parser.add_argument('--max_epochs', type=int, default=64) 44 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 45 | parser.add_argument('--save_top_k', type=int, default=5) 46 | 47 | # # # # # # Editables 48 | parser.add_argument('--onlyFullTimestep', type=bool, default=False) 49 | # parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/_logs_LED/jrdb', help='Directory where logs and checkpoints are saved') 50 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/_logs_LED_curFixed/jrdb', help='Directory where logs and checkpoints are saved') 51 | 52 | parser.add_argument('--embed_dim', type=int, default=128) 53 | parser.add_argument('--modality', type=str, default='traj') # traj, traj_pose, traj_text, traj_pose_text 54 | # parser.add_argument('--log_name', type=str, default='debug', help='Directory where logs and checkpoints are saved') 55 | parser.add_argument('--log_name', type=str, default='V2embed128_batch32_allTimestep_traj_wGlobal_batchAve', help='Directory where logs and checkpoints are saved') 56 | parser.add_argument('--diff_steps', type=int, default=5) 57 | parser.add_argument('--diff_beta_start', type=float, default=1.e-4) 58 | parser.add_argument('--diff_beta_end', type=float, default=5.e-2) 59 | parser.add_argument('--diff_beta_schedule', type=str, default='linear') 60 | 61 | parser = LED.add_model_specific_args(parser) 62 | args = parser.parse_args() 63 | 64 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 65 | trainer = pl.Trainer.from_argparse_args(args, 66 | callbacks=[model_checkpoint], 67 | default_root_dir=os.path.join(args.log_dir, args.log_name), 68 | num_sanity_val_steps=2, 69 | ) 70 | model = LED(**vars(args)) 71 | datamodule = jrdbDatamodule_LED.from_argparse_args(args) 72 | trainer.fit(model, datamodule) 73 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/_train_LED.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '2' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule_LED 25 | # from models._LED_onlyDiff import LED 26 | from models._LED import LED 27 | 28 | if __name__ == '__main__': 29 | pl.seed_everything(4433) 30 | 31 | parser = ArgumentParser() 32 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/ethucy_v2_fps_2_5_frame_20_eth') 33 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot') 34 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 35 | 36 | parser.add_argument('--train_batch_size', type=int, default=32) 37 | parser.add_argument('--val_batch_size', type=int, default=32) 38 | parser.add_argument('--shuffle', type=bool, default=True) 39 | parser.add_argument('--num_workers', type=int, default=4) 40 | parser.add_argument('--pin_memory', type=bool, default=True) 41 | parser.add_argument('--persistent_workers', type=bool, default=True) 42 | parser.add_argument('--gpus', type=int, default=1) 43 | parser.add_argument('--max_epochs', type=int, default=64) 44 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 45 | parser.add_argument('--save_top_k', type=int, default=5) 46 | 47 | # # # # # # Editables 48 | parser.add_argument('--onlyFullTimestep', type=bool, default=False) 49 | # parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/_logs_LED/jrdb', help='Directory where logs and checkpoints are saved') 50 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/_logs_LED_curFixed/jrdb', help='Directory where logs and checkpoints are saved') 51 | 52 | parser.add_argument('--embed_dim', type=int, default=128) 53 | parser.add_argument('--modality', type=str, default='traj') # traj, traj_pose, traj_text, traj_pose_text 54 | # parser.add_argument('--log_name', type=str, default='debug', help='Directory where logs and checkpoints are saved') 55 | parser.add_argument('--log_name', type=str, default='V2embed128_batch32_allTimestep_traj_wGlobal_batchAve', help='Directory where logs and checkpoints are saved') 56 | parser.add_argument('--diff_steps', type=int, default=5) 57 | parser.add_argument('--diff_beta_start', type=float, default=1.e-4) 58 | parser.add_argument('--diff_beta_end', type=float, default=5.e-2) 59 | 
parser.add_argument('--diff_beta_schedule', type=str, default='linear') 60 | 61 | parser = LED.add_model_specific_args(parser) 62 | args = parser.parse_args() 63 | 64 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 65 | trainer = pl.Trainer.from_argparse_args(args, 66 | callbacks=[model_checkpoint], 67 | default_root_dir=os.path.join(args.log_dir, args.log_name), 68 | num_sanity_val_steps=2, 69 | ) 70 | model = LED(**vars(args)) 71 | datamodule = jrdbDatamodule_LED.from_argparse_args(args) 72 | trainer.fit(model, datamodule) 73 | -------------------------------------------------------------------------------- /HiVT/train_sit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | from models.hivt_v3_sit import HiVT # Fixed to intra 26 | 27 | if __name__ == '__main__': 28 | pl.seed_everything(1555) 29 | 30 | parser = ArgumentParser() 31 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 32 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_prompt2') 33 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_prompt3') 34 | # parser.add_argument('--embed_dim', type=int, default=64) 35 | parser.add_argument('--train_batch_size', type=int, default=64) 36 | parser.add_argument('--val_batch_size', type=int, default=64) 37 | parser.add_argument('--shuffle', type=bool, default=True) 38 | parser.add_argument('--num_workers', type=int, default=1) 39 | parser.add_argument('--pin_memory', type=bool, default=True) 40 | parser.add_argument('--persistent_workers', type=bool, default=True) 41 | parser.add_argument('--gpus', type=int, default=1) 42 | parser.add_argument('--max_epochs', type=int, default=64) 43 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 44 | parser.add_argument('--save_top_k', type=int, default=5) 45 | 46 | # # # # # # Editables 47 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_sit_curFixed/', help='Directory where logs and checkpoints are saved') 48 | # parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 49 | parser.add_argument('--encoder_type', type=str, 
default='traj_v3', choices=['traj', 'traj_v2', 'traj_v3', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text', 'traj_pose_text_v1', 'traj_pose_text_v2', 'traj_pose_text_v3', 'traj_text']) 50 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 51 | parser.add_argument('--embed_dim', type=int, default=128) 52 | parser.add_argument('--emb_numHead', type=int, default=4) 53 | parser.add_argument('--emb_dropout', type=int, default=0.5) 54 | parser.add_argument('--emb_numLayers', type=int, default=2) 55 | parser.add_argument('--log_name', type=str, default='_V2InstCorrected_hivtV2_TrajText_scaled3_intra1.5_prompt1', help='Directory where logs and checkpoints are saved') 56 | 57 | parser = HiVT.add_model_specific_args(parser) 58 | args = parser.parse_args() 59 | 60 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 61 | trainer = pl.Trainer.from_argparse_args(args, 62 | callbacks=[model_checkpoint], 63 | default_root_dir=os.path.join(args.log_dir, args.log_name), 64 | # num_sanity_val_steps=0, 65 | ) 66 | model = HiVT(**vars(args)) 67 | datamodule = jrdbDatamodule.from_argparse_args(args) 68 | trainer.fit(model, datamodule) 69 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/train_sit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | from models.hivt_v3_sit import HiVT # Fixed to intra 26 | 27 | if __name__ == '__main__': 28 | pl.seed_everything(1555) 29 | 30 | parser = ArgumentParser() 31 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 32 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_prompt2') 33 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_prompt3') 34 | # parser.add_argument('--embed_dim', type=int, default=64) 35 | parser.add_argument('--train_batch_size', type=int, default=64) 36 | parser.add_argument('--val_batch_size', type=int, default=64) 37 | parser.add_argument('--shuffle', type=bool, default=True) 38 | parser.add_argument('--num_workers', type=int, default=1) 39 | parser.add_argument('--pin_memory', type=bool, default=True) 40 | parser.add_argument('--persistent_workers', type=bool, default=True) 41 | parser.add_argument('--gpus', type=int, default=1) 42 | parser.add_argument('--max_epochs', type=int, default=64) 43 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 44 | parser.add_argument('--save_top_k', type=int, default=5) 45 | 46 | # # # # # # Editables 47 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_sit_curFixed/', help='Directory where logs and checkpoints are saved') 48 | # parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 49 | parser.add_argument('--encoder_type', type=str, default='traj_v3', choices=['traj', 'traj_v2', 'traj_v3', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text', 'traj_pose_text_v1', 'traj_pose_text_v2', 'traj_pose_text_v3', 'traj_text']) 50 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 51 | parser.add_argument('--embed_dim', type=int, default=128) 52 | parser.add_argument('--emb_numHead', type=int, default=4) 53 | parser.add_argument('--emb_dropout', type=int, default=0.5) 54 | parser.add_argument('--emb_numLayers', type=int, default=2) 55 | parser.add_argument('--log_name', type=str, default='_V2InstCorrected_hivtV2_TrajText_scaled3_intra1.5_prompt1', help='Directory where logs and checkpoints are saved') 56 | 57 | parser = HiVT.add_model_specific_args(parser) 58 | args = parser.parse_args() 59 | 60 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 61 | trainer = pl.Trainer.from_argparse_args(args, 62 | callbacks=[model_checkpoint], 63 | default_root_dir=os.path.join(args.log_dir, args.log_name), 64 | # num_sanity_val_steps=0, 65 | ) 66 | model = HiVT(**vars(args)) 67 | datamodule = jrdbDatamodule.from_argparse_args(args) 68 | trainer.fit(model, datamodule) 69 | -------------------------------------------------------------------------------- /romp/vis_human/pyrenderer.py: 
-------------------------------------------------------------------------------- 1 | import os,sys 2 | os.environ['PYOPENGL_PLATFORM'] = 'egl' 3 | try: 4 | import pyrender 5 | except ImportError: 6 | print('pyrender is not installed; trying to install it via pip install pyrender.') 7 | print('If you run into any issue during this process, please refer to https://pyrender.readthedocs.io/en/latest/install/index.html and install it yourself.') 8 | os.system('pip install pyrender') 9 | import pyrender 10 | 11 | import trimesh 12 | import numpy as np 13 | 14 | 15 | def add_light(scene, light): 16 | # Use 3 directional lights 17 | light_pose = np.eye(4) 18 | light_pose[:3, 3] = np.array([0, -1, 1]) 19 | scene.add(light, pose=light_pose) 20 | light_pose[:3, 3] = np.array([0, 1, 1]) 21 | scene.add(light, pose=light_pose) 22 | light_pose[:3, 3] = np.array([1, 1, 2]) 23 | scene.add(light, pose=light_pose) 24 | 25 | 26 | class Py3DR(object): 27 | def __init__(self, FOV=60, height=512, width=512, focal_length=None): 28 | self.renderer = pyrender.OffscreenRenderer(height, width) 29 | if focal_length is None: 30 | self.focal_length = 1/(np.tan(np.radians(FOV/2))) # focal length normalized by half the image extent, derived from the FOV 31 | else: 32 | self.focal_length = focal_length / max(height, width)*2 33 | self.rot = trimesh.transformations.rotation_matrix( 34 | np.radians(180), [1, 0, 0]) 35 | self.colors = [ 36 | (.7, .7, .6, 1.), 37 | (.7, .5, .5, 1.), # Pink 38 | (.5, .5, .7, 1.), # Blue 39 | (.5, .55, .3, 1.), # capsule 40 | (.3, .5, .55, 1.), # Yellow 41 | ] 42 | 43 | def __call__(self, vertices, triangles, image_shape=[724,512], mesh_colors=None, f=None, persp=True, camera_pose=None): 44 | img_height, img_width = image_shape[0], image_shape[1] 45 | self.renderer.viewport_height = img_height 46 | self.renderer.viewport_width = img_width 47 | # Create a scene for each image and render all meshes 48 | scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0], 49 | ambient_light=(0.3, 0.3, 0.3)) 50 | 51 | if camera_pose is None: 52 | camera_pose = np.eye(4) 53 | if persp: 54 | if f is None: 55 | f = self.focal_length * max(img_height, img_width) / 2 56 | camera = pyrender.camera.IntrinsicsCamera(fx=f, fy=f, cx=img_width / 2., cy=img_height / 2.)
57 | else: 58 | xmag = ymag = np.abs(vertices[:,:,:2]).max() * 1.05 59 | camera = pyrender.camera.OrthographicCamera(xmag, ymag, znear=0.05, zfar=100.0, name=None) 60 | scene.add(camera, pose=camera_pose) 61 | 62 | if len(triangles.shape) == 2: 63 | triangles = [triangles for _ in range(len(vertices))] 64 | 65 | # Create light source 66 | light = pyrender.DirectionalLight(color=[1.0, 1.0, 1.0], intensity=0.5) 67 | # For every person in the scene 68 | for n in range(vertices.shape[0]): 69 | mesh = trimesh.Trimesh(vertices[n], triangles[n]) 70 | mesh.apply_transform(self.rot) 71 | if mesh_colors is None: 72 | mesh_color = self.colors[n % len(self.colors)] 73 | else: 74 | mesh_color = mesh_colors[n % len(mesh_colors)] 75 | material = pyrender.MetallicRoughnessMaterial( 76 | metallicFactor=0.2, 77 | alphaMode='OPAQUE', 78 | baseColorFactor=mesh_color) 79 | mesh = pyrender.Mesh.from_trimesh(mesh, material=material) 80 | scene.add(mesh, 'mesh') 81 | 82 | add_light(scene, light) 83 | # The alpha channel was not working previously; this needs to be checked again. 84 | # Until it is fixed, use the depth image as a hack to recover the opacity. 85 | color, rend_depth = self.renderer.render(scene, flags=pyrender.RenderFlags.RGBA) 86 | 87 | # color = color.astype(np.float32) 88 | # valid_mask = (rend_depth > 0)[:, :, None] 89 | # output_image = (color[:, :, :3] * valid_mask + 90 | # (1 - valid_mask) * image).astype(np.uint8) 91 | 92 | return color, rend_depth 93 | 94 | def delete(self): 95 | self.renderer.delete() -------------------------------------------------------------------------------- /preprocessing_script_upload/preprocess_1st_jrdb.py: -------------------------------------------------------------------------------- 1 | # ### JRDB parser with image captioning via llava-next ### 2 | from PIL import Image 3 | import cv2 4 | import numpy as np 5 | import os, sys 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "3" 7 | sys.path.append('/mnt/jaewoo4tb/t2p') 8 | import os.path as osp 9 | import torch 10 | from torch import nn 11 | import glob 12 | from tqdm import tqdm 13 | import argparse 14 | import copy 15 | import json 16 | from preprocess_utils import * 17 | from captions_jrdb import Pose3DEngine 18 | sys.path.append('/home/user/anaconda3/envs/t2p/lib/python3.8/site-packages/bev/') 19 | 20 | def main_parse(): 21 | print("Start parsing.") 22 | '''Parameters''' 23 | VERSION = 'v1_debug' 24 | VLM_TYPE = 'vllm' # ['local', 'vllm']; for local, use bev_jw; 
for vllm, use t2p (conda env) 25 | VISUALIZE = False 26 | 27 | default_save_dir = '/mnt/jaewoo4tb/t2p/preprocessed_1st/' + VERSION 28 | base_jrdb = '/mnt/jaewoo4tb/t2p/jrdb/train_dataset/' 29 | 30 | label_2d_base = f"{base_jrdb}labels/labels_2d_stitched/" 31 | label_3d_base = f"{base_jrdb}labels/labels_3d/" 32 | label_social_base = f"{base_jrdb}labels/labels_2d_activity_social_stitched/" 33 | img_base = f"{base_jrdb}images/image_stitched/" 34 | odometry_base = f'{base_jrdb}odometry_processed' 35 | 36 | scenes = sorted(glob.glob(base_jrdb+ 'images/image_0/*')) 37 | 38 | 39 | for sceneIdx, scene in enumerate(scenes): 40 | scene_name = os.path.basename(os.path.normpath(scene)) 41 | print(f'Processing scene: {scene_name}, {sceneIdx} out of {len(scenes)} scenes.') 42 | 43 | os.makedirs(default_save_dir, exist_ok=True) 44 | output_savedir = os.path.join(default_save_dir, scene_name) 45 | annot_odometry_pos = np.load(f'{odometry_base}/{scene_name}_pos.npy') 46 | annot_odometry_pos -= annot_odometry_pos[0] # Normalize by reference to the first frame 47 | annot_odometry_ori = np.load(f'{odometry_base}/{scene_name}_orientation.npy') 48 | annot_odometry_ori -= annot_odometry_ori[0] # Normalize by reference to the first frame 49 | 50 | frame_save_dir, gif_save_dir, data_save_dir, pose2d_save_dir = os.path.join(output_savedir, 'frames'), os.path.join(output_savedir, 'gif'), os.path.join(output_savedir, 'data'), os.path.join(output_savedir, 'pose2d') 51 | save_data, save_interaction = {}, {} 52 | 53 | 54 | frame_num = len(glob.glob(f'{base_jrdb}images/image_stitched/{scene_name}/*.jpg')) 55 | parse_engine = Pose3DEngine() 56 | parse_engine.load_files(label_2d_base + scene_name + ".json", label_3d_base + scene_name + ".json", img_base + scene_name + "/", label_social_base + scene_name + ".json") 57 | 58 | start_frame_idx = 0 59 | last_frame_idx = frame_num 60 | for frame_idx in tqdm(range(start_frame_idx, frame_num)): 61 | parse_engine.preprocess_frame(frame_idx, annot_odometry_pos[frame_idx], annot_odometry_ori[frame_idx]) 62 | parse_engine.regress_3dpose() 63 | 64 | save_data[frame_idx] = {} 65 | for agent_id in parse_engine.agents.keys(): 66 | if agent_id not in save_data[frame_idx].keys(): 67 | save_data[frame_idx][agent_id] = {} 68 | 69 | save_data[frame_idx][agent_id]["global_position"] = parse_engine.agents[agent_id].global_pos 70 | save_data[frame_idx][agent_id]["local_position"] = parse_engine.agents[agent_id].local_pos 71 | save_data[frame_idx][agent_id]["pose"] = parse_engine.agents[agent_id].pose 72 | save_data[frame_idx][agent_id]["robot_pos"] = parse_engine.agents[agent_id].robot_pos 73 | save_data[frame_idx][agent_id]["robot_ori"] = parse_engine.agents[agent_id].robot_ori 74 | save_data[frame_idx][agent_id]["rot_z"] = parse_engine.agents[agent_id].rot_z 75 | 76 | # visualize 77 | if VISUALIZE: 78 | os.makedirs(frame_save_dir, exist_ok=True) 79 | plot_3d_human(frame_idx, save_data, None, save_dir=frame_save_dir) 80 | 81 | torch.save(save_data, os.path.join(default_save_dir, scene_name+f'_agents_{start_frame_idx}_to_{last_frame_idx}.pt')) 82 | 83 | if __name__ == "__main__": 84 | main_parse() -------------------------------------------------------------------------------- /HiVT/___scripts_copy/validate_viz_sit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | from torch_geometric.data import DataLoader 24 | 25 | from datasets import jrdbDataset 26 | 27 | from datamodules import jrdbDatamodule 28 | from models.hivt_v2_withViz import HiVT 29 | 30 | if __name__ == '__main__': 31 | pl.seed_everything(2024) 32 | 33 | parser = ArgumentParser() 34 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 35 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 36 | # parser.add_argument('--embed_dim', type=int, default=64) 37 | parser.add_argument('--train_batch_size', type=int, default=64) 38 | parser.add_argument('--val_batch_size', type=int, default=1) 39 | parser.add_argument('--shuffle', type=bool, default=True) 40 | parser.add_argument('--num_workers', type=int, default=1) 41 | parser.add_argument('--pin_memory', type=bool, default=True) 42 | parser.add_argument('--persistent_workers', type=bool, default=True) 43 | parser.add_argument('--gpus', type=int, default=1) 44 | parser.add_argument('--max_epochs', type=int, default=64) 45 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 46 | parser.add_argument('--save_top_k', type=int, default=5) 47 | 48 | # # # # # # Editables 49 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_withViz_sit/', help='Directory where logs and checkpoints are saved') 50 | # parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 51 | parser.add_argument('--encoder_type', type=str, default='traj_v3', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 52 | parser.add_argument('--log_name', type=str, default='_InstCorrected_val', help='Directory where logs and checkpoints are saved') 53 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 54 | parser.add_argument('--embed_dim', type=int, default=128) 55 | parser.add_argument('--emb_numHead', type=int, default=8) 56 | parser.add_argument('--emb_dropout', type=int, default=0.5) 57 | parser.add_argument('--emb_numLayers', type=int, default=1) 58 | parser.add_argument('--ckpt_path', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_sit_curFixed/_V2InstCorrected_hivtV2_Traj_scaled3_intra1.5/lightning_logs/version_1/checkpoints/epoch=3-step=268.ckpt') 59 | parser = HiVT.add_model_specific_args(parser) 60 | args 
= parser.parse_args() 61 | 62 | trainer = pl.Trainer.from_argparse_args(args, 63 | default_root_dir=os.path.join(args.log_dir, args.log_name), 64 | # num_sanity_val_steps=0, 65 | ) 66 | model = HiVT.load_from_checkpoint(checkpoint_path=args.ckpt_path, parallel=False) 67 | val_dataset = jrdbDataset(root=args.root, split='val') 68 | dataloader = DataLoader(val_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers, 69 | pin_memory=args.pin_memory, persistent_workers=args.persistent_workers) 70 | trainer.validate(model, dataloader) 71 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/validate_viz_saveSample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | from torch_geometric.data import DataLoader 24 | 25 | from datasets import jrdbDataset 26 | 27 | from datamodules import jrdbDatamodule 28 | from models.hivt_v2_withViz_saveSamples import HiVT 29 | 30 | if __name__ == '__main__': 31 | pl.seed_everything(2024) 32 | 33 | parser = ArgumentParser() 34 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 35 | # parser.add_argument('--embed_dim', type=int, default=64) 36 | parser.add_argument('--train_batch_size', type=int, default=64) 37 | parser.add_argument('--val_batch_size', type=int, default=1) 38 | parser.add_argument('--shuffle', type=bool, default=True) 39 | parser.add_argument('--num_workers', type=int, default=1) 40 | parser.add_argument('--pin_memory', type=bool, default=True) 41 | parser.add_argument('--persistent_workers', type=bool, default=True) 42 | parser.add_argument('--gpus', type=int, default=1) 43 | parser.add_argument('--max_epochs', type=int, default=64) 44 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 45 | parser.add_argument('--save_top_k', type=int, default=5) 46 | 47 | # # # # # # Editables 48 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_withViz/', help='Directory where logs and checkpoints are saved') 49 | # parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 50 | parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 51 | parser.add_argument('--log_name', type=str, default='_InstCorrected_val', help='Directory where 
logs and checkpoints are saved') 52 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 53 | parser.add_argument('--embed_dim', type=int, default=128) 54 | parser.add_argument('--emb_numHead', type=int, default=8) 55 | parser.add_argument('--emb_dropout', type=int, default=0.5) 56 | parser.add_argument('--emb_numLayers', type=int, default=1) 57 | 58 | # Traj pose 59 | parser.add_argument('--ckpt_path', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_jrdb_curFixed/_V2InstCorrected_hivtV2_trajPose_onlyIntra_1.5m/lightning_logs/version_0/checkpoints/epoch=4-step=1610.ckpt') 60 | # Traj 61 | # parser.add_argument('--ckpt_path', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_jrdb_curFixed/_V2InstCorrected_hivtV2_traj_allTimesteps_intra1.5/lightning_logs/version_1/checkpoints/epoch=4-step=1610.ckpt') 62 | parser = HiVT.add_model_specific_args(parser) 63 | args = parser.parse_args() 64 | 65 | trainer = pl.Trainer.from_argparse_args(args, 66 | default_root_dir=os.path.join(args.log_dir, args.log_name), 67 | # num_sanity_val_steps=0, 68 | ) 69 | model = HiVT.load_from_checkpoint(checkpoint_path=args.ckpt_path, parallel=False) 70 | val_dataset = jrdbDataset(root=args.root, split='val') 71 | dataloader = DataLoader(val_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers, 72 | pin_memory=args.pin_memory, persistent_workers=args.persistent_workers) 73 | trainer.validate(model, dataloader) 74 | -------------------------------------------------------------------------------- /HiVT/_train_socialTransmotion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | # from models._social_transmotion_v1 import socialTransmotion 26 | from HiVT.models._social_transmotion_original import socialTransmotion 27 | 28 | if __name__ == '__main__': 29 | pl.seed_everything(9567) 30 | 31 | parser = ArgumentParser() 32 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot') 33 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 34 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 35 | parser.add_argument('--train_batch_size', type=int, default=44) 36 | parser.add_argument('--val_batch_size', type=int, default=44) 37 | parser.add_argument('--shuffle', type=bool, default=True) 38 | parser.add_argument('--num_workers', type=int, default=1) 39 | parser.add_argument('--pin_memory', type=bool, default=True) 40 | parser.add_argument('--persistent_workers', type=bool, default=True) 41 | parser.add_argument('--gpus', type=int, default=1) 42 | parser.add_argument('--max_epochs', type=int, default=64) 43 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 44 | parser.add_argument('--save_top_k', type=int, default=5) 45 | 46 | # # # # # # Editables 47 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/_logs_socialTransmotion_curFixed/jrdb', help='Directory where logs and checkpoints are saved') 48 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 49 | parser.add_argument('--embed_dim', type=int, default=128) 50 | parser.add_argument('--seq_len', type=int, default=212) 51 | parser.add_argument('--token_num', type=int, default=25) 52 | parser.add_argument('--num_layers_local', type=int, default=5) 53 | parser.add_argument('--num_layers_global', type=int, default=3) 54 | parser.add_argument('--dim_feedforward', type=int, default=1024) 55 | parser.add_argument('--output_scale', type=int, default=1) 56 | parser.add_argument('--num_head_st', type=int, default=4) 57 | parser.add_argument('--modality', type=str, default='traj_pose') # traj, traj_pose 58 | parser.add_argument('--log_name', type=str, default='socialTransmotion_original_v1_trajPose_curFixed_onlyDiff', help='Directory where logs and checkpoints are saved') 59 | # parser.add_argument('--log_name', type=str, default='socialTransmotionV1_woRotate_trajPose_trmPose_batch32', help='Directory where logs and checkpoints are saved') 60 | # parser.add_argument('--accumulate_grad_batches', type=int, default=2) 61 | 62 | parser = socialTransmotion.add_model_specific_args(parser) 63 | args = parser.parse_args() 64 | 65 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 66 | trainer = pl.Trainer.from_argparse_args(args, 67 | callbacks=[model_checkpoint], 68 | default_root_dir=os.path.join(args.log_dir, args.log_name), 69 | # num_sanity_val_steps=0, 70 | ) 71 | model = 
socialTransmotion(**vars(args)) 72 | datamodule = jrdbDatamodule.from_argparse_args(args) 73 | trainer.fit(model, datamodule) 74 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '2' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | # from models.hivt_v2_sit import HiVT 26 | from models.hivt_v2 import HiVT 27 | 28 | if __name__ == '__main__': 29 | pl.seed_everything(1110) 30 | 31 | parser = ArgumentParser() 32 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 33 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot') 34 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot') 35 | # parser.add_argument('--embed_dim', type=int, default=64) 36 | parser.add_argument('--train_batch_size', type=int, default=64) 37 | parser.add_argument('--val_batch_size', type=int, default=64) 38 | parser.add_argument('--shuffle', type=bool, default=True) 39 | parser.add_argument('--num_workers', type=int, default=1) 40 | parser.add_argument('--pin_memory', type=bool, default=True) 41 | parser.add_argument('--persistent_workers', type=bool, default=True) 42 | parser.add_argument('--gpus', type=int, default=1) 43 | parser.add_argument('--max_epochs', type=int, default=64) 44 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 45 | parser.add_argument('--save_top_k', type=int, default=5) 46 | 47 | # # # # # # Editables 48 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_jrdb_curFixed/', help='Directory where logs and checkpoints are saved') 49 | # parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 50 | parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj', 'traj_v2', 'traj_v3', 'traj_pose', 'traj_pose_withAngle', 'traj_pose_woRotation', 'traj_pose_8steps', 'traj_pose_text', 'traj_pose_text_woRotation', 'traj_pose_text_embedMask', 'traj_pose_text_v1', 'traj_pose_text_v2', 'traj_pose_text_v3', 'traj_text', 'traj_text_onlyEgo', 'traj_text_egoOther', 'traj_text_egoOtherInteraction', 'traj_pose_text_onlyEgo', 'traj_pose_text_egoOther', 
'traj_pose_text_egoOtherInteraction']) 51 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 52 | parser.add_argument('--embed_dim', type=int, default=128) 53 | parser.add_argument('--emb_numHead', type=int, default=4) 54 | parser.add_argument('--emb_dropout', type=int, default=0.5) 55 | parser.add_argument('--emb_numLayers', type=int, default=2) 56 | parser.add_argument('--onlyFullTimestep', type=bool, default=False) 57 | parser.add_argument('--log_name', type=str, default='debug', help='Directory where logs and checkpoints are saved') 58 | 59 | parser = HiVT.add_model_specific_args(parser) 60 | args = parser.parse_args() 61 | 62 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 63 | trainer = pl.Trainer.from_argparse_args(args, 64 | callbacks=[model_checkpoint], 65 | default_root_dir=os.path.join(args.log_dir, args.log_name), 66 | # num_sanity_val_steps=0, 67 | ) 68 | model = HiVT(**vars(args)) 69 | datamodule = jrdbDatamodule.from_argparse_args(args) 70 | trainer.fit(model, datamodule) 71 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/_train_socialTransmotion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | # from models._social_transmotion_v1 import socialTransmotion 26 | from HiVT.models._social_transmotion_original import socialTransmotion 27 | 28 | if __name__ == '__main__': 29 | pl.seed_everything(9567) 30 | 31 | parser = ArgumentParser() 32 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot') 33 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 34 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 35 | parser.add_argument('--train_batch_size', type=int, default=44) 36 | parser.add_argument('--val_batch_size', type=int, default=44) 37 | parser.add_argument('--shuffle', type=bool, default=True) 38 | parser.add_argument('--num_workers', type=int, default=1) 39 | parser.add_argument('--pin_memory', type=bool, default=True) 40 | parser.add_argument('--persistent_workers', type=bool, default=True) 41 | parser.add_argument('--gpus', type=int, default=1) 42 | parser.add_argument('--max_epochs', type=int, default=64) 43 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 44 | parser.add_argument('--save_top_k', type=int, default=5) 45 | 46 | # # # # # # Editables 47 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/_logs_socialTransmotion_curFixed/jrdb', help='Directory where logs and checkpoints are saved') 48 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 49 | parser.add_argument('--embed_dim', type=int, default=128) 50 | parser.add_argument('--seq_len', type=int, default=212) 51 | parser.add_argument('--token_num', type=int, default=25) 52 | parser.add_argument('--num_layers_local', type=int, default=5) 53 | parser.add_argument('--num_layers_global', type=int, default=3) 54 | parser.add_argument('--dim_feedforward', type=int, default=1024) 55 | parser.add_argument('--output_scale', type=int, default=1) 56 | parser.add_argument('--num_head_st', type=int, default=4) 57 | parser.add_argument('--modality', type=str, default='traj_pose') # traj, traj_pose 58 | parser.add_argument('--log_name', type=str, default='socialTransmotion_original_v1_trajPose_curFixed_onlyDiff', help='Directory where logs and checkpoints are saved') 59 | # parser.add_argument('--log_name', type=str, default='socialTransmotionV1_woRotate_trajPose_trmPose_batch32', help='Directory where logs and checkpoints are saved') 60 | # parser.add_argument('--accumulate_grad_batches', type=int, default=2) 61 | 62 | parser = socialTransmotion.add_model_specific_args(parser) 63 | args = parser.parse_args() 64 | 65 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 66 | trainer = pl.Trainer.from_argparse_args(args, 67 | callbacks=[model_checkpoint], 68 | default_root_dir=os.path.join(args.log_dir, args.log_name), 69 | # num_sanity_val_steps=0, 70 | ) 71 | model = 
socialTransmotion(**vars(args)) 72 | datamodule = jrdbDatamodule.from_argparse_args(args) 73 | trainer.fit(model, datamodule) 74 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/train_joint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | # from models.hivt_v2_sit import HiVT 26 | from models.hivt_v2 import HiVT 27 | 28 | if __name__ == '__main__': 29 | pl.seed_everything(5817) 30 | 31 | parser = ArgumentParser() 32 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 33 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot') 34 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot_joint') 35 | # parser.add_argument('--embed_dim', type=int, default=64) 36 | parser.add_argument('--train_batch_size', type=int, default=64) 37 | parser.add_argument('--val_batch_size', type=int, default=64) 38 | parser.add_argument('--shuffle', type=bool, default=True) 39 | parser.add_argument('--num_workers', type=int, default=1) 40 | parser.add_argument('--pin_memory', type=bool, default=True) 41 | parser.add_argument('--persistent_workers', type=bool, default=True) 42 | parser.add_argument('--gpus', type=int, default=1) 43 | parser.add_argument('--max_epochs', type=int, default=64) 44 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 45 | parser.add_argument('--save_top_k', type=int, default=5) 46 | 47 | # # # # # # Editables 48 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_jrdb_curFixed/', help='Directory where logs and checkpoints are saved') 49 | # parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 50 | parser.add_argument('--encoder_type', type=str, default='traj_pose_joint', choices=['traj', 'traj_v2', 'traj_v3', 'traj_pose', 'traj_pose_withAngle', 'traj_pose_woRotation', 'traj_pose_8steps', 'traj_pose_text', 'traj_pose_text_woRotation', 'traj_pose_text_embedMask', 'traj_pose_text_v1', 'traj_pose_text_v2', 'traj_pose_text_v3', 'traj_text', 'traj_text_onlyEgo', 'traj_text_egoOther', 'traj_text_egoOtherInteraction', 'traj_pose_text_onlyEgo', 'traj_pose_text_egoOther', 
'traj_pose_text_egoOtherInteraction']) 51 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 52 | parser.add_argument('--embed_dim', type=int, default=128) 53 | parser.add_argument('--emb_numHead', type=int, default=4) 54 | parser.add_argument('--emb_dropout', type=int, default=0.5) 55 | parser.add_argument('--emb_numLayers', type=int, default=2) 56 | parser.add_argument('--onlyFullTimestep', type=bool, default=False) 57 | parser.add_argument('--log_name', type=str, default='_V2InstCorrected_hivtV2_trajPose_allTimesteps_intra1.5_JOINT', help='Directory where logs and checkpoints are saved') 58 | 59 | parser = HiVT.add_model_specific_args(parser) 60 | args = parser.parse_args() 61 | 62 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 63 | trainer = pl.Trainer.from_argparse_args(args, 64 | callbacks=[model_checkpoint], 65 | default_root_dir=os.path.join(args.log_dir, args.log_name), 66 | # num_sanity_val_steps=0, 67 | ) 68 | model = HiVT(**vars(args)) 69 | datamodule = jrdbDatamodule.from_argparse_args(args) 70 | trainer.fit(model, datamodule) 71 | -------------------------------------------------------------------------------- /HiVT/train_kd_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '2' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | # from HiVT.models.hivt_kd_onlyAgent import HiVT # only full timestep x loss 26 | from HiVT.models.hivt_kd_onlyAgent_v4 import HiVT # with inst loss 27 | 28 | if __name__ == '__main__': 29 | pl.seed_everything(4545) 30 | 31 | parser = ArgumentParser() 32 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20') 33 | # parser.add_argument('--embed_dim', type=int, default=64) 34 | parser.add_argument('--train_batch_size', type=int, default=64) 35 | parser.add_argument('--val_batch_size', type=int, default=64) 36 | parser.add_argument('--shuffle', type=bool, default=True) 37 | parser.add_argument('--num_workers', type=int, default=1) 38 | parser.add_argument('--pin_memory', type=bool, default=True) 39 | parser.add_argument('--persistent_workers', type=bool, default=True) 40 | parser.add_argument('--gpus', type=int, default=1) 41 | parser.add_argument('--max_epochs', type=int, default=64) 42 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 43 | parser.add_argument('--save_top_k', type=int, default=5) 44 | parser.add_argument('--encoder_type', type=str, default='traj_pose_text', choices=['traj', 'traj_v2', 'traj_v3', 'traj_pose', 'traj_pose_text_embedMask', 'traj_pose_text_embedMask_teacher', 'traj_pose_8steps', 'traj_pose_text', 'traj_pose_text_v2', 'traj_pose_text_v3', 'traj_text']) 45 | parser.add_argument('--ckpt_path', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_jrdb/__KD_TrajPoseText2trajPose_fullKD_modalMask_const0.25FIXED/lightning_logs/version_0/checkpoints/epoch=29-step=9660.ckpt') 46 | 47 | # # # # # # Editables 48 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_jrdb/', help='Directory where logs and checkpoints are saved') 49 | 50 | ################# ALSO EDIT TRAINER ####################### 51 | parser.add_argument('--encoder_type_student', type=str, default='traj_pose', choices=['traj', 'traj_v2', 'traj_v3', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text', 'traj_pose_text_v2', 'traj_pose_text_v3', 'traj_text']) 52 | ################# ALSO EDIT TRAINER ####################### 53 | 54 | parser.add_argument('--embed_dim', type=int, default=128) 55 | parser.add_argument('--emb_numHead', type=int, default=4) 56 | parser.add_argument('--emb_dropout', type=int, default=0.5) 57 | parser.add_argument('--emb_numLayers', type=int, default=2) 58 | # parser.add_argument('--log_name', type=str, default='__KD_TrajPoseText2trajPose_onlyLocalKL_fixed', help='Directory where logs and checkpoints are saved') 59 | parser.add_argument('--log_name', type=str, default='__KD_TrajPoseText2trajPose_fullKD_modalMask_finetune5e-4', help='Directory where logs and checkpoints are saved') 60 | 61 | parser = HiVT.add_model_specific_args(parser) 62 | args = parser.parse_args() 63 | model = HiVT.load_from_checkpoint(checkpoint_path=args.ckpt_path, parallel=False, strict=False) 64 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 65 | trainer = pl.Trainer.from_argparse_args(args, 66 | callbacks=[model_checkpoint], 67 | 
default_root_dir=os.path.join(args.log_dir, args.log_name), 68 | # num_sanity_val_steps=0, 69 | ) 70 | model = HiVT(**vars(args)) 71 | datamodule = jrdbDatamodule.from_argparse_args(args) 72 | trainer.fit(model, datamodule) 73 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/validate_viz_saveSample_sit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | from torch_geometric.data import DataLoader 24 | 25 | from datasets import jrdbDataset 26 | 27 | from datamodules import jrdbDatamodule 28 | from models.hivt_v2_withViz_saveSamples_sit import HiVT 29 | 30 | if __name__ == '__main__': 31 | pl.seed_everything(2024) 32 | 33 | parser = ArgumentParser() 34 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/sit_v2_fps_2_5_frame_20_withRobot_withPoseInfo_v2') 35 | # parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 36 | # parser.add_argument('--embed_dim', type=int, default=64) 37 | parser.add_argument('--train_batch_size', type=int, default=64) 38 | parser.add_argument('--val_batch_size', type=int, default=1) 39 | parser.add_argument('--shuffle', type=bool, default=True) 40 | parser.add_argument('--num_workers', type=int, default=1) 41 | parser.add_argument('--pin_memory', type=bool, default=True) 42 | parser.add_argument('--persistent_workers', type=bool, default=True) 43 | parser.add_argument('--gpus', type=int, default=1) 44 | parser.add_argument('--max_epochs', type=int, default=64) 45 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 46 | parser.add_argument('--save_top_k', type=int, default=5) 47 | 48 | # # # # # # Editables 49 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/sit/logs_withViz/', help='Directory where logs and checkpoints are saved') 50 | # parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 51 | parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 52 | parser.add_argument('--log_name', type=str, default='_InstCorrected_val', help='Directory where logs and checkpoints are saved') 53 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 54 | 
parser.add_argument('--embed_dim', type=int, default=128) 55 | parser.add_argument('--emb_numHead', type=int, default=8) 56 | parser.add_argument('--emb_dropout', type=int, default=0.5) 57 | parser.add_argument('--emb_numLayers', type=int, default=1) 58 | 59 | # Traj pose 60 | parser.add_argument('--ckpt_path', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_sit_curFixed/_V2InstCorrected_hivtV2_TrajPose_scaled3_intra1.5/lightning_logs/version_2/checkpoints/epoch=4-step=335.ckpt') 61 | # Traj 62 | # parser.add_argument('--ckpt_path', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_jrdb_curFixed/_V2InstCorrected_hivtV2_traj_allTimesteps_intra1.5/lightning_logs/version_1/checkpoints/epoch=4-step=1610.ckpt') 63 | parser = HiVT.add_model_specific_args(parser) 64 | args = parser.parse_args() 65 | 66 | trainer = pl.Trainer.from_argparse_args(args, 67 | default_root_dir=os.path.join(args.log_dir, args.log_name), 68 | # num_sanity_val_steps=0, 69 | ) 70 | model = HiVT.load_from_checkpoint(checkpoint_path=args.ckpt_path, parallel=False) 71 | val_dataset = jrdbDataset(root=args.root, split='val') 72 | dataloader = DataLoader(val_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers, 73 | pin_memory=args.pin_memory, persistent_workers=args.persistent_workers) 74 | trainer.validate(model, dataloader) 75 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/train_kd_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '2' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | 24 | from datamodules import jrdbDatamodule 25 | # from HiVT.models.hivt_kd_onlyAgent import HiVT # only full timestep x loss 26 | from HiVT.models.hivt_kd_onlyAgent_v4 import HiVT # with inst loss 27 | 28 | if __name__ == '__main__': 29 | pl.seed_everything(4545) 30 | 31 | parser = ArgumentParser() 32 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20') 33 | # parser.add_argument('--embed_dim', type=int, default=64) 34 | parser.add_argument('--train_batch_size', type=int, default=64) 35 | parser.add_argument('--val_batch_size', type=int, default=64) 36 | parser.add_argument('--shuffle', type=bool, default=True) 37 | parser.add_argument('--num_workers', type=int, default=1) 38 | parser.add_argument('--pin_memory', type=bool, default=True) 39 | parser.add_argument('--persistent_workers', type=bool, default=True) 40 | parser.add_argument('--gpus', type=int, default=1) 41 | parser.add_argument('--max_epochs', type=int, default=64) 42 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 43 | parser.add_argument('--save_top_k', type=int, default=5) 44 | parser.add_argument('--encoder_type', type=str, default='traj_pose_text', choices=['traj', 'traj_v2', 'traj_v3', 'traj_pose', 'traj_pose_text_embedMask', 'traj_pose_text_embedMask_teacher', 'traj_pose_8steps', 'traj_pose_text', 'traj_pose_text_v2', 'traj_pose_text_v3', 'traj_text']) 45 | parser.add_argument('--ckpt_path', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_jrdb/__KD_TrajPoseText2trajPose_fullKD_modalMask_const0.25FIXED/lightning_logs/version_0/checkpoints/epoch=29-step=9660.ckpt') 46 | 47 | # # # # # # Editables 48 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_jrdb/', help='Directory where logs and checkpoints are saved') 49 | 50 | ################# ALSO EDIT TRAINER ####################### 51 | parser.add_argument('--encoder_type_student', type=str, default='traj_pose', choices=['traj', 'traj_v2', 'traj_v3', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text', 'traj_pose_text_v2', 'traj_pose_text_v3', 'traj_text']) 52 | ################# ALSO EDIT TRAINER ####################### 53 | 54 | parser.add_argument('--embed_dim', type=int, default=128) 55 | parser.add_argument('--emb_numHead', type=int, default=4) 56 | parser.add_argument('--emb_dropout', type=int, default=0.5) 57 | parser.add_argument('--emb_numLayers', type=int, default=2) 58 | # parser.add_argument('--log_name', type=str, default='__KD_TrajPoseText2trajPose_onlyLocalKL_fixed', help='Directory where logs and checkpoints are saved') 59 | parser.add_argument('--log_name', type=str, default='__KD_TrajPoseText2trajPose_fullKD_modalMask_finetune5e-4', help='Directory where logs and checkpoints are saved') 60 | 61 | parser = HiVT.add_model_specific_args(parser) 62 | args = parser.parse_args() 63 | model = HiVT.load_from_checkpoint(checkpoint_path=args.ckpt_path, parallel=False, strict=False) 64 | model_checkpoint = ModelCheckpoint(monitor=args.monitor, save_top_k=args.save_top_k, mode='min') 65 | trainer = pl.Trainer.from_argparse_args(args, 66 | callbacks=[model_checkpoint], 67 | 
default_root_dir=os.path.join(args.log_dir, args.log_name), 68 | # num_sanity_val_steps=0, 69 | ) 70 | model = HiVT(**vars(args)) 71 | datamodule = jrdbDatamodule.from_argparse_args(args) 72 | trainer.fit(model, datamodule) 73 | -------------------------------------------------------------------------------- /HiVT/___scripts_copy/validate_viz_qual.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, Zikang Zhou. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from argparse import ArgumentParser 15 | import os 16 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 17 | import torch 18 | torch.multiprocessing.set_sharing_strategy('file_system') 19 | # torch.multiprocessing.set_start_method('spawn') 20 | 21 | import pytorch_lightning as pl 22 | from pytorch_lightning.callbacks import ModelCheckpoint 23 | from torch_geometric.data import DataLoader 24 | 25 | from datasets import jrdbDataset 26 | 27 | from datamodules import jrdbDatamodule 28 | from models.hivt_v2_withViz_qual import HiVT 29 | 30 | if __name__ == '__main__': 31 | pl.seed_everything(2024) 32 | 33 | parser = ArgumentParser() 34 | parser.add_argument('--root', type=str, default='/mnt/jaewoo4tb/textraj/preprocessed_2nd/jrdb_v2_fps_2_5_frame_20_withRobot_withPoseInfo') 35 | # parser.add_argument('--embed_dim', type=int, default=64) 36 | parser.add_argument('--train_batch_size', type=int, default=64) 37 | parser.add_argument('--val_batch_size', type=int, default=1) 38 | parser.add_argument('--shuffle', type=bool, default=True) 39 | parser.add_argument('--num_workers', type=int, default=1) 40 | parser.add_argument('--pin_memory', type=bool, default=True) 41 | parser.add_argument('--persistent_workers', type=bool, default=True) 42 | parser.add_argument('--gpus', type=int, default=1) 43 | parser.add_argument('--max_epochs', type=int, default=64) 44 | parser.add_argument('--monitor', type=str, default='val_minADE', choices=['val_minADE', 'val_minFDE']) 45 | parser.add_argument('--save_top_k', type=int, default=5) 46 | 47 | # # # # # # Editables 48 | parser.add_argument('--log_dir', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_withViz/', help='Directory where logs and checkpoints are saved') 49 | # parser.add_argument('--encoder_type', type=str, default='traj_pose', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 50 | parser.add_argument('--encoder_type', type=str, default='traj_pose_text_intra', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 51 | # parser.add_argument('--encoder_type', type=str, default='traj_pose_text_intra', choices=['traj', 'traj_pose', 'traj_pose_8steps', 'traj_pose_text']) 52 | parser.add_argument('--log_name', type=str, default='_InstCorrected_val', help='Directory where logs and checkpoints are saved') 53 | # parser.add_argument('--log_name', type=str, default='_InstCorrected_traj_pose_text', help='Directory where logs and checkpoints are saved') 54 | 
parser.add_argument('--embed_dim', type=int, default=128) 55 | parser.add_argument('--emb_numHead', type=int, default=8) 56 | parser.add_argument('--emb_dropout', type=int, default=0.5) 57 | parser.add_argument('--emb_numLayers', type=int, default=1) 58 | # Traj Pose Text intra 59 | parser.add_argument('--ckpt_path', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_jrdb_curFixed/_V2InstCorrected_hivtV3_trajPoseTextIntra_onlyFullTimesteps_intra1.5/lightning_logs/version_0/checkpoints/epoch=4-step=1610.ckpt') 60 | 61 | # Traj Pose 62 | # parser.add_argument('--ckpt_path', type=str, default='/mnt/jaewoo4tb/textraj/HiVT/logs_jrdb_curFixed/_V2InstCorrected_hivtV2_trajPose_onlyIntra_1.5m/lightning_logs/version_0/checkpoints/epoch=4-step=1610.ckpt') 63 | 64 | parser = HiVT.add_model_specific_args(parser) 65 | args = parser.parse_args() 66 | 67 | trainer = pl.Trainer.from_argparse_args(args, 68 | default_root_dir=os.path.join(args.log_dir, args.log_name), 69 | # num_sanity_val_steps=0, 70 | ) 71 | model = HiVT.load_from_checkpoint(checkpoint_path=args.ckpt_path, parallel=False) 72 | val_dataset = jrdbDataset(root=args.root, split='val') 73 | dataloader = DataLoader(val_dataset, batch_size=args.val_batch_size, shuffle=False, num_workers=args.num_workers, 74 | pin_memory=args.pin_memory, persistent_workers=args.persistent_workers) 75 | trainer.validate(model, dataloader) 76 | --------------------------------------------------------------------------------