├── .gitattributes ├── .gitignore ├── .vscode └── launch.json ├── CITATION.cff ├── LICENSE ├── README.md ├── assets └── sample_rollout.gif ├── evaluation ├── BITS.yaml └── BITS_example.yaml ├── scripts ├── evaluate.py ├── generate_config_templates.py ├── parse_results.py ├── train.py └── visualize.py ├── setup.py ├── tbsim ├── algos │ ├── __init__.py │ ├── algo_utils.py │ ├── algos.py │ ├── factory.py │ ├── metric_algos.py │ └── multiagent_algos.py ├── configs │ ├── __init__.py │ ├── algo_config.py │ ├── base.py │ ├── config.py │ ├── eval_config.py │ ├── l5kit_config.py │ ├── nusc_config.py │ └── registry.py ├── datasets │ ├── __init__.py │ ├── factory.py │ ├── l5kit_datamodules.py │ └── trajdata_datamodules.py ├── dynamics │ ├── __init__.py │ ├── base.py │ ├── bicycle.py │ ├── double_integrator.py │ ├── single_integrator.py │ └── unicycle.py ├── envs │ ├── base.py │ ├── env_l5kit.py │ ├── env_metrics.py │ ├── env_metrics_old.py │ └── env_trajdata.py ├── evaluation │ ├── __init__.py │ ├── env_builders.py │ ├── metric_composers.py │ └── policy_composers.py ├── l5kit │ ├── agent_sampling_mixed.py │ ├── l5_agent_dataset.py │ ├── l5_ego_dataset.py │ ├── simulation_dataset.py │ ├── vectorizer.py │ ├── vis_rasterizer.py │ └── visualizer.py ├── models │ ├── GAN_regularizer.py │ ├── ScePT.py │ ├── Transformer.py │ ├── __init__.py │ ├── agentformer.py │ ├── agentformer_lib.py │ ├── base_models.py │ ├── cnn_roi_encoder.py │ ├── layers.py │ ├── learned_metrics.py │ ├── multiagent_models.py │ ├── policy_net.py │ ├── rasterized_models.py │ ├── roi_align.py │ ├── unet.py │ └── vaes.py ├── policies │ ├── __init__.py │ ├── base.py │ ├── common.py │ ├── hardcoded.py │ └── wrappers.py └── utils │ ├── __init__.py │ ├── batch_utils.py │ ├── bokeh_script.py │ ├── config_utils.py │ ├── env_utils.py │ ├── experiment_utils.py │ ├── geometry_utils.py │ ├── l5_utils.py │ ├── lane_utils.py │ ├── log_utils.py │ ├── loss_utils.py │ ├── math_utils.py │ ├── metrics.py │ ├── model_utils.py │ ├── 
planning_utils.py │ ├── rollout_logger.py │ ├── tensor_utils.py │ ├── timer.py │ ├── torch_utils.py │ ├── train_utils.py │ ├── trajdata_utils.py │ ├── tree.py │ └── vis_utils.py └── trajdata_requirements.txt /.gitattributes: -------------------------------------------------------------------------------- 1 | * text=auto eol=lf 2 | *.{cmd,[cC][mM][dD]} text eol=crlf 3 | *.{bat,[bB][aA][tT]} text eol=crlf -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | videos/ 3 | figures/ 4 | experiments/ 5 | checkpoints 6 | visualizations/ 7 | results/ 8 | ## config files 9 | #*.yaml 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | bin/ 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Mr Developer 67 | .mr.developer.cfg 68 | .project 69 | .pydevproject 70 | 71 | # Rope 72 | .ropeproject 73 | # Django stuff: 74 | *.log 75 | *.pot 76 | local_settings.py 77 | db.sqlite3 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | 90 | # PyBuilder 91 | target/ 92 | 93 | # Jupyter Notebook 94 | .ipynb_checkpoints 95 | 96 | # IPython 97 | profile_default/ 98 | ipython_config.py 99 | 100 | # pyenv 101 | .python-version 102 | 103 | # celery beat schedule file 104 | celerybeat-schedule 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | 137 | agentformer_trained_models/ 138 | bc_trained_models/ 139 | tree_trained_models/ 140 | tree_vae_trained_models/ 141 | ma_rasterized_trained_models 142 | scept_trained_models 143 | spatial_planner_trained_models 144 | 145 | 146 | 147 | 148 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "version": "0.2.0", 3 | "configurations": [ 4 | { 5 | "name": "BITS sim nuScenes", 6 | "type": "python", 7 
| "request": "launch", 8 | "python": "/home/yuxiaoc/.conda/envs/pub_tbsim/bin/python3", 9 | "program": "scripts/evaluate.py", 10 | "console": "integratedTerminal", 11 | "justMyCode": true, 12 | "args":[ 13 | "--results_root_dir=results/", 14 | "--num_scenes_per_batch=2", 15 | "--dataset_path=", 16 | "--env=nusc", 17 | "--ckpt_yaml=evaluation/BITS_example.yaml", 18 | "--eval_class=HierAgentAware", 19 | ] 20 | } 21 | ] 22 | } 23 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: Xu 5 | given-names: Danfei 6 | orcid: https://orcid.org/0000-0002-8744-3861 7 | - family-names: Chen 8 | given-names: Yuxiao 9 | orcid: https://orcid.org/0000-0001-5276-7156 10 | - family-names: Ivanovic 11 | given-names: Boris 12 | orcid: https://orcid.org/0000-0002-8698-202X 13 | - family-names: Pavone 14 | given-names: Marco 15 | orcid: https://orcid.org/0000-0002-0206-4337 16 | title: "Traffic Behavior Simulation" 17 | version: 0.0.1 18 | date-released: 2023-05-26 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Nvidia Source Code License-NC 2 | 1. Definitions 3 | “Licensor” means any person or entity that distributes its Work. 4 | 5 | “Software” means the original work of authorship made available under this License. 6 | 7 | “Work” means the Software and any additions to or derivative works of the Software that are made available under this License. 
8 | 9 | “Nvidia Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or provided by Nvidia or its affiliates. 10 | 11 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 12 | 13 | Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 14 | 15 | 2. License Grants 16 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 17 | 18 | 3. Limitations 19 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 20 | 21 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. 
Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 22 | 23 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. The Work or derivative works thereof may be used or intended for use by Nvidia or its affiliates commercially or non-commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 24 | 25 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grants in Sections 2.1 and 2.2) will terminate immediately. 26 | 27 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License. 28 | 29 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grants in Sections 2.1 and 2.2) will terminate immediately. 30 | 31 | 4. Disclaimer of Warranty. 32 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. 33 | 34 | 5. Limitation of Liability. 
35 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 36 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Traffic Behavior Simulation (tbsim) 2 | TBSIM is a simulation environment designed for data-driven closed-loop simulation of autonomous vehicles. It supports training and evaluation of popular traffic models such as behavior cloning, CVAE, and our new [BITS](https://arxiv.org/abs/2208.12403) model specifically designed for AV simulation. Users can flexibly specify the simulation environment and plug in their own model (learned or analytic) for evaluation. 3 | 4 | Thanks to [trajdata](https://github.com/NVlabs/trajdata), TBSIM can access data and scenarios from a wide range of public datasets, including [Lyft Level 5](https://woven.toyota/en/prediction-dataset), [nuScenes](https://www.nuscenes.org/nuscenes), and [nuPlan](https://nuplan.org/). 5 | 6 | TBSIM is equipped with a rich set of utility functions and supports parallel batched simulation, logging, and replay. We also provide a suite of simulation metrics that measure the safety, liveness, and diversity of the simulation.
7 | 8 | 9 | 10 | ## Installation 11 | 12 | Install `tbsim` 13 | ```bash 14 | conda create -n tbsim python=3.8 15 | conda activate tbsim 16 | git clone git@github.com:NVlabs/traffic-behavior-simulation.git tbsim 17 | cd tbsim 18 | pip install -e . 19 | ``` 20 | 21 | Install `trajdata` 22 | ``` 23 | cd .. 24 | git clone git@github.com:NVlabs/trajdata.git trajdata 25 | cd trajdata 26 | # replace requirements.txt with trajdata_requirements.txt included in tbsim 27 | pip install -e . 28 | ``` 29 | 30 | Install `Pplan` 31 | ``` 32 | cd .. 33 | git clone git@github.com:NVlabs/spline-planner.git Pplan 34 | cd Pplan 35 | pip install -e . 36 | ``` 37 | 38 | PyTorch usually needs to be installed separately so that it matches your hardware setup (OS, GPU, CUDA version, etc.); see https://pytorch.org/get-started/locally/ for instructions. 39 | ## Quick start 40 | ### 1. Obtain dataset(s) 41 | We currently support the Lyft Level 5 [dataset](https://woven.toyota/en/prediction-dataset) and the nuScenes [dataset](https://www.nuscenes.org/nuscenes). 42 | 43 | #### Lyft Level 5: 44 | * Download the Lyft Prediction dataset (only the metadata and the map) and organize the dataset directory as follows: 45 | ``` 46 | lyft_prediction/ 47 | │ aerial_map/ 48 | │ semantic_map/ 49 | │ meta.json 50 | └───scenes 51 | │ │ sample.zarr 52 | │ │ train_full.zarr 53 | │ │ train.zarr 54 | │ │ validate.zarr 55 | ``` 56 | 57 | #### nuScenes 58 | * Download the nuScenes dataset (with the v1.3 map extension pack) and organize the dataset directory as follows: 59 | ``` 60 | nuscenes/ 61 | │ maps/ 62 | │ └── expansion/ 63 | │ v1.0-mini/ 64 | │ v1.0-trainval/ 65 | ``` 66 | ### 2.
Train a behavior cloning model 67 | Lyft dataset (set `--debug` flag to suppress wandb logging): 68 | ``` 69 | python scripts/train.py --dataset_path --config_name l5_bc --debug 70 | ``` 71 | 72 | nuScenes dataset (set `--debug` flag to suppress wandb logging): 73 | ``` 74 | python scripts/train.py --dataset_path --config_name nusc_bc --debug 75 | ``` 76 | 77 | See the list of registered algorithms in `tbsim/configs/registry.py` 78 | 79 | ### 3. Train BITS model 80 | 81 | Lyft dataset: 82 | 83 | First train a spatial planner: 84 | ``` 85 | python scripts/train.py --dataset_path --config_name l5_spatial_planner --debug 86 | ``` 87 | Then train a multiagent predictor: 88 | ``` 89 | python scripts/train.py --dataset_path --config_name l5_agent_predictor --debug 90 | ``` 91 | 92 | nuScenes dataset: 93 | First train a spatial planner: 94 | ``` 95 | python scripts/train.py --dataset_path --config_name nusc_spatial_planner --debug 96 | ``` 97 | Then train a multiagent predictor: 98 | ``` 99 | python scripts/train.py --dataset_path --config_name nusc_agent_predictor --debug 100 | ``` 101 | 102 | See the list of registered algorithms in `tbsim/configs/registry.py` 103 | ### 4. Evaluate a trained model (closed-loop simulation) 104 | ``` 105 | python scripts/evaluate.py \ 106 | --results_root_dir results/ \ 107 | --num_scenes_per_batch 2 \ 108 | --dataset_path \ 109 | --env \ 110 | --policy_ckpt_dir \ 111 | --policy_ckpt_key \ 112 | --eval_class BC \ 113 | --render 114 | ``` 115 | 116 | ### 5.
Closed-loop simulation with BITS 117 | With the spatial planner and multiagent predictor trained, you can run a BITS simulation with: 118 | 119 | ``` 120 | python scripts/evaluate.py \ 121 | --results_root_dir results/ \ 122 | --dataset_path \ 123 | --env \ 124 | --ckpt_yaml \ 125 | --eval_class HierAgentAware \ 126 | --render 127 | ``` 128 | The ckpt_yaml file specifies the checkpoints for the spatial planner and the predictor; an example with pretrained checkpoints can be found at `evaluation/BITS_example.yaml`. 129 | 130 | Pretrained checkpoints can be downloaded at [link](https://drive.google.com/drive/folders/1y3_HO1c721pFrFOYeGGjORV58g6zNEds?usp=drive_link). 131 | 132 | If you use VS Code, see `.vscode/launch.json` for an example launch configuration. 133 | 134 | ### 6. Closed-loop evaluation of policy with BITS 135 | 136 | TBSIM allows the ego vehicle to use a different policy from the rest of the agents. An example command is: 137 | 138 | ``` 139 | python scripts/evaluate.py \ 140 | --results_root_dir results/ \ 141 | --dataset_path \ 142 | --env \ 143 | --ckpt_yaml \ 144 | --eval_class \ 145 | --agent_eval_class=HierAgentAware \ 146 | --render 147 | ``` 148 | 149 | Here, your policy should be declared in `tbsim/evaluation/policy_composers.py`.
150 | ## BibTeX Citation 151 | 152 | If you use TBSIM in a scientific publication, we would appreciate using the following citations: 153 | 154 | ``` 155 | @inproceedings{xu2023bits, 156 | title={Bits: Bi-level imitation for traffic simulation}, 157 | author={Xu, Danfei and Chen, Yuxiao and Ivanovic, Boris and Pavone, Marco}, 158 | booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)}, 159 | pages={2929--2936}, 160 | year={2023}, 161 | organization={IEEE} 162 | } 163 | ``` 164 | -------------------------------------------------------------------------------- /assets/sample_rollout.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/traffic-behavior-simulation/c470538011f15207b0688e4b430f8568c4cbe257/assets/sample_rollout.gif -------------------------------------------------------------------------------- /evaluation/BITS.yaml: -------------------------------------------------------------------------------- 1 | planner: 2 | ckpt_dir: 3 | ckpt_key: 4 | predictor: 5 | ckpt_dir: 6 | ckpt_key: -------------------------------------------------------------------------------- /evaluation/BITS_example.yaml: -------------------------------------------------------------------------------- 1 | planner: 2 | ckpt_dir: checkpoints/nusc_archresnet50_bs50_4130359/run0 3 | ckpt_key: iter41000 4 | predictor: 5 | ckpt_dir: checkpoints/nusc_dynUnicycle_gl0_yrl0_tfTrue_4130553 6 | ckpt_key: 94000 -------------------------------------------------------------------------------- /scripts/generate_config_templates.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpful script to generate example config files for each algorithm. These should be re-generated 3 | when new config options are added, or when default settings in the config classes are modified. 
4 | """ 5 | import os 6 | 7 | import tbsim 8 | from tbsim.configs.registry import EXP_CONFIG_REGISTRY 9 | 10 | 11 | def main(): 12 | # store template config jsons in this directory 13 | target_dir = os.path.join(tbsim.__path__[0], "../experiments/templates/") 14 | 15 | for name, cfg in EXP_CONFIG_REGISTRY.items(): 16 | cfg.dump(filename=os.path.join(target_dir, name + ".json")) 17 | 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /scripts/parse_results.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import numpy as np 4 | import os 5 | from pprint import pprint 6 | import torch 7 | import h5py 8 | from trajdata.simulation.sim_stats import calc_stats 9 | import tbsim.utils.tensor_utils as TensorUtils 10 | import pathlib 11 | from pyemd import emd 12 | 13 | 14 | def parse(args): 15 | rjson = json.load(open(os.path.join(args.results_dir, "stats.json"), "r")) 16 | cfg = json.load(open(os.path.join(args.results_dir, "config.json"), "r")) 17 | 18 | results = dict() 19 | for k in rjson: 20 | if k != "scene_index": 21 | if args.num_scenes is None: 22 | rnum = np.mean(rjson[k]) 23 | print("{} = {}".format(k, np.mean(rjson[k]))) 24 | else: 25 | rnum = np.mean(rjson[k][:args.num_scenes]) 26 | print("{} = {}".format(k, rnum)) 27 | results[k] = rnum 28 | 29 | hist_stats_fn = os.path.join(args.results_dir, "hist_stats.json") 30 | if not os.path.exists(hist_stats_fn): 31 | compute_and_save_stats(os.path.join(args.results_dir, "data.hdf5")) 32 | 33 | print("num_scenes: {}".format(len(rjson["scene_index"]))) 34 | ade = results["ade"] if "ade" in results else results["ADE"] 35 | fde = results["fde"] if "fde" in results else results["FDE"] 36 | 37 | pprint(cfg["ckpt"]) 38 | results_str = [ 39 | ade, 40 | fde, 41 | results["all_failure_any"] * 100, 42 | results["all_failure_coll"] * 100, 43 | results["all_failure_offroad"] 
* 100, 44 | results["all_diversity"], 45 | results["all_coverage_success"], 46 | results["all_coverage_total"], 47 | results["all_collision_rate_coll_any"] * 100, 48 | results["all_collision_rate_CollisionType.REAR"] * 100, 49 | results["all_collision_rate_CollisionType.FRONT"] * 100, 50 | results["all_collision_rate_CollisionType.SIDE"] * 100, 51 | results["all_off_road_rate_rate"] * 100, 52 | # results["velocity_dist"], 53 | # results["lon_accel_dist"], 54 | # results["lat_accel_dist"], 55 | # results["jerk_dist"] 56 | ] 57 | 58 | results_str = ["{:.3f}".format(r) for r in results_str] 59 | 60 | print(",".join(results_str)) 61 | 62 | 63 | def calc_hist_distance(hist1, hist2, bin_edges): 64 | bins = np.array(bin_edges) 65 | bins_dist = np.abs(bins[:, None] - bins[None, :]) 66 | hist_dist = emd(hist1, hist2, bins_dist) 67 | return hist_dist 68 | 69 | 70 | def compute_and_save_stats(h5_path): 71 | """Compute histogram statistics for a run""" 72 | h5f = h5py.File(h5_path, "r") 73 | bins = { 74 | "velocity": torch.linspace(0, 30, 21), 75 | "lon_accel": torch.linspace(0, 10, 21), 76 | "lat_accel": torch.linspace(0, 10, 21), 77 | "jerk": torch.linspace(0, 20, 21), 78 | } 79 | 80 | sim_stats = dict() 81 | # gt_stats = dict() 82 | ticks = None 83 | 84 | for i, scene_index in enumerate(h5f.keys()): 85 | if i % 10 == 0: 86 | print(i) 87 | scene_data = h5f[scene_index] 88 | sim_pos = scene_data["centroid"] 89 | sim_yaw = scene_data["yaw"][:][:, None] 90 | sim = calc_stats(positions=torch.Tensor(sim_pos), heading=torch.Tensor(sim_yaw), dt=0.1, bins=bins) 91 | 92 | for k in sim: 93 | if k not in sim_stats: 94 | sim_stats[k] = sim[k].hist.long() 95 | else: 96 | sim_stats[k] += sim[k].hist.long() 97 | 98 | if ticks is None: 99 | ticks = dict() 100 | for k in sim: 101 | ticks[k] = sim[k].bin_edges 102 | 103 | for k in sim_stats: 104 | sim_stats[k] = TensorUtils.to_numpy(sim_stats[k] / len(h5f.keys())).tolist() 105 | for k in ticks: 106 | ticks[k] = 
TensorUtils.to_numpy(ticks[k]).tolist() 107 | 108 | results_path = pathlib.Path(h5_path).parent.resolve() 109 | output_file = os.path.join(results_path, "hist_stats.json") 110 | json.dump({"stats": sim_stats, "ticks": ticks}, open(output_file, "w+"), indent=4) 111 | print("results dumped to {}".format(output_file)) 112 | 113 | 114 | if __name__ == "__main__": 115 | parser = argparse.ArgumentParser() 116 | parser.add_argument( 117 | "--results_dir", 118 | type=str, 119 | default=None, 120 | help="A directory of results files (including config.json and stats.json)" 121 | ) 122 | 123 | parser.add_argument( 124 | "--num_scenes", 125 | type=int, 126 | default=None 127 | ) 128 | 129 | args = parser.parse_args() 130 | 131 | parse(args) -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | 5 | import wandb 6 | import pytorch_lightning as pl 7 | from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger 8 | 9 | from tbsim.utils.log_utils import PrintLogger 10 | import tbsim.utils.train_utils as TrainUtils 11 | from tbsim.utils.env_utils import RolloutCallback 12 | from tbsim.configs.registry import get_registered_experiment_config 13 | from tbsim.datasets.factory import datamodule_factory 14 | from tbsim.utils.config_utils import get_experiment_config_from_file 15 | from tbsim.utils.batch_utils import set_global_batch_type 16 | from tbsim.algos.factory import algo_factory 17 | 18 | 19 | def main(cfg, auto_remove_exp_dir=False, debug=False): 20 | pl.seed_everything(cfg.seed) 21 | 22 | if cfg.env.name == "l5kit": 23 | set_global_batch_type("l5kit") 24 | elif "nusc" in cfg.env.name: 25 | set_global_batch_type("trajdata") 26 | else: 27 | raise NotImplementedError("Env {} is not supported".format(cfg.env.name)) 28 | 29 | print("\n============= New Training Run with Config =============") 30 | 
print(cfg) 31 | print("") 32 | root_dir, log_dir, ckpt_dir, video_dir, version_key = TrainUtils.get_exp_dir( 33 | exp_name=cfg.name, 34 | output_dir=cfg.root_dir, 35 | save_checkpoints=cfg.train.save.enabled, 36 | auto_remove_exp_dir=auto_remove_exp_dir 37 | ) 38 | 39 | # Save experiment config to the training dir 40 | cfg.dump(os.path.join(root_dir, version_key, "config.json")) 41 | 42 | if cfg.train.logging.terminal_output_to_txt and not debug: 43 | # log stdout and stderr to a text file 44 | logger = PrintLogger(os.path.join(log_dir, "log.txt")) 45 | sys.stdout = logger 46 | sys.stderr = logger 47 | 48 | train_callbacks = [] 49 | 50 | # Training Parallelism 51 | assert cfg.train.parallel_strategy in [ 52 | "dp", 53 | "ddp_spawn", 54 | None, 55 | ] # TODO: look into other strategies 56 | if not cfg.devices.num_gpus > 1: 57 | # Override strategy when training on a single GPU 58 | with cfg.train.unlocked(): 59 | cfg.train.parallel_strategy = None 60 | if cfg.train.parallel_strategy in ["ddp_spawn"]: 61 | with cfg.train.training.unlocked(): 62 | cfg.train.training.batch_size = int( # ddp_spawn: split the configured batch size across GPUs (int() truncates) 63 | cfg.train.training.batch_size / cfg.devices.num_gpus 64 | ) 65 | with cfg.train.validation.unlocked(): 66 | cfg.train.validation.batch_size = int( 67 | cfg.train.validation.batch_size / cfg.devices.num_gpus 68 | ) 69 | 70 | # Dataset 71 | datamodule = datamodule_factory( 72 | cls_name=cfg.train.datamodule_class, config=cfg 73 | ) 74 | datamodule.setup() 75 | 76 | # Environment for closed-loop evaluation 77 | if cfg.train.rollout.enabled: 78 | # Run rollout at regular intervals 79 | rollout_callback = RolloutCallback( 80 | exp_config=cfg, 81 | every_n_steps=cfg.train.rollout.every_n_steps, 82 | warm_start_n_steps=cfg.train.rollout.warm_start_n_steps, 83 | verbose=True, 84 | save_video=cfg.train.rollout.save_video, 85 | video_dir=video_dir 86 | ) 87 | train_callbacks.append(rollout_callback) 88 | 89 | # Model 90 | model = algo_factory( 91 | config=cfg, 92 |
modality_shapes=datamodule.modality_shapes 93 | ) 94 | 95 | # Checkpointing 96 | if cfg.train.validation.enabled and cfg.train.save.save_best_validation: 97 | assert ( # compares step intervals: checkpoints must be saved less often than validation runs 98 | cfg.train.save.every_n_steps > cfg.train.validation.every_n_steps 99 | ), "checkpointing frequency needs to be greater than validation frequency" 100 | for metric_name, metric_key in model.checkpoint_monitor_keys.items(): 101 | print( 102 | "Monitoring metrics {} under alias {}".format(metric_key, metric_name) 103 | ) 104 | ckpt_valid_callback = pl.callbacks.ModelCheckpoint( 105 | dirpath=ckpt_dir, 106 | filename="iter{step}_ep{epoch}_%s{%s:.2f}" % (metric_name, metric_key), 107 | # explicitly spell out metric names, otherwise PL parses '/' in metric names to directories 108 | auto_insert_metric_name=False, 109 | save_top_k=cfg.train.save.best_k, # save the best k models 110 | monitor=metric_key, 111 | mode="min", 112 | every_n_train_steps=cfg.train.save.every_n_steps, 113 | verbose=True, 114 | ) 115 | train_callbacks.append(ckpt_valid_callback) 116 | 117 | if cfg.train.rollout.enabled and cfg.train.save.save_best_rollout: 118 | assert ( # compares step intervals: checkpoints must be saved less often than rollouts run 119 | cfg.train.save.every_n_steps > cfg.train.rollout.every_n_steps 120 | ), "checkpointing frequency needs to be greater than rollout frequency" 121 | ckpt_rollout_callback = pl.callbacks.ModelCheckpoint( 122 | dirpath=ckpt_dir, 123 | filename="iter{step}_ep{epoch}_simADE{rollout/metrics_ego_ADE:.2f}", 124 | # explicitly spell out metric names, otherwise PL parses '/' in metric names to directories 125 | auto_insert_metric_name=False, 126 | save_top_k=cfg.train.save.best_k, # save the best k models 127 | monitor="rollout/metrics_ego_ADE", 128 | mode="min", 129 | every_n_train_steps=cfg.train.save.every_n_steps, 130 | verbose=True, 131 | ) 132 | train_callbacks.append(ckpt_rollout_callback) 133 | 134 | # a ckpt monitor to save at fixed interval 135 | ckpt_fixed_callback = pl.callbacks.ModelCheckpoint( 136 | dirpath=ckpt_dir, 137 | filename="iter{step}", 138 |
auto_insert_metric_name=False, 139 | save_top_k=-1, 140 | monitor=None, 141 | every_n_train_steps=10000, # fixed interval, independent of cfg.train.save.every_n_steps 142 | verbose=True, 143 | ) 144 | train_callbacks.append(ckpt_fixed_callback) 145 | 146 | # Logging 147 | assert not (cfg.train.logging.log_tb and cfg.train.logging.log_wandb) 148 | logger = None # rebinds the name only; a PrintLogger installed above stays on sys.stdout/sys.stderr 149 | if debug: 150 | print("Debugging mode, suppress logging.") 151 | elif cfg.train.logging.log_tb: 152 | logger = TensorBoardLogger( 153 | save_dir=root_dir, version=version_key, name=None, sub_dir="logs/" 154 | ) 155 | print("Tensorboard event will be saved at {}".format(logger.log_dir)) 156 | elif cfg.train.logging.log_wandb: 157 | assert ( 158 | "WANDB_APIKEY" in os.environ 159 | ), "Set api key by `export WANDB_APIKEY=`" 160 | apikey = os.environ["WANDB_APIKEY"] 161 | wandb.login(key=apikey) 162 | logger = WandbLogger( 163 | name=cfg.name, project=cfg.train.logging.wandb_project_name, 164 | ) 165 | # record the entire config on wandb 166 | logger.experiment.config.update(cfg.to_dict()) 167 | logger.watch(model=model) 168 | else: 169 | print("WARNING: not logging training stats") 170 | 171 | # Train 172 | trainer = pl.Trainer( 173 | default_root_dir=root_dir, 174 | # checkpointing 175 | enable_checkpointing=cfg.train.save.enabled, 176 | # logging 177 | logger=logger, 178 | # flush_logs_every_n_steps=cfg.train.logging.flush_every_n_steps, 179 | log_every_n_steps=cfg.train.logging.log_every_n_steps, 180 | # training 181 | max_steps=cfg.train.training.num_steps, 182 | # validation 183 | val_check_interval=cfg.train.validation.every_n_steps, 184 | limit_val_batches=cfg.train.validation.num_steps_per_epoch, 185 | # all callbacks 186 | callbacks=train_callbacks, 187 | # device & distributed training setup 188 | gpus=cfg.devices.num_gpus, # NOTE(review): `gpus=` is not accepted by newer PyTorch Lightning releases and setup.py does not pin the version; confirm 189 | strategy=cfg.train.parallel_strategy, 190 | # setting for overfit debugging 191 | # limit_val_batches=0, 192 | # overfit_batches=2 193 | ) 194 | 195 | trainer.fit(model=model, datamodule=datamodule) 196 | 197 | 198 | if __name__
== "__main__": 199 | parser = argparse.ArgumentParser() 200 | 201 | # External config file that overwrites default config 202 | parser.add_argument( 203 | "--config_file", 204 | type=str, 205 | default=None, 206 | help="(optional) path to a config json that will be used to override the default settings. \ 207 | If omitted, default settings are used. This is the preferred way to run experiments.", 208 | ) 209 | 210 | parser.add_argument( 211 | "--config_name", # checked before --config_file below, so it wins when both are given 212 | type=str, 213 | default=None, 214 | help="(optional) create experiment config from a preregistered name (see configs/registry.py)", 215 | ) 216 | # Experiment Name (for tensorboard, saving models, etc.) 217 | parser.add_argument( 218 | "--name", 219 | type=str, 220 | default=None, 221 | help="(optional) if provided, override the experiment name defined in the config", 222 | ) 223 | 224 | parser.add_argument( 225 | "--wandb_project_name", 226 | type=str, 227 | default=None, 228 | help="(optional) if provided, override the wandb project name defined in the config", 229 | ) 230 | 231 | parser.add_argument( 232 | "--dataset_path", 233 | type=str, 234 | default=None, 235 | help="(optional) if provided, override the dataset root path", 236 | ) 237 | 238 | parser.add_argument( 239 | "--output_dir", 240 | type=str, 241 | default=None, 242 | help="Root directory of training output (checkpoints, visualization, tensorboard log, etc.)", 243 | ) 244 | 245 | parser.add_argument( 246 | "--remove_exp_dir", 247 | action="store_true", 248 | help="Whether to automatically remove existing experiment directory of the same name (remember to set this to " 249 | "True to avoid unexpected stall when launching cloud experiments).", 250 | ) 251 | 252 | 253 | 254 | parser.add_argument( 255 | "--debug", action="store_true", help="Debug mode, suppress wandb logging, etc."
256 | ) 257 | 258 | args = parser.parse_args() 259 | 260 | if args.config_name is not None: 261 | default_config = get_registered_experiment_config(args.config_name) 262 | elif args.config_file is not None: 263 | # Update default config with external json file 264 | default_config = get_experiment_config_from_file(args.config_file, locked=False) 265 | else: 266 | raise Exception( 267 | "Need either a config name or a json file to create experiment config" 268 | ) 269 | 270 | if args.name is not None: 271 | default_config.name = args.name 272 | 273 | if args.dataset_path is not None: 274 | default_config.train.dataset_path = args.dataset_path 275 | 276 | if args.output_dir is not None: 277 | default_config.root_dir = os.path.abspath(args.output_dir) 278 | 279 | if args.wandb_project_name is not None: 280 | default_config.train.logging.wandb_project_name = args.wandb_project_name 281 | 282 | 283 | if args.debug: 284 | # Test policy rollout 285 | default_config.train.rollout.every_n_steps = 10 # roll out frequently so the rollout path is exercised early 286 | default_config.train.rollout.num_episodes = 1 287 | 288 | # make rollout evaluation config consistent with the rest of the config 289 | if default_config.train.rollout.enabled: 290 | default_config.eval.env = default_config.env.name 291 | assert default_config.algo.eval_class is not None, \ 292 | "Please set an eval_class for {}".format(default_config.algo.name) 293 | default_config.eval.eval_class = default_config.algo.eval_class 294 | default_config.eval.dataset_path = default_config.train.dataset_path 295 | for k in default_config.eval[default_config.eval.env]: # copy env-specific config to the global-level 296 | default_config.eval[k] = default_config.eval[default_config.eval.env][k] 297 | default_config.eval.pop("nusc") # drop the per-env sub-configs now that the active one is copied to the top level 298 | default_config.eval.pop("l5kit") 299 | 300 | default_config.lock() # Make config read-only 301 | main(default_config, auto_remove_exp_dir=args.remove_exp_dir, debug=args.debug) 302 |
# --------------------------------------------------------------------------------
# /setup.py:
# --------------------------------------------------------------------------------
from setuptools import setup, find_packages

# read the contents of your README file
from os import path

# README lines mentioning any of these are dropped from the PyPI long description,
# since relative image links do not render there (assets/ holds a .gif).
_IMAGE_MARKERS = (".png", ".gif")

this_directory = path.abspath(path.dirname(__file__))
with open(path.join(this_directory, "README.md"), encoding="utf-8") as f:
    lines = f.readlines()

# remove images from README
lines = [x for x in lines if not any(marker in x for marker in _IMAGE_MARKERS)]
long_description = "".join(lines)

setup(
    name="tbsim",
    packages=[package for package in find_packages() if package.startswith("tbsim")],
    install_requires=[
        "l5kit==1.5.0",
        "numpy>=1.19.0",  # need to manually update numpy version to (1.21.4) due to conflict with l5kit's requirement
        "pytorch-lightning",  # listed once; it was previously duplicated further down
        "wandb",
        "torch>=1.13.1",
        "torchvision>=0.11.3",
        "pyemd",
        "h5py",
        "imageio-ffmpeg",
        "python-louvain",
        "networkx",
        "torchtext",  # weird pytorch-lightning dependency bug
        "nuscenes-devkit",
    ],
    eager_resources=["*"],
    include_package_data=True,
    python_requires=">=3",
    description="Traffic Behavior Simulation",
    author="NVIDIA AV Research",
    author_email="danfeix@nvidia.com",
    version="0.0.1",
    long_description=long_description,
    long_description_content_type="text/markdown",
)

# --------------------------------------------------------------------------------
# /tbsim/algos/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/NVlabs/traffic-behavior-simulation/c470538011f15207b0688e4b430f8568c4cbe257/tbsim/algos/__init__.py
# --------------------------------------------------------------------------------
# /tbsim/algos/factory.py:
# --------------------------------------------------------------------------------
"""Factory methods for creating models"""
from pytorch_lightning import LightningDataModule
from tbsim.configs.base import ExperimentConfig

from tbsim.algos.algos import (
    BehaviorCloning,
    VAETrafficModel,
    DiscreteVAETrafficModel,
    BehaviorCloningGC,
    SpatialPlanner,
    GANTrafficModel,
    BehaviorCloningEC,
    TreeVAETrafficModel,
    SceneTreeTrafficModel,
    ScePTTrafficModel,
    AgentFormerTrafficModel
)

from tbsim.algos.multiagent_algos import (
    MATrafficModel,
)

from tbsim.algos.metric_algos import (
    OccupancyMetric
)


def algo_factory(config: ExperimentConfig, modality_shapes: dict):
    """
    A factory for creating training algos

    Args:
        config (ExperimentConfig): an ExperimentConfig object; ``config.algo.name``
            selects the algorithm class and ``config.algo`` is passed to it
        modality_shapes (dict): a dictionary that maps observation modality names to shapes

    Returns:
        algo: pl.LightningModule

    Raises:
        NotImplementedError: if ``config.algo.name`` is not one of the names below
    """
    algo_config = config.algo
    algo_name = algo_config.name

    # Maps the config's algorithm name to the class that implements it.
    algo_classes = {
        "bc": BehaviorCloning,
        "bc_gc": BehaviorCloningGC,
        "vae": VAETrafficModel,
        "discrete_vae": DiscreteVAETrafficModel,
        "tree": SceneTreeTrafficModel,
        "bc_ec": BehaviorCloningEC,
        "spatial_planner": SpatialPlanner,
        "occupancy": OccupancyMetric,
        "agent_predictor": MATrafficModel,
        "gan": GANTrafficModel,
        "scept": ScePTTrafficModel,
        "agentformer": AgentFormerTrafficModel,
    }
    if algo_name not in algo_classes:
        # Use str.format: the previous `"{}..." % algo_name` raised TypeError
        # ("not all arguments converted") instead of this NotImplementedError.
        raise NotImplementedError("{} is not a valid algorithm".format(algo_name))
    return algo_classes[algo_name](algo_config=algo_config, modality_shapes=modality_shapes)

# --------------------------------------------------------------------------------
# /tbsim/configs/__init__.py:
# --------------------------------------------------------------------------------
from tbsim.configs.base import ExperimentConfig

# --------------------------------------------------------------------------------
# /tbsim/configs/base.py:
# --------------------------------------------------------------------------------
from tbsim.configs.config import Dict
from copy import deepcopy
from tbsim.configs.eval_config import TrainTimeEvaluationConfig


class TrainConfig(Dict):
    """Default training settings; dataset-specific train configs subclass and override these."""

    def __init__(self):
        super(TrainConfig, self).__init__()
        self.logging.terminal_output_to_txt = True  # whether to log stdout to txt file
        self.logging.log_tb = False  # enable tensorboard logging
        self.logging.log_wandb = True  # enable wandb logging
        self.logging.wandb_project_name = "tbsim"
        self.logging.log_every_n_steps = 10
        self.logging.flush_every_n_steps = 100

        ## save config - if and when to save model checkpoints ##
        self.save.enabled = True  # whether model saving should be enabled or disabled
        self.save.every_n_steps = 100  # save model every n steps
        self.save.best_k = 5
        self.save.save_best_rollout = False
        self.save.save_best_validation = True

        ## evaluation rollout config ##
        self.rollout.save_video = True
        self.rollout.enabled = False  # enable evaluation rollouts
        self.rollout.every_n_steps = 1000  # do rollouts every n steps
        self.rollout.warm_start_n_steps = 1  # number of steps to wait before starting rollouts

        ## training config
        self.training.batch_size = 100
        self.training.num_steps = 200000
        self.training.num_data_workers = 0

        ## validation config
        self.validation.enabled = True
        self.validation.batch_size = 100
        self.validation.num_data_workers = 0
        self.validation.every_n_steps = 1000
        self.validation.num_steps_per_epoch = 100

        ## Training parallelism (e.g., multi-GPU)
        self.parallel_strategy = "ddp_spawn"


class EnvConfig(Dict):
    """Base environment config; subclasses set ``name`` and env-specific settings."""

    def __init__(self):
        super(EnvConfig, self).__init__()
        self.name = "my_env"


class AlgoConfig(Dict):
    """Base algorithm config; subclasses set ``name`` and algorithm-specific settings."""

    def __init__(self):
        super(AlgoConfig, self).__init__()
        self.name = "my_algo"


class ExperimentConfig(Dict):
    def __init__(
        self,
        train_config: TrainConfig,
        env_config: EnvConfig,
        algo_config: AlgoConfig,
        eval_config: TrainTimeEvaluationConfig = None,
        registered_name: str = None,
    ):
        """

        Args:
            train_config (TrainConfig): training config
            env_config (EnvConfig): environment config
            algo_config (AlgoConfig): algorithm config
            eval_config (TrainTimeEvaluationConfig): evaluation config; a fresh
                TrainTimeEvaluationConfig is created when None
            registered_name (str): name of the experiment config object in the global config registry
        """
        super(ExperimentConfig, self).__init__()
        self.registered_name = registered_name

        self.train = train_config
        self.env = env_config
        self.algo = algo_config
        self.eval = TrainTimeEvaluationConfig() if eval_config is None else eval_config

        # Write all results to this directory. A new folder with the timestamp will be created
        # in this directory, and it will contain three subfolders - "log", "models", and "videos".
        # The "log" directory will contain tensorboard and stdout txt logs. The "models" directory
        # will contain saved model checkpoints. The "videos" directory contains evaluation rollout
        # videos.
        self.name = (
            "test"  # name of the experiment (creates a subdirectory under root_dir)
        )

        self.root_dir = "{}_trained_models/".format(self.algo.name)
        self.seed = 1  # seed for everything (for reproducibility)

        self.devices.num_gpus = 1  # Set to 0 to use CPU

    def clone(self):
        """Return a new ExperimentConfig whose sub-configs are deep copies of this one's."""
        return self.__class__(
            train_config=deepcopy(self.train),
            env_config=deepcopy(self.env),
            algo_config=deepcopy(self.algo),
            eval_config=deepcopy(self.eval),
            registered_name=self.registered_name,
        )

# --------------------------------------------------------------------------------
# /tbsim/configs/config.py:
# --------------------------------------------------------------------------------
"""
Basic config class - provides a convenient way to work with nested
dictionaries (by exposing keys as attributes) and to save / load from jsons.

Based on addict: https://github.com/mewwts/addict
"""

import json
import copy
import contextlib
from copy import deepcopy


class Dict(dict):
    """Nested dict whose keys can also be read and written as attributes.

    Reading a missing key on an unlocked Dict returns a new, empty child Dict
    that attaches itself to its parent on first assignment, so deep paths such
    as ``cfg.a.b.c = 1`` can be created in one statement. After ``lock()``,
    reading a missing key or adding a new key raises KeyError; existing keys
    can still be reassigned.
    """

    def __init__(__self, *args, **kwargs):
        # Bookkeeping lives in instance attributes (not dict items):
        # '__parent'/'__key' link a lazily created child to its parent,
        # '__frozen' is the lock flag.
        object.__setattr__(__self, '__parent', kwargs.pop('__parent', None))
        object.__setattr__(__self, '__key', kwargs.pop('__key', None))
        object.__setattr__(__self, '__frozen', False)
        for arg in args:
            if not arg:
                continue
            elif isinstance(arg, dict):
                for key, val in arg.items():
                    __self[key] = __self._hook(val)
            elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)):
                __self[arg[0]] = __self._hook(arg[1])
            else:
                for key, val in iter(arg):
                    __self[key] = __self._hook(val)

        for key, val in kwargs.items():
            __self[key] = __self._hook(val)

    def __setattr__(self, name, value):
        # Names defined on the class (methods etc.) cannot be shadowed by config keys.
        if hasattr(self.__class__, name):
            raise AttributeError("'Dict' object attribute "
                                 "'{0}' is read-only".format(name))
        else:
            self[name] = value

    def __setitem__(self, name, value):
        isFrozen = (hasattr(self, '__frozen') and
                    object.__getattribute__(self, '__frozen'))
        if isFrozen and name not in super(Dict, self).keys():
            raise KeyError(name)
        super(Dict, self).__setitem__(name, value)
        try:
            p = object.__getattribute__(self, '__parent')
            key = object.__getattribute__(self, '__key')
        except AttributeError:
            p = None
            key = None
        if p is not None:
            # First write into a lazily created child: attach it to its parent
            # and drop the back-link so this happens only once.
            p[key] = self
            object.__delattr__(self, '__parent')
            object.__delattr__(self, '__key')

    def __add__(self, other):
        if not self.keys():
            return other
        else:
            self_type = type(self).__name__
            other_type = type(other).__name__
            msg = "unsupported operand type(s) for +: '{}' and '{}'"
            raise TypeError(msg.format(self_type, other_type))

    @classmethod
    def _hook(cls, item):
        # Recursively convert plain dicts (also inside lists/tuples) into this class.
        if isinstance(item, dict):
            return cls(item)
        elif isinstance(item, (list, tuple)):
            return type(item)(cls._hook(elem) for elem in item)
        return item

    def __getattr__(self, item):
        return self.__getitem__(item)

    def __missing__(self, name):
        if object.__getattribute__(self, '__frozen'):
            raise KeyError(name)
        return Dict(__parent=self, __key=name)

    def __delattr__(self, name):
        del self[name]

    def __repr__(self):
        json_string = json.dumps(self.to_dict(), indent=4)
        return json_string

    def to_dict(self):
        """Return a plain (recursively converted) dict copy of this config."""
        base = {}
        for key, value in self.items():
            if isinstance(value, type(self)):
                base[key] = value.to_dict()
            elif isinstance(value, (list, tuple)):
                base[key] = type(value)(
                    item.to_dict() if isinstance(item, type(self)) else
                    item for item in value)
            else:
                base[key] = value
        return base

    def copy(self):
        return copy.copy(self)

    def deepcopy(self):
        return copy.deepcopy(self)

    def __deepcopy__(self, memo):
        # Allocate the copy without running the subclass __init__: calling
        # self.__class__() re-inserted default keys that had been deleted
        # (e.g. eval.pop("nusc")) and failed for subclasses whose __init__
        # requires arguments (ExperimentConfig). The copy starts unlocked,
        # as before.
        other = self.__class__.__new__(self.__class__)
        Dict.__init__(other)
        memo[id(self)] = other
        for key, value in self.items():
            other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
        return other

    def update(self, *args, **kwargs):
        """Recursively merge a mapping (and/or kwargs) into this config.

        Values are merged key-by-key when both the existing and the new value
        are dicts; otherwise the new value replaces the old one.
        """
        other = {}
        if args:
            if len(args) > 1:
                raise TypeError()
            other.update(args[0])
        other.update(kwargs)
        for k, v in other.items():
            if ((k not in self) or
                (not isinstance(self[k], dict)) or
                (not isinstance(v, dict))):
                self[k] = v
            else:
                self[k].update(v)

    def __getnewargs__(self):
        return tuple(self.items())

    def __getstate__(self):
        return self

    def __setstate__(self, state):
        self.update(state)

    def __or__(self, other):
        if not isinstance(other, (Dict, dict)):
            return NotImplemented
        new = Dict(self)
        new.update(other)
        return new

    def __ror__(self, other):
        if not isinstance(other, (Dict, dict)):
            return NotImplemented
        new = Dict(other)
        new.update(self)
        return new

    def __ior__(self, other):
        self.update(other)
        return self

    def setdefault(self, key, default=None):
        if key in self:
            return self[key]
        else:
            self[key] = default
            return default

    def lock(self, should_lock=True):
        """Recursively set (or clear) the frozen flag on this config and all nested Dicts."""
        object.__setattr__(self, '__frozen', should_lock)
        for key, val in self.items():
            if isinstance(val, Dict):
                val.lock(should_lock)

    def dump(self, filename=None):
        """Serialize the config to an indented JSON string.

        If ``filename`` is given, the string is also written to that file.
        Returns the JSON string.
        """
        json_string = json.dumps(self.to_dict(), indent=4)
        if filename is not None:
            # `with` guarantees the handle is closed even if the write fails.
            with open(filename, "w") as f:
                f.write(json_string)
        return json_string

    def unlock(self):
        self.lock(False)

    @contextlib.contextmanager
    def unlocked(self):
        """Unlock the config for the duration of a ``with`` block.

        The config is locked again on exit, including when the body raises.
        """
        self.unlock()
        try:
            yield
        finally:
            self.lock()

    def clone(self):
        return deepcopy(self)

# --------------------------------------------------------------------------------
# /tbsim/configs/eval_config.py:
# --------------------------------------------------------------------------------
import numpy as np
from copy import deepcopy

from tbsim.configs.config import Dict


class EvaluationConfig(Dict):
    """Default settings for closed-loop evaluation (checkpoints, policy, metrics, per-dataset scenes)."""

    def __init__(self):
        super(EvaluationConfig, self).__init__()
        self.name = None
        self.env = "nusc"  # [l5kit, nusc]
        self.dataset_path = None
        self.eval_class = ""
        self.seed = 0
        self.num_scenes_per_batch = 2
        self.num_scenes_to_evaluate = 2

        self.num_episode_repeats = 1
        self.start_frame_index_each_episode = None  # if specified, should be the same length as num_episode_repeats
        self.seed_each_episode = None  # if specified, should be the same length as num_episode_repeats

        self.ego_only = False
        self.agent_eval_class = None

        self.ckpt_root_dir = "checkpoints/"
        self.experience_hdf5_path = None
        self.results_dir = "results/"

        self.ckpt.policy.ckpt_dir = None
        self.ckpt.policy.ckpt_key = None

        self.ckpt.planner.ckpt_dir = None
        self.ckpt.planner.ckpt_key = None

        self.ckpt.predictor.ckpt_dir = None
        self.ckpt.predictor.ckpt_key = None

        self.ckpt.cvae_metric.ckpt_dir = None
        self.ckpt.cvae_metric.ckpt_key = None

        self.ckpt.occupancy_metric.ckpt_dir = None
        self.ckpt.occupancy_metric.ckpt_key = None

        self.policy.mask_drivable = True
        self.policy.num_plan_samples = 50
        self.policy.num_action_samples = 10
        self.policy.pos_to_yaw = True
        self.policy.yaw_correction_speed = 1.0
        self.policy.diversification_clearance = None
        self.policy.sample = True

        self.policy.cost_weights.collision_weight = 15.0
        self.policy.cost_weights.lane_weight = 1.0
        self.policy.cost_weights.lane_dir_weight = 1.0
        self.policy.cost_weights.likelihood_weight = 0.0  # 0.1
        self.policy.cost_weights.progress_weight = 0.01  # 0.005

        self.metrics.compute_analytical_metrics = True
        self.metrics.compute_learned_metrics = False

        self.perturb.enabled = False
        self.perturb.OU.theta = 0.8
        self.perturb.OU.sigma = [0.0, 0.1, 0.2, 0.5, 1.0, 2.0, 4.0]
        self.perturb.OU.scale = [1.0, 1.0, 0.2]

        self.rolling_perturb.enabled = False
        self.rolling_perturb.OU.theta = 0.8
        self.rolling_perturb.OU.sigma = 0.5
        self.rolling_perturb.OU.scale = [1.0, 1.0, 0.2]

        self.occupancy.rolling = True
        self.occupancy.rolling_horizon = [5, 10, 20]

        self.cvae.rolling = True
        self.cvae.rolling_horizon = [5, 10, 20]

        self.nusc.eval_scenes = np.arange(0, 100).tolist()
        self.nusc.n_step_action = 2
        self.nusc.num_simulation_steps = 200
        self.nusc.skip_first_n = 0

        self.drivesim.eval_scenes = np.arange(0, 40).tolist()
        self.drivesim.n_step_action = 3
        self.drivesim.num_simulation_steps = 200
        self.drivesim.skip_first_n = 9

        self.l5kit.eval_scenes = [9058, 5232, 14153, 8173, 10314, 7027, 9812, 1090, 9453, 978, 10263, 874, 5563, 9613, 261, 2826, 2175, 9977, 6423, 1069, 1836, 8198, 5034, 6016, 2525, 927, 3634, 11806, 4911, 6192, 11641, 461, 142, 15493, 4919, 8494, 14572, 2402, 308, 1952, 13287, 15614, 6529, 12, 11543, 4558, 489, 6876, 15279, 6095, 5877, 8928, 10599, 16150, 11296, 9382, 13352, 1794, 16122, 12429, 15321, 8614, 12447, 4502, 13235, 2919, 15893, 12960, 7043, 9278, 952, 4699, 768, 13146, 8827, 16212, 10777, 15885, 11319, 9417, 14092, 14873, 6740, 11847, 15331, 15639, 11361, 14784, 13448, 10124, 4872, 3567, 5543, 2214, 7624, 10193, 7297, 1308, 3951, 14001]
        self.l5kit.n_step_action = 5
        self.l5kit.num_simulation_steps = 200
        self.l5kit.skip_first_n = 1
        self.l5kit.skimp_rollout = False

        self.adjustment.enabled = False
        self.adjustment.random_init_plan = False
        self.adjustment.remove_existing_neighbors = False
        self.adjustment.initial_num_neighbors = 4
        self.adjustment.num_frame_per_new_agent = 2000

    def clone(self):
        return deepcopy(self)


class TrainTimeEvaluationConfig(EvaluationConfig):
    """Evaluation defaults for rollouts during training: more scenes per batch, fewer scenes in total, no policy sampling."""

    def __init__(self):
        super(TrainTimeEvaluationConfig, self).__init__()

        self.num_scenes_per_batch = 4
        self.nusc.eval_scenes = np.arange(0, 100, 10).tolist()
        self.l5kit.eval_scenes = self.l5kit.eval_scenes[:20]

        self.policy.sample = False

# --------------------------------------------------------------------------------
# /tbsim/configs/l5kit_config.py:
# --------------------------------------------------------------------------------
from tbsim.configs.base import TrainConfig, EnvConfig


class L5KitTrainConfig(TrainConfig):
    def __init__(self):
        super(L5KitTrainConfig, self).__init__()

        # NOTE(review): developer-specific absolute default; override it with
        # --dataset_path in scripts/train.py.
        self.dataset_path = "/home/yuxiaoc/repos/l5kit/prediction-dataset"
        self.dataset_valid_key = "scenes/validate.zarr"
        self.dataset_train_key = "scenes/train.zarr"
        self.dataset_meta_key = "meta.json"
        self.datamodule_class = "L5MixedDataModule"
        self.dataset_mode = "agents"

        self.rollout.enabled = False
        self.rollout.save_video = True
        self.rollout.every_n_steps = 5000

        # training config
        self.training.batch_size = 100
        self.training.num_steps = 100000
        self.training.num_data_workers = 8

        self.save.every_n_steps = 1000
        self.save.best_k = 10

        # validation config
        self.validation.enabled = True
        self.validation.batch_size = 32
        self.validation.num_data_workers = 6
        self.validation.every_n_steps = 500
        self.validation.num_steps_per_epoch = 50


class L5KitMixedEnvConfig(EnvConfig):
    """Vectorized Scene Component + Rasterized Map"""

    def __init__(self):
        super(L5KitMixedEnvConfig, self).__init__()
        self.name = "l5kit"
        # the keys are relative to the dataset environment variable
        self.rasterizer.semantic_map_key = "semantic_map/semantic_map.pb"
        self.rasterizer.dataset_meta_key = "meta.json"

        # e.g. 0.0 include every obstacle, 0.5 show those obstacles with >0.5 probability of being
        # one of the classes we care about (cars, bikes, peds, etc.), >=1.0 filter all other agents.
        self.rasterizer.filter_agents_threshold = 0.5

        # whether to completely disable traffic light faces in the semantic rasterizer
        # this disable option is not supported in avsw_semantic
        self.rasterizer.disable_traffic_light_faces = False

        self.generate_agent_obs = False

        self.data_generation_params.other_agents_num = 20
        self.data_generation_params.max_agents_distance = 50
        self.data_generation_params.lane_params.max_num_lanes = 15
        self.data_generation_params.lane_params.max_points_per_lane = 5
        self.data_generation_params.lane_params.max_points_per_crosswalk = 5
        self.data_generation_params.lane_params.max_retrieval_distance_m = 35
        self.data_generation_params.lane_params.max_num_crosswalks = 20
        self.data_generation_params.rasterize_agents = False
        self.data_generation_params.vectorize_agents = True

        # step size of lane interpolation
        self.data_generation_params.lane_params.lane_interp_step_size = 5.0
        self.data_generation_params.vectorize_lane = True

        self.rasterizer.raster_size = (224, 224)

        # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to.
        self.rasterizer.pixel_size = (0.5, 0.5)

        # From 0 to 1 per axis, [0.5,0.5] would show the ego centered in the image.
        self.rasterizer.ego_center = (0.25, 0.5)

        self.rasterizer.map_type = "py_semantic"
        # self.rasterizer.map_type = "scene_semantic"

        # the keys are relative to the dataset environment variable
        self.rasterizer.satellite_map_key = "aerial_map/aerial_map.png"
        self.rasterizer.semantic_map_key = "semantic_map/semantic_map.pb"

        # When set to True, the rasterizer will set the raster origin at bottom left,
        # i.e. vehicles are driving on the right side of the road.
        # With this change, the vertical flipping on the raster used in the visualization code is no longer needed.
        # Set it to False for models trained before v1.1.0-25-g3c517f0 (December 2020).
        # In that case visualisation will be flipped (we've removed the flip there) but the model's input will be correct.
        self.rasterizer.set_origin_to_bottom = True

        # if a tracked agent is closer than this value to ego, it will be controlled
        self.simulation.distance_th_far = 50

        # if a new agent is closer than this value to ego, it will be controlled
        self.simulation.distance_th_close = 50

        # whether to disable agents that are not returned at start_frame_index
        self.simulation.disable_new_agents = False

        # maximum number of simulation steps to run (0.1sec / step)
        self.simulation.num_simulation_steps = 50

        # which frame to start a simulation episode with
        self.simulation.start_frame_index = 0


class L5KitMixedSemanticMapEnvConfig(L5KitMixedEnvConfig):
    """L5Kit env config with rasterized semantic map, no lane vectors, and per-agent observations."""

    def __init__(self):
        super(L5KitMixedSemanticMapEnvConfig, self).__init__()
        self.rasterizer.map_type = "py_semantic"
        self.data_generation_params.vectorize_lane = False
        self.generate_agent_obs = True

# --------------------------------------------------------------------------------
# /tbsim/configs/nusc_config.py:
# --------------------------------------------------------------------------------
import math

from tbsim.configs.base import TrainConfig, EnvConfig, AlgoConfig

MAX_POINTS_LANE = 5


class NuscTrainConfig(TrainConfig):
    def __init__(self):
        super(NuscTrainConfig, self).__init__()

        self.trajdata_source_train = "train"
        self.trajdata_source_valid = "val"
        self.trajdata_source_root = "nusc_trainval"

        self.dataset_path = "SET-THIS-THROUGH-TRAIN-SCRIPT-ARGS"
        self.datamodule_class = "UnifiedDataModule"
        self.ego_only = False

        self.rollout.enabled = False
        self.rollout.save_video = True
        self.rollout.every_n_steps = 5000

        # training config
        self.training.batch_size = 120
        self.training.num_steps = 200000
        self.training.num_data_workers = 8

        self.save.every_n_steps = 100
        self.save.best_k = 10

        # validation config
        self.validation.enabled = True
        self.validation.batch_size = 32
        self.validation.num_data_workers = 6
        self.validation.every_n_steps = 60
        self.validation.num_steps_per_epoch = 50


class NuscEnvConfig(EnvConfig):
    def __init__(self):
        super(NuscEnvConfig, self).__init__()

        self.name = "nusc_trainval"

        # raster image size [pixels]
        self.rasterizer.raster_size = 224

        # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to.
        self.rasterizer.pixel_size = 0.5

        # where the agent is on the map, (0.0, 0.0) is the center
        # WARNING: this should not be changed before resolving TODO in parse_trajdata_batch() in trajdata_utils.py
        self.rasterizer.ego_center = (-0.5, 0.0)

        # maximum number of agents to consider during training
        self.data_generation_params.other_agents_num = 20

        self.data_generation_params.max_agents_distance = 30

        # correct for yaw (zero-out delta yaw) when speed is lower than this threshold
        self.data_generation_params.yaw_correction_speed = 1.0

        self.simulation.distance_th_close = 30

        # maximum number of simulation steps to run (0.1sec / step)
        self.simulation.num_simulation_steps = 50

        # which frame to start a simulation episode with
        self.simulation.start_frame_index = 0

        # whether to get lane information
        self.simulation.vectorize_lane = "ego"

        # whether to include neighbor map patches
        self.incl_neighbor_map = False

# --------------------------------------------------------------------------------
# /tbsim/configs/registry.py:
-------------------------------------------------------------------------------- 1 | """A global registry for looking up named experiment configs""" 2 | from tbsim.configs.base import ExperimentConfig 3 | 4 | from tbsim.configs.l5kit_config import ( 5 | L5KitTrainConfig, 6 | L5KitMixedEnvConfig, 7 | L5KitMixedSemanticMapEnvConfig, 8 | ) 9 | 10 | from tbsim.configs.nusc_config import ( 11 | NuscTrainConfig, 12 | NuscEnvConfig 13 | ) 14 | 15 | from tbsim.configs.algo_config import ( 16 | AgentFormerConfig, 17 | BehaviorCloningConfig, 18 | BehaviorCloningECConfig, 19 | SpatialPlannerConfig, 20 | BehaviorCloningGCConfig, 21 | AgentPredictorConfig, 22 | VAEConfig, 23 | EBMMetricConfig, 24 | GANConfig, 25 | DiscreteVAEConfig, 26 | TreeAlgoConfig, 27 | OccupancyMetricConfig, 28 | UnetConfig, 29 | ScePTConfig, 30 | SQPMPCConfig, 31 | ) 32 | 33 | 34 | EXP_CONFIG_REGISTRY = dict() 35 | 36 | EXP_CONFIG_REGISTRY["l5_bc"] = ExperimentConfig( 37 | train_config=L5KitTrainConfig(), 38 | env_config=L5KitMixedSemanticMapEnvConfig(), 39 | algo_config=BehaviorCloningConfig(), 40 | registered_name="l5_bc", 41 | ) 42 | 43 | EXP_CONFIG_REGISTRY["l5_gan"] = ExperimentConfig( 44 | train_config=L5KitTrainConfig(), 45 | env_config=L5KitMixedSemanticMapEnvConfig(), 46 | algo_config=GANConfig(), 47 | registered_name="l5_gan", 48 | ) 49 | 50 | EXP_CONFIG_REGISTRY["l5_bc_gc"] = ExperimentConfig( 51 | train_config=L5KitTrainConfig(), 52 | env_config=L5KitMixedSemanticMapEnvConfig(), 53 | algo_config=BehaviorCloningGCConfig(), 54 | registered_name="l5_bc_gc", 55 | ) 56 | 57 | EXP_CONFIG_REGISTRY["l5_spatial_planner"] = ExperimentConfig( 58 | train_config=L5KitTrainConfig(), 59 | env_config=L5KitMixedSemanticMapEnvConfig(), 60 | algo_config=SpatialPlannerConfig(), 61 | registered_name="l5_spatial_planner", 62 | ) 63 | 64 | EXP_CONFIG_REGISTRY["l5_agent_predictor"] = ExperimentConfig( 65 | train_config=L5KitTrainConfig(), 66 | env_config=L5KitMixedSemanticMapEnvConfig(), 67 | 
algo_config=AgentPredictorConfig(), 68 | registered_name="l5_agent_predictor" 69 | ) 70 | 71 | EXP_CONFIG_REGISTRY["l5_vae"] = ExperimentConfig( 72 | train_config=L5KitTrainConfig(), 73 | env_config=L5KitMixedSemanticMapEnvConfig(), 74 | algo_config=VAEConfig(), 75 | registered_name="l5_vae", 76 | ) 77 | 78 | EXP_CONFIG_REGISTRY["l5_bc_ec"] = ExperimentConfig( 79 | train_config=L5KitTrainConfig(), 80 | env_config=L5KitMixedSemanticMapEnvConfig(), 81 | algo_config=BehaviorCloningECConfig(), 82 | registered_name="l5_bc_ec", 83 | ) 84 | 85 | EXP_CONFIG_REGISTRY["l5_discrete_vae"] = ExperimentConfig( 86 | train_config=L5KitTrainConfig(), 87 | env_config=L5KitMixedSemanticMapEnvConfig(), 88 | algo_config=DiscreteVAEConfig(), 89 | registered_name="l5_discrete_vae", 90 | ) 91 | 92 | EXP_CONFIG_REGISTRY["l5_tree"] = ExperimentConfig( 93 | train_config=L5KitTrainConfig(), 94 | env_config=L5KitMixedSemanticMapEnvConfig(), 95 | algo_config=TreeAlgoConfig(), 96 | registered_name="l5_tree", 97 | ) 98 | 99 | 100 | EXP_CONFIG_REGISTRY["l5_ebm"] = ExperimentConfig( 101 | train_config=L5KitTrainConfig(), 102 | env_config=L5KitMixedSemanticMapEnvConfig(), 103 | algo_config=EBMMetricConfig(), 104 | registered_name="l5_ebm", 105 | ) 106 | 107 | EXP_CONFIG_REGISTRY["l5_occupancy"] = ExperimentConfig( 108 | train_config=L5KitTrainConfig(), 109 | env_config=L5KitMixedSemanticMapEnvConfig(), 110 | algo_config=OccupancyMetricConfig(), 111 | registered_name="l5_occupancy" 112 | ) 113 | 114 | EXP_CONFIG_REGISTRY["nusc_bc"] = ExperimentConfig( 115 | train_config=NuscTrainConfig(), 116 | env_config=NuscEnvConfig(), 117 | algo_config=BehaviorCloningConfig(), 118 | registered_name="nusc_bc" 119 | ) 120 | 121 | EXP_CONFIG_REGISTRY["nusc_bc_gc"] = ExperimentConfig( 122 | train_config=NuscTrainConfig(), 123 | env_config=NuscEnvConfig(), 124 | algo_config=BehaviorCloningGCConfig(), 125 | registered_name="nusc_bc_gc" 126 | ) 127 | 128 | EXP_CONFIG_REGISTRY["nusc_spatial_planner"] = ExperimentConfig( 
129 | train_config=NuscTrainConfig(), 130 | env_config=NuscEnvConfig(), 131 | algo_config=SpatialPlannerConfig(), 132 | registered_name="nusc_spatial_planner" 133 | ) 134 | 135 | EXP_CONFIG_REGISTRY["nusc_vae"] = ExperimentConfig( 136 | train_config=NuscTrainConfig(), 137 | env_config=NuscEnvConfig(), 138 | algo_config=VAEConfig(), 139 | registered_name="nusc_vae" 140 | ) 141 | 142 | EXP_CONFIG_REGISTRY["nusc_discrete_vae"] = ExperimentConfig( 143 | train_config=NuscTrainConfig(), 144 | env_config=NuscEnvConfig(), 145 | algo_config=DiscreteVAEConfig(), 146 | registered_name="nusc_discrete_vae" 147 | ) 148 | 149 | EXP_CONFIG_REGISTRY["nusc_tree"] = ExperimentConfig( 150 | train_config=NuscTrainConfig(), 151 | env_config=NuscEnvConfig(), 152 | algo_config=TreeAlgoConfig(), 153 | registered_name="nusc_tree" 154 | ) 155 | 156 | EXP_CONFIG_REGISTRY["nusc_diff_stack"] = ExperimentConfig( 157 | train_config=NuscTrainConfig(), 158 | env_config=NuscEnvConfig(), 159 | algo_config=BehaviorCloningConfig(), 160 | registered_name="nusc_diff_stack" 161 | ) 162 | 163 | 164 | EXP_CONFIG_REGISTRY["nusc_agent_predictor"] = ExperimentConfig( 165 | train_config=NuscTrainConfig(), 166 | env_config=NuscEnvConfig(), 167 | algo_config=AgentPredictorConfig(), 168 | registered_name="nusc_agent_predictor" 169 | ) 170 | 171 | EXP_CONFIG_REGISTRY["nusc_gan"] = ExperimentConfig( 172 | train_config=NuscTrainConfig(), 173 | env_config=NuscEnvConfig(), 174 | algo_config=GANConfig(), 175 | registered_name="nusc_gan" 176 | ) 177 | 178 | EXP_CONFIG_REGISTRY["nusc_occupancy"] = ExperimentConfig( 179 | train_config=NuscTrainConfig(), 180 | env_config=NuscEnvConfig(), 181 | algo_config=OccupancyMetricConfig(), 182 | registered_name="nusc_occupancy" 183 | ) 184 | EXP_CONFIG_REGISTRY["nusc_unet"] = ExperimentConfig( 185 | train_config=NuscTrainConfig(), 186 | env_config=NuscEnvConfig(), 187 | algo_config=UnetConfig(), 188 | registered_name="nusc_unet" 189 | ) 190 | 191 | EXP_CONFIG_REGISTRY["nusc_scept"] = 
ExperimentConfig( 192 | train_config=NuscTrainConfig(), 193 | env_config=NuscEnvConfig(), 194 | algo_config=ScePTConfig(), 195 | registered_name="nusc_scept" 196 | ) 197 | 198 | EXP_CONFIG_REGISTRY["nusc_agentformer"] = ExperimentConfig( 199 | train_config=NuscTrainConfig(), 200 | env_config=NuscEnvConfig(), 201 | algo_config=AgentFormerConfig(), 202 | registered_name="nusc_agentformer" 203 | ) 204 | 205 | EXP_CONFIG_REGISTRY["nusc_MPC"] = ExperimentConfig( 206 | train_config=NuscTrainConfig(), 207 | env_config=NuscEnvConfig(), 208 | algo_config=SQPMPCConfig(), 209 | registered_name="nusc_MPC" 210 | ) 211 | 212 | def get_registered_experiment_config(registered_name): 213 | registered_name = backward_compatible_translate(registered_name) 214 | 215 | if registered_name not in EXP_CONFIG_REGISTRY.keys(): 216 | raise KeyError( 217 | "'{}' is not a registered experiment config please choose from {}".format( 218 | registered_name, list(EXP_CONFIG_REGISTRY.keys()) 219 | ) 220 | ) 221 | return EXP_CONFIG_REGISTRY[registered_name].clone() 222 | 223 | 224 | def backward_compatible_translate(registered_name): 225 | """Try to translate registered name to maintain backward compatibility.""" 226 | translation = { 227 | "l5_mixed_plan": "l5_bc", 228 | "l5_mixed_gc": "l5_bc_gc", 229 | "l5_ma_rasterized_plan": "l5_agent_predictor", 230 | "l5_gan_plan": "l5_gan", 231 | "l5_mixed_ec_plan": "l5_bc_ec", 232 | "l5_mixed_vae_plan": "l5_vae", 233 | "l5_mixed_discrete_vae_plan": "l5_discrete_vae", 234 | "l5_mixed_tree_vae_plan": "l5_tree_vae", 235 | "nusc_rasterized_plan": "nusc_bc", 236 | "nusc_mixed_gc": "nusc_bc_gc", 237 | "nusc_ma_rasterized_plan": "nusc_agent_predictor", 238 | "nusc_gan_plan": "nusc_gan", 239 | "nusc_vae_plan": "nusc_vae", 240 | "nusc_mixed_tree_vae_plan": "nusc_tree", 241 | } 242 | if registered_name in translation: 243 | registered_name = translation[registered_name] 244 | return registered_name 
-------------------------------------------------------------------------------- /tbsim/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/traffic-behavior-simulation/c470538011f15207b0688e4b430f8568c4cbe257/tbsim/datasets/__init__.py -------------------------------------------------------------------------------- /tbsim/datasets/factory.py: -------------------------------------------------------------------------------- 1 | """DataModule / Dataset factory""" 2 | from tbsim.utils.config_utils import translate_l5kit_cfg, translate_trajdata_cfg 3 | from tbsim.datasets.l5kit_datamodules import L5MixedDataModule, L5RasterizedDataModule 4 | from tbsim.datasets.trajdata_datamodules import UnifiedDataModule 5 | 6 | def datamodule_factory(cls_name: str, config): 7 | """ 8 | A factory for creating pl.DataModule. 9 | 10 | Valid module class names: "L5MixedDataModule", "L5RasterizedDataModule" 11 | Args: 12 | cls_name (str): name of the datamodule class 13 | config (Config): an Experiment config object 14 | **kwargs: any other kwargs needed by the datamodule 15 | 16 | Returns: 17 | A DataModule 18 | """ 19 | if cls_name.startswith("L5"): # TODO: make this less hacky 20 | l5_config = translate_l5kit_cfg(config) 21 | datamodule = eval(cls_name)(l5_config=l5_config, train_config=config.train) 22 | elif cls_name.startswith("Unified"): 23 | trajdata_config = translate_trajdata_cfg(config) 24 | datamodule = eval(cls_name)(data_config=trajdata_config, train_config=config.train) 25 | else: 26 | raise NotImplementedError("{} is not a supported datamodule type".format(cls_name)) 27 | return datamodule -------------------------------------------------------------------------------- /tbsim/datasets/l5kit_datamodules.py: -------------------------------------------------------------------------------- 1 | """Functions and classes for dataset I/O""" 2 | import abc 3 | from collections import 
OrderedDict 4 | 5 | from typing import Optional 6 | import os 7 | import pytorch_lightning as pl 8 | from torch.utils.data import DataLoader 9 | 10 | from l5kit.rasterization import build_rasterizer 11 | from l5kit.rasterization.rasterizer import Rasterizer 12 | from l5kit.data import LocalDataManager, ChunkedDataset 13 | from l5kit.dataset import EgoDataset, AgentDataset 14 | 15 | from tbsim.configs.base import TrainConfig 16 | from tbsim.l5kit.vectorizer import build_vectorizer 17 | from tbsim.l5kit.l5_ego_dataset import ( 18 | EgoDatasetMixed 19 | ) 20 | 21 | from tbsim.l5kit.l5_agent_dataset import AgentDatasetMixed, AgentDataset 22 | 23 | 24 | class LazyRasterizer(Rasterizer): 25 | """ 26 | Only creates the actual rasterizer when a member function is called. 27 | 28 | A Rasterizer class is non-pickleable, which means that pickle complains about it when we try to do 29 | multi-process training (e.g., multiGPU training). This class is a way to circumvent the issue by only 30 | creating the rasterizer object when it's being used in the spawned processes. 
31 | """ 32 | def __init__(self, l5_config, data_manager): 33 | super(LazyRasterizer, self).__init__() 34 | self._l5_config = l5_config 35 | self._dm = data_manager 36 | self._rasterizer = None 37 | 38 | @property 39 | def rasterizer(self): 40 | if self._rasterizer is None: 41 | self._rasterizer = build_rasterizer(self._l5_config, self._dm) 42 | return self._rasterizer 43 | 44 | def rasterize(self, *args, **kwargs): 45 | return self.rasterizer.rasterize(*args, **kwargs) 46 | 47 | def to_rgb(self,*args, **kwargs): 48 | return self.rasterizer.to_rgb(*args, **kwargs) 49 | 50 | def num_channels(self) -> int: 51 | return self.rasterizer.num_channels() 52 | 53 | 54 | class L5BaseDatasetModule(abc.ABC): 55 | pass 56 | 57 | 58 | class L5RasterizedDataModule(pl.LightningDataModule, L5BaseDatasetModule): 59 | def __init__( 60 | self, 61 | l5_config: dict, 62 | train_config: TrainConfig, 63 | ): 64 | super().__init__() 65 | self.train_dataset = None 66 | self.valid_dataset = None 67 | self.env_dataset = None 68 | self.experience_dataset = None # replay buffer 69 | self.rasterizer = None 70 | self._train_config = train_config 71 | self._l5_config = l5_config 72 | self._mode = train_config.dataset_mode 73 | 74 | assert self._mode in ["ego", "agents"] 75 | 76 | @property 77 | def modality_shapes(self): 78 | dm = LocalDataManager(None) 79 | rasterizer = build_rasterizer(self._l5_config, dm) 80 | h, w = self._l5_config["raster_params"]["raster_size"] 81 | return OrderedDict(image=(rasterizer.num_channels(), h, w)) 82 | 83 | def setup(self, stage: Optional[str] = None): 84 | os.environ["L5KIT_DATA_FOLDER"] = os.path.abspath(self._train_config.dataset_path) 85 | dm = LocalDataManager(None) 86 | self.rasterizer = LazyRasterizer(self._l5_config, dm) 87 | 88 | train_zarr = ChunkedDataset(dm.require(self._train_config.dataset_train_key)).open() 89 | valid_zarr = ChunkedDataset(dm.require(self._train_config.dataset_valid_key)).open() 90 | 91 | self.env_dataset = 
EgoDataset(self._l5_config, valid_zarr, self.rasterizer) 92 | 93 | if self._mode == "ego": 94 | self.train_dataset = EgoDataset(self._l5_config, train_zarr, self.rasterizer) 95 | self.valid_dataset = EgoDataset(self._l5_config, valid_zarr, self.rasterizer) 96 | else: 97 | read_cached_mask = True 98 | self.train_dataset = AgentDataset(self._l5_config, train_zarr, self.rasterizer, read_cached_mask=read_cached_mask) 99 | self.valid_dataset = AgentDataset(self._l5_config, valid_zarr, self.rasterizer, read_cached_mask=read_cached_mask) 100 | 101 | def train_dataloader(self): 102 | return DataLoader( 103 | dataset=self.train_dataset, 104 | shuffle=True, 105 | batch_size=self._train_config.training.batch_size, 106 | num_workers=self._train_config.training.num_data_workers, 107 | drop_last=True, 108 | persistent_workers=True, 109 | ) 110 | 111 | def val_dataloader(self): 112 | return DataLoader( 113 | dataset=self.valid_dataset, 114 | shuffle=True, 115 | batch_size=self._train_config.validation.batch_size, 116 | num_workers=self._train_config.validation.num_data_workers, 117 | drop_last=True, 118 | persistent_workers=True, 119 | ) 120 | 121 | def test_dataloader(self): 122 | pass 123 | 124 | def predict_dataloader(self): 125 | pass 126 | 127 | 128 | class L5MixedDataModule(L5RasterizedDataModule): 129 | def __init__( 130 | self, 131 | l5_config, 132 | train_config: TrainConfig, 133 | ): 134 | super(L5MixedDataModule, self).__init__( 135 | l5_config=l5_config, train_config=train_config) 136 | self.vectorizer = None 137 | 138 | def setup(self, stage: Optional[str] = None): 139 | os.environ["L5KIT_DATA_FOLDER"] = os.path.abspath(self._train_config.dataset_path) 140 | dm = LocalDataManager(None) 141 | self.rasterizer = build_rasterizer(self._l5_config, dm) 142 | self.vectorizer = build_vectorizer(self._l5_config, dm) 143 | 144 | train_zarr = ChunkedDataset(dm.require(self._train_config.dataset_train_key)).open() 145 | valid_zarr = 
ChunkedDataset(dm.require(self._train_config.dataset_valid_key)).open() 146 | if self._mode == "ego": 147 | self.train_dataset = EgoDatasetMixed(self._l5_config, train_zarr, self.vectorizer, self.rasterizer) 148 | self.valid_dataset = EgoDatasetMixed(self._l5_config, valid_zarr, self.vectorizer, self.rasterizer) 149 | else: 150 | read_cached_mask = True 151 | self.train_dataset = AgentDatasetMixed(self._l5_config, train_zarr, self.vectorizer, self.rasterizer, read_cached_mask=read_cached_mask) 152 | self.valid_dataset = AgentDatasetMixed(self._l5_config, valid_zarr, self.vectorizer, self.rasterizer, read_cached_mask=read_cached_mask) -------------------------------------------------------------------------------- /tbsim/datasets/trajdata_datamodules.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from collections import defaultdict 4 | from torch.utils.data import Dataset 5 | import pytorch_lightning as pl 6 | from torch.utils.data import DataLoader 7 | from tbsim.configs.base import TrainConfig 8 | 9 | from trajdata import AgentBatch, AgentType, UnifiedDataset 10 | 11 | 12 | class UnifiedDataModule(pl.LightningDataModule): 13 | def __init__(self, data_config, train_config: TrainConfig): 14 | super(UnifiedDataModule, self).__init__() 15 | self._data_config = data_config 16 | self._train_config = train_config 17 | self.train_dataset = None 18 | self.valid_dataset = None 19 | 20 | @property 21 | def modality_shapes(self): 22 | # TODO: better way to figure out channel size? 
23 | return dict( 24 | image=(3 + self._data_config.history_num_frames + 1, # semantic map + num_history + current 25 | self._data_config.raster_size, 26 | self._data_config.raster_size), 27 | static=(3,self._data_config.raster_size,self._data_config.raster_size), 28 | dynamic=(self._data_config.history_num_frames + 1,self._data_config.raster_size,self._data_config.raster_size) 29 | 30 | ) 31 | 32 | def setup(self, stage = None): 33 | data_cfg = self._data_config 34 | future_sec = data_cfg.future_num_frames * data_cfg.step_time 35 | history_sec = data_cfg.history_num_frames * data_cfg.step_time 36 | neighbor_distance = data_cfg.max_agents_distance 37 | kwargs = dict( 38 | centric = data_cfg.centric, 39 | desired_data=[data_cfg.trajdata_source_train], 40 | desired_dt=data_cfg.step_time, 41 | future_sec=(future_sec, future_sec), 42 | history_sec=(history_sec, history_sec), 43 | data_dirs={ 44 | data_cfg.trajdata_source_root: data_cfg.dataset_path, 45 | }, 46 | only_types=[AgentType.VEHICLE], 47 | agent_interaction_distances=defaultdict(lambda: neighbor_distance), 48 | incl_raster_map=True, 49 | raster_map_params={ 50 | "px_per_m": int(1 / data_cfg.pixel_size), 51 | "map_size_px": data_cfg.raster_size, 52 | "return_rgb": False, 53 | "offset_frac_xy": data_cfg.raster_center, 54 | "original_format": True, 55 | }, 56 | cache_location= "~/.unified_data_cache", 57 | verbose=False, 58 | max_agent_num = 1+data_cfg.other_agents_num, 59 | # max_neighbor_num = data_cfg.other_agents_num, 60 | num_workers=os.cpu_count(), 61 | # ego_only = self._train_config.ego_only, 62 | ) 63 | print(kwargs) 64 | self.train_dataset = UnifiedDataset(**kwargs) 65 | 66 | kwargs["desired_data"] = [data_cfg.trajdata_source_valid] 67 | kwargs["rebuild_cache"] = False 68 | self.valid_dataset = UnifiedDataset(**kwargs) 69 | 70 | def train_dataloader(self): 71 | return DataLoader( 72 | dataset=self.train_dataset, 73 | shuffle=True, 74 | batch_size=self._train_config.training.batch_size, 75 | 
num_workers=self._train_config.training.num_data_workers, 76 | drop_last=True, 77 | collate_fn=self.train_dataset.get_collate_fn(return_dict=True), 78 | persistent_workers=True if self._train_config.training.num_data_workers>0 else False 79 | 80 | ) 81 | 82 | def val_dataloader(self): 83 | return DataLoader( 84 | dataset=self.valid_dataset, 85 | shuffle=True, 86 | batch_size=self._train_config.validation.batch_size, 87 | num_workers=self._train_config.validation.num_data_workers, 88 | drop_last=True, 89 | collate_fn=self.valid_dataset.get_collate_fn(return_dict=True), 90 | persistent_workers=True if self._train_config.validation.num_data_workers>0 else False 91 | ) 92 | 93 | def test_dataloader(self): 94 | pass 95 | 96 | def predict_dataloader(self): 97 | pass 98 | -------------------------------------------------------------------------------- /tbsim/dynamics/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from tbsim.dynamics.single_integrator import SingleIntegrator 4 | from tbsim.dynamics.unicycle import Unicycle 5 | from tbsim.dynamics.bicycle import Bicycle 6 | from tbsim.dynamics.double_integrator import DoubleIntegrator 7 | from tbsim.dynamics.base import Dynamics, DynType 8 | 9 | 10 | def get_dynamics_model(dyn_type: Union[str, DynType]): 11 | if dyn_type in ["Unicycle", DynType.UNICYCLE]: 12 | return Unicycle 13 | elif dyn_type == ["SingleIntegrator", DynType.SI]: 14 | return SingleIntegrator 15 | elif dyn_type == ["DoubleIntegrator", DynType.DI]: 16 | return DoubleIntegrator 17 | else: 18 | raise NotImplementedError("Dynamics model {} is not implemented".format(dyn_type)) 19 | -------------------------------------------------------------------------------- /tbsim/dynamics/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math, copy, time 4 | import abc 5 | from copy import deepcopy 6 
| 7 | 8 | class DynType: 9 | """ 10 | Holds environment types - one per environment class. 11 | These act as identifiers for different environments. 12 | """ 13 | 14 | UNICYCLE = 1 15 | SI = 2 16 | DI = 3 17 | BICYCLE = 4 18 | 19 | 20 | class Dynamics(abc.ABC): 21 | @abc.abstractmethod 22 | def __init__(self, name, **kwargs): 23 | self._name = name 24 | self.xdim = 4 25 | self.udim = 2 26 | 27 | @abc.abstractmethod 28 | def __call__(self, x, u): 29 | return 30 | 31 | @abc.abstractmethod 32 | def step(self, x, u, dt, bound=True): 33 | return 34 | 35 | @abc.abstractmethod 36 | def name(self): 37 | return self._name 38 | 39 | @abc.abstractmethod 40 | def type(self): 41 | return 42 | 43 | @abc.abstractmethod 44 | def ubound(self, x): 45 | return 46 | 47 | @staticmethod 48 | def state2pos(x): 49 | return 50 | 51 | @staticmethod 52 | def state2yaw(x): 53 | return 54 | 55 | @staticmethod 56 | def get_state(pos,yaw,dt,mask): 57 | return 58 | 59 | def forward_dynamics(self,initial_states: torch.Tensor,actions: torch.Tensor,step_time: float,bound: bool = True,): 60 | """ 61 | Integrate the state forward with initial state x0, action u 62 | Args: 63 | initial_states (Torch.tensor): state tensor of size [B, (A), 4] 64 | actions (Torch.tensor): action tensor of size [B, (A), T, 2] 65 | step_time (float): delta time between steps 66 | Returns: 67 | state tensor of size [B, (A), T, 4] 68 | """ 69 | num_steps = actions.shape[-2] 70 | x = [initial_states] + [None] * num_steps 71 | for t in range(num_steps): 72 | x[t + 1] = self.step(x[t], actions[..., t, :], step_time,bound=bound) 73 | 74 | x = torch.stack(x[1:], dim=-2) 75 | pos = self.state2pos(x) 76 | yaw = self.state2yaw(x) 77 | return x, pos, yaw 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /tbsim/dynamics/bicycle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | from tbsim.dynamics.base import 
Dynamics, DynType 5 | 6 | 7 | def bicycle_model(state, acc, ddh, vehicle_length, dt, max_hdot=math.pi * 2.0, max_s=50.0): 8 | """ 9 | Simple differentiable bicycle model that does not allow reverse 10 | Args: 11 | state (torch.Tensor): a batch of current kinematic state [B, ..., 5] (x, y, yaw, speed, hdot) 12 | acc (torch.Tensor): a batch of acceleration profile [B, ...] (acc) 13 | ddh (torch.Tensor): a batch of heading acceleration profile [B, ...] (heading) 14 | vehicle_length (torch.Tensor): a batch of vehicle length [B, ...] (length) 15 | dt (float): time between steps 16 | max_hdot (float): maximum change of heading (rad/s) 17 | max_s (float): maximum speed (m/s) 18 | 19 | Returns: 20 | New kinematic state (torch.Tensor) 21 | """ 22 | # state: (x, y, h, speed, hdot) 23 | assert state.shape[-1] == 5 24 | newhdot = (state[..., 4] + ddh * dt).clamp(-max_hdot, max_hdot) 25 | newh = state[..., 2] + dt * state[..., 3].abs() / vehicle_length * newhdot 26 | news = (state[..., 3] + acc * dt).clamp(0.0, max_s) # no reverse 27 | newy = state[..., 1] + news * newh.sin() * dt 28 | newx = state[..., 0] + news * newh.cos() * dt 29 | 30 | newstate = torch.empty_like(state) 31 | newstate[..., 0] = newx 32 | newstate[..., 1] = newy 33 | newstate[..., 2] = newh 34 | newstate[..., 3] = news 35 | newstate[..., 4] = newhdot 36 | 37 | return newstate 38 | 39 | 40 | class Bicycle(Dynamics): 41 | 42 | def __init__( 43 | self, 44 | acc_bound=(-10, 8), 45 | ddh_bound=(-math.pi * 2.0, math.pi * 2.0), 46 | max_speed=50.0, 47 | max_hdot=math.pi * 2.0 48 | ): 49 | """ 50 | A simple bicycle dynamics model 51 | Args: 52 | acc_bound (tuple): acceleration bound (m/s^2) 53 | ddh_bound (tuple): angular acceleration bound (rad/s^2) 54 | max_speed (float): maximum speed, must be positive 55 | max_hdot (float): maximum turning speed, must be positive 56 | """ 57 | super(Bicycle, self).__init__(name="bicycle") 58 | self.xdim = 6 59 | self.udim = 2 60 | assert max_speed >= 0 61 | assert max_hdot >= 0 
62 | self.acc_bound = acc_bound 63 | self.ddh_bound = ddh_bound 64 | self.max_speed = max_speed 65 | self.max_hdot = max_hdot 66 | 67 | def get_normalized_controls(self, u): 68 | u = torch.sigmoid(u) # normalize to [0, 1] 69 | acc = self.acc_bound[0] + (self.acc_bound[1] - self.acc_bound[0]) * u[..., 0] 70 | ddh = self.ddh_bound[0] + (self.ddh_bound[1] - self.ddh_bound[0]) * u[..., 1] 71 | return acc, ddh 72 | 73 | def get_clipped_controls(self, u): 74 | acc = torch.clip(u[..., 0], self.acc_bound[0], self.acc_bound[1]) 75 | ddh = torch.clip(u[..., 1], self.ddh_bound[0], self.ddh_bound[1]) 76 | return acc, ddh 77 | 78 | def step(self, x, u, dt, normalize=True): 79 | """ 80 | Take a step with the dynamics model 81 | Args: 82 | x (torch.Tensor): current state [B, ..., 6] (x, y, h, speed, dh, veh_length) 83 | u (torch.Tensor): (un-normalized) actions [B, ..., 2] (acc, ddh) 84 | dt (float): time between steps 85 | normalize (bool): whether to normalize the actions 86 | 87 | Returns: 88 | next_x (torch.Tensor): next state after taking the action 89 | """ 90 | assert x.shape[-1] == self.xdim 91 | assert u.shape[:-1] == x.shape[:-1] 92 | assert u.shape[-1] == self.udim 93 | if normalize: 94 | acc, ddh = self.get_normalized_controls(u) 95 | else: 96 | acc, ddh = self.get_clipped_controls(u) 97 | next_x = x.clone() # keep the extent the same 98 | next_x[..., :5] = bicycle_model( 99 | state=x[..., :5], 100 | acc=acc, 101 | ddh=ddh, 102 | vehicle_length=x[..., 5], 103 | dt=dt, 104 | max_hdot=self.max_hdot, 105 | max_s=self.max_speed 106 | ) 107 | return next_x 108 | @staticmethod 109 | def calculate_vel(pos, yaw, dt, mask): 110 | 111 | vel = (pos[..., 1:, 0:1] - pos[..., :-1, 0:1]) / dt * torch.cos( 112 | yaw[..., 1:, :] 113 | ) + (pos[..., 1:, 1:2] - pos[..., :-1, 1:2]) / dt * torch.sin( 114 | yaw[..., 1:, :] 115 | ) 116 | # right finite difference velocity 117 | vel_r = torch.cat((vel[..., 0:1, :], vel), dim=-2) 118 | # left finite difference velocity 119 | vel_l = 
torch.cat((vel, vel[..., -1:, :]), dim=-2) 120 | mask_r = torch.roll(mask, 1, dims=-1) 121 | mask_r[..., 0] = False 122 | mask_r = mask_r & mask 123 | 124 | mask_l = torch.roll(mask, -1, dims=-1) 125 | mask_l[..., -1] = False 126 | mask_l = mask_l & mask 127 | vel = ( 128 | (mask_l & mask_r).unsqueeze(-1) * (vel_r + vel_l) / 2 129 | + (mask_l & (~mask_r)).unsqueeze(-1) * vel_l 130 | + (mask_r & (~mask_l)).unsqueeze(-1) * vel_r 131 | ) 132 | 133 | return vel 134 | 135 | def type(self): 136 | return DynType.BICYCLE 137 | 138 | def state2pos(self, x): 139 | return x[..., :2] 140 | 141 | def state2yaw(self, x): 142 | return x[..., 2:3] 143 | 144 | def __call__(self, x, u): 145 | pass 146 | 147 | def ubound(self, x): 148 | pass 149 | 150 | def name(self): 151 | return self._name 152 | -------------------------------------------------------------------------------- /tbsim/dynamics/double_integrator.py: -------------------------------------------------------------------------------- 1 | from tbsim.dynamics.base import DynType, Dynamics 2 | from tbsim.utils.math_utils import soft_sat 3 | import torch 4 | import numpy as np 5 | from copy import deepcopy 6 | 7 | 8 | 9 | class DoubleIntegrator(Dynamics): 10 | def __init__(self, name, abound, vbound=None): 11 | self._name = name 12 | self._type = DynType.DI 13 | self.xdim = 2 14 | self.udim = 2 15 | self.cyclic_state = list() 16 | self.vbound = vbound 17 | self.abound = abound 18 | 19 | def __call__(self, x, u): 20 | assert x.shape[:-1] == u.shape[:, -1] 21 | if isinstance(x, np.ndarray): 22 | return np.hstack((x[..., 2:], u)) 23 | elif isinstance(x, torch.Tensor): 24 | return torch.cat((x[..., 2:], u), dim=-1) 25 | else: 26 | raise NotImplementedError 27 | 28 | def step(self, x, u, dt, bound=True): 29 | 30 | if isinstance(x, np.ndarray): 31 | if bound: 32 | lb, ub = self.ubound(x) 33 | u = np.clip(u, lb, ub) 34 | xn = np.hstack( 35 | ((x[..., 2:4] + 0.5 * u * dt) * dt + x[..., 0:2], x[..., 2:4] + u * dt) 36 | ) 37 | elif 
isinstance(x, torch.Tensor): 38 | if bound: 39 | lb, ub = self.ubound(x) 40 | u = torch.clip(u, min=lb, max=ub) 41 | xn = torch.clone(x) 42 | xn[..., 0:2] += (x[..., 2:4] + 0.5 * u * dt) * dt 43 | xn[..., 2:4] += u * dt 44 | else: 45 | raise NotImplementedError 46 | return xn 47 | 48 | def name(self): 49 | return self._name 50 | 51 | def type(self): 52 | return self._type 53 | 54 | def ubound(self, x): 55 | if self.vbound is None: 56 | if isinstance(x, np.ndarray): 57 | lb = np.ones_like(x[..., 2:]) * self.abound[0] 58 | ub = np.ones_like(x[..., 2:]) * self.abound[1] 59 | 60 | elif isinstance(x, torch.Tensor): 61 | lb = torch.ones_like(x[..., 2:]) * torch.from_numpy( 62 | self.abound[:, 0] 63 | ).to(x.device) 64 | ub = torch.ones_like(x[..., 2:]) * torch.from_numpy( 65 | self.abound[:, 1] 66 | ).to(x.device) 67 | 68 | else: 69 | raise NotImplementedError 70 | else: 71 | if isinstance(x, np.ndarray): 72 | lb = (x[..., 2:] > self.vbound[0]) * self.abound[0] 73 | ub = (x[..., 2:] < self.vbound[1]) * self.abound[1] 74 | 75 | elif isinstance(x, torch.Tensor): 76 | lb = ( 77 | x[..., 2:] > torch.from_numpy(self.vbound[0]).to(x.device) 78 | ) * torch.from_numpy(self.abound[0]).to(x.device) 79 | ub = ( 80 | x[..., 2:] < torch.from_numpy(self.vbound[1]).to(x.device) 81 | ) * torch.from_numpy(self.abound[1]).to(x.device) 82 | else: 83 | raise NotImplementedError 84 | return lb, ub 85 | 86 | @staticmethod 87 | def state2pos(x): 88 | return x[..., 0:2] 89 | 90 | @staticmethod 91 | def state2yaw(x): 92 | # return torch.atan2(x[..., 3:], x[..., 2:3]) 93 | return torch.zeros_like(x[..., 0:1]) 94 | @staticmethod 95 | def inverse_dyn(x,xp,dt): 96 | return (xp[...,2:]-x[...,2:])/dt 97 | @staticmethod 98 | def calculate_vel(pos, yaw, dt, mask): 99 | vel = (pos[...,1:,:]-pos[...,:-1,:])/dt 100 | if isinstance(pos, torch.Tensor): 101 | # right finite difference velocity 102 | vel_r = torch.cat((vel[..., 0:1, :], vel), dim=-2) 103 | # left finite difference velocity 104 | vel_l = 
torch.cat((vel, vel[..., -1:, :]), dim=-2) 105 | mask_r = torch.roll(mask, 1, dims=-1) 106 | mask_r[..., 0] = False 107 | mask_r = mask_r & mask 108 | 109 | mask_l = torch.roll(mask, -1, dims=-1) 110 | mask_l[..., -1] = False 111 | mask_l = mask_l & mask 112 | vel = ( 113 | (mask_l & mask_r).unsqueeze(-1) * (vel_r + vel_l) / 2 114 | + (mask_l & (~mask_r)).unsqueeze(-1) * vel_l 115 | + (mask_r & (~mask_l)).unsqueeze(-1) * vel_r 116 | ) 117 | elif isinstance(pos, np.ndarray): 118 | # right finite difference velocity 119 | vel_r = np.concatenate((vel[..., 0:1, :], vel), axis=-2) 120 | # left finite difference velocity 121 | vel_l = np.concatenate((vel, vel[..., -1:, :]), axis=-2) 122 | mask_r = np.roll(mask, 1, axis=-1) 123 | mask_r[..., 0] = False 124 | mask_r = mask_r & mask 125 | mask_l = np.roll(mask, -1, axis=-1) 126 | mask_l[..., -1] = False 127 | mask_l = mask_l & mask 128 | vel = ( 129 | np.expand_dims(mask_l & mask_r,-1) * (vel_r + vel_l) / 2 130 | + np.expand_dims(mask_l & (~mask_r),-1) * vel_l 131 | + np.expand_dims(mask_r & (~mask_l),-1) * vel_r 132 | ) 133 | else: 134 | raise NotImplementedError 135 | return vel 136 | @staticmethod 137 | def get_state(pos,yaw,dt,mask): 138 | vel = DoubleIntegrator.calculate_vel(pos, yaw, dt, mask) 139 | if isinstance(vel,np.ndarray): 140 | return np.concatenate((pos,vel),-1) 141 | elif isinstance(vel,torch.Tensor): 142 | return torch.cat((pos,vel),-1) 143 | 144 | def forward_dynamics(self, 145 | initial_states: torch.Tensor, 146 | actions: torch.Tensor, 147 | step_time: float, 148 | ): 149 | if isinstance(actions, np.ndarray): 150 | actions = np.clip(actions,self.abound[0],self.abound[1]) 151 | delta_v = np.cumsum(actions*step_time,-2) 152 | vel = initial_states[...,np.newaxis,2:]+delta_v 153 | vel = np.clip(vel,self.vbound[0],self.vbound[1]) 154 | delta_xy = np.cumsum(vel*step_time,-2) 155 | xy = initial_states[...,np.newaxis,:2]+delta_xy 156 | 157 | traj = np.concatenate((xy,vel),-1) 158 | elif 
isinstance(actions,torch.Tensor): 159 | actions = soft_sat(actions,self.abound[0],self.abound[1]) 160 | delta_v = torch.cumsum(actions*step_time,-2) 161 | vel = initial_states[...,2:].unsqueeze(-2)+delta_v 162 | vel = soft_sat(vel,self.vbound[0],self.vbound[1]) 163 | delta_xy = torch.cumsum(vel*step_time,-2) 164 | xy = initial_states[...,:2].unsqueeze(-2)+delta_xy 165 | 166 | traj = torch.cat((xy,vel),-1) 167 | xy = self.state2pos(traj) 168 | yaw = self.state2yaw(traj) 169 | return traj, xy, yaw -------------------------------------------------------------------------------- /tbsim/dynamics/single_integrator.py: -------------------------------------------------------------------------------- 1 | from tbsim.dynamics.base import DynType, Dynamics 2 | import torch 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | 7 | class SingleIntegrator(Dynamics): 8 | def __init__(self, name, vbound=[1.5,1.5]): 9 | self._name = name 10 | self._type = DynType.SI 11 | self.xdim = vbound.shape[0] 12 | self.udim = vbound.shape[0] 13 | self.cyclic_state = list() 14 | self.vbound = np.array(vbound) 15 | 16 | def __call__(self, x, u): 17 | assert x.shape[:-1] == u.shape[:, -1] 18 | 19 | return u 20 | 21 | def step(self, x, u, dt, bound=True): 22 | assert x.shape[:-1] == u.shape[:, -1] 23 | if bound: 24 | lb, ub = self.ubound(x) 25 | if isinstance(x, np.ndarray): 26 | u = np.clip(u, lb, ub) 27 | elif isinstance(x, torch.Tensor): 28 | u = torch.clip(u, min=lb, max=ub) 29 | 30 | return x + u * dt 31 | 32 | def name(self): 33 | return self._name 34 | 35 | def type(self): 36 | return self._type 37 | 38 | def ubound(self, x): 39 | if isinstance(x, np.ndarray): 40 | lb = np.ones_like(x) * self.vbound[:, 0] 41 | ub = np.ones_like(x) * self.vbound[:, 1] 42 | return lb, ub 43 | elif isinstance(x, torch.Tensor): 44 | lb = torch.ones_like(x) * torch.from_numpy(self.vbound[:, 0]) 45 | ub = torch.ones_like(x) * torch.from_numpy(self.vbound[:, 1]) 46 | return lb, ub 47 | else: 48 | raise 
NotImplementedError 49 | 50 | @staticmethod 51 | def state2pos(x): 52 | return x[..., 0:2] 53 | -------------------------------------------------------------------------------- /tbsim/envs/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class SimulationException(Exception): 5 | pass 6 | 7 | 8 | class BaseEnv(abc.ABC): 9 | """TODO: Make a Simulator MetaClass""" 10 | 11 | @abc.abstractmethod 12 | def reset(self, scene_indices=None, start_frame_index=None): 13 | return 14 | 15 | @abc.abstractmethod 16 | def reset_multi_episodes_metrics(self): 17 | pass 18 | 19 | @abc.abstractmethod 20 | def step(self, action, num_steps_to_take, render): 21 | return 22 | 23 | @abc.abstractmethod 24 | def update_random_seed(self, seed): 25 | return 26 | 27 | @abc.abstractmethod 28 | def get_metrics(self): 29 | return 30 | 31 | @abc.abstractmethod 32 | def get_multi_episode_metrics(self): 33 | return 34 | 35 | @abc.abstractmethod 36 | def render(self, actions_to_take): 37 | return 38 | 39 | @abc.abstractmethod 40 | def get_info(self): 41 | return 42 | 43 | @abc.abstractmethod 44 | def get_observation(self): 45 | return 46 | 47 | @abc.abstractmethod 48 | def get_reward(self): 49 | return 50 | 51 | @abc.abstractmethod 52 | def is_done(self): 53 | return 54 | 55 | @abc.abstractmethod 56 | def get_info(self): 57 | return 58 | 59 | 60 | class BatchedEnv(abc.ABC): 61 | @abc.abstractmethod 62 | def num_instances(self): 63 | return 64 | -------------------------------------------------------------------------------- /tbsim/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/traffic-behavior-simulation/c470538011f15207b0688e4b430f8568c4cbe257/tbsim/evaluation/__init__.py -------------------------------------------------------------------------------- /tbsim/evaluation/metric_composers.py: 
-------------------------------------------------------------------------------- 1 | from tbsim.algos.algos import ( 2 | DiscreteVAETrafficModel, 3 | ) 4 | 5 | from tbsim.algos.metric_algos import ( 6 | OccupancyMetric, 7 | ) 8 | 9 | import tbsim.envs.env_metrics as EnvMetrics 10 | 11 | from tbsim.utils.batch_utils import batch_utils 12 | from tbsim.utils.config_utils import get_experiment_config_from_file 13 | from tbsim.configs.base import ExperimentConfig 14 | 15 | from tbsim.utils.experiment_utils import get_checkpoint 16 | 17 | try: 18 | from Pplan.Sampling.spline_planner import SplinePlanner 19 | from Pplan.Sampling.trajectory_tree import TrajTree 20 | except ImportError: 21 | print("Cannot import Pplan") 22 | 23 | 24 | class MetricsComposer(object): 25 | """Wrapper for building learned metrics from trained checkpoints.""" 26 | def __init__(self, eval_config, device, ckpt_root_dir="checkpoints/"): 27 | self.device = device 28 | self.ckpt_root_dir = ckpt_root_dir 29 | self.eval_config = eval_config 30 | self._exp_config = None 31 | 32 | def get_modality_shapes(self, exp_cfg: ExperimentConfig): 33 | return batch_utils().get_modality_shapes(exp_cfg) 34 | 35 | def get_metrics(self): 36 | raise NotImplementedError 37 | 38 | 39 | class CVAEMetrics(MetricsComposer): 40 | def get_metrics(self, eval_config, perturbations=None, rolling=False, env="l5kit", **kwargs): 41 | ckpt_path, config_path = get_checkpoint( 42 | ckpt_key=eval_config.ckpt.cvae_metric.ckpt_key, 43 | ckpt_root_dir=self.ckpt_root_dir 44 | ) 45 | 46 | controller_cfg = get_experiment_config_from_file(config_path) 47 | modality_shapes = batch_utils().get_modality_shapes(controller_cfg) 48 | CVAE_model = DiscreteVAETrafficModel.load_from_checkpoint( 49 | ckpt_path, 50 | algo_config=controller_cfg.algo, 51 | modality_shapes=modality_shapes 52 | ).to(self.device).eval() 53 | if not rolling: 54 | return EnvMetrics.LearnedCVAENLL(metric_algo=CVAE_model, perturbations=perturbations) 55 | else: 56 | if 
"rolling_horizon" in kwargs: 57 | rolling_horizon = kwargs["rolling_horizon"] 58 | else: 59 | rolling_horizon = None 60 | return EnvMetrics.LearnedCVAENLLRolling(metric_algo=CVAE_model, rolling_horizon=rolling_horizon, perturbations=perturbations) 61 | 62 | 63 | class OccupancyMetrics(MetricsComposer): 64 | def get_metrics(self, eval_config, perturbations = None, rolling=False, env="l5kit", **kwargs): 65 | ckpt_path, config_path = get_checkpoint( 66 | ckpt_key=eval_config.ckpt.occupancy_metric.ckpt_key, 67 | ckpt_root_dir=self.ckpt_root_dir 68 | ) 69 | 70 | cfg = get_experiment_config_from_file(config_path) 71 | 72 | modality_shapes = batch_utils().get_modality_shapes(cfg) 73 | occupancy_model = OccupancyMetric.load_from_checkpoint( 74 | ckpt_path, 75 | algo_config=cfg.algo, 76 | modality_shapes=modality_shapes 77 | ).to(self.device).eval() 78 | 79 | if not rolling: 80 | return EnvMetrics.Occupancy_likelihood(metric_algo=occupancy_model, perturbations=perturbations) 81 | else: 82 | if "rolling_horizon" in kwargs: 83 | rolling_horizon = kwargs["rolling_horizon"] 84 | else: 85 | rolling_horizon = None 86 | return EnvMetrics.Occupancy_rolling(metric_algo=occupancy_model, rolling_horizon=rolling_horizon, perturbations=perturbations) 87 | 88 | -------------------------------------------------------------------------------- /tbsim/l5kit/l5_ego_dataset.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | from functools import partial 3 | from typing import Callable, Optional, List 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data import Dataset, IterableDataset 8 | 9 | from l5kit.data import ChunkedDataset, get_frames_slice_from_scenes 10 | from l5kit.dataset.utils import convert_str_to_fixed_length_tensor 11 | from l5kit.kinematic import Perturbation 12 | from l5kit.rasterization import Rasterizer, RenderContext 13 | from l5kit.sampling.agent_sampling import generate_agent_sample 14 | from 
l5kit.sampling.agent_sampling_vectorized import generate_agent_sample_vectorized 15 | from tbsim.l5kit.agent_sampling_mixed import generate_agent_sample_mixed 16 | from tbsim.l5kit.vectorizer import Vectorizer 17 | from l5kit.dataset.ego import BaseEgoDataset 18 | from tbsim.utils.timer import Timers 19 | 20 | 21 | class EgoDatasetMixed(BaseEgoDataset): 22 | def __init__( 23 | self, 24 | cfg: dict, 25 | zarr_dataset: ChunkedDataset, 26 | vectorizer: Vectorizer, 27 | rasterizer: Rasterizer, 28 | perturbation: Optional[Perturbation] = None, 29 | ): 30 | """ 31 | Get a PyTorch dataset object that can be used to train DNNs with vectorized input 32 | 33 | Args: 34 | cfg (dict): configuration file 35 | zarr_dataset (ChunkedDataset): the raw zarr dataset 36 | vectorizer (Vectorizer): a object that supports vectorization around an AV 37 | perturbation (Optional[Perturbation]): an object that takes care of applying trajectory perturbations. 38 | None if not desired 39 | """ 40 | self.perturbation = perturbation 41 | self.vectorizer = vectorizer 42 | self.rasterizer = rasterizer 43 | self.timer = Timers() 44 | self._skimp = False 45 | super().__init__(cfg, zarr_dataset) 46 | 47 | def set_skimp(self, skimp): 48 | self._skimp = skimp 49 | 50 | def is_skimp(self): 51 | return self._skimp 52 | 53 | def _get_sample_function(self) -> Callable[..., dict]: 54 | render_context = RenderContext( 55 | raster_size_px=np.array(self.cfg["raster_params"]["raster_size"]), 56 | pixel_size_m=np.array(self.cfg["raster_params"]["pixel_size"]), 57 | center_in_raster_ratio=np.array(self.cfg["raster_params"]["ego_center"]), 58 | set_origin_to_bottom=self.cfg["raster_params"]["set_origin_to_bottom"], 59 | ) 60 | return partial( 61 | generate_agent_sample_mixed, 62 | render_context=render_context, 63 | history_num_frames_ego=self.cfg["model_params"]["history_num_frames_ego"], 64 | history_num_frames_agents=self.cfg["model_params"][ 65 | "history_num_frames_agents" 66 | ], 67 | 
future_num_frames=self.cfg["model_params"]["future_num_frames"], 68 | step_time=self.cfg["model_params"]["step_time"], 69 | filter_agents_threshold=self.cfg["raster_params"][ 70 | "filter_agents_threshold" 71 | ], 72 | timer=self.timer, 73 | perturbation=self.perturbation, 74 | vectorizer=self.vectorizer, 75 | rasterizer=self.rasterizer, 76 | skimp_fn=self.is_skimp, 77 | vectorize_lane=self.cfg["data_generation_params"]["vectorize_lane"], 78 | rasterize_agents = self.cfg["data_generation_params"].get("rasterize_agents", False), 79 | vectorize_agents = self.cfg["data_generation_params"].get("vectorize_agents", True), 80 | ) 81 | 82 | def get_scene_dataset(self, scene_index: int) -> "EgoDatasetMixed": 83 | dataset = self.dataset.get_scene_dataset(scene_index) 84 | return EgoDatasetMixed( 85 | self.cfg, 86 | dataset, 87 | self.vectorizer, 88 | self.rasterizer, 89 | self.perturbation, 90 | ) 91 | 92 | def get_frame( 93 | self, scene_index: int, state_index: int, track_id: Optional[int] = None 94 | ) -> dict: 95 | with self.timer.timed("get_frame"): 96 | data = super().get_frame(scene_index, state_index, track_id=track_id) 97 | # TODO (@lberg): this should not be here but in the rasterizer 98 | if "image" in data: 99 | data["image"] = data["image"].transpose(2, 0, 1) # 0,1,C -> C,0,1 100 | if "other_agents_image" in data: 101 | data["other_agents_image"] = data["other_agents_image"].transpose(0,3,1,2) 102 | return data 103 | 104 | 105 | class EgoReplayBufferMixed(Dataset): 106 | """A Dataset class object for wrapping environment interaction episodes""" 107 | def __init__( 108 | self, 109 | cfg, 110 | vectorizer: Vectorizer, 111 | rasterizer: Rasterizer, 112 | capacity=None, 113 | perturbation: Perturbation = None, 114 | ): 115 | super(EgoReplayBufferMixed, self).__init__() 116 | self.cfg = cfg 117 | self.dataset = dict() 118 | self._capacity = capacity 119 | self._active_scenes = [] 120 | 121 | self.perturbation = perturbation 122 | self.vectorizer = vectorizer 123 | 
self.rasterizer = rasterizer 124 | 125 | self.sample_function = self._get_sample_function() 126 | 127 | def _get_sample_function(self) -> Callable[..., dict]: 128 | render_context = RenderContext( 129 | raster_size_px=np.array(self.cfg["raster_params"]["raster_size"]), 130 | pixel_size_m=np.array(self.cfg["raster_params"]["pixel_size"]), 131 | center_in_raster_ratio=np.array(self.cfg["raster_params"]["ego_center"]), 132 | set_origin_to_bottom=self.cfg["raster_params"]["set_origin_to_bottom"], 133 | ) 134 | return partial( 135 | generate_agent_sample_mixed, 136 | render_context=render_context, 137 | history_num_frames_ego=self.cfg["model_params"]["history_num_frames_ego"], 138 | history_num_frames_agents=self.cfg["model_params"][ 139 | "history_num_frames_agents" 140 | ], 141 | future_num_frames=self.cfg["model_params"]["future_num_frames"], 142 | step_time=self.cfg["model_params"]["step_time"], 143 | filter_agents_threshold=self.cfg["raster_params"][ 144 | "filter_agents_threshold" 145 | ], 146 | perturbation=self.perturbation, 147 | vectorizer=self.vectorizer, 148 | rasterizer=self.rasterizer, 149 | vectorize_lane=self.cfg["data_generation_params"]["vectorize_lane"], 150 | rasterize_agents = self.cfg["data_generation_params"]["rasterize_agents"], 151 | ) 152 | 153 | def append_experience(self, episodes_data: List): 154 | """ 155 | Append list of episodic experience 156 | Args: 157 | episodes_data (list): a list of episodic experiences 158 | 159 | """ 160 | self._active_scenes.extend([d[0] for d in episodes_data]) 161 | for si, ds in episodes_data: 162 | self.dataset[si] = ds 163 | 164 | if self._capacity is not None and len(self._active_scenes) > self._capacity: 165 | n_to_remove = len(self._active_scenes) - self._capacity 166 | for si in self._active_scenes[:n_to_remove]: 167 | self.dataset.pop(si) 168 | self._active_scenes = self._active_scenes[n_to_remove:] 169 | 170 | def _get_scene_indices(self): 171 | fi = dict() 172 | ind = 0 173 | for si in 
self._active_scenes: 174 | fl = len(self.dataset[si].frames) 175 | fi[si] = (ind, ind + fl) 176 | ind += fl 177 | return fi 178 | 179 | def _get_scene_by_index(self, index): 180 | fi = self._get_scene_indices() 181 | for si, (start, end) in fi.items(): 182 | if start <= index < end: 183 | return si, index - start 184 | raise IndexError("index {} is out of range".format(index)) 185 | 186 | def __len__(self): 187 | return self._get_scene_indices()[self._active_scenes[-1]][1] 188 | 189 | def __getitem__(self, index): 190 | scene_index, state_index = self._get_scene_by_index(index) 191 | dataset = self.dataset[scene_index] 192 | tl_faces = dataset.tl_faces 193 | if self.cfg["raster_params"]["disable_traffic_light_faces"]: 194 | tl_faces = np.empty(0, dtype= dataset.tl_faces.dtype) 195 | data = self.sample_function( 196 | state_index, 197 | dataset.frames, 198 | dataset.agents, 199 | tl_faces, 200 | selected_track_id=None 201 | ) 202 | data["image"] = data["image"].transpose(2, 0, 1) 203 | 204 | # add information only, so that all data keys are always preserved 205 | data["scene_index"] = scene_index 206 | data["track_id"] = np.int64(-1) # always a number to avoid crashing torch 207 | return data 208 | 209 | 210 | class ExperienceIterableWrapper(IterableDataset): 211 | def __init__(self, dataset): 212 | self.dataset = dataset 213 | self._should_update_indices = False 214 | self._indices = None 215 | self._rnd = None 216 | self._curr_index = 0 217 | 218 | def _update_indices(self): 219 | self._indices = np.arange(len(self.dataset)) 220 | self._rnd.shuffle(self._indices) 221 | self._curr_index = 0 222 | 223 | def append_experience(self, episodes_data: List): 224 | self._should_update_indices = True 225 | self.dataset.append_experience(episodes_data) 226 | 227 | def __iter__(self): 228 | while True: 229 | if self._rnd is None: 230 | winfo = torch.utils.data.get_worker_info() 231 | if winfo is not None: 232 | seed = winfo.id 233 | else: 234 | seed = 0 235 | self._rnd = 
np.random.RandomState(seed=seed) 236 | if self._should_update_indices: 237 | self._update_indices() 238 | self._should_update_indices = False 239 | yield self.dataset[self._curr_index] -------------------------------------------------------------------------------- /tbsim/l5kit/vis_rasterizer.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from enum import IntEnum 3 | from typing import Dict, List, Optional 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | from l5kit.data.filter import filter_tl_faces_by_status 9 | from l5kit.data.map_api import InterpolationMethod, MapAPI, TLFacesColors 10 | from l5kit.geometry import rotation33_as_yaw, transform_point, transform_points 11 | from l5kit.rasterization.rasterizer import Rasterizer 12 | from l5kit.rasterization.render_context import RenderContext 13 | 14 | 15 | # sub-pixel drawing precision constants 16 | CV2_SUB_VALUES = {"shift": 9, "lineType": cv2.LINE_AA} 17 | CV2_SHIFT_VALUE = 2 ** CV2_SUB_VALUES["shift"] 18 | INTERPOLATION_POINTS = 20 19 | 20 | 21 | class RasterEls(IntEnum): # map elements 22 | LANE_NOTL = 0 23 | ROAD = 1 24 | CROSSWALK = 2 25 | 26 | 27 | COLORS = { 28 | TLFacesColors.GREEN.name: (0, 255, 0), 29 | TLFacesColors.RED.name: (255, 0, 0), 30 | TLFacesColors.YELLOW.name: (255, 255, 0), 31 | # RasterEls.LANE_NOTL.name: (255, 217, 82), 32 | RasterEls.LANE_NOTL.name: (164, 184, 196), 33 | # RasterEls.ROAD.name: (17, 17, 31), 34 | RasterEls.ROAD.name: (200, 211, 213), 35 | RasterEls.CROSSWALK.name: (96, 117, 138), 36 | # RasterEls.CROSSWALK.name: (255, 117, 69), 37 | # "Background": (37, 40, 61), 38 | "Background": (255, 255, 255), 39 | } 40 | 41 | 42 | def indices_in_bounds(center: np.ndarray, bounds: np.ndarray, half_extent: float) -> np.ndarray: 43 | """ 44 | Get indices of elements for which the bounding box described by bounds intersects the one defined around 45 | center (square with side 2*half_side) 46 | 47 | Args: 48 
| center (float): XY of the center 49 | bounds (np.ndarray): array of shape Nx2x2 [[x_min,y_min],[x_max, y_max]] 50 | half_extent (float): half the side of the bounding box centered around center 51 | 52 | Returns: 53 | np.ndarray: indices of elements inside radius from center 54 | """ 55 | x_center, y_center = center 56 | 57 | x_min_in = x_center > bounds[:, 0, 0] - half_extent 58 | y_min_in = y_center > bounds[:, 0, 1] - half_extent 59 | x_max_in = x_center < bounds[:, 1, 0] + half_extent 60 | y_max_in = y_center < bounds[:, 1, 1] + half_extent 61 | return np.nonzero(x_min_in & y_min_in & x_max_in & y_max_in)[0] 62 | 63 | 64 | def cv2_subpixel(coords: np.ndarray) -> np.ndarray: 65 | """ 66 | Cast coordinates to numpy.int but keep fractional part by previously multiplying by 2**CV2_SHIFT 67 | cv2 calls will use shift to restore original values with higher precision 68 | 69 | Args: 70 | coords (np.ndarray): XY coords as float 71 | 72 | Returns: 73 | np.ndarray: XY coords as int for cv2 shift draw 74 | """ 75 | coords = coords * CV2_SHIFT_VALUE 76 | coords = coords.astype(int) 77 | return coords 78 | 79 | 80 | class VisualizationRasterizer(object): 81 | """ 82 | Rasteriser for visualization purposes 83 | """ 84 | 85 | def __init__( 86 | self, render_context: RenderContext, semantic_map_path: str, world_to_ecef: np.ndarray, 87 | ): 88 | self.render_context = render_context 89 | self.raster_size = render_context.raster_size_px 90 | self.pixel_size = render_context.pixel_size_m 91 | self.ego_center = render_context.center_in_raster_ratio 92 | 93 | self.world_to_ecef = world_to_ecef 94 | 95 | self.mapAPI = MapAPI(semantic_map_path, world_to_ecef) 96 | 97 | def rasterize( 98 | self, 99 | ego_translation_m, 100 | ego_yaw_rad, 101 | ) -> np.ndarray: 102 | raster_from_world = self.render_context.raster_from_world(ego_translation_m, ego_yaw_rad) 103 | world_from_raster = np.linalg.inv(raster_from_world) 104 | 105 | # get XY of center pixel in world coordinates 106 | 
center_in_raster_px = np.asarray(self.raster_size) * (0.5, 0.5) 107 | center_in_world_m = transform_point(center_in_raster_px, world_from_raster) 108 | 109 | sem_im = self.render_semantic_map(center_in_world_m, raster_from_world) 110 | return sem_im.astype(float32) / 255 111 | 112 | def render_semantic_map( 113 | self, center_in_world: np.ndarray, raster_from_world: np.ndarray 114 | ) -> np.ndarray: 115 | """Renders the semantic map at given x,y coordinates. 116 | 117 | Args: 118 | center_in_world (np.ndarray): XY of the image center in world ref system 119 | raster_from_world (np.ndarray): 120 | Returns: 121 | np.ndarray: RGB raster 122 | 123 | """ 124 | 125 | img = np.ones(shape=(self.raster_size[1], self.raster_size[0], 3)) 126 | img *= [[COLORS["Background"]]] 127 | img = img.astype(np.uint8) 128 | 129 | # filter using half a radius from the center 130 | raster_radius = float(np.linalg.norm(self.raster_size * self.pixel_size)) / 2 131 | 132 | # get all lanes as interpolation so that we can transform them all together 133 | 134 | lane_indices = indices_in_bounds(center_in_world, self.mapAPI.bounds_info["lanes"]["bounds"], raster_radius) 135 | lanes_mask: Dict[str, np.ndarray] = defaultdict(lambda: np.zeros(len(lane_indices) * 2, dtype=bool)) 136 | lanes_area = np.zeros((len(lane_indices) * 2, INTERPOLATION_POINTS, 2)) 137 | 138 | for idx, lane_idx in enumerate(lane_indices): 139 | lane_idx = self.mapAPI.bounds_info["lanes"]["ids"][lane_idx] 140 | 141 | # interpolate over polyline to always have the same number of points 142 | lane_coords = self.mapAPI.get_lane_as_interpolation( 143 | lane_idx, INTERPOLATION_POINTS, InterpolationMethod.INTER_ENSURE_LEN 144 | ) 145 | lanes_area[idx * 2] = lane_coords["xyz_left"][:, :2] 146 | lanes_area[idx * 2 + 1] = lane_coords["xyz_right"][::-1, :2] 147 | 148 | lane_type = RasterEls.LANE_NOTL.name 149 | lane_tl_ids = set(self.mapAPI.get_lane_traffic_control_ids(lane_idx)) 150 | # for tl_id in 
lane_tl_ids.intersection(active_tl_ids): 151 | # for tl_id in lane_tl_ids: 152 | # lane_type = self.mapAPI.get_color_for_face(tl_id) 153 | 154 | lanes_mask[lane_type][idx * 2: idx * 2 + 2] = True 155 | 156 | if len(lanes_area): 157 | lanes_area = cv2_subpixel(transform_points(lanes_area.reshape((-1, 2)), raster_from_world)) 158 | 159 | for lane_area in lanes_area.reshape((-1, INTERPOLATION_POINTS * 2, 2)): 160 | # need to for-loop otherwise some of them are empty 161 | cv2.fillPoly(img, [lane_area], COLORS[RasterEls.ROAD.name], **CV2_SUB_VALUES) 162 | 163 | lanes_area = lanes_area.reshape((-1, INTERPOLATION_POINTS, 2)) 164 | for name, mask in lanes_mask.items(): # draw each type of lane with its own color 165 | cv2.polylines(img, lanes_area[mask], False, COLORS[name], **CV2_SUB_VALUES) 166 | 167 | # plot crosswalks 168 | crosswalks = [] 169 | for idx in indices_in_bounds(center_in_world, self.mapAPI.bounds_info["crosswalks"]["bounds"], raster_radius): 170 | crosswalk = self.mapAPI.get_crosswalk_coords(self.mapAPI.bounds_info["crosswalks"]["ids"][idx]) 171 | xy_cross = cv2_subpixel(transform_points(crosswalk["xyz"][:, :2], raster_from_world)) 172 | crosswalks.append(xy_cross) 173 | 174 | cv2.polylines(img, crosswalks, True, COLORS[RasterEls.CROSSWALK.name], **CV2_SUB_VALUES) 175 | 176 | return img 177 | 178 | def to_rgb(self, in_im: np.ndarray, **kwargs: dict) -> np.ndarray: 179 | return (in_im * 255).astype(np.uint8) 180 | 181 | def num_channels(self) -> int: 182 | return 3 183 | -------------------------------------------------------------------------------- /tbsim/models/GAN_regularizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tbsim.dynamics.base import DynType 3 | import tbsim.utils.l5_utils as L5Utils 4 | from tbsim.models.cnn_roi_encoder import obtain_map_enc 5 | 6 | 7 | def pred2obs( 8 | dyn_list, 9 | step_time, 10 | src_pos, 11 | src_yaw, 12 | src_mask, 13 | data_batch, 14 | pos_pred, 15 | 
yaw_pred, 16 | pred_mask, 17 | raw_type, 18 | src_lanes, 19 | CNNmodel, 20 | algo_config, 21 | f_steps=1, 22 | M=1, 23 | ): 24 | """generate observation for the predicted scene f_step steps into the future 25 | 26 | Args: 27 | src_pos (torch.tensor[torch.float]): xy position in src 28 | src_yaw (torch.tensor[torch.float]): yaw in src 29 | src_mask (torch.tensor[torch.bool]): mask for src 30 | data_batch (dict): input data dictionary 31 | pos_pred (torch.tensor[torch.float]): predicted xy trajectory 32 | yaw_pred (torch.tensor[torch.float]): predicted yaw 33 | pred_mask (torch.tensor[torch.bool]): mask for prediction 34 | raw_type (torch.tensor[torch.int]): type of agents 35 | src_lanes (torch.tensor[torch.float]): lane info 36 | f_steps (int, optional): [description]. Defaults to 1. 37 | 38 | Returns: 39 | torch.tensor[torch.float]: new src for the transformer 40 | torch.tensor[torch.bool]: new src mask 41 | torch.tensor[torch.float]: new map encoding 42 | """ 43 | if pos_pred.ndim == 5: 44 | src_pos = src_pos.unsqueeze(1).repeat(1, M, 1, 1, 1) 45 | src_yaw = src_yaw.unsqueeze(1).repeat(1, M, 1, 1, 1) 46 | src_mask = src_mask.unsqueeze(1).repeat(1, M, 1, 1) 47 | pred_mask = pred_mask.unsqueeze(1).repeat(1, M, 1, 1) 48 | raw_type = raw_type.unsqueeze(1).repeat(1, M, 1) 49 | pos_new = torch.cat( 50 | (src_pos[..., f_steps:, :], pos_pred[..., :f_steps, :]), dim=-2 51 | ) 52 | yaw_new = torch.cat( 53 | (src_yaw[..., f_steps:, :], yaw_pred[..., :f_steps, :]), dim=-2 54 | ) 55 | src_mask_new = torch.cat( 56 | (src_mask[..., f_steps:], pred_mask[..., :f_steps]), dim=-1 57 | ) 58 | vel_new = dyn_list[DynType.UNICYCLE].calculate_vel( 59 | pos_new, yaw_new, step_time, src_mask_new 60 | ) 61 | src_new, _, _ = L5Utils.raw2feature( 62 | pos_new, 63 | vel_new, 64 | yaw_new, 65 | raw_type, 66 | src_mask_new, 67 | torch.zeros_like(src_lanes) if src_lanes is not None else None, 68 | ) 69 | 70 | if M == 1: 71 | map_emb_new = obtain_map_enc( 72 | data_batch["image"], 73 | CNNmodel, 
74 | pos_new, 75 | yaw_new, 76 | data_batch["raster_from_agent"], 77 | src_mask_new, 78 | torch.tensor(algo_config.CNN.patch_size).to(src_pos.device), 79 | algo_config.CNN.output_size, 80 | mode="last", 81 | ) 82 | 83 | else: 84 | map_emb_new = list() 85 | for i in range(M): 86 | 87 | map_emb_new_i = obtain_map_enc( 88 | data_batch["image"], 89 | CNNmodel, 90 | pos_new[:, i], 91 | yaw_new[:, i], 92 | data_batch["raster_from_agent"], 93 | src_mask_new, 94 | torch.tensor(algo_config.CNN.patch_size).to(src_pos.device), 95 | algo_config.CNN.output_size, 96 | mode="last", 97 | ) 98 | map_emb_new.append(map_emb_new_i) 99 | map_emb_new = torch.stack(map_emb_new, dim=1) 100 | return src_new, src_mask_new, map_emb_new 101 | 102 | 103 | def pred2obs_static( 104 | dyn_list, 105 | step_time, 106 | data_batch, 107 | pos_pred, 108 | yaw_pred, 109 | pred_mask, 110 | raw_type, 111 | src_lanes, 112 | CNNmodel, 113 | algo_config, 114 | M=1, 115 | ): 116 | """generate observation for every step of the predictions 117 | 118 | Args: 119 | data_batch (dict): input data dictionary 120 | pos_pred (torch.tensor[torch.float]): predicted xy trajectory 121 | yaw_pred (torch.tensor[torch.float]): predicted yaw 122 | pred_mask (torch.tensor[torch.bool]): mask for prediction 123 | raw_type (torch.tensor[torch.int]): type of agents 124 | src_lanes (torch.tensor[torch.float]): lane info 125 | Returns: 126 | torch.tensor[torch.float]: new src for the transformer 127 | torch.tensor[torch.bool]: new src mask 128 | torch.tensor[torch.float]: new map encoding 129 | """ 130 | if pos_pred.ndim == 5: 131 | pred_mask = pred_mask.unsqueeze(1).repeat(1, M, 1, 1) 132 | raw_type = raw_type.unsqueeze(1).repeat(1, M, 1) 133 | 134 | pred_vel = dyn_list[DynType.UNICYCLE].calculate_vel( 135 | pos_pred, yaw_pred, step_time, pred_mask 136 | ) 137 | src_new, _, _ = L5Utils.raw2feature( 138 | pos_pred, 139 | pred_vel, 140 | yaw_pred, 141 | raw_type, 142 | pred_mask, 143 | torch.zeros_like(src_lanes) if src_lanes is not 
None else None, 144 | add_noise=True, 145 | ) 146 | 147 | if M == 1: 148 | map_emb_new = obtain_map_enc( 149 | data_batch["image"], 150 | CNNmodel, 151 | pos_pred, 152 | yaw_pred, 153 | data_batch["raster_from_agent"], 154 | pred_mask, 155 | torch.tensor(algo_config.CNN.patch_size).to(pos_pred.device), 156 | algo_config.CNN.output_size, 157 | mode="all", 158 | ) 159 | else: 160 | map_emb_new = list() 161 | for i in range(M): 162 | map_emb_new_i = obtain_map_enc( 163 | data_batch["image"], 164 | CNNmodel, 165 | pos_pred[:, i], 166 | yaw_pred[:, i], 167 | data_batch["raster_from_agent"], 168 | pred_mask, 169 | torch.tensor(algo_config.CNN.patch_size).to(pos_pred.device), 170 | algo_config.CNN.output_size, 171 | mode="all", 172 | ) 173 | map_emb_new.append(map_emb_new_i) 174 | map_emb_new = torch.stack(map_emb_new, dim=1) 175 | 176 | return src_new, pred_mask, map_emb_new 177 | -------------------------------------------------------------------------------- /tbsim/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/traffic-behavior-simulation/c470538011f15207b0688e4b430f8568c4cbe257/tbsim/models/__init__.py -------------------------------------------------------------------------------- /tbsim/models/learned_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import tbsim.models.base_models as base_models 7 | import tbsim.utils.tensor_utils as TensorUtils 8 | 9 | 10 | class PermuteEBM(nn.Module): 11 | """Raster-based model for planning. 
12 | """ 13 | 14 | def __init__( 15 | self, 16 | model_arch: str, 17 | input_image_shape, 18 | map_feature_dim: int, 19 | traj_feature_dim: int, 20 | embedding_dim: int, 21 | embed_layer_dims: tuple 22 | ) -> None: 23 | 24 | super().__init__() 25 | self.map_encoder = base_models.RasterizedMapEncoder( 26 | model_arch=model_arch, 27 | input_image_shape=input_image_shape, 28 | feature_dim=map_feature_dim, 29 | use_spatial_softmax=False, 30 | output_activation=nn.ReLU 31 | ) 32 | self.traj_encoder = base_models.RNNTrajectoryEncoder( 33 | trajectory_dim=3, 34 | rnn_hidden_size=100, 35 | feature_dim=traj_feature_dim 36 | ) 37 | self.embed_net = base_models.MLP( 38 | input_dim=traj_feature_dim + map_feature_dim, 39 | output_dim=embedding_dim, 40 | output_activation=nn.ReLU, 41 | layer_dims=embed_layer_dims 42 | ) 43 | self.score_net = nn.Linear(embedding_dim, 1) 44 | 45 | def forward(self, data_batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 46 | image_batch = data_batch["image"] 47 | trajs = torch.cat((data_batch["target_positions"], data_batch["target_yaws"]), dim=2) 48 | bs = image_batch.shape[0] 49 | 50 | map_feat = self.map_encoder(image_batch) # [B, D_m] 51 | traj_feat = self.traj_encoder(trajs) # [B, D_t] 52 | 53 | # construct contrastive samples 54 | map_feat_rep = TensorUtils.unsqueeze_expand_at(map_feat, size=bs, dim=1) # [B, B, D_m] 55 | traj_feat_rep = TensorUtils.unsqueeze_expand_at(traj_feat, size=bs, dim=0) # [B, B, D_t] 56 | cat_rep = torch.cat((map_feat_rep, traj_feat_rep), dim=-1) # [B, B, D_m + D_t] 57 | ebm_rep = TensorUtils.time_distributed(cat_rep, self.embed_net) # [B, B, D] 58 | 59 | # calculate embeddings and scores for InfoNCE loss 60 | scores = TensorUtils.time_distributed(ebm_rep, self.score_net).squeeze(-1) # [B, B] 61 | out_dict = dict(features=ebm_rep, scores=scores) 62 | 63 | return out_dict 64 | 65 | def get_scores(self, data_batch): 66 | image_batch = data_batch["image"] 67 | trajs = torch.cat((data_batch["target_positions"], 
data_batch["target_yaws"]), dim=2) 68 | 69 | map_feat = self.map_encoder(image_batch) # [B, D_m] 70 | traj_feat = self.traj_encoder(trajs) # [B, D_t] 71 | cat_rep = torch.cat((map_feat, traj_feat), dim=-1) # [B, D_m + D_t] 72 | ebm_rep = self.embed_net(cat_rep) 73 | scores = self.score_net(ebm_rep) 74 | out_dict = dict(features=ebm_rep, scores=scores) 75 | 76 | return out_dict 77 | 78 | def compute_losses(self, pred_batch, data_batch): 79 | scores = pred_batch["scores"] 80 | bs = scores.shape[0] 81 | labels = torch.arange(bs).to(scores.device) 82 | loss = nn.CrossEntropyLoss()(scores, labels) 83 | losses = dict(infoNCE_loss=loss) 84 | 85 | return losses -------------------------------------------------------------------------------- /tbsim/models/roi_align.py: -------------------------------------------------------------------------------- 1 | from logging import raiseExceptions 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from tbsim.utils.geometry_utils import batch_nd_transform_points 6 | 7 | 8 | def bilinear_interpolate(img, x, y, floattype=torch.float, flip_y=False): 9 | """Return bilinear interpolation of 4 nearest pts w.r.t to x,y from img 10 | Args: 11 | img (torch.Tensor): Tensor of size cxwxh. 
Usually one channel of feature layer 12 | x (torch.Tensor): Float dtype, x axis location for sampling 13 | y (torch.Tensor): Float dtype, y axis location for sampling 14 | 15 | Returns: 16 | torch.Tensor: interpolated value 17 | """ 18 | if flip_y: 19 | y = img.shape[-2]-1-y 20 | if img.device.type == "cuda": 21 | x0 = torch.floor(x).type(torch.cuda.LongTensor) 22 | y0 = torch.floor(y).type(torch.cuda.LongTensor) 23 | elif img.device.type == "cpu": 24 | x0 = torch.floor(x).type(torch.LongTensor) 25 | y0 = torch.floor(y).type(torch.LongTensor) 26 | else: 27 | raise ValueError("device not recognized") 28 | x1 = x0 + 1 29 | y1 = y0 + 1 30 | 31 | x0 = torch.clamp(x0, 0, img.shape[-1] - 1) 32 | x1 = torch.clamp(x1, 0, img.shape[-1] - 1) 33 | y0 = torch.clamp(y0, 0, img.shape[-2] - 1) 34 | y1 = torch.clamp(y1, 0, img.shape[-2] - 1) 35 | 36 | Ia = img[..., y0, x0] 37 | Ib = img[..., y1, x0] 38 | Ic = img[..., y0, x1] 39 | Id = img[..., y1, x1] 40 | 41 | step = (x1.type(floattype) - x0.type(floattype)) * ( 42 | y1.type(floattype) - y0.type(floattype) 43 | ) 44 | step = torch.clamp(step, 1e-3, 2) 45 | norm_const = 1 / step 46 | 47 | wa = (x1.type(floattype) - x) * (y1.type(floattype) - y) * norm_const 48 | wb = (x1.type(floattype) - x) * (y - y0.type(floattype)) * norm_const 49 | wc = (x - x0.type(floattype)) * (y1.type(floattype) - y) * norm_const 50 | wd = (x - x0.type(floattype)) * (y - y0.type(floattype)) * norm_const 51 | return ( 52 | Ia * wa.unsqueeze(0) 53 | + Ib * wb.unsqueeze(0) 54 | + Ic * wc.unsqueeze(0) 55 | + Id * wd.unsqueeze(0) 56 | ) 57 | 58 | 59 | def ROI_align(features, ROI, outdim): 60 | """Given feature layers and proposals return bilinear interpolated 61 | points in feature layer 62 | 63 | Args: 64 | features (torch.Tensor): Tensor of shape channels x width x height 65 | proposal (list of torch.Tensor): x0,y0,W1,W2,H1,H2,psi 66 | """ 67 | 68 | bs, num_channels, h, w = features.shape 69 | 70 | xg = ( 71 | torch.cat( 72 | ( 73 | torch.arange(0, 
outdim).view(-1, 1) - (outdim - 1) / 2, 74 | torch.zeros([outdim, 1]), 75 | ), 76 | dim=-1, 77 | ) 78 | / outdim 79 | ) 80 | yg = ( 81 | torch.cat( 82 | ( 83 | torch.zeros([outdim, 1]), 84 | torch.arange(0, outdim).view(-1, 1) - (outdim - 1) / 2, 85 | ), 86 | dim=-1, 87 | ) 88 | / outdim 89 | ) 90 | gg = xg.view(1, -1, 2) + yg.view(-1, 1, 2) 91 | gg = gg.to(features.device) 92 | res = list() 93 | for i in range(bs): 94 | if ROI[i] is not None: 95 | W1 = ROI[i][..., 2:3] 96 | W2 = ROI[i][..., 3:4] 97 | H1 = ROI[i][..., 4:5] 98 | H2 = ROI[i][..., 5:6] 99 | psi = ROI[i][..., 6:] 100 | WH = torch.cat((W1 + W2, H1 + H2), dim=-1) 101 | offset = torch.cat(((W1 - W2) / 2, (H1 - H2) / 2), dim=-1) 102 | s = torch.sin(psi).unsqueeze(-1) 103 | c = torch.cos(psi).unsqueeze(-1) 104 | rotM = torch.cat( 105 | (torch.cat((c, -s), dim=-1), torch.cat((s, c), dim=-1)), dim=-2 106 | ) 107 | ggi = gg * WH[..., None, None, :] - offset[..., None, None, :] 108 | ggi = ggi @ rotM[..., None, :, :] + ROI[i][..., None, None, 0:2] 109 | 110 | x_sample = ggi[..., 0].flatten() 111 | y_sample = ggi[..., 1].flatten() 112 | res.append( 113 | bilinear_interpolate(features[i], x_sample, y_sample).view( 114 | ggi.shape[0], num_channels, *ggi.shape[1:-1] 115 | ) 116 | ) 117 | else: 118 | res.append(None) 119 | 120 | return res 121 | 122 | 123 | def generate_ROIs( 124 | pos, 125 | yaw, 126 | raster_from_agent, 127 | mask, 128 | patch_size, 129 | mode="last", 130 | ): 131 | """ 132 | This version generates ROI for all agents only at most recent time step unless specified otherwise 133 | """ 134 | if mode == "all": 135 | bs = pos.shape[0] 136 | yaw = yaw.type(torch.float) 137 | Mat = raster_from_agent.view(-1, 1, 1, 3, 3).type(torch.float) 138 | raster_xy = batch_nd_transform_points(pos, Mat) 139 | raster_mult = torch.linalg.norm( 140 | raster_from_agent[0, 0, 0:2], dim=[-1]).item() 141 | patch_size = patch_size.type(torch.float) 142 | patch_size *= raster_mult 143 | ROI = [None] * bs 144 | index = [None] 
* bs 145 | for i in range(bs): 146 | ii, jj = torch.where(mask[i]) 147 | index[i] = (ii, jj) 148 | if patch_size.ndim == 1: 149 | patches_size = patch_size.repeat(ii.shape[0], 1) 150 | else: 151 | sizes = patch_size[i, ii] 152 | patches_size = torch.cat( 153 | ( 154 | sizes[:, 0:1] * 0.5, 155 | sizes[:, 0:1] * 0.5, 156 | sizes[:, 1:2] * 0.5, 157 | sizes[:, 1:2] * 0.5, 158 | ), 159 | dim=-1, 160 | ) 161 | ROI[i] = torch.cat( 162 | ( 163 | raster_xy[i, ii, jj], 164 | patches_size, 165 | yaw[i, ii, jj], 166 | ), 167 | dim=-1, 168 | ).to(pos.device) 169 | return ROI, index 170 | elif mode == "last": 171 | num = torch.arange(0, mask.shape[2]).view(1, 1, -1).to(mask.device) 172 | nummask = num * mask 173 | last_idx, _ = torch.max(nummask, dim=2) 174 | bs = pos.shape[0] 175 | Mat = raster_from_agent.view(-1, 1, 1, 3, 3).type(torch.float) 176 | raster_xy = batch_nd_transform_points(pos, Mat) 177 | raster_mult = torch.linalg.norm( 178 | raster_from_agent[0, 0, 0:2], dim=[-1]).item() 179 | patch_size = patch_size.type(torch.float) 180 | patch_size *= raster_mult 181 | agent_mask = mask.any(dim=2) 182 | ROI = [None] * bs 183 | index = [None] * bs 184 | for i in range(bs): 185 | ii = torch.where(agent_mask[i])[0] 186 | index[i] = ii 187 | if patch_size.ndim == 1: 188 | patches_size = patch_size.repeat(ii.shape[0], 1) 189 | else: 190 | sizes = patch_size[i, ii] 191 | patches_size = torch.cat( 192 | ( 193 | sizes[:, 0:1] * 0.5, 194 | sizes[:, 0:1] * 0.5, 195 | sizes[:, 1:2] * 0.5, 196 | sizes[:, 1:2] * 0.5, 197 | ), 198 | dim=-1, 199 | ) 200 | ROI[i] = torch.cat( 201 | ( 202 | raster_xy[i, ii, last_idx[i, ii]], 203 | patches_size, 204 | yaw[i, ii, last_idx[i, ii]], 205 | ), 206 | dim=-1, 207 | ) 208 | return ROI, index 209 | else: 210 | raise ValueError("mode must be 'all' or 'last'") 211 | 212 | 213 | def Indexing_ROI_result(CNN_out, index, emb_size): 214 | """put the lists of ROI align result into embedding tensor with the help of index""" 215 | bs = len(CNN_out) 216 | map_emb 
= torch.zeros(emb_size).to(CNN_out[0].device) 217 | if map_emb.ndim == 3: 218 | for i in range(bs): 219 | map_emb[i, index[i]] = CNN_out[i] 220 | elif map_emb.ndim == 4: 221 | for i in range(bs): 222 | ii, jj = index[i] 223 | map_emb[i, ii, jj] = CNN_out[i] 224 | else: 225 | raise ValueError("wrong dimension for the map embedding!") 226 | 227 | return map_emb 228 | -------------------------------------------------------------------------------- /tbsim/policies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/traffic-behavior-simulation/c470538011f15207b0688e4b430f8568c4cbe257/tbsim/policies/__init__.py -------------------------------------------------------------------------------- /tbsim/policies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Policy(abc.ABC): 5 | def __init__(self, device, *args, **kwargs): 6 | self.device = device 7 | 8 | @abc.abstractmethod 9 | def get_action(self, obs_dict, **kwargs): 10 | """Predict an action based on the input observation """ 11 | pass 12 | 13 | @abc.abstractmethod 14 | def eval(self): 15 | """Set the policy to evaluation mode""" 16 | pass -------------------------------------------------------------------------------- /tbsim/policies/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from copy import deepcopy 4 | 5 | import tbsim.utils.tensor_utils as TensorUtils 6 | from tbsim.utils.geometry_utils import transform_points_tensor 7 | from l5kit.geometry import transform_points 8 | 9 | 10 | class Trajectory(object): 11 | """Container for sequences of 2D positions and yaws""" 12 | def __init__(self, positions, yaws): 13 | assert positions.shape[:-1] == yaws.shape[:-1] 14 | assert positions.shape[-1] == 2 15 | assert yaws.shape[-1] == 1 16 | self._positions = positions 17 | self._yaws = 
yaws 18 | 19 | @property 20 | def trajectories(self): 21 | if isinstance(self.positions, np.ndarray): 22 | return np.concatenate([self._positions, self._yaws], axis=-1) 23 | else: 24 | return torch.cat([self._positions, self._yaws], dim=-1) 25 | 26 | @property 27 | def positions(self): 28 | return TensorUtils.clone(self._positions) 29 | 30 | @positions.setter 31 | def positions(self, x): 32 | self._positions = TensorUtils.clone(x) 33 | 34 | @property 35 | def yaws(self): 36 | return TensorUtils.clone(self._yaws) 37 | 38 | @yaws.setter 39 | def yaws(self, x): 40 | self._yaws = TensorUtils.clone(x) 41 | 42 | def to_dict(self): 43 | return dict( 44 | positions=self.positions, 45 | yaws=self.yaws 46 | ) 47 | 48 | @classmethod 49 | def from_dict(cls, d): 50 | return cls(**d) 51 | 52 | def transform(self, trans_mats, rot_rads): 53 | if isinstance(self.positions, np.ndarray): 54 | pos = transform_points(self.positions, trans_mats) 55 | else: 56 | pos = transform_points_tensor(self.positions, trans_mats) 57 | 58 | yaw = self.yaws + rot_rads 59 | return self.__class__(pos, yaw) 60 | 61 | def to_numpy(self): 62 | return self.__class__(**TensorUtils.to_numpy(self.to_dict())) 63 | 64 | 65 | class Action(Trajectory): 66 | pass 67 | 68 | 69 | class Plan(Trajectory): 70 | """Container for sequences of 2D positions, yaws, controls, availabilities.""" 71 | def __init__(self, positions, yaws, availabilities, controls=None): 72 | assert positions.shape[:-1] == yaws.shape[:-1] 73 | assert positions.shape[-1] == 2 74 | assert yaws.shape[-1] == 1 75 | assert availabilities.shape == positions.shape[:-1] 76 | self._positions = positions 77 | self._yaws = yaws 78 | self._availabilities = availabilities 79 | self._controls = controls 80 | 81 | @property 82 | def availabilities(self): 83 | return TensorUtils.clone(self._availabilities) 84 | 85 | @property 86 | def controls(self): 87 | return TensorUtils.clone(self._controls) 88 | 89 | def to_dict(self): 90 | p = dict( 91 | 
positions=self.positions, 92 | yaws=self.yaws, 93 | availabilities=self.availabilities, 94 | ) 95 | if self._controls is not None: 96 | p["controls"] = self.controls 97 | return p 98 | 99 | def transform(self, trans_mats, rot_rads): 100 | if isinstance(self.positions, np.ndarray): 101 | pos = transform_points(self.positions, trans_mats) 102 | else: 103 | pos = transform_points_tensor(self.positions, trans_mats) 104 | 105 | yaw = self.yaws + rot_rads 106 | return self.__class__(pos, yaw, self.availabilities, controls=self.controls) 107 | 108 | 109 | class RolloutAction(object): 110 | """Actions used to control agent rollouts""" 111 | def __init__(self, ego=None, ego_info=None, agents=None, agents_info=None): 112 | assert ego is None or isinstance(ego, Action) 113 | assert agents is None or isinstance(agents, Action) 114 | assert ego_info is None or isinstance(ego_info, dict) 115 | assert agents_info is None or isinstance(agents_info, dict) 116 | 117 | self.ego = ego 118 | self.ego_info = ego_info 119 | self.agents = agents 120 | self.agents_info = agents_info 121 | 122 | @property 123 | def has_ego(self): 124 | return self.ego is not None 125 | 126 | @property 127 | def has_agents(self): 128 | return self.agents is not None 129 | 130 | def transform(self, ego_trans_mats, ego_rot_rads, agents_trans_mats=None, agents_rot_rads=None): 131 | trans_action = RolloutAction() 132 | if self.has_ego: 133 | trans_action.ego = self.ego.transform( 134 | trans_mats=ego_trans_mats, rot_rads=ego_rot_rads) 135 | if self.ego_info is not None: 136 | trans_action.ego_info = deepcopy(self.ego_info) 137 | if "plan" in trans_action.ego_info: 138 | plan = Plan.from_dict(trans_action.ego_info["plan"]) 139 | trans_action.ego_info["plan"] = plan.transform( 140 | trans_mats=ego_trans_mats, rot_rads=ego_rot_rads 141 | ).to_dict() 142 | if self.has_agents: 143 | assert agents_trans_mats is not None and agents_rot_rads is not None 144 | trans_action.agents = self.agents.transform( 145 | 
trans_mats=agents_trans_mats, rot_rads=agents_rot_rads) 146 | if self.agents_info is not None: 147 | trans_action.agents_info = deepcopy(self.agents_info) 148 | if "plan" in trans_action.agents_info: 149 | plan = Plan.from_dict(trans_action.agents_info["plan"]) 150 | trans_action.agents_info["plan"] = plan.transform( 151 | trans_mats=agents_trans_mats, rot_rads=agents_rot_rads 152 | ).to_dict() 153 | return trans_action 154 | 155 | def to_dict(self): 156 | d = dict() 157 | if self.has_ego: 158 | d["ego"] = self.ego.to_dict() 159 | d["ego_info"] = deepcopy(self.ego_info) 160 | if self.has_agents: 161 | d["agents"] = self.agents.to_dict() 162 | d["agents_info"] = deepcopy(self.agents_info) 163 | return d 164 | 165 | def to_numpy(self): 166 | return self.__class__( 167 | ego=self.ego.to_numpy() if self.has_ego else None, 168 | ego_info=TensorUtils.to_numpy( 169 | self.ego_info) if self.has_ego else None, 170 | agents=self.agents.to_numpy() if self.has_agents else None, 171 | agents_info=TensorUtils.to_numpy( 172 | self.agents_info) if self.has_agents else None, 173 | ) 174 | 175 | @classmethod 176 | def from_dict(cls, d): 177 | d = deepcopy(d) 178 | if "ego" in d: 179 | d["ego"] = Action.from_dict(d["ego"]) 180 | if "agents" in d: 181 | d["agents"] = Action.from_dict(d["agents"]) 182 | return cls(**d) 183 | 184 | -------------------------------------------------------------------------------- /tbsim/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/traffic-behavior-simulation/c470538011f15207b0688e4b430f8568c4cbe257/tbsim/utils/__init__.py -------------------------------------------------------------------------------- /tbsim/utils/batch_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import tbsim.utils.l5_utils as l5_utils 4 | import tbsim.utils.trajdata_utils as av_utils 5 | from tbsim import dynamics as 
dynamics 6 | from tbsim.configs.base import ExperimentConfig 7 | 8 | 9 | BATCH_TYPE = None 10 | 11 | 12 | def set_global_batch_type(batch_type): 13 | global BATCH_TYPE 14 | assert batch_type in ["trajdata", "l5kit"] 15 | BATCH_TYPE = batch_type 16 | 17 | 18 | def batch_utils(**kwargs): 19 | if BATCH_TYPE == "trajdata": 20 | return trajdataBatchUtils(**kwargs) 21 | elif BATCH_TYPE == "l5kit": 22 | return L5BatchUtils(**kwargs) 23 | else: 24 | raise NotImplementedError("Please set BATCH_TYPE in batch_utils.py to {trajdata, l5kit}") 25 | 26 | 27 | class BatchUtils(object): 28 | """A base class for processing environment-independent batches""" 29 | 30 | def __init__(self,**kwargs): 31 | if "parse" in kwargs: 32 | self.parse = kwargs["parse"] 33 | else: 34 | self.parse = True 35 | if "rasterize_mode" in kwargs: 36 | self.rasterize_mode = kwargs["rasterize_mode"] 37 | else: 38 | self.rasterize_mode = "point" 39 | 40 | @staticmethod 41 | def get_last_available_index(avails): 42 | """ 43 | Args: 44 | avails (torch.Tensor): target availabilities [B, (A), T] 45 | 46 | Returns: 47 | last_indices (torch.Tensor): index of the last available frame 48 | """ 49 | num_frames = avails.shape[-1] 50 | inds = torch.arange(0, num_frames).to(avails.device) # [T] 51 | inds = (avails > 0).float() * inds # [B, (A), T] arange indices with unavailable indices set to 0 52 | last_inds = inds.max(dim=-1)[1] # [B, (A)] calculate the index of the last availale frame 53 | return last_inds 54 | 55 | @staticmethod 56 | def get_current_states(batch: dict, dyn_type: dynamics.DynType) -> torch.Tensor: 57 | """Get the dynamic states of the current timestep""" 58 | bs = batch["curr_speed"].shape[0] 59 | if dyn_type == dynamics.DynType.BICYCLE: 60 | current_states = torch.zeros(bs, 6).to(batch["curr_speed"].device) # [x, y, yaw, vel, dh, veh_len] 61 | current_states[:, 3] = batch["curr_speed"].abs() 62 | current_states[:, [4]] = (batch["history_yaws"][:, 0] - batch["history_yaws"][:, 1]).abs() 63 | 
current_states[:, 5] = batch["extent"][:, 0] # [veh_len] 64 | else: 65 | current_states = torch.zeros(bs, 4).to(batch["curr_speed"].device) # [x, y, vel, yaw] 66 | current_states[:, 2] = batch["curr_speed"] 67 | return current_states 68 | 69 | @classmethod 70 | def get_current_states_all_agents(cls, batch: dict, step_time, dyn_type: dynamics.DynType) -> torch.Tensor: 71 | raise NotImplementedError 72 | 73 | @staticmethod 74 | def parse_batch(data_batch): 75 | raise NotImplementedError 76 | 77 | @staticmethod 78 | def batch_to_raw_all_agents(data_batch, step_time): 79 | raise NotImplementedError 80 | 81 | @staticmethod 82 | def batch_to_target_all_agents(data_batch): 83 | raise NotImplementedError 84 | 85 | @staticmethod 86 | def get_edges_from_batch(data_batch, ego_predictions=None, all_predictions=None): 87 | raise NotImplementedError 88 | 89 | @staticmethod 90 | def generate_edges(raw_type, extents, pos_pred, yaw_pred, edge_mask=None): 91 | raise NotImplementedError 92 | 93 | @staticmethod 94 | def gen_edges_masked(raw_type, extents, pred): 95 | raise NotImplementedError 96 | 97 | @staticmethod 98 | def gen_ego_edges(ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types): 99 | raise NotImplementedError 100 | 101 | @staticmethod 102 | def gen_EC_edges(ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types, mask=None): 103 | raise NotImplementedError 104 | 105 | @staticmethod 106 | def get_drivable_region_map(rasterized_map): 107 | raise NotImplementedError 108 | 109 | @staticmethod 110 | def get_modality_shapes(cfg: ExperimentConfig): 111 | raise NotImplementedError 112 | 113 | 114 | class L5BatchUtils(BatchUtils): 115 | """Batch utils for L5Kit""" 116 | @staticmethod 117 | def parse_batch(data_batch): 118 | return data_batch 119 | 120 | @staticmethod 121 | def batch_to_raw_all_agents(data_batch, step_time): 122 | return l5_utils.batch_to_raw_all_agents(data_batch, step_time) 123 | 124 | @staticmethod 125 | def 
get_current_states_all_agents(batch, step_time, dyn_type): 126 | return l5_utils.get_current_states_all_agents(batch,step_time,dyn_type) 127 | 128 | @staticmethod 129 | def batch_to_target_all_agents(data_batch): 130 | return l5_utils.batch_to_target_all_agents(data_batch) 131 | 132 | @staticmethod 133 | def get_edges_from_batch(data_batch, ego_predictions=None, all_predictions=None): 134 | return l5_utils.get_edges_from_batch(data_batch, ego_predictions, all_predictions) 135 | 136 | @staticmethod 137 | def generate_edges(raw_type, extents, pos_pred, yaw_pred,edge_mask = None): 138 | return l5_utils.generate_edges(raw_type, extents, pos_pred, yaw_pred,edge_mask) 139 | 140 | @staticmethod 141 | def gen_edges_masked(raw_type, extents, pred): 142 | return l5_utils.gen_edges_masked(raw_type, extents, pred) 143 | 144 | @staticmethod 145 | def gen_ego_edges(ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types): 146 | return l5_utils.gen_ego_edges(ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types) 147 | 148 | @staticmethod 149 | def gen_EC_edges(ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types, mask=None): 150 | return l5_utils.gen_EC_edges(ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types, mask) 151 | 152 | @staticmethod 153 | def get_drivable_region_map(rasterized_map): 154 | return l5_utils.get_drivable_region_map(rasterized_map) 155 | 156 | @staticmethod 157 | def get_modality_shapes(cfg: ExperimentConfig): 158 | return l5_utils.get_modality_shapes(cfg) 159 | 160 | 161 | class trajdataBatchUtils(BatchUtils): 162 | """Batch utils for trajdata""" 163 | def parse_batch(self,data_batch): 164 | if self.parse: 165 | return av_utils.parse_trajdata_batch(data_batch,self.rasterize_mode) 166 | else: 167 | return data_batch 168 | 169 | @staticmethod 170 | def batch_to_raw_all_agents(data_batch, step_time): 171 | raw_type = torch.cat( 172 | (data_batch["type"].unsqueeze(1), 
data_batch["all_other_agents_types"]), 173 | dim=1, 174 | ).type(torch.int64) 175 | 176 | src_pos = torch.cat( 177 | ( 178 | data_batch["history_positions"].unsqueeze(1), 179 | data_batch["all_other_agents_history_positions"], 180 | ), 181 | dim=1, 182 | ) 183 | src_yaw = torch.cat( 184 | ( 185 | data_batch["history_yaws"].unsqueeze(1), 186 | data_batch["all_other_agents_history_yaws"], 187 | ), 188 | dim=1, 189 | ) 190 | src_mask = torch.cat( 191 | ( 192 | data_batch["history_availabilities"].unsqueeze(1), 193 | data_batch["all_other_agents_history_availability"], 194 | ), 195 | dim=1, 196 | ).bool() 197 | 198 | extents = torch.cat( 199 | ( 200 | data_batch["extent"][..., :2].unsqueeze(1), 201 | data_batch["all_other_agents_history_extents"][..., -1,:2], 202 | ), 203 | dim=1, 204 | ) 205 | 206 | curr_speed = torch.cat( 207 | ( 208 | data_batch["curr_speed"].unsqueeze(1), 209 | data_batch["all_other_agents_curr_speed"] 210 | ), 211 | dim=1, 212 | ) 213 | 214 | return { 215 | "history_positions": src_pos, 216 | "history_yaws": src_yaw, 217 | "curr_speed": curr_speed, 218 | "raw_types": raw_type, 219 | "history_availabilities": src_mask, 220 | "extents": extents, 221 | } 222 | 223 | @staticmethod 224 | def batch_to_target_all_agents(data_batch): 225 | pos = torch.cat( 226 | ( 227 | data_batch["target_positions"].unsqueeze(1), 228 | data_batch["all_other_agents_future_positions"], 229 | ), 230 | dim=1, 231 | ) 232 | yaw = torch.cat( 233 | ( 234 | data_batch["target_yaws"].unsqueeze(1), 235 | data_batch["all_other_agents_future_yaws"], 236 | ), 237 | dim=1, 238 | ) 239 | avails = torch.cat( 240 | ( 241 | data_batch["target_availabilities"].unsqueeze(1), 242 | data_batch["all_other_agents_future_availability"], 243 | ), 244 | dim=1, 245 | ) 246 | 247 | extents = torch.cat( 248 | ( 249 | data_batch["extent"][..., :2].unsqueeze(1), 250 | data_batch["all_other_agents_extents"][..., :2], 251 | ), 252 | dim=1, 253 | ) 254 | 255 | return { 256 | "target_positions": pos, 257 | 
"target_yaws": yaw, 258 | "target_availabilities": avails, 259 | "extents": extents 260 | } 261 | 262 | @staticmethod 263 | def get_current_states_all_agents(batch: dict, step_time, dyn_type: dynamics.DynType) -> torch.Tensor: 264 | if batch["history_positions"].ndim==3: 265 | state_all = trajdataBatchUtils.batch_to_raw_all_agents(batch, step_time) 266 | else: 267 | state_all = batch 268 | bs, na = state_all["curr_speed"].shape[:2] 269 | if dyn_type == dynamics.DynType.BICYCLE: 270 | current_states = torch.zeros(bs, na, 6).to(state_all["curr_speed"].device) # [x, y, yaw, vel, dh, veh_len] 271 | current_states[:, :, :2] = state_all["history_positions"][:, :, -1] 272 | current_states[:, :, 3] = state_all["curr_speed"].abs() 273 | current_states[:, :, [4]] = (state_all["history_yaws"][:, :, -1] - state_all["history_yaws"][:, :, 1]).abs() 274 | current_states[:, :, 5] = state_all["extent"][:, :, -1] # [veh_len] 275 | else: 276 | current_states = torch.zeros(bs, na, 4).to(state_all["curr_speed"].device) # [x, y, vel, yaw] 277 | current_states[:, :, :2] = state_all["history_positions"][:, :, -1] 278 | current_states[:, :, 2] = state_all["curr_speed"] 279 | current_states[:,:,3:] = state_all["history_yaws"][:,:,-1] 280 | return current_states 281 | 282 | @staticmethod 283 | def get_edges_from_batch(data_batch, ego_predictions=None, all_predictions=None): 284 | raise NotImplementedError 285 | 286 | @staticmethod 287 | def generate_edges(raw_type, extents, pos_pred, yaw_pred,edge_mask=None): 288 | return l5_utils.generate_edges(raw_type, extents, pos_pred, yaw_pred,edge_mask) 289 | 290 | @staticmethod 291 | def gen_edges_masked(raw_type, extents, pred): 292 | return l5_utils.gen_edges_masked(raw_type, extents, pred) 293 | 294 | @staticmethod 295 | def gen_ego_edges(ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types): 296 | return l5_utils.gen_ego_edges(ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types) 297 | 298 | 
@staticmethod 299 | def gen_EC_edges(ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types, mask=None): 300 | return l5_utils.gen_EC_edges(ego_trajectories, agent_trajectories, ego_extents, agent_extents, raw_types, mask) 301 | 302 | @staticmethod 303 | def get_drivable_region_map(rasterized_map): 304 | return av_utils.get_drivable_region_map(rasterized_map) 305 | 306 | def get_modality_shapes(self, cfg: ExperimentConfig): 307 | return av_utils.get_modality_shapes(cfg,rasterize_mode=self.rasterize_mode) 308 | 309 | -------------------------------------------------------------------------------- /tbsim/utils/bokeh_script.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import tbsim.utils.geometry_utils as GeoUtils 4 | import tbsim.utils.tensor_utils as TensorUtils 5 | import tbsim.utils.lane_utils as LaneUtils 6 | from tbsim.policies.MPC.homotopy import HomotopyType,HOMOTOPY_THRESHOLD 7 | 8 | XYH_INDEX = np.array([0,1,3]) 9 | import h5py 10 | import pickle 11 | import json 12 | import pandas as pd 13 | from bokeh.io import curdoc 14 | from bokeh.layouts import row,column 15 | from bokeh.models import ColumnDataSource,LabelSet,PointDrawTool 16 | from bokeh.models.widgets import Slider, Paragraph,Button,CheckboxButtonGroup 17 | from bokeh.plotting import figure, show, output_file, output_notebook 18 | from scipy.spatial import ConvexHull 19 | from trajdata import MapAPI, VectorMap 20 | from pathlib import Path 21 | 22 | from bokeh.models import Range1d 23 | import bokeh 24 | import sys 25 | from collections import defaultdict 26 | import os 27 | args = defaultdict(list) 28 | argname = None 29 | for x in sys.argv: 30 | if x.startswith("-"): 31 | argname=x[1:] 32 | else: 33 | if argname is not None: 34 | args[argname].append(x) 35 | 36 | for k,v in args.items(): 37 | if len(v)==1: 38 | args[k]=v[0] 39 | 40 | 41 | assert "result_dir" in args 42 | result_dir = 
args["result_dir"] 43 | 44 | if "cache_path" in args: 45 | cache_path = Path(args["cache_path"]).expanduser() 46 | else: 47 | cache_path = Path("~/.unified_data_cache").expanduser() 48 | mapAPI = MapAPI(cache_path) 49 | 50 | sim_info_path = os.path.join(result_dir,"sim_info.json") 51 | sim_info = json.load(open(sim_info_path, "r")) 52 | 53 | if "scene_name" in args: 54 | scene_name = args["scene_name"] 55 | else: 56 | scene_name = sim_info["scene_index"][0] 57 | 58 | if "episode" in args: 59 | ei = args["episode"] 60 | else: 61 | ei = 0 62 | sim_name = f"{scene_name}_{ei}" 63 | 64 | 65 | map_name = sim_info["map_info"][scene_name] 66 | 67 | vec_map = mapAPI.get_map(map_name, scene_cache=None) 68 | 69 | 70 | 71 | 72 | hdf5_path = os.path.join(result_dir,"data.hdf5") 73 | h5f = h5py.File(hdf5_path, "r") 74 | 75 | trace_path = hdf5_path = os.path.join(result_dir,"trace.pkl") 76 | 77 | with open(trace_path, "rb") as f: 78 | trace = pickle.load(f) 79 | 80 | 81 | trace_offset = 11 82 | def plot_lane(lane,plot,color="grey"): 83 | bdry_l,_ = LaneUtils.get_edge(lane,dir="L",num_pts=15) 84 | bdry_r,_ = LaneUtils.get_edge(lane,dir="R",num_pts=15) 85 | lane_center,_ = LaneUtils.get_edge(lane,dir="C",num_pts=15) 86 | bdry_xy = np.concatenate([bdry_l,np.flip(bdry_r,0)],0) 87 | patch_glyph = plot.patch(x=bdry_xy[:,0],y=bdry_xy[:,1],fill_alpha=0.5,color = color) 88 | centerline_glyph = plot.line(x=lane_center[:,0],y=lane_center[:,1],line_dash="dashed",line_width=2) 89 | return patch_glyph,centerline_glyph 90 | 91 | 92 | def get_agent_edge(xy,h,extent): 93 | 94 | edges = np.array([[0.5,0.5],[0.5,-0.5],[-0.5,-0.5],[-0.5,0.5]])*extent[np.newaxis,:2] 95 | rotM = np.array([[np.cos(h),-np.sin(h)],[np.sin(h),np.cos(h)]]) 96 | edges = (rotM@edges[...,np.newaxis]).squeeze(-1)+xy[np.newaxis,:] 97 | return edges 98 | 99 | def button_callback(): 100 | sys.exit() # Stop the server 101 | plot = figure(name='base',height=1000, width=1000, title="traffic Animation", 102 | 
tools="reset,save",toolbar_location="below",match_aspect=True) 103 | plot.xgrid.grid_line_color = None 104 | plot.ygrid.grid_line_color = None 105 | plot.axis.visible=False 106 | 107 | sim_record = h5f[sim_name] 108 | sim_trace = trace[sim_name] 109 | plan_ts = np.array(list(trace[sim_name].keys())) 110 | lanes = set() 111 | lanecenter_glyph = dict() 112 | lanepatch_glyph = dict() 113 | agents = set() 114 | plan_glyph = dict() 115 | 116 | Na,T = sim_record["centroid"].shape[:2] 117 | 118 | agent_id=["ego"]+[f"A{i}" for i in range(1,Na)] 119 | palette = bokeh.palettes.Category20[20] 120 | agent_color = ["blueviolet"] + [palette[i%20] for i in range(Na-1)] 121 | agent_ds = dict() 122 | agent_patch = dict() 123 | agents = set() 124 | agent_plan_ds=dict() 125 | agent_plan_glyph = dict() 126 | 127 | # setup patches 128 | 129 | numModes = 5 130 | 131 | for t in range(T): 132 | # plotting lanes 133 | xyz = np.hstack([sim_record["centroid"][0,t],np.zeros(1)]) 134 | lanes_t = vec_map.get_lanes_within(xyz,100) 135 | for lane in lanes_t: 136 | if lane not in lanes: 137 | patch_glyph,centerline_glyph = plot_lane(lane,plot) 138 | lanecenter_glyph[lane] = centerline_glyph 139 | lanepatch_glyph[lane] = patch_glyph 140 | lanes.add(lane) 141 | 142 | track_ids = np.where((sim_record["centroid"][:,t]!=0).any(-1))[0] 143 | for id in track_ids: 144 | if id not in agents: 145 | agents.add(id) 146 | edges = get_agent_edge(sim_record["centroid"][id,t],sim_record["yaw"][id,t],sim_record["extent"][id,t]) 147 | source = ColumnDataSource(data=dict(x=edges[:,0],y=edges[:,1])) 148 | agent_patch[id] = plot.patch(x="x",y="y",source=source,color=agent_color[id]) 149 | agent_ds[id] = source 150 | # plotting plans 151 | 152 | plan_source = [ColumnDataSource(data=dict(x=np.random.randn(30),y=np.random.randn(30))) for i in range(numModes)] 153 | agent_plan_glyph[id] = [plot.line(x="x",y="y",source=plan_source[i],color=agent_color[id],line_width=2) for i in range(numModes)] 154 | agent_plan_ds[id] = 
plan_source 155 | 156 | xref_source = ColumnDataSource(data=dict(x=np.random.randn(30),y=np.random.randn(30))) 157 | xref_glyph = plot.line(x="x",y="y",source=xref_source,color=agent_color[0],line_width=1.5,line_dash="dashed") 158 | 159 | 160 | slider = Slider(title="Sim time", value=0, start=0, end=T-1, step=1) 161 | 162 | def update_data(attrname, old, new): 163 | t = slider.value #holds the current time value of slider after updating the slider 164 | 165 | # update agent patches 166 | track_ids = np.where((sim_record["centroid"][:,t]!=0).any(-1))[0] 167 | last_update_idx = np.where(t+trace_offset>=plan_ts)[0].argmax().item() 168 | last_update_t = plan_ts[last_update_idx].item() 169 | world_from_agent = sim_record["world_from_agent"][0,last_update_t-trace_offset] 170 | if sim_trace[last_update_t]["obj_x"] is not None: 171 | obj_plan = GeoUtils.batch_nd_transform_points_np(sim_trace[last_update_t]["obj_x"][:,:,:2],world_from_agent[np.newaxis,:]) 172 | else: 173 | obj_plan = None 174 | ego_plan = GeoUtils.batch_nd_transform_points_np(sim_trace[last_update_t]["ego_x"][:,:2],world_from_agent) 175 | if "xref" in sim_trace[last_update_t] and sim_trace[last_update_t]["xref"] is not None: 176 | xref = sim_trace[last_update_t]["xref"][...,:2] 177 | 178 | xref = GeoUtils.batch_nd_transform_points_np(xref,world_from_agent) 179 | else: 180 | xref = None 181 | for id in track_ids: 182 | edges = get_agent_edge(sim_record["centroid"][id,t],sim_record["yaw"][id,t],sim_record["extent"][id,t]) 183 | agent_ds[id].data.update(dict(x=edges[:,0],y=edges[:,1])) 184 | if id==0: 185 | agent_plan_ds[id][0].data.update(dict(x=ego_plan[:,0],y=ego_plan[:,1])) 186 | if sim_trace[last_update_t]["ego_x"] is not None: 187 | if sim_trace[last_update_t]["ego_candidate_x"] is not None: 188 | ego_candidate_trajs = GeoUtils.batch_nd_transform_points_np(sim_trace[last_update_t]["ego_candidate_x"][...,:2],world_from_agent[None,None,:]) 189 | else: 190 | ego_candidate_trajs = None 191 | for i in 
range(numModes-1): 192 | if ego_candidate_trajs is not None and i slider.end: 230 | t = 0 #if slider value+1 is above max, reset to 0 231 | slider.value = t 232 | #Update the label on the button once the button is clicked 233 | global play_cb 234 | def animate(): 235 | global play_cb 236 | if play_button.label == '► Play': 237 | play_button.label = '❚❚ Pause' 238 | 239 | play_cb = curdoc().add_periodic_callback(animate_update, 100) #50 is speed of animation 240 | # curdoc().remove_periodic_callback(animate_update) 241 | else: 242 | play_button.label = '► Play' 243 | curdoc().remove_periodic_callback(play_cb) 244 | 245 | #callback when button is clicked. 246 | play_button.on_click(animate) 247 | 248 | 249 | layout = column(row(slider,exit_button,play_button),plot) #add plot to layout 250 | 251 | curdoc().add_root(layout) 252 | 253 | -------------------------------------------------------------------------------- /tbsim/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tbsim.configs.registry import get_registered_experiment_config 3 | from tbsim.configs.base import ExperimentConfig 4 | from tbsim.configs.config import Dict 5 | 6 | 7 | def translate_l5kit_cfg(cfg: ExperimentConfig): 8 | """ 9 | Translate a tbsim config to a l5kit config 10 | 11 | Args: 12 | cfg (ExperimentConfig): an ExperimentConfig instance 13 | 14 | Returns: 15 | cfg for l5kit 16 | """ 17 | rcfg = dict() 18 | 19 | rcfg["raster_params"] = cfg.env.rasterizer.to_dict() 20 | rcfg["raster_params"]["dataset_meta_key"] = cfg.train.dataset_meta_key 21 | rcfg["model_params"] = cfg.algo 22 | if "data_generation_params" in cfg.env.keys(): 23 | rcfg["data_generation_params"] = cfg.env["data_generation_params"] 24 | return rcfg 25 | 26 | 27 | def get_experiment_config_from_file(file_path, locked=False): 28 | ext_cfg = json.load(open(file_path, "r")) 29 | cfg = get_registered_experiment_config(ext_cfg["registered_name"]) 30 | 
cfg.update(**ext_cfg) 31 | cfg.lock(locked) 32 | return cfg 33 | 34 | 35 | def translate_trajdata_cfg(cfg: ExperimentConfig): 36 | rcfg = Dict() 37 | # assert cfg.algo.step_time == 0.5 # TODO: support interpolation 38 | if "scene_centric" in cfg.algo and cfg.algo.scene_centric: 39 | rcfg.centric="scene" 40 | else: 41 | rcfg.centric="agent" 42 | if "standardize_data" in cfg.env.data_generation_params: 43 | rcfg.standardize_data = cfg.env.data_generation_params.standardize_data 44 | else: 45 | rcfg.standardize_data = True 46 | rcfg.step_time = cfg.algo.step_time 47 | rcfg.trajdata_source_root = cfg.train.trajdata_source_root 48 | rcfg.trajdata_source_train = cfg.train.trajdata_source_train 49 | rcfg.trajdata_source_valid = cfg.train.trajdata_source_valid 50 | rcfg.dataset_path = cfg.train.dataset_path 51 | rcfg.history_num_frames = cfg.algo.history_num_frames 52 | rcfg.future_num_frames = cfg.algo.future_num_frames 53 | rcfg.max_agents_distance = cfg.env.data_generation_params.max_agents_distance 54 | rcfg.num_other_agents = cfg.env.data_generation_params.other_agents_num 55 | rcfg.max_agents_distance_simulation = cfg.env.simulation.distance_th_close 56 | rcfg.pixel_size = cfg.env.rasterizer.pixel_size 57 | rcfg.raster_size = int(cfg.env.rasterizer.raster_size) 58 | rcfg.raster_center = cfg.env.rasterizer.ego_center 59 | rcfg.yaw_correction_speed = cfg.env.data_generation_params.yaw_correction_speed 60 | rcfg.incl_neighbor_map = cfg.env.incl_neighbor_map 61 | rcfg.other_agents_num = cfg.env.data_generation_params.other_agents_num 62 | if "vectorize_lane" in cfg.env.data_generation_params: 63 | rcfg.vectorize_lane = cfg.env.data_generation_params.vectorize_lane 64 | else: 65 | rcfg.vectorize_lane = "None" 66 | 67 | rcfg.lock() 68 | return rcfg 69 | -------------------------------------------------------------------------------- /tbsim/utils/experiment_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import 
os 3 | import itertools 4 | from collections import namedtuple 5 | from typing import List 6 | from glob import glob 7 | import subprocess 8 | import shutil 9 | from pathlib import Path 10 | 11 | import tbsim 12 | from tbsim.configs.registry import get_registered_experiment_config 13 | from tbsim.configs.config import Dict 14 | from tbsim.configs.eval_config import EvaluationConfig 15 | from tbsim.configs.base import ExperimentConfig 16 | 17 | 18 | class Param(namedtuple("Param", "config_var alias value")): 19 | pass 20 | 21 | 22 | class ParamRange(namedtuple("Param", "config_var alias range")): 23 | def linearize(self): 24 | return [Param(self.config_var, self.alias, v) for v in self.range] 25 | 26 | def __len__(self): 27 | return len(self.range) 28 | 29 | 30 | class ParamConfig(object): 31 | def __init__(self, params: List[Param] = None): 32 | self.params = [] 33 | self.aliases = [] 34 | self.config_vars = [] 35 | print(params) 36 | if params is not None: 37 | for p in params: 38 | self.add(p) 39 | 40 | def add(self, param: Param): 41 | assert param.config_var not in self.config_vars 42 | assert param.alias not in self.aliases 43 | self.config_vars.append(param.config_var) 44 | self.aliases.append(param.alias) 45 | self.params.append(param) 46 | 47 | def __str__(self): 48 | char_to_remove = [" ", "(", ")", ";", "[", "]"] 49 | name = [] 50 | for p in self.params: 51 | v_str = str(p.value) 52 | for c in char_to_remove: 53 | v_str = v_str.replace(c, "") 54 | name.append(p.alias + v_str) 55 | 56 | return "_".join(name) 57 | 58 | def generate_config(self, base_cfg: Dict): 59 | cfg = base_cfg.clone() 60 | for p in self.params: 61 | var_list = p.config_var.split(".") 62 | c = cfg 63 | # traverse the indexing list 64 | for v in var_list[:-1]: 65 | assert v in c, "{} is not a valid config variable".format( 66 | p.config_var) 67 | c = c[v] 68 | assert var_list[-1] in c, "{} is not a valid config variable".format( 69 | p.config_var 70 | ) 71 | c[var_list[-1]] = p.value 72 | 
cfg.name = str(self) 73 | return cfg 74 | 75 | 76 | class ParamSearchPlan(object): 77 | def __init__(self): 78 | self.param_configs = [] 79 | self.const_params = [] 80 | 81 | def add_const_param(self, param: Param): 82 | self.const_params.append(param) 83 | 84 | def add(self, param_config: ParamConfig): 85 | for c in self.const_params: 86 | param_config.add(c) 87 | self.param_configs.append(param_config) 88 | 89 | def extend(self, param_configs: List[ParamConfig]): 90 | for pc in param_configs: 91 | self.add(pc) 92 | 93 | @staticmethod 94 | def compose_concate(param_ranges: List[ParamRange]): 95 | pcs = [] 96 | for pr in param_ranges: 97 | for p in pr.linearize(): 98 | pcs.append(ParamConfig([p])) 99 | return pcs 100 | 101 | @staticmethod 102 | def compose_cartesian(param_ranges: List[ParamRange]): 103 | """Cartesian product among parameters""" 104 | prs = [pr.linearize() for pr in param_ranges] 105 | return [ParamConfig(pr) for pr in itertools.product(*prs)] 106 | 107 | @staticmethod 108 | def compose_zip(param_ranges: List[ParamRange]): 109 | l = len(param_ranges[0]) 110 | assert all( 111 | len(pr) == l for pr in param_ranges 112 | ), "All param_range must be the same length" 113 | prs = [pr.linearize() for pr in param_ranges] 114 | return [ParamConfig(prz) for prz in zip(*prs)] 115 | 116 | def generate_configs(self, base_cfg: Dict): 117 | """ 118 | Generate configs from the parameter search plan, also rename the experiment by generating the correct alias. 
119 | """ 120 | if len(self.param_configs) > 0: 121 | return [pc.generate_config(base_cfg) for pc in self.param_configs] 122 | else: 123 | # constant-only 124 | const_cfg = ParamConfig(self.const_params) 125 | return [const_cfg.generate_config(base_cfg)] 126 | 127 | 128 | def create_configs( 129 | configs_to_search_fn, 130 | config_name, 131 | config_file, 132 | config_dir, 133 | prefix, 134 | delete_config_dir=True, 135 | ): 136 | if config_name is not None: 137 | cfg = get_registered_experiment_config(config_name) 138 | print("Generating configs for {}".format(config_name)) 139 | elif config_file is not None: 140 | # Update default config with external json file 141 | ext_cfg = json.load(open(config_file, "r")) 142 | cfg = get_registered_experiment_config(ext_cfg["registered_name"]) 143 | cfg.update(**ext_cfg) 144 | print("Generating configs with {} as template".format(config_file)) 145 | else: 146 | raise FileNotFoundError("No base config is provided") 147 | 148 | configs = configs_to_search_fn(base_cfg=cfg) 149 | for c in configs: 150 | pfx = "{}_".format(prefix) if prefix is not None else "" 151 | c.name = pfx + c.name 152 | config_fns = [] 153 | 154 | if delete_config_dir and os.path.exists(config_dir): 155 | shutil.rmtree(config_dir) 156 | os.makedirs(config_dir, exist_ok=True) 157 | for c in configs: 158 | fn = os.path.join(config_dir, "{}.json".format(c.name)) 159 | config_fns.append(fn) 160 | print("Saving config to {}".format(fn)) 161 | c.dump(fn) 162 | 163 | return configs, config_fns 164 | 165 | 166 | def read_configs(config_dir): 167 | configs = [] 168 | config_fns = [] 169 | for cfn in glob(config_dir + "/*.json"): 170 | print(cfn) 171 | config_fns.append(cfn) 172 | ext_cfg = json.load(open(cfn, "r")) 173 | c = get_registered_experiment_config(ext_cfg["registered_name"]) 174 | c.update(**ext_cfg) 175 | configs.append(c) 176 | return configs, config_fns 177 | 178 | 179 | def create_evaluation_configs( 180 | configs_to_search_fn, 181 | config_dir, 182 
| cfg, 183 | prefix=None, 184 | delete_config_dir=True, 185 | ): 186 | configs = configs_to_search_fn(base_cfg=cfg) 187 | for c in configs: 188 | if prefix is not None: 189 | c.name = prefix + "_" + c.name 190 | 191 | config_fns = [] 192 | 193 | if delete_config_dir and os.path.exists(config_dir): 194 | shutil.rmtree(config_dir) 195 | os.makedirs(config_dir, exist_ok=True) 196 | for c in configs: 197 | fn = os.path.join(config_dir, "{}.json".format(c.name)) 198 | config_fns.append(fn) 199 | print("Saving config to {}".format(fn)) 200 | c.dump(fn) 201 | 202 | return configs, config_fns 203 | 204 | 205 | def read_evaluation_configs(config_dir): 206 | configs = [] 207 | config_fns = [] 208 | for cfn in glob(config_dir + "/*.json"): 209 | print(cfn) 210 | config_fns.append(cfn) 211 | c = EvaluationConfig() 212 | ext_cfg = json.load(open(cfn, "r")) 213 | c.update(**ext_cfg) 214 | configs.append(c) 215 | return configs, config_fns 216 | 217 | 218 | 219 | 220 | def launch_experiments_local(script_path, cfgs, cfg_paths, extra_args=[]): 221 | for cfg, cpath in zip(cfgs, cfg_paths): 222 | cmd = ["python", script_path, "--config_file", cpath] + extra_args 223 | subprocess.run(cmd) 224 | 225 | 226 | 227 | 228 | def get_checkpoint( 229 | ckpt_key, ckpt_dir=None, ckpt_root_dir="checkpoints/", download_tmp_dir="/tmp" 230 | ): 231 | """ 232 | Get checkpoint and config path given a local dir. 233 | 234 | 235 | 236 | If a @ckpt_dir is specified, the function will look for the directory locally and return the ckpt that contains 237 | @ckpt_key, as well as its config.json. 238 | 239 | Args: 240 | ckpt_key (str): a string that uniquely identifies a checkpoint file with a directory, e.g., `iter50000.ckpt` 241 | ckpt_dir (str): (Optional) a local directory that contains the specified checkpoint 242 | ckpt_root_dir (str): (Optional) a directory that the function will look for checkpoints 243 | download_tmp_dir (str): a temporary storage for the checkpoint. 
244 | 245 | Returns: 246 | ckpt_path (str): path to a checkpoint file 247 | cfg_path (str): path to a config.json file 248 | """ 249 | def ckpt_path_func(paths): return [p for p in paths if str(ckpt_key) in p] 250 | local_dir = ckpt_dir 251 | assert ckpt_dir is not None 252 | 253 | ckpt_paths = glob(local_dir + "/**/*.ckpt", recursive=True) 254 | if len(ckpt_path_func(ckpt_paths)) == 0: 255 | raise FileNotFoundError("Cannot find checkpoint in {} with key {}".format(local_dir, ckpt_key)) 256 | else: 257 | ckpt_dir = local_dir 258 | 259 | ckpt_paths = ckpt_path_func(glob(ckpt_dir + "/**/*.ckpt", recursive=True)) 260 | assert len(ckpt_paths) > 0, "Could not find a checkpoint that has key {}".format( 261 | ckpt_key 262 | ) 263 | assert len(ckpt_paths) == 1, "More than one checkpoint found {}".format(ckpt_paths) 264 | cfg_path = glob(ckpt_dir + "/**/config.json", recursive=True)[0] 265 | print("Checkpoint path: {}".format(ckpt_paths[0])) 266 | print("Config path: {}".format(cfg_path)) 267 | return ckpt_paths[0], cfg_path 268 | 269 | 270 | -------------------------------------------------------------------------------- /tbsim/utils/lane_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tbsim.utils.geometry_utils as GeoUtils 3 | import tbsim.utils.tensor_utils as TensorUtils 4 | 5 | 6 | def get_edge(lane,dir,W=2.0,num_pts = None): 7 | 8 | if dir == "L": 9 | 10 | if lane.left_edge is not None: 11 | lane.left_edge = lane.left_edge.interpolate(num_pts) 12 | xy = lane.left_edge.xy 13 | if lane.left_edge.has_heading: 14 | h = lane.left_edge.h 15 | else: 16 | dxy = xy[1:]-xy[:-1] 17 | h = GeoUtils.round_2pi(np.arctan2(dxy[:,1],dxy[:,0])) 18 | h = np.hstack((h,h[-1])) 19 | else: 20 | lane.center = lane.center.interpolate(num_pts) 21 | angle = lane.center.h+np.pi/2 22 | offset = np.stack([W*np.cos(angle),W*np.sin(angle)],-1) 23 | xy = lane.center.xy+offset 24 | h = lane.center.h 25 | elif dir == "R": 26 | 
27 | if lane.right_edge is not None: 28 | lane.right_edge = lane.right_edge.interpolate(num_pts) 29 | xy = lane.right_edge.xy 30 | if lane.right_edge.has_heading: 31 | h = lane.right_edge.h 32 | else: 33 | dxy = xy[1:]-xy[:-1] 34 | h = GeoUtils.round_2pi(np.arctan2(dxy[:,1],dxy[:,0])) 35 | h = np.hstack((h,h[-1])) 36 | else: 37 | lane.center = lane.center.interpolate(num_pts) 38 | angle = lane.center.h-np.pi/2 39 | offset = np.stack([W*np.cos(angle),W*np.sin(angle)],-1) 40 | xy = lane.center.xy+offset 41 | h = lane.center.h 42 | elif dir =="C": 43 | lane.center = lane.center.interpolate(num_pts) 44 | xy = lane.center.xy 45 | if lane.center.has_heading: 46 | h = lane.center.h 47 | else: 48 | dxy = xy[1:]-xy[:-1] 49 | h = GeoUtils.round_2pi(np.arctan2(dxy[:,1],dxy[:,0])) 50 | h = np.hstack((h,h[-1])) 51 | return xy,h 52 | 53 | def get_bdry_xyh(lane1,lane2=None,dir="L",W=3.6,num_pts = 25): 54 | if lane2 is None: 55 | xy,h = get_edge(lane1,dir,W,num_pts*2) 56 | else: 57 | xy1,h1 = get_edge(lane1,dir,W,num_pts) 58 | xy2,h2 = get_edge(lane2,dir,W,num_pts) 59 | xy = np.concatenate((xy1,xy2),0) 60 | h = np.concatenate((h1,h2),0) 61 | return xy,h 62 | -------------------------------------------------------------------------------- /tbsim/utils/log_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains utility classes and functions for logging to stdout, stderr, 3 | and to tensorboard. 4 | """ 5 | import sys 6 | 7 | 8 | class PrintLogger(object): 9 | """ 10 | This class redirects print statements to both console and a file. 11 | """ 12 | def __init__(self, log_file): 13 | self.terminal = sys.stdout 14 | print('STDOUT will be forked to %s' % log_file) 15 | self.log_file = open(log_file, "a") 16 | 17 | def write(self, message): 18 | self.terminal.write(message) 19 | self.log_file.write(message) 20 | self.log_file.flush() 21 | 22 | def flush(self): 23 | # this flush method is needed for python 3 compatibility. 
24 | # this handles the flush command by doing nothing. 25 | # you might want to specify some extra behavior here. 26 | pass 27 | -------------------------------------------------------------------------------- /tbsim/utils/math_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def soft_min(x,y,gamma=5): 5 | if isinstance(x,torch.Tensor): 6 | expfun = torch.exp 7 | elif isinstance(x,np.ndarray): 8 | expfun = np.exp 9 | exp1 = expfun((y-x)/2) 10 | exp2 = expfun((x-y)/2) 11 | return (exp1*x+exp2*y)/(exp1+exp2) 12 | 13 | def soft_max(x,y,gamma=5): 14 | if isinstance(x,torch.Tensor): 15 | expfun = torch.exp 16 | elif isinstance(x,np.ndarray): 17 | expfun = np.exp 18 | exp1 = expfun((x-y)/2) 19 | exp2 = expfun((y-x)/2) 20 | return (exp1*x+exp2*y)/(exp1+exp2) 21 | 22 | def soft_sat(x,x_min=None,x_max=None,gamma=5): 23 | if x_min is None and x_max is None: 24 | return x 25 | elif x_min is None and x_max is not None: 26 | return soft_min(x,x_max,gamma) 27 | elif x_min is not None and x_max is None: 28 | return soft_max(x,x_min,gamma) 29 | else: 30 | if isinstance(x_min,torch.Tensor) or isinstance(x_min,np.ndarray): 31 | assert (x_max>x_min).all() 32 | else: 33 | assert x_max>x_min 34 | xc = x - (x_min+x_max)/2 35 | if isinstance(x,torch.Tensor): 36 | return xc/(torch.pow(1+torch.pow(torch.abs(xc*2/(x_max-x_min)),gamma),1/gamma))+(x_min+x_max)/2 37 | elif isinstance(x,np.ndarray): 38 | return xc/(np.power(1+np.power(np.abs(xc*2/(x_max-x_min)),gamma),1/gamma))+(x_min+x_max)/2 39 | else: 40 | raise Exception("data type not supported") 41 | 42 | -------------------------------------------------------------------------------- /tbsim/utils/rollout_logger.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import numpy as np 3 | from copy import deepcopy 4 | 5 | import tbsim.utils.tensor_utils as TensorUtils 6 | from 
tbsim.policies.common import RolloutAction 7 | 8 | from torch.nn.utils.rnn import pad_sequence 9 | class RolloutLogger(object): 10 | """Log trajectories and other essential info during rollout for visualization and evaluation""" 11 | def __init__(self, obs_keys=None): 12 | if obs_keys is None: 13 | obs_keys = dict() 14 | self._obs_keys = obs_keys 15 | self._scene_indices = None 16 | self._agent_id_per_scene = dict() 17 | self._agent_data_by_scene = dict() 18 | self._scene_ts = defaultdict(lambda:0) 19 | 20 | def _combine_obs(self, obs): 21 | combined = dict() 22 | excluded_keys = ["extras"] 23 | if "ego" in obs and obs["ego"] is not None: 24 | combined.update(obs["ego"]) 25 | if "agents" in obs and obs["agents"] is not None: 26 | for k in obs["agents"].keys(): 27 | if k in combined and k not in excluded_keys: 28 | if obs["agents"][k] is not None: 29 | if combined[k] is not None: 30 | combined[k] = np.concatenate((combined[k], obs["agents"][k]), axis=0) 31 | else: 32 | combined[k] = obs["agents"][k] 33 | else: 34 | combined[k] = obs["agents"][k] 35 | return combined 36 | 37 | def _combine_action(self, action: RolloutAction): 38 | combined = dict(action=dict()) 39 | if action.has_ego and not action.has_agents: 40 | combined["action"] = action.ego.to_dict() 41 | if action.ego_info is not None and "action_samples" in action.ego_info: 42 | combined["action_samples"] = action.ego_info["action_samples"] 43 | return combined 44 | 45 | elif action.has_agents and not action.has_ego: 46 | combined["action"] = action.agents.to_dict() 47 | if action.agents_info is not None and "action_samples" in action.agents_info: 48 | combined["action_samples"] = action.agents_info["action_samples"] 49 | return combined 50 | elif action.has_agents and action.has_ego: 51 | Nego = action.ego.positions.shape[0] 52 | Nagents = action.agents.positions.shape[0] 53 | combined["action"] = dict() 54 | agents_action = action.agents.to_dict() 55 | ego_action = action.ego.to_dict() 56 | for k in
agents_action: 57 | if k in ego_action: 58 | combined["action"][k] = np.concatenate((ego_action[k], agents_action[k]), axis=0) 59 | if action.agents_info is not None and action.ego_info is not None: 60 | if "action_samples" in action.ego_info: 61 | ego_samples = action.ego_info["action_samples"] 62 | else: 63 | ego_samples = None 64 | if "action_samples" in action.agents_info: 65 | agents_samples = action.agents_info["action_samples"] 66 | else: 67 | agents_samples = None 68 | if ego_samples is not None and agents_samples is None: 69 | combined["action_samples"] = dict() 70 | for k in ego_samples: 71 | pad_k = np.zeros([Nagents,*ego_samples[k].shape[1:]]) 72 | combined["action_samples"][k]=np.concatenate((ego_samples[k],pad_k),0) 73 | elif ego_samples is None and agents_samples is not None: 74 | combined["action_samples"] = dict() 75 | for k in agents_samples: 76 | pad_k = np.zeros([Nego,*agents_samples[k].shape[1:]]) 77 | combined["action_samples"][k]=np.concatenate((pad_k,agents_samples[k]),0) 78 | elif ego_samples is not None and agents_samples is not None: 79 | combined["action_samples"] = dict() 80 | for k in ego_samples: 81 | if k in agents_samples: 82 | if ego_samples[k].shape[1]>agents_samples[k].shape[1]: 83 | pad_k = np.zeros([Nagents,ego_samples[k].shape[1]-agents_samples[k].shape[1],*agents_samples[k].shape[2:]]) 84 | agents_samples[k]=np.concatenate((agents_samples[k],pad_k),1) 85 | elif ego_samples[k].shape[1]0: 156 | default_val = list(self._agent_data_by_scene[si][k][ti].values())[0] 157 | ti_k = list() 158 | for ts in range(self._scene_ts[si]): 159 | ti_k.append(self._agent_data_by_scene[si][k][ti][ts] if ts in self._agent_data_by_scene[si][k][ti] else np.ones_like(default_val)*np.nan) 160 | default_val = ti_k[-1] 161 | if not all(elem.shape==ti_k[0].shape for elem in ti_k): 162 | # requires padding 163 | if np.issubdtype(ti_k[0].dtype,float): 164 | padding_value = np.nan 165 | else: 166 | padding_value = 0 167 | ti_k = [x[0] for x in ti_k] 168 |
# NOTE(review): this repository dump of tbsim/utils/rollout_logger.py is damaged. In the line
# above, the line-number gutter jumps from 85 (`elif ego_samples[k].shape[1]`) to the tail `0:`
# of line 155, so source lines 86-155 are missing -- apparently everything between a '<' and a
# later '>' was stripped. The lost span includes the end of `_combine_action` and the header of
# the method that later returns `deepcopy(self._serialized_scene_buffer)` (presumably
# `get_serialized_scene_buffer`, which `get_trajectory` calls). Indentation is also lost, so e.g.
# which `if` the `else:` at source line 33 of `_combine_obs` belongs to cannot be read from here.
# Review and edit this class against the original file, not this dump.
ti_k_torch = TensorUtils.to_tensor(ti_k,ignore_if_unspecified=True) 169 | 170 | ti_k_padded = pad_sequence(ti_k_torch,padding_value=padding_value,batch_first=True) 171 | serialized[si][k].append(TensorUtils.to_numpy(ti_k_padded)[np.newaxis,:]) 172 | else: 173 | if ti_k[0].ndim==0: 174 | serialized[si][k].append(np.array(ti_k)[np.newaxis,:]) 175 | else: 176 | serialized[si][k].append(np.concatenate(ti_k,axis=0)[np.newaxis,:]) 177 | else: 178 | serialized[si][k].append(np.zeros_like(serialized[si][k][-1])) 179 | if not all(elem.shape==serialized[si][k][0].shape for elem in serialized[si][k]): 180 | # requires padding 181 | if np.issubdtype(serialized[si][k][0][0].dtype,float): 182 | padding_value = np.nan 183 | else: 184 | padding_value = 0 185 | axes=[1,0]+np.arange(2,serialized[si][k][0].ndim-1).tolist() 186 | mk_transpose = [np.transpose(x[0],axes) for x in serialized[si][k]] 187 | mk_torch = TensorUtils.to_tensor(mk_transpose,ignore_if_unspecified=True) 188 | mk_padded = pad_sequence(mk_torch,padding_value=padding_value) 189 | mk = TensorUtils.to_numpy(mk_padded) 190 | axes=[1,2,0]+np.arange(3,mk.ndim).tolist() 191 | serialized[si][k]=np.transpose(mk,axes) 192 | else: 193 | serialized[si][k] = np.concatenate(serialized[si][k],axis=0) 194 | 195 | 196 | 197 | self._serialized_scene_buffer = serialized 198 | return deepcopy(self._serialized_scene_buffer) 199 | 200 | def get_trajectory(self): 201 | """Get per-scene rollout trajectory in the world coordinate system""" 202 | buffer = self.get_serialized_scene_buffer() 203 | traj = dict() 204 | for si in buffer: 205 | traj[si] = dict( 206 | positions=buffer[si]["centroid"], 207 | yaws=buffer[si]["yaw"] 208 | ) 209 | return traj 210 | 211 | def get_track_id(self): 212 | return deepcopy(self._agent_id_per_scene) 213 | 214 | def get_stats(self): 215 | # TODO 216 | raise NotImplementedError() 217 | 218 | def log_step(self, obs, action: RolloutAction): 219 | combined_obs = self._combine_obs(obs) 220 | combined_action =
self._combine_action(action) 221 | assert combined_obs["scene_index"].shape[0] == combined_action["action"]["positions"].shape[0] 222 | self._maybe_initialize(combined_obs) 223 | self._append_buffer(combined_obs, combined_action) 224 | for si in np.unique(combined_obs["scene_index"]): 225 | self._scene_ts[si]+=1 226 | del combined_obs 227 | -------------------------------------------------------------------------------- /tbsim/utils/timer.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import numpy as np 4 | from contextlib import contextmanager 5 | 6 | 7 | class Timer(object): 8 | """A simple timer.""" 9 | def __init__(self): 10 | self.total_time = 0. 11 | self.calls = 0 12 | self.start_time = 0. 13 | self.diff = 0. 14 | self.average_time = 0. 15 | self.times = [] 16 | 17 | def recent_average_time(self, latest_n): 18 | return np.mean(np.array(self.times)[-latest_n:]) 19 | 20 | def tic(self): 21 | # using time.time instead of time.clock because time time.clock 22 | # does not normalize for multithreading 23 | self.start_time = time.time() 24 | 25 | def toc(self, average=True): 26 | self.diff = time.time() - self.start_time 27 | self.times.append(self.diff) 28 | self.total_time += self.diff 29 | self.calls += 1 30 | self.average_time = self.total_time / self.calls 31 | if average: 32 | return self.average_time 33 | else: 34 | return self.diff 35 | 36 | @contextmanager 37 | def timed(self): 38 | self.tic() 39 | yield 40 | self.toc() 41 | 42 | 43 | class Timers(object): 44 | def __init__(self): 45 | self._timers = {} 46 | 47 | def tic(self, key): 48 | if key not in self._timers: 49 | self._timers[key] = Timer() 50 | self._timers[key].tic() 51 | 52 | def toc(self, key): 53 | self._timers[key].toc() 54 | 55 | @contextmanager 56 | def timed(self, key): 57 | self.tic(key) 58 | yield 59 | self.toc(key) 60 | 61 | def __str__(self): 62 | msg = [] 63 | for k, v in self._timers.items(): 64 | msg.append('%s: %f' % (k, 
v.average_time)) 65 | return ', '.join(msg) -------------------------------------------------------------------------------- /tbsim/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains some PyTorch utilities. 3 | """ 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import torch 7 | import torch.optim as optim 8 | import functools 9 | from tqdm.auto import tqdm 10 | 11 | 12 | def soft_update(source, target, tau): 13 | """ 14 | Soft update from the parameters of a @source torch module to a @target torch module 15 | with strength @tau. The update follows target = target * (1 - tau) + source * tau. 16 | 17 | Args: 18 | source (torch.nn.Module): source network to push target network parameters towards 19 | target (torch.nn.Module): target network to update 20 | """ 21 | for target_param, param in zip(target.parameters(), source.parameters()): 22 | target_param.copy_(target_param * (1.0 - tau) + param * tau) 23 | 24 | 25 | def hard_update(source, target): 26 | """ 27 | Hard update @target parameters to match @source. 28 | 29 | Args: 30 | source (torch.nn.Module): source network to provide parameters 31 | target (torch.nn.Module): target network to update parameters for 32 | """ 33 | for target_param, param in zip(target.parameters(), source.parameters()): 34 | target_param.copy_(param) 35 | 36 | 37 | def get_torch_device(try_to_use_cuda): 38 | """ 39 | Return torch device. If using cuda (GPU), will also set cudnn.benchmark to True 40 | to optimize CNNs. 
41 | 42 | Args: 43 | try_to_use_cuda (bool): if True and cuda is available, will use GPU 44 | 45 | Returns: 46 | device (torch.Device): device to use for models 47 | """ 48 | if try_to_use_cuda and torch.cuda.is_available(): 49 | torch.backends.cudnn.benchmark = True 50 | device = torch.device("cuda:0") 51 | else: 52 | device = torch.device("cpu") 53 | return device 54 | 55 | 56 | def reparameterize(mu, logvar): 57 | """ 58 | Reparameterize for the backpropagation of z instead of q. 59 | This makes it so that we can backpropagate through the sampling of z from 60 | our encoder when feeding the sampled variable to the decoder. 61 | 62 | (See "The reparameterization trick" section of https://arxiv.org/abs/1312.6114) 63 | 64 | Args: 65 | mu (torch.Tensor): batch of means from the encoder distribution 66 | logvar (torch.Tensor): batch of log variances from the encoder distribution 67 | 68 | Returns: 69 | z (torch.Tensor): batch of sampled latents from the encoder distribution that 70 | support backpropagation 71 | """ 72 | # logvar = \log(\sigma^2) = 2 * \log(\sigma) 73 | # \sigma = \exp(0.5 * logvar) 74 | 75 | # clamped for numerical stability 76 | logstd = (0.5 * logvar).clamp(-4, 15) 77 | std = torch.exp(logstd) 78 | 79 | # Sample \epsilon from normal distribution 80 | # use std to create a new tensor, so we don't have to care 81 | # about running on GPU or not 82 | eps = std.new(std.size()).normal_() 83 | 84 | # Then multiply with the standard deviation and add the mean 85 | z = eps.mul(std).add_(mu) 86 | 87 | return z 88 | 89 | 90 | def optimizer_from_optim_params(net_optim_params, net): 91 | """ 92 | Helper function to return a torch Optimizer from the optim_params 93 | section of the config for a particular network. 94 | 95 | Args: 96 | optim_params (Config): optim_params part of algo_config corresponding 97 | to @net. This determines the optimizer that is created. 
98 | 99 | net (torch.nn.Module): module whose parameters this optimizer will be 100 | responsible 101 | 102 | Returns: 103 | optimizer (torch.optim.Optimizer): optimizer 104 | """ 105 | return optim.Adam( 106 | params=net.parameters(), 107 | lr=net_optim_params["learning_rate"]["initial"], 108 | weight_decay=net_optim_params["regularization"]["L2"], 109 | ) 110 | 111 | 112 | def lr_scheduler_from_optim_params(net_optim_params, net, optimizer): 113 | """ 114 | Helper function to return a LRScheduler from the optim_params 115 | section of the config for a particular network. Returns None 116 | if a scheduler is not needed. 117 | 118 | Args: 119 | optim_params (Config): optim_params part of algo_config corresponding 120 | to @net. This determines whether a learning rate scheduler is created. 121 | 122 | net (torch.nn.Module): module whose parameters this optimizer will be 123 | responsible 124 | 125 | optimizer (torch.optim.Optimizer): optimizer for this net 126 | 127 | Returns: 128 | lr_scheduler (torch.optim.lr_scheduler or None): learning rate scheduler 129 | """ 130 | lr_scheduler = None 131 | if len(net_optim_params["learning_rate"]["epoch_schedule"]) > 0: 132 | # decay LR according to the epoch schedule 133 | lr_scheduler = optim.lr_scheduler.MultiStepLR( 134 | optimizer=optimizer, 135 | milestones=net_optim_params["learning_rate"]["epoch_schedule"], 136 | gamma=net_optim_params["learning_rate"]["decay_factor"], 137 | ) 138 | return lr_scheduler 139 | 140 | 141 | def backprop_for_loss(net, optim, loss, max_grad_norm=None, retain_graph=False): 142 | """ 143 | Backpropagate loss and update parameters for network with 144 | name @name. 
145 | 146 | Args: 147 | net (torch.nn.Module): network to update 148 | 149 | optim (torch.optim.Optimizer): optimizer to use 150 | 151 | loss (torch.Tensor): loss to use for backpropagation 152 | 153 | max_grad_norm (float): if provided, used to clip gradients 154 | 155 | retain_graph (bool): if True, graph is not freed after backward call 156 | 157 | Returns: 158 | grad_norms (float): average gradient norms from backpropagation 159 | """ 160 | 161 | # backprop 162 | optim.zero_grad() 163 | loss.backward(retain_graph=retain_graph) 164 | 165 | # gradient clipping 166 | if max_grad_norm is not None: 167 | torch.nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm) 168 | 169 | # compute grad norms 170 | grad_norms = 0.0 171 | for p in net.parameters(): 172 | # only clip gradients for parameters for which requires_grad is True 173 | if p.grad is not None: 174 | grad_norms += p.grad.data.norm(2).pow(2).item() 175 | 176 | # step 177 | optim.step() 178 | 179 | return grad_norms 180 | 181 | 182 | class dummy_context_mgr: 183 | """ 184 | A dummy context manager - useful for having conditional scopes (such 185 | as @maybe_no_grad). Nothing happens in this scope. 
186 | """ 187 | 188 | def __enter__(self): 189 | return None 190 | 191 | def __exit__(self, exc_type, exc_value, traceback): 192 | return False 193 | 194 | 195 | def maybe_no_grad(no_grad): 196 | """ 197 | Args: 198 | no_grad (bool): if True, the returned context will be torch.no_grad(), otherwise 199 | it will be a dummy context 200 | """ 201 | return torch.no_grad() if no_grad else dummy_context_mgr() 202 | 203 | 204 | def rgetattr(obj, attr, *args): 205 | "recursively get attributes" 206 | 207 | def _getattr(obj, attr): 208 | return getattr(obj, attr, *args) 209 | 210 | return functools.reduce(_getattr, [obj] + attr.split(".")) 211 | 212 | 213 | def rsetattr(obj, attr, val): 214 | "recursively set attributes" 215 | pre, _, post = attr.rpartition(".") 216 | return setattr(rgetattr(obj, pre) if pre else obj, post, val) 217 | 218 | 219 | class ProgressBar(pl.Callback): 220 | def __init__( 221 | self, global_progress: bool = True, leave_global_progress: bool = True 222 | ): 223 | super().__init__() 224 | 225 | self.global_progress = global_progress 226 | self.global_desc = "Epoch: {epoch}/{max_epoch}" 227 | self.leave_global_progress = leave_global_progress 228 | self.global_pb = None 229 | 230 | def on_fit_start(self, trainer, pl_module): 231 | desc = self.global_desc.format( 232 | epoch=trainer.current_epoch + 1, max_epoch=trainer.max_epochs 233 | ) 234 | 235 | self.global_pb = tqdm( 236 | desc=desc, 237 | total=trainer.max_epochs, 238 | initial=trainer.current_epoch, 239 | leave=self.leave_global_progress, 240 | disable=not self.global_progress, 241 | ) 242 | 243 | def on_fit_end(self, trainer, pl_module): 244 | self.global_pb.close() 245 | self.global_pb = None 246 | 247 | def on_epoch_end(self, trainer, pl_module): 248 | 249 | # Set description 250 | desc = self.global_desc.format( 251 | epoch=trainer.current_epoch + 1, max_epoch=trainer.max_epochs 252 | ) 253 | self.global_pb.set_description(desc) 254 | 255 | # Set logs and metrics 256 | # logs = 
pl_module.logs 257 | # for k, v in logs.items(): 258 | # if isinstance(v, torch.Tensor): 259 | # logs[k] = v.squeeze().item() 260 | # self.global_pb.set_postfix(logs) 261 | 262 | # Update progress 263 | self.global_pb.update(1) 264 | -------------------------------------------------------------------------------- /tbsim/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains several utility functions used to define the main training loop. It 3 | mainly consists of functions to assist with logging, rollouts, and the @run_epoch function, 4 | which is the core training logic for models in this repository. 5 | """ 6 | import os 7 | import socket 8 | import shutil 9 | 10 | 11 | def infinite_iter(data_loader): 12 | """ 13 | Get an infinite generator 14 | Args: 15 | data_loader (DataLoader): data loader to iterate through 16 | 17 | """ 18 | c_iter = iter(data_loader) 19 | while True: 20 | try: 21 | data = next(c_iter) 22 | except StopIteration: 23 | c_iter = iter(data_loader) 24 | data = next(c_iter) 25 | yield data 26 | 27 | 28 | def get_exp_dir(exp_name, output_dir, save_checkpoints=True, auto_remove_exp_dir=False): 29 | """ 30 | Create experiment directory from config. If an identical experiment directory 31 | exists and @auto_remove_exp_dir is False (default), the function will prompt 32 | the user on whether to remove and replace it, or keep the existing one and 33 | add a new subdirectory with the new timestamp for the current run. 34 | 35 | Args: 36 | exp_name (str): name of the experiment 37 | output_dir (str): output directory of the experiment 38 | save_checkpoints (bool): if save checkpoints 39 | auto_remove_exp_dir (bool): if True, automatically remove the existing experiment 40 | folder if it exists at the same path. 
41 | 42 | Returns: 43 | log_dir (str): path to created log directory (sub-folder in experiment directory) 44 | output_dir (str): path to created models directory (sub-folder in experiment directory) 45 | to store model checkpoints 46 | video_dir (str): path to video directory (sub-folder in experiment directory) 47 | to store rollout videos 48 | """ 49 | 50 | # create directory for where to dump model parameters, tensorboard logs, and videos 51 | base_output_dir = output_dir 52 | if not os.path.isabs(base_output_dir): 53 | base_output_dir = os.path.abspath(base_output_dir) 54 | base_output_dir = os.path.join(base_output_dir, exp_name) 55 | if os.path.exists(base_output_dir): 56 | if not auto_remove_exp_dir: 57 | ans = input( 58 | "WARNING: model directory ({}) already exists! \noverwrite? (y/n)\n".format( 59 | base_output_dir 60 | ) 61 | ) 62 | else: 63 | ans = "y" 64 | if ans == "y": 65 | print("REMOVING") 66 | shutil.rmtree(base_output_dir) 67 | os.makedirs(base_output_dir, exist_ok=True) 68 | 69 | # version the run 70 | existing_runs = [ 71 | a 72 | for a in os.listdir(base_output_dir) 73 | if os.path.isdir(os.path.join(base_output_dir, a)) 74 | ] 75 | run_counts = [-1] 76 | for ep in existing_runs: 77 | m = ep.split("run") 78 | if len(m) == 2 and m[0] == "": 79 | run_counts.append(int(m[1])) 80 | version_str = "run{}".format(max(run_counts) + 1) 81 | 82 | # only make model directory if model saving is enabled 83 | ckpt_dir = None 84 | if save_checkpoints: 85 | ckpt_dir = os.path.join(base_output_dir, version_str, "checkpoints") 86 | os.makedirs(ckpt_dir) 87 | 88 | # tensorboard directory 89 | log_dir = os.path.join(base_output_dir, version_str, "logs") 90 | os.makedirs(log_dir) 91 | 92 | # video directory 93 | video_dir = os.path.join(base_output_dir, version_str, "videos") 94 | os.makedirs(video_dir) 95 | return base_output_dir, log_dir, ckpt_dir, video_dir, version_str 96 | -------------------------------------------------------------------------------- 
/tbsim/utils/tree.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from collections import defaultdict 3 | import networkx as nx 4 | try: 5 | import pydot 6 | except: 7 | print("pydot not found") 8 | 9 | class Tree(object): 10 | 11 | def __init__(self, content, parent, depth): 12 | self.content = content 13 | self.children = list() 14 | self.parent = parent 15 | if parent is not None: 16 | parent.expand(self) 17 | self.depth = depth 18 | self.attribute = dict() 19 | 20 | def expand(self, child): 21 | self.children.append(child) 22 | 23 | def expand_set(self, children): 24 | self.children += children 25 | 26 | def isroot(self): 27 | return self.parent is None 28 | 29 | def isleaf(self): 30 | return len(self.children) == 0 31 | 32 | def get_subseq_trajs(self): 33 | return [child.traj for child in self.children] 34 | 35 | 36 | def get_all_leaves(self,leaf_set=[]): 37 | if self.isleaf(): 38 | leaf_set.append(self) 39 | else: 40 | for child in self.children: 41 | leaf_set = child.get_all_leaves(leaf_set) 42 | return leaf_set 43 | def get_label(self): 44 | raise NotImplementedError 45 | 46 | @staticmethod 47 | def get_nodes_by_level(obj,depth,nodes=None,trim_short_branch=True): 48 | assert obj.depth<=depth 49 | if nodes is None: 50 | nodes = defaultdict(lambda: list()) 51 | if obj.depth==depth: 52 | nodes[depth].append(obj) 53 | return nodes, True 54 | else: 55 | if obj.isleaf(): 56 | return nodes, False 57 | 58 | else: 59 | flag = False 60 | children_flags = dict() 61 | for child in obj.children: 62 | nodes, child_flag = Tree.get_nodes_by_level(child,depth,nodes) 63 | children_flags[child] = child_flag 64 | flag = flag or child_flag 65 | if trim_short_branch: 66 | obj.children = [child for child in obj.children if children_flags[child]] 67 | if flag: 68 | nodes[obj.depth].append(obj) 69 | return nodes, flag 70 | 71 | @staticmethod 72 | def get_children(obj): 73 | if isinstance(obj, Tree): 74 | return 
obj.children 75 | elif isinstance(obj, list): 76 | children = [node.children for node in obj] 77 | children = list(itertools.chain.from_iterable(children)) 78 | return children 79 | else: 80 | raise TypeError("obj must be a TrajTree or a list") 81 | 82 | def as_network(self): 83 | G = nx.Graph() 84 | G.add_node(self.get_label()) 85 | for child in self.children: 86 | G = nx.union(G,child.as_network()) 87 | G.add_edge(self.get_label(),child.get_label()) 88 | return G 89 | 90 | def plot(self): 91 | G = self.as_network() 92 | 93 | pos = nx.nx_agraph.pygraphviz_layout(G, prog="dot") 94 | nx.draw(G, pos,with_labels = True) 95 | # nx.draw(G, with_labels = True) 96 | 97 | 98 | 99 | 100 | def depth_first_traverse(tree:Tree,func, visited:dict, result): 101 | result = func(tree,result) 102 | visited[tree] = True 103 | for child in tree.children: 104 | if not (child in visited and visited[child]): 105 | result, visited = depth_first_traverse(child, func, visited, result) 106 | return result, visited -------------------------------------------------------------------------------- /trajdata_requirements.txt: -------------------------------------------------------------------------------- 1 | # Python 3.8 2 | 3 | pyarrow 4 | tqdm==4.62 5 | matplotlib==3.5 6 | dill==0.3.4 7 | pandas==1.4.1 8 | pyarrow==7.0.0 9 | 10 | nuscenes-devkit==1.1.9 11 | 12 | black==22.1.0 13 | isort==5.10.1 14 | pytest==7.1.1 15 | pytest-xdist==2.5.0 16 | 17 | zarr==2.11.0 18 | kornia==0.6.4 --------------------------------------------------------------------------------