├── .gitignore ├── LICENSE ├── README.md ├── gt_files └── trajdata_nusc_new │ └── GroundTruth │ └── hist_stats_new_scene_current_speed.json ├── scripts ├── parse_scene_edit_results.py ├── scene_editor.py └── train.py ├── setup.py └── tbsim ├── algos ├── __init__.py ├── algo_utils.py ├── algos.py ├── factory.py ├── metric_algos.py └── multiagent_algos.py ├── configs ├── __init__.py ├── algo_config.py ├── base.py ├── config.py ├── eval_config.py ├── l5kit_config.py ├── nusc_config.py ├── orca_config.py ├── registry.py ├── scene_edit_config.py ├── trajdata_config.py ├── trajdata_drivesim_config.py ├── trajdata_eupeds_config.py ├── trajdata_l5kit_config.py ├── trajdata_nuplan_all_config.py ├── trajdata_nuplan_config.py ├── trajdata_nuplan_ped_config.py ├── trajdata_nuplan_scene_config.py ├── trajdata_nusc_all_config.py ├── trajdata_nusc_config.py ├── trajdata_nusc_ped_config.py └── trajdata_nusc_scene_config.py ├── datasets ├── __init__.py ├── factory.py ├── l5kit_datamodules.py └── trajdata_datamodules.py ├── dynamics ├── __init__.py ├── base.py ├── bicycle.py ├── double_integrator.py ├── single_integrator.py └── unicycle.py ├── envs ├── base.py ├── env_avdata.py ├── env_l5kit.py ├── env_metrics.py └── env_trajdata.py ├── evaluation ├── __init__.py ├── env_builders.py ├── metric_composers.py └── policy_composers.py ├── l5kit ├── agent_sampling_mixed.py ├── l5_agent_dataset.py ├── l5_ego_dataset.py ├── simulation_dataset.py ├── vectorizer.py ├── vis_rasterizer.py └── visualizer.py ├── models ├── GAN_regularizer.py ├── Transformer.py ├── __init__.py ├── base_models.py ├── cnn_roi_encoder.py ├── context_encoders.py ├── diffuser.py ├── diffuser_helpers.py ├── learned_metrics.py ├── multiagent_models.py ├── rasterized_models.py ├── roi_align.py ├── scenediffuser.py ├── scenetemporal.py ├── strive.py ├── temporal.py ├── transformer_model.py └── vaes.py ├── policies ├── __init__.py ├── base.py ├── common.py ├── differential_stack_policy.py ├── hardcoded.py └── wrappers.py ├── rules └── stl_traffic_rules.py └── utils ├── __init__.py ├── batch_utils.py ├── config_utils.py ├── diffuser_utils ├── arrays.py └── progress.py ├── ema.py ├── env_utils.py ├── experiment_utils.py ├── ftocp.py ├── geometry_utils.py ├── gpt_utils.py ├── guidance_loss.py ├── guidance_metrics.py ├── l5_utils.py ├── lane_utils.py ├── log_utils.py ├── loss_utils.py ├── metrics.py ├── planning_utils.py ├── rollout_logger.py ├── scene_edit_utils.py ├── tensor_utils.py ├── timer.py ├── torch_utils.py ├── train_utils.py ├── trajdata_utils.py └── vis_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | videos/ 3 | figures/ 4 | checkpoints 5 | experiments/ 6 | visualizations/ 7 | results/ 8 | new_results/ 9 | new_results_52/ 10 | new_results_20_2/ 11 | new_results_20_3/ 12 | new_results_20_4/ 13 | nusc_results/ 14 | l5_results/ 15 | nuplan_results/ 16 | drivesim_results/ 17 | out/ 18 | experiments/ 19 | visualizations/ 20 | ## config files 21 | #*.yaml 22 | 23 | # Byte-compiled / optimized / DLL files 24 | __pycache__/ 25 | *.py[cod] 26 | 27 | # C extensions 28 | *.so 29 | 30 | # Distribution / packaging 31 | bin/ 32 | build/ 33 | develop-eggs/ 34 | dist/ 35 | eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | .tox/ 51 | .coverage 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | 56 | # 
Translations
57 | *.mo
58 |
59 | # Mr Developer
60 | .mr.developer.cfg
61 | .project
62 | .pydevproject
63 |
64 | # Rope
65 | .ropeproject
66 |
67 | # Django stuff:
68 | *.log
69 | *.pot
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # special local ignore
75 | .vscode/*
76 | agent_predictor_trained_models/*
77 | bc_trained_models/*
78 | diffuser_trained_models/*
79 | scene_diffuser_trained_models/*
80 | spatial_planner_trained_models/*
81 | vae_trained_models/*
82 | trained_models/*
83 | trained_models_only/*
84 | trained_models_only_new/*
85 | visualizations/*
86 | scripts/create_evaluation.py
87 | scripts/evaluate.py
88 | tbsim/configs/eval_configs.py
89 |
90 | l5_rasterized_tree_vae_trained_models/test/run0/logs/log.txt
91 | generate_all_evaluations.sh
92 | generate_all_evaluations_l5kit.sh
93 | generate_all_evaluations_nusc.sh
94 | plot_result_l5kit.ipynb
95 | plot_result_nusc.ipynb
96 | plot_result.ipynb
97 | test.py
98 |
99 | unified_data_cache
100 | openai_key.py
101 | lan2stl.md
102 | *.TTF
103 | add_text.py
104 | gpt_query.md
105 | changes.md
-------------------------------------------------------------------------------- /LICENSE: --------------------------------------------------------------------------------
1 | NVIDIA License
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means any person or entity that distributes its Work.
6 | “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
7 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
8 | Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
9 |
10 | 2. License Grant
11 |
12 | 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
13 |
14 | 3. Limitations
15 |
16 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
17 |
18 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
19 |
20 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
21 |
22 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
23 |
24 | 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
25 |
26 | 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
27 |
28 | 4. Disclaimer of Warranty.
29 |
30 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
31 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
32 |
33 | 5. Limitation of Liability.
34 |
35 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
36 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Controllable Traffic Generation (CTG)
2 | Codebase of Controllable Traffic Generation (CTG) and Controllable Traffic Generation Plus Plus (CTG++).
3 |
4 | This repo is mostly built on top of [traffic-behavior-simulation (tbsim)](https://github.com/NVlabs/traffic-behavior-simulation). The diffusion model part is built on top of the initial implementation in [Diffuser](https://github.com/jannerm/diffuser). It also lightly uses [STLCG](https://github.com/StanfordASL/stlcg).
5 |
6 |
7 | ## Installation
8 | ### Basic (mainly based on tbsim)
9 | Create a conda environment (note: nuplan-devkit requires `python>=3.9`, so the virtual environment must be created with Python 3.9 rather than 3.8):
10 | ```angular2html
11 | conda create -n bg3.9 python=3.9
12 | conda activate bg3.9
13 | ```
14 |
15 | Install `CTG` (this repo)
16 | ```angular2html
17 | git clone https://github.com/NVlabs/CTG.git
18 | cd CTG
19 | pip install -e .
20 | ```
21 |
22 | Install a customized version of `trajdata`
23 | ```angular2html
24 | cd ..
25 | git clone https://github.com/AIasd/trajdata.git
26 | cd trajdata
27 | pip install -r trajdata_requirements.txt
28 | pip install -e .
29 | ```
30 |
31 | Install `Pplan`
32 | ```angular2html
33 | cd ..
34 | git clone https://github.com/NVlabs/spline-planner.git Pplan
35 | cd Pplan
36 | pip install -e .
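# optional sanity check (assumes the editable install exposes the package as `Pplan`):
# python -c "import Pplan"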
37 | ```
38 |
39 | ### Potential Issue
40 | Depending on your CUDA setup, one might need to run the following:
41 | ```angular2html
42 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 torchmetrics==0.11.1 torchtext --extra-index-url https://download.pytorch.org/whl/cu113
43 | ```
44 |
45 |
46 |
47 | ### Extra Setup for STL (CTG)
48 | Install STLCG and switch to the `dev` branch
49 | ```angular2html
50 | cd ..
51 | git clone https://github.com/StanfordASL/stlcg.git
52 | cd stlcg
53 | pip install graphviz
54 | pip install -e .
55 | git checkout dev
56 | ```
57 |
58 |
59 | ### Extra Setup for ChatGPT API (if using the language interface of CTG++)
60 | ```
61 | pip install openai
62 | pip install tiktoken
63 | ```
64 | Create a file `openai_key.py` and put your OpenAI API key in it under the variable name `openai_key`.
65 |
66 | ## Quick start
67 | ### 1. Obtain dataset(s)
68 | We currently support the nuScenes [dataset](https://www.nuscenes.org/nuscenes).
69 |
70 |
71 | #### nuScenes
72 | * Download the nuScenes dataset (with the v1.3 map expansion pack) and organize the dataset directory as follows:
73 | ```
74 | nuscenes/
75 | │ maps/
76 | │ v1.0-mini/
77 | │ v1.0-trainval/
78 | ```
79 |
80 |
81 | ### 2. Train a diffuser model
82 | nuScenes dataset (Note: remove the `--debug` flag for actual training and to enable wandb logging):
83 | ```
84 | python scripts/train.py --dataset_path <path-to-nuscenes-data-directory> --config_name trajdata_nusc_diff --debug
85 | ```
86 |
87 | A concrete example (CTG):
88 | ```
89 | python scripts/train.py --dataset_path ../behavior-generation-dataset/nuscenes --config_name trajdata_nusc_diff --debug
90 | ```
91 |
92 | A concrete example (CTG++):
93 | ```
94 | python scripts/train.py --dataset_path ../behavior-generation-dataset/nuscenes --config_name trajdata_nusc_scene_diff --debug
95 | ```
96 |
97 | ### 3. Run rollout of a trained model (closed-loop simulation)
98 | Run Rollout
99 | ```
100 | python scripts/scene_editor.py \
101 | --results_root_dir nusc_results/ \
102 | --num_scenes_per_batch 1 \
103 | --dataset_path <path-to-nuscenes-data-directory> \
104 | --env trajdata \
105 | --policy_ckpt_dir <path-to-policy-ckpt-dir> \
106 | --policy_ckpt_key <ckpt-file-name> \
107 | --eval_class <eval-class> \
108 | --editing_source 'config' 'heuristic' \
109 | --registered_name 'trajdata_nusc_diff' \
110 | --render
111 | ```
112 |
113 | The following is a concrete example for running CTG (when using the pre-trained model):
114 | ```
115 | python scripts/scene_editor.py \
116 | --results_root_dir nusc_results/ \
117 | --num_scenes_per_batch 1 \
118 | --dataset_path ../behavior-generation-dataset/nuscenes \
119 | --env trajdata \
120 | --policy_ckpt_dir ../../summer_project/behavior-generation/trained_models_only_new/trajdata_nusc/ctg_original \
121 | --policy_ckpt_key iter70000.ckpt \
122 | --eval_class Diffuser \
123 | --editing_source 'config' 'heuristic' \
124 | --registered_name 'trajdata_nusc_diff' \
125 | --render
126 | ```
127 |
128 | The following is a concrete example for running CTG++ (when using the pre-trained model):
129 | ```
130 | python scripts/scene_editor.py \
131 | --results_root_dir nusc_results/ \
132 | --num_scenes_per_batch 1 \
133 | --dataset_path ../behavior-generation-dataset/nuscenes \
134 | --env trajdata \
135 | --policy_ckpt_dir ../../summer_project/behavior-generation/trained_models_only_new/trajdata_nusc/ctg++8_9,10edge \
136 | --policy_ckpt_key iter50000.ckpt \
137 | --eval_class SceneDiffuser \
138 | --editing_source 'config' 'heuristic' \
139 | --registered_name 'trajdata_nusc_scene_diff' \
140 | --render
141 | ```
142 |
143 | ### 4. Parse results for rollout
144 | ```
145 | python scripts/parse_scene_edit_results.py --results_dir <path-to-rollout-results> \
146 | --estimate_dist
147 | ```
148 |
149 | ## Pre-trained models
150 | We provide checkpoints for the CTG and CTG++ models [here](https://drive.google.com/drive/folders/17oYCNGTzBPWjKqvvA8JO67WswyI0j5vw?usp=sharing).
151 | Note that the provided CTG model differs slightly from the one in the original CTG paper; the main difference is that the prediction horizon is 52 rather than 20. The pre-trained models are provided under the **CC-BY-NC-SA-4.0 license**.
152 |
153 | ## Configurations
154 | Check out `class DiffuserConfig` and `class SceneDiffuserConfig` in `algo_config.py` for algorithm configs, `trajdata_nusc_config.py` for dataset configs, and `scene_edit_config.py` for rollout configs (including changing the guidance used during denoising).
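For a quick look at what a registered config contains, the following sketch can help (the helper name `get_registered_experiment_config` and its exact signature are assumptions based on `tbsim/configs/registry.py`; check that file for the actual API):

```python
from tbsim.configs.registry import get_registered_experiment_config  # assumed helper name

# look up the experiment config registered under a name accepted by
# --config_name / --registered_name, e.g., "trajdata_nusc_diff"
cfg = get_registered_experiment_config("trajdata_nusc_diff")

# configs are nested Dict objects (see tbsim/configs/config.py), so fields are dot-accessible
print(cfg.algo.name)                  # algorithm config (e.g., DiffuserConfig fields)
print(cfg.train.training.batch_size)  # training config

# serialize the full config to JSON for inspection
print(cfg.dump())
```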
155 |
156 |
157 | ## References
158 | If you find this repo useful, please consider citing our relevant work:
159 |
160 | ```
161 | @INPROCEEDINGS{10161463,
162 | author={Zhong, Ziyuan and Rempe, Davis and Xu, Danfei and Chen, Yuxiao and Veer, Sushant and Che, Tong and Ray, Baishakhi and Pavone, Marco},
163 | booktitle={2023 IEEE International Conference on Robotics and Automation (ICRA)},
164 | title={Guided Conditional Diffusion for Controllable Traffic Simulation},
165 | year={2023},
166 | volume={},
167 | number={},
168 | pages={3560-3566},
169 | doi={10.1109/ICRA48891.2023.10161463}}
170 |
171 | ```
172 |
173 | ```
174 | @inproceedings{
175 | zhong2023languageguided,
176 | title={Language-Guided Traffic Simulation via Scene-Level Diffusion},
177 | author={Ziyuan Zhong and Davis Rempe and Yuxiao Chen and Boris Ivanovic and Yulong Cao and Danfei Xu and Marco Pavone and Baishakhi Ray},
178 | booktitle={7th Annual Conference on Robot Learning},
179 | year={2023},
180 | url={https://openreview.net/forum?id=nKWQnYkkwX}
181 | }
182 | ```
183 |
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | # read the contents of your README file
4 | from os import path
5 |
6 | this_directory = path.abspath(path.dirname(__file__))
7 | with open(path.join(this_directory, "README.md"), encoding="utf-8") as f:
8 |     lines = f.readlines()
9 |
10 | # remove images from README
11 | lines = [x for x in lines if ".png" not in x]
12 | long_description = "".join(lines)
13 |
14 | setup(
15 |     name="tbsim",
16 |     packages=[package for package in find_packages() if package.startswith("tbsim")],
17 |     install_requires=[
18 |         "l5kit==1.5.0",
19 |         "numpy==1.23.4",  # may need to be manually changed to 1.21.4 due to a conflict with l5kit's requirement
20 |         "pytorch-lightning==1.8.3.post0",
21 |         "wandb",
22 |         "torch==1.11",
23 |         "torchvision==0.12.0",
24 |         "pyemd",
25 |         "h5py",
26 |         "imageio-ffmpeg",
27 |         "casadi",
28 |         "protobuf==3.20.1",  # newer versions might cause errors
29 |         "einops==0.6.0",
30 |         "torchtext",  # weird pytorch-lightning dependency bug
31 |     ],
32 |     eager_resources=["*"],
33 |     include_package_data=True,
34 |     python_requires=">=3",
35 |     description="Traffic Behavior Simulation",
36 |     author="NVIDIA AV Research",
37 |     author_email="danfeix@nvidia.com",
38 |     version="0.0.1",
39 |     long_description=long_description,
40 |     long_description_content_type="text/markdown",
41 | )
42 |
-------------------------------------------------------------------------------- /tbsim/algos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/CTG/f916c008c3ecf2360bfa050639606eaab7c207f5/tbsim/algos/__init__.py
-------------------------------------------------------------------------------- /tbsim/algos/factory.py: --------------------------------------------------------------------------------
1 | """Factory methods for creating models"""
2 | from pytorch_lightning import LightningDataModule
3 | from tbsim.configs.base import ExperimentConfig
4 |
5 | from tbsim.algos.algos import (
6 |     BehaviorCloning,
7 |     TransformerTrafficModel,
8 |     TransformerGANTrafficModel,
9 |     VAETrafficModel,
10 |     DiscreteVAETrafficModel,
11 |     BehaviorCloningGC,
12 |     SpatialPlanner,
13 |     GANTrafficModel,
14 |     BehaviorCloningEC,
15 |     TreeVAETrafficModel,
16 |     DiffuserTrafficModel,
17 |     SceneTreeTrafficModel,
18 |     STRIVETrafficModel,
19 |     SceneDiffuserTrafficModel,
20 | )
21 |
22 | from tbsim.algos.multiagent_algos import (
23 |     MATrafficModel,
24 | )
25 |
26 | from tbsim.algos.metric_algos import (
27 |     OccupancyMetric
28 | )
29 |
30 |
31 | def algo_factory(config: ExperimentConfig, modality_shapes: dict):
32 |     """
33 |     A factory for creating training algos
34 |
35 |     Args:
36 |         config (ExperimentConfig): an ExperimentConfig object
37 |         modality_shapes (dict): a dictionary that maps observation modality names to shapes
38 |
39 |     Returns:
40 |         algo: pl.LightningModule
41 |     """
42 |     algo_config = config.algo
43 |     algo_name = algo_config.name
44 |
45 |     if algo_name == "bc":
46 |         algo = BehaviorCloning(algo_config=algo_config, modality_shapes=modality_shapes)
47 |     elif algo_name == "bc_gc":
48 |         algo = BehaviorCloningGC(algo_config=algo_config, modality_shapes=modality_shapes)
49 |     elif algo_name == "vae":
50 |         algo = VAETrafficModel(algo_config=algo_config, modality_shapes=modality_shapes)
51 |     elif algo_name == "discrete_vae":
52 |         algo = DiscreteVAETrafficModel(algo_config=algo_config, modality_shapes=modality_shapes)
53 |     elif algo_name == "tree_vae":
54 |         if algo_config.scene_centric:
55 |             algo = SceneTreeTrafficModel(algo_config=algo_config, modality_shapes=modality_shapes)
56 |         else:
57 |             algo = TreeVAETrafficModel(algo_config=algo_config, modality_shapes=modality_shapes)
58 |     elif algo_name == "bc_ec":
59 |         algo = BehaviorCloningEC(algo_config=algo_config, modality_shapes=modality_shapes)
60 |     elif algo_name == "spatial_planner":
61 |         algo = SpatialPlanner(algo_config=algo_config, modality_shapes=modality_shapes)
62 |     elif algo_name == "occupancy":
63 |         algo = OccupancyMetric(algo_config=algo_config, modality_shapes=modality_shapes)
64 |     elif algo_name == "agent_predictor":
65 |         algo = MATrafficModel(algo_config=algo_config, modality_shapes=modality_shapes)
66 |     elif algo_name == "TransformerPred":
67 |         algo = TransformerTrafficModel(algo_config=algo_config)
68 |     elif algo_name == "TransformerGAN":
69 |         algo = TransformerGANTrafficModel(algo_config=algo_config)
70 |     elif algo_name == "gan":
71 |         algo = GANTrafficModel(algo_config=algo_config, modality_shapes=modality_shapes)
72 |     elif algo_name == "diffuser":
73 |         algo = DiffuserTrafficModel(algo_config=algo_config, modality_shapes=modality_shapes, registered_name=config.registered_name)
74 |     elif algo_name == "strive":
75 |         algo = STRIVETrafficModel(algo_config=algo_config, modality_shapes=modality_shapes)
76 |     elif algo_name == "scene_diffuser":
77 |         algo = SceneDiffuserTrafficModel(algo_config=algo_config, modality_shapes=modality_shapes, registered_name=config.registered_name)
78 |     else:
79 |         raise NotImplementedError("{} is not a valid algorithm".format(algo_name))
80 |     return algo
81 |
-------------------------------------------------------------------------------- /tbsim/configs/__init__.py: --------------------------------------------------------------------------------
1 | from tbsim.configs.base import ExperimentConfig
2 |
3 |
-------------------------------------------------------------------------------- /tbsim/configs/base.py: --------------------------------------------------------------------------------
1 | from tbsim.configs.config import Dict
2 | from copy import deepcopy
3 | # CHANGE: Change this to handle train-time rollout
4 | # from tbsim.configs.eval_config import TrainTimeEvaluationConfig
5 | from tbsim.configs.scene_edit_config import TrainTimeEvaluationConfig
6 |
7 | class TrainConfig(Dict):
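    """Settings for a training run: logging, checkpoint saving, train-time rollout evaluation, optimization, and validation."""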
8 |     def __init__(self):
9 |         super(TrainConfig, self).__init__()
10 |         self.logging.terminal_output_to_txt = True  # whether to log stdout to txt file
11 |         self.logging.log_tb = False  # enable tensorboard logging
12 |         self.logging.log_wandb = True  # enable wandb logging
13 |         self.logging.wandb_project_name = "tbsim"
14 |         self.logging.log_every_n_steps = 10
15 |         self.logging.flush_every_n_steps = 100
16 |
17 |         ## save config - if and when to save model checkpoints ##
18 |         self.save.enabled = True  # whether model saving should be enabled or disabled
19 |         self.save.every_n_steps = 100  # save model every n steps
20 |         self.save.best_k = 5
21 |         self.save.save_best_rollout = False
22 |         self.save.save_best_validation = True
23 |
24 |         ## evaluation rollout config ##
25 |         self.rollout.save_video = True
26 |         self.rollout.enabled = False  # enable evaluation rollouts
27 |         self.rollout.every_n_steps = 1000  # do rollouts every n steps
28 |         self.rollout.warm_start_n_steps = 1  # number of steps to wait before starting rollouts
29 |
30 |
31 |         ## training config
32 |         self.training.batch_size = 100
33 |         self.training.num_steps = 100000
34 |         self.training.num_data_workers = 0
35 |
36 |         ## validation config
37 |         self.validation.enabled = True
38 |         self.validation.batch_size = 100
39 |         self.validation.num_data_workers = 0
40 |         self.validation.every_n_steps = 1000
41 |         self.validation.num_steps_per_epoch = 100
42 |
43 |         ## Training parallelism (e.g., multi-GPU)
44 |         self.parallel_strategy = "ddp_spawn"
45 |
46 |         self.on_ngc = False
47 |
48 |
49 | class EnvConfig(Dict):
50 |     def __init__(self):
51 |         super(EnvConfig, self).__init__()
52 |         self.name = "my_env"
53 |
54 |
55 | class AlgoConfig(Dict):
56 |     def __init__(self):
57 |         super(AlgoConfig, self).__init__()
58 |         self.name = "my_algo"
59 |
60 |
61 | class ExperimentConfig(Dict):
62 |     def __init__(
63 |         self,
64 |         train_config: TrainConfig,
65 |         env_config: EnvConfig,
66 |         algo_config: AlgoConfig,
67 |         eval_config: TrainTimeEvaluationConfig = None,
68 |         registered_name: str = None,
69 |     ):
70 |         """
71 |         Args:
72 |             train_config (TrainConfig): training config
73 |             env_config (EnvConfig): environment config
74 |             algo_config (AlgoConfig): algorithm config
75 |             eval_config (TrainTimeEvaluationConfig): evaluation config for train-time rollouts (built from registered_name if None)
76 |             registered_name (str): name of the experiment config object in the global config registry
77 |         """
78 |         super(ExperimentConfig, self).__init__()
79 |         self.registered_name = registered_name
80 |
81 |         self.train = train_config
82 |         self.env = env_config
83 |         self.algo = algo_config
84 |         self.eval = TrainTimeEvaluationConfig(registered_name) if eval_config is None else eval_config
85 |
86 |         # Write all results to this directory. A new folder with the timestamp will be created
87 |         # in this directory, and it will contain three subfolders - "log", "models", and "videos".
88 |         # The "log" directory will contain tensorboard and stdout txt logs. The "models" directory
89 |         # will contain saved model checkpoints. The "videos" directory contains evaluation rollout
90 |         # videos.
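        # For example, with the defaults below (root_dir "<algo>_trained_models/", name "test"), results
        # might land in diffuser_trained_models/test/run0/ with logs/, models/, and videos/ inside (illustrative).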
91 | self.name = ( 92 | "test" # name of the experiment (creates a subdirectory under root_dir) 93 | ) 94 | 95 | self.root_dir = "{}_trained_models/".format(self.algo.name) 96 | self.seed = 1 # seed for everything (for reproducibility) 97 | 98 | self.devices.num_gpus = 1 # Set to 0 to use CPU 99 | 100 | def clone(self): 101 | return self.__class__( 102 | train_config=deepcopy(self.train), 103 | env_config=deepcopy(self.env), 104 | algo_config=deepcopy(self.algo), 105 | eval_config=deepcopy(self.eval), 106 | registered_name=self.registered_name, 107 | ) 108 | -------------------------------------------------------------------------------- /tbsim/configs/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Basic config class - provides a convenient way to work with nested 3 | dictionaries (by exposing keys as attributes) and to save / load from jsons. 4 | 5 | Based on addict: https://github.com/mewwts/addict 6 | """ 7 | 8 | import json 9 | import copy 10 | import contextlib 11 | from copy import deepcopy 12 | 13 | 14 | class Dict(dict): 15 | 16 | def __init__(__self, *args, **kwargs): 17 | object.__setattr__(__self, '__parent', kwargs.pop('__parent', None)) 18 | object.__setattr__(__self, '__key', kwargs.pop('__key', None)) 19 | object.__setattr__(__self, '__frozen', False) 20 | for arg in args: 21 | if not arg: 22 | continue 23 | elif isinstance(arg, dict): 24 | for key, val in arg.items(): 25 | __self[key] = __self._hook(val) 26 | elif isinstance(arg, tuple) and (not isinstance(arg[0], tuple)): 27 | __self[arg[0]] = __self._hook(arg[1]) 28 | else: 29 | for key, val in iter(arg): 30 | __self[key] = __self._hook(val) 31 | 32 | for key, val in kwargs.items(): 33 | __self[key] = __self._hook(val) 34 | 35 | def __setattr__(self, name, value): 36 | if hasattr(self.__class__, name): 37 | raise AttributeError("'Dict' object attribute " 38 | "'{0}' is read-only".format(name)) 39 | else: 40 | self[name] = value 41 | 42 | def __setitem__(self, name, value): 43 | isFrozen = (hasattr(self, '__frozen') and 44 | object.__getattribute__(self, '__frozen')) 45 | if isFrozen and name not in super(Dict, self).keys(): 46 | raise KeyError(name) 47 | super(Dict, self).__setitem__(name, value) 48 | try: 49 | p = object.__getattribute__(self, '__parent') 50 | key = object.__getattribute__(self, '__key') 51 | except AttributeError: 52 | p = None 53 | key = None 54 | if p is not None: 55 | p[key] = self 56 | object.__delattr__(self, '__parent') 57 | object.__delattr__(self, '__key') 58 | 59 | def __add__(self, other): 60 | if not self.keys(): 61 | return other 62 | else: 63 | self_type = type(self).__name__ 64 | other_type = type(other).__name__ 65 | msg = "unsupported operand type(s) for +: '{}' and '{}'" 66 | raise TypeError(msg.format(self_type, other_type)) 67 | 68 | @classmethod 69 | def _hook(cls, item): 70 | if isinstance(item, dict): 71 | return cls(item) 72 | elif isinstance(item, (list, tuple)): 73 | return type(item)(cls._hook(elem) for elem in item) 74 | return item 75 | 76 | def __getattr__(self, item): 77 | return self.__getitem__(item) 78 | 79 | def __missing__(self, name): 80 | if object.__getattribute__(self, '__frozen'): 81 | raise KeyError(name) 82 | return Dict(__parent=self, __key=name) 83 | 84 | def __delattr__(self, name): 85 | del self[name] 86 | 87 | def __repr__(self): 88 | json_string = json.dumps(self.to_dict(), indent=4) 89 | return json_string 90 | 91 | def to_dict(self): 92 | base = {} 93 | for key, value in self.items(): 94 | if 
isinstance(value, type(self)): 95 | base[key] = value.to_dict() 96 | elif isinstance(value, (list, tuple)): 97 | base[key] = type(value)( 98 | item.to_dict() if isinstance(item, type(self)) else 99 | item for item in value) 100 | else: 101 | base[key] = value 102 | return base 103 | 104 | def copy(self): 105 | return copy.copy(self) 106 | 107 | def deepcopy(self): 108 | return copy.deepcopy(self) 109 | 110 | def __deepcopy__(self, memo): 111 | other = self.__class__() 112 | memo[id(self)] = other 113 | for key, value in self.items(): 114 | other[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo) 115 | return other 116 | 117 | def update(self, *args, **kwargs): 118 | other = {} 119 | if args: 120 | if len(args) > 1: 121 | raise TypeError() 122 | other.update(args[0]) 123 | other.update(kwargs) 124 | for k, v in other.items(): 125 | if ((k not in self) or 126 | (not isinstance(self[k], dict)) or 127 | (not isinstance(v, dict))): 128 | self[k] = v 129 | else: 130 | self[k].update(v) 131 | 132 | def __getnewargs__(self): 133 | return tuple(self.items()) 134 | 135 | def __getstate__(self): 136 | return self 137 | 138 | def __setstate__(self, state): 139 | self.update(state) 140 | 141 | def __or__(self, other): 142 | if not isinstance(other, (Dict, dict)): 143 | return NotImplemented 144 | new = Dict(self) 145 | new.update(other) 146 | return new 147 | 148 | def __ror__(self, other): 149 | if not isinstance(other, (Dict, dict)): 150 | return NotImplemented 151 | new = Dict(other) 152 | new.update(self) 153 | return new 154 | 155 | def __ior__(self, other): 156 | self.update(other) 157 | return self 158 | 159 | def setdefault(self, key, default=None): 160 | if key in self: 161 | return self[key] 162 | else: 163 | self[key] = default 164 | return default 165 | 166 | def lock(self, should_lock=True): 167 | object.__setattr__(self, '__frozen', should_lock) 168 | for key, val in self.items(): 169 | if isinstance(val, Dict): 170 | val.lock(should_lock) 171 | 172 | def dump(self, filename = None): 173 | json_string = json.dumps(self.to_dict(), indent=4) 174 | if filename is not None: 175 | f = open(filename, "w") 176 | f.write(json_string) 177 | f.close() 178 | return json_string 179 | 180 | def unlock(self): 181 | self.lock(False) 182 | 183 | @contextlib.contextmanager 184 | def unlocked(self): 185 | self.unlock() 186 | yield 187 | self.lock() 188 | 189 | def clone(self): 190 | return deepcopy(self) -------------------------------------------------------------------------------- /tbsim/configs/eval_config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from copy import deepcopy 3 | 4 | from tbsim.configs.config import Dict 5 | 6 | l5kit_indices = [9058, 5232, 14153, 8173, 10314, 7027, 9812, 1090, 9453, 978, 10263, 874, 5563, 9613, 261, 2826, 2175, 9977, 6423, 1069, 1836, 8198, 5034, 6016, 2525, 927, 3634, 11806, 4911, 6192, 11641, 461, 142, 15493, 4919, 8494, 14572, 2402, 308, 1952, 13287, 15614, 6529, 12, 11543, 4558, 489, 6876, 15279, 6095, 5877, 8928, 10599, 16150, 11296, 9382, 13352, 1794, 16122, 12429, 15321, 8614, 12447, 4502, 13235, 2919, 15893, 12960, 7043, 9278, 952, 4699, 768, 13146, 8827, 16212, 10777, 15885, 11319, 9417, 14092, 14873, 6740, 11847, 15331, 15639, 11361, 14784, 13448, 10124, 4872, 3567, 5543, 2214, 7624, 10193, 7297, 1308, 3951, 14001] 7 | class EvaluationConfig(Dict): 8 | def __init__(self): 9 | super(EvaluationConfig, self).__init__() 10 | self.name = None 11 | self.env = "nusc" # [l5kit, nusc] 12 | 
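        # typically supplied at runtime (an assumption based on the --dataset_path flag of scripts/scene_editor.py)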
self.dataset_path = None
13 |         self.eval_class = ""
14 |         self.seed = 0
15 |         self.num_scenes_per_batch = 1
16 |         # need to be equal to len(self.nusc.eval_scenes) when nusc is used
17 |         # and len(self.l5kit.eval_scenes) when l5kit is used
18 |         self.num_scenes_to_evaluate = 1
19 |
20 |         self.num_episode_repeats = 1
21 |         self.start_frame_index_each_episode = None  # if specified, should be the same length as num_episode_repeats
22 |         self.seed_each_episode = None  # if specified, should be the same length as num_episode_repeats
23 |
24 |         self.ego_only = False
25 |         self.agent_eval_class = None
26 |
27 |         self.ckpt_root_dir = "checkpoints/"
28 |         self.experience_hdf5_path = None
29 |         self.results_dir = "results/"
30 |
31 |         self.ckpt.policy.ngc_job_id = None
32 |         self.ckpt.policy.ckpt_dir = None
33 |         self.ckpt.policy.ckpt_key = None
34 |
35 |         self.ckpt.planner.ngc_job_id = None
36 |         self.ckpt.planner.ckpt_dir = None
37 |         self.ckpt.planner.ckpt_key = None
38 |
39 |         self.ckpt.predictor.ngc_job_id = None
40 |         self.ckpt.predictor.ckpt_dir = None
41 |         self.ckpt.predictor.ckpt_key = None
42 |
43 |         self.ckpt.cvae_metric.ngc_job_id = None
44 |         self.ckpt.cvae_metric.ckpt_dir = None
45 |         self.ckpt.cvae_metric.ckpt_key = None
46 |
47 |         self.ckpt.occupancy_metric.ngc_job_id = None
48 |         self.ckpt.occupancy_metric.ckpt_dir = None
49 |         self.ckpt.occupancy_metric.ckpt_key = None
50 |
51 |         self.policy.mask_drivable = True
52 |         self.policy.num_plan_samples = 50
53 |         self.policy.num_action_samples = 10
54 |         self.policy.pos_to_yaw = True
55 |         self.policy.yaw_correction_speed = 1.0
56 |         self.policy.diversification_clearance = None
57 |         self.policy.sample = True
58 |
59 |
60 |         self.policy.cost_weights.collision_weight = 10.0
61 |         self.policy.cost_weights.lane_weight = 1.0
62 |         self.policy.cost_weights.likelihood_weight = 0.0  # 0.1
63 |         self.policy.cost_weights.progress_weight = 0.0  # 0.005
64 |
65 |         # CHANGE: add ema, perturb_output_trajectory
66 |         self.policy.use_ema = False
67 |         self.goal_conditional = False
68 |         self.perturb_output_trajectory = False
69 |         # perturb_th is large since usually the scales are not standardized
70 |         self.perturb_opt_params = {'optimizer':'adam', 'grad_steps':30, 'perturb_th':100.0, 'optimizer_params':{'lr':0.001}}
71 |         self.filtration = False
72 |         self.num_filtration_samples = 5
73 |
74 |         # cvae, bc, diffuser
75 |         self.guidance_optimization_params = {
76 |             'optimizer': 'adam',
77 |             'lr': 0.3,
78 |             'grad_steps': 1,  # Diffuser: 1, BC: 5, CVAE: 35
79 |             'perturb_th': None,  # when None, sigma is used for Diffuser; no threshold for others
80 |         }
81 |         # diffuser specific
82 |         self.denoising_params = {
83 |             'stride': 1,  # only for diffuser
84 |             'perturb_output_trajectory': False,  # only for diffuser
85 |         }
86 |
87 |         self.metrics.compute_analytical_metrics = True
88 |         self.metrics.compute_learned_metrics = False
89 |
90 |         self.perturb.enabled = False
91 |         self.perturb.OU.theta = 0.8
92 |         self.perturb.OU.sigma = [0.0, 0.1, 0.2, 0.5, 1.0, 2.0, 4.0]
93 |         self.perturb.OU.scale = [1.0, 1.0, 0.2]
94 |
95 |         self.rolling_perturb.enabled = False
96 |         self.rolling_perturb.OU.theta = 0.8
97 |         self.rolling_perturb.OU.sigma = 0.5
98 |         self.rolling_perturb.OU.scale = [1.0, 1.0, 0.2]
99 |
100 |         self.occupancy.rolling = True
101 |         self.occupancy.rolling_horizon = [5, 10, 20]
102 |
103 |         self.cvae.rolling = True
104 |         self.cvae.rolling_horizon = [5, 10, 20]
105 |
106 |         # to make nusc compatible with scene_editor
107 |         self.nusc.trajdata_source_test = ["nusc_trainval-val"]
108 |         self.nusc.trajdata_data_dirs = {
109 |             "nusc_trainval" : "../behavior-generation-dataset/nuscenes",
110 |         }
111 |         self.nusc.future_sec = 5.2  # 2.0, 5.2, 14.0
112 |         self.nusc.history_sec = 3.0  # 1.0, 3.0
113 |         #----------------------------------------------------------------------------------------------
114 |         self.nusc.eval_scenes = [30]  # np.arange(100).tolist(); [30] for multi-veh intersection, [75, 79] for simple collision (state), [62, 63] for collision (state-action)
115 |         self.nusc.n_step_action = 5
116 |         self.nusc.num_simulation_steps = 200
117 |         self.nusc.skip_first_n = 0
118 |
119 |         self.l5kit.eval_scenes = l5kit_indices  # [l5kit_indices[1]] #
120 |         self.l5kit.n_step_action = 5
121 |         self.l5kit.num_simulation_steps = 200
122 |         self.l5kit.skip_first_n = 1
123 |         self.l5kit.skimp_rollout = False
124 |
125 |         self.adjustment.random_init_plan = False
126 |         self.adjustment.remove_existing_neighbors = False
127 |         self.adjustment.initial_num_neighbors = 4
128 |         self.adjustment.num_frame_per_new_agent = 20
129 |
130 |         # to make nusc compatible with scene_editor
131 |         self.trajdata.trajdata_cache_location = "~/.unified_data_cache"
132 |         self.trajdata.trajdata_rebuild_cache = False
133 |         #----------------------------------------------------------------------------------------------
134 |
135 |     def clone(self):
136 |         return deepcopy(self)
137 |
138 |
139 | class TrainTimeEvaluationConfig(EvaluationConfig):
140 |     def __init__(self):
141 |         super(TrainTimeEvaluationConfig, self).__init__()
142 |
143 |         self.num_scenes_per_batch = 4
144 |         self.nusc.eval_scenes = np.arange(0, 100, 10).tolist()
145 |         self.l5kit.eval_scenes = self.l5kit.eval_scenes[:20]
146 |
147 |         self.policy.sample = False
148 |
-------------------------------------------------------------------------------- /tbsim/configs/l5kit_config.py: --------------------------------------------------------------------------------
1 | from tbsim.configs.base import TrainConfig, EnvConfig
2 |
3 |
4 | class L5KitTrainConfig(TrainConfig):
5 |     def __init__(self):
6 |         super(L5KitTrainConfig, self).__init__()
7 |
8 |         self.dataset_path = "/home/yuxiaoc/repos/l5kit/prediction-dataset"
9 |         self.dataset_valid_key = "scenes/validate.zarr"
10 |         self.dataset_train_key = "scenes/train.zarr"
11 |         self.dataset_meta_key = "meta.json"
12 |         self.datamodule_class = "L5MixedDataModule"
13 |         self.dataset_mode = "agents"
14 |
15 |         self.rollout.enabled = True
16 |         self.rollout.save_video = True
17 |         self.rollout.every_n_steps = 10000
18 |
19 |         # training config
20 |         self.training.batch_size = 100
21 |         self.training.num_steps = 300000
22 |         self.training.num_data_workers = 12
23 |
24 |         self.save.every_n_steps = 5000
25 |         self.save.best_k = 10
26 |
27 |         # validation config
28 |         self.validation.enabled = True
29 |         self.validation.batch_size = 32
30 |         self.validation.num_data_workers = 8
31 |         self.validation.every_n_steps = 1000
32 |         self.validation.num_steps_per_epoch = 50
33 |
34 |
35 | class L5KitMixedEnvConfig(EnvConfig):
36 |     """Vectorized Scene Component + Rasterized Map"""
37 |
38 |     def __init__(self):
39 |         super(L5KitMixedEnvConfig, self).__init__()
40 |         self.name = "l5kit"
41 |         # the keys are relative to the dataset environment variable
42 |         self.rasterizer.semantic_map_key = "semantic_map/semantic_map.pb"
43 |         self.rasterizer.dataset_meta_key = "meta.json"
44 |
45 |         # e.g. 0.0 includes every obstacle, 0.5 shows those obstacles with >0.5 probability of being
46 |         # one of the classes we care about (cars, bikes, peds, etc.), >=1.0 filters all other agents.
47 | self.rasterizer.filter_agents_threshold = 0.5 48 | 49 | # whether to completely disable traffic light faces in the semantic rasterizer 50 | # this disable option is not supported in avsw_semantic 51 | self.rasterizer.disable_traffic_light_faces = False 52 | 53 | self.generate_agent_obs = False 54 | 55 | self.data_generation_params.other_agents_num = 20 56 | self.data_generation_params.max_agents_distance = 50 57 | self.data_generation_params.lane_params.max_num_lanes = 15 58 | self.data_generation_params.lane_params.max_points_per_lane = 5 59 | self.data_generation_params.lane_params.max_points_per_crosswalk = 5 60 | self.data_generation_params.lane_params.max_retrieval_distance_m = 35 61 | self.data_generation_params.lane_params.max_num_crosswalks = 20 62 | self.data_generation_params.rasterize_agents = False 63 | self.data_generation_params.vectorize_agents = True 64 | 65 | # step size of lane interpolation 66 | self.data_generation_params.lane_params.lane_interp_step_size = 5.0 67 | self.data_generation_params.vectorize_lane = True 68 | 69 | self.rasterizer.raster_size = (224, 224) 70 | 71 | # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. 72 | self.rasterizer.pixel_size = (0.5, 0.5) 73 | 74 | # From 0 to 1 per axis, [0.5,0.5] would show the ego centered in the image. 75 | self.rasterizer.ego_center = (0.25, 0.5) 76 | 77 | self.rasterizer.map_type = "py_semantic" 78 | # self.rasterizer.map_type = "scene_semantic" 79 | 80 | # the keys are relative to the dataset environment variable 81 | self.rasterizer.satellite_map_key = "aerial_map/aerial_map.png" 82 | self.rasterizer.semantic_map_key = "semantic_map/semantic_map.pb" 83 | 84 | # When set to True, the rasterizer will set the raster origin at bottom left, 85 | # i.e. vehicles are driving on the right side of the road. 86 | # With this change, the vertical flipping on the raster used in the visualization code is no longer needed. 87 | # Set it to False for models trained before v1.1.0-25-g3c517f0 (December 2020). 88 | # In that case visualisation will be flipped (we've removed the flip there) but the model's input will be correct. 
89 |         self.rasterizer.set_origin_to_bottom = True
90 |
91 |         # if a tracked agent is closer than this value to ego, it will be controlled
92 |         self.simulation.distance_th_far = 50
93 |
94 |         # if a new agent is closer than this value to ego, it will be controlled
95 |         self.simulation.distance_th_close = 50
96 |
97 |         # whether to disable agents that are not returned at start_frame_index
98 |         self.simulation.disable_new_agents = False
99 |
100 |         # maximum number of simulation steps to run (0.1sec / step)
101 |         self.simulation.num_simulation_steps = 50
102 |
103 |         # which frame to start a simulation episode with
104 |         self.simulation.start_frame_index = 0
105 |
106 |
107 | class L5KitMixedSemanticMapEnvConfig(L5KitMixedEnvConfig):
108 |     def __init__(self):
109 |         super(L5KitMixedSemanticMapEnvConfig, self).__init__()
110 |         self.rasterizer.map_type = "py_semantic"
111 |         self.data_generation_params.vectorize_lane = False
112 |         self.generate_agent_obs = True
113 |
-------------------------------------------------------------------------------- /tbsim/configs/nusc_config.py: --------------------------------------------------------------------------------
1 | import math
2 |
3 | from tbsim.configs.base import TrainConfig, EnvConfig, AlgoConfig
4 |
5 | MAX_POINTS_LANE = 5
6 |
7 |
8 | class NuscTrainConfig(TrainConfig):
9 |     def __init__(self):
10 |         super(NuscTrainConfig, self).__init__()
11 |
12 |         self.trajdata_source_train = "train"
13 |         self.trajdata_source_train_val = "train_val"
14 |         self.trajdata_source_valid = "val"
15 |         self.trajdata_source_root = "nusc_trainval"
16 |
17 |         self.dataset_path = "SET-THIS-THROUGH-TRAIN-SCRIPT-ARGS"
18 |         self.datamodule_class = "UnifiedDataModule"
19 |
20 |         self.rollout.enabled = True
21 |         self.rollout.save_video = True
22 |         self.rollout.every_n_steps = 5000
23 |
24 |         # training config
25 |         self.training.batch_size = 100
26 |         self.training.num_steps = 100000
27 |         self.training.num_data_workers = 8
28 |
29 |         self.save.every_n_steps = 2000
30 |         self.save.best_k = 10
31 |
32 |         # validation config
33 |         self.validation.enabled = True
34 |         self.validation.batch_size = 32
35 |         self.validation.num_data_workers = 6
36 |         self.validation.every_n_steps = 1000
37 |         self.validation.num_steps_per_epoch = 50
38 |
39 |
40 | class NuscEnvConfig(EnvConfig):
41 |     def __init__(self):
42 |         super(NuscEnvConfig, self).__init__()
43 |
44 |         self.name = "nusc"
45 |
46 |         # raster image size [pixels]
47 |         self.rasterizer.raster_size = 224
48 |
49 |         # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to.
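        # e.g., with raster_size = 224 px above and pixel_size = 0.5 m/px below, the raster covers 224 * 0.5 = 112 m per side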
50 |         self.rasterizer.pixel_size = 0.5
51 |
52 |         # where the agent is on the map, (0.0, 0.0) is the center
53 |         # WARNING: this should not be changed before resolving TODO in parse_trajdata_batch() in trajdata_utils.py
54 |         self.rasterizer.ego_center = (-0.5, 0.0)
55 |
56 |         # maximum number of agents to consider during training
57 |         self.data_generation_params.other_agents_num = 20
58 |
59 |         self.data_generation_params.max_agents_distance = 30
60 |
61 |         # correct for yaw (zero-out delta yaw) when speed is lower than this threshold
62 |         self.data_generation_params.yaw_correction_speed = 1.0
63 |
64 |         self.simulation.distance_th_close = 30
65 |
66 |         # maximum number of simulation steps to run (0.1sec / step)
67 |         self.simulation.num_simulation_steps = 50
68 |
69 |         # which frame to start a simulation episode with
70 |         self.simulation.start_frame_index = 0
71 |
72 |         # whether to get lane information
73 |         self.simulation.vectorize_lane = "ego"
74 |
-------------------------------------------------------------------------------- /tbsim/configs/orca_config.py: --------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | from tbsim.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig
5 |
6 |
7 | class OrcaTrainConfig(TrajdataTrainConfig):
8 |     def __init__(self):
9 |         super(OrcaTrainConfig, self).__init__()
10 |
11 |         self.trajdata_cache_location = "~/.unified_data_cache"
12 |         #
13 |         # with maps
14 |         #
15 |         # self.trajdata_source_train = ["orca_maps-train"]
16 |         # self.trajdata_source_valid = ["orca_maps-val"]
17 |         #
18 |         # no maps
19 |         #
20 |         # self.trajdata_source_train = ["orca_no_maps-train"]
21 |         # self.trajdata_source_valid = ["orca_no_maps-val"]
22 |         #
23 |         # mixed
24 |         #
25 |         self.trajdata_source_train = ["orca_maps-train", "orca_no_maps-train"]
26 |         self.trajdata_source_valid = ["orca_maps-val", "orca_no_maps-val"]
27 |         # dict mapping dataset IDs -> root path
28 |         # all datasets that will be used must be included here
29 |         self.trajdata_data_dirs = {
30 |             "orca_maps" : "../orca-sim/datagen_out/datagen_trajdata_v3",
31 |             "orca_no_maps" : "../orca-sim/datagen_out/datagen_trajdata_v3",
32 |         }
33 |
34 |         # for debug
35 |         self.trajdata_rebuild_cache = False
36 |
37 |         self.rollout.enabled = False
38 |         self.rollout.save_video = True
39 |         self.rollout.every_n_steps = 5000
40 |         self.rollout.warm_start_n_steps = 0
41 |
42 |         # training config
43 |         self.training.batch_size = 400
44 |         self.training.num_steps = 100000
45 |         self.training.num_data_workers = 8
46 |
47 |         self.save.every_n_steps = 3000  # 1000
48 |         self.save.best_k = 5
49 |
50 |         # validation config
51 |         self.validation.enabled = True
52 |         self.validation.batch_size = 32
53 |         self.validation.num_data_workers = 4
54 |         self.validation.every_n_steps = 200  # 570 # 210
55 |         self.validation.num_steps_per_epoch = 100  # 25
56 |
57 |         self.on_ngc = False
58 |         self.logging.terminal_output_to_txt = True  # whether to log stdout to txt file
59 |         self.logging.log_tb = False  # enable tensorboard logging
60 |         self.logging.log_wandb = True  # enable wandb logging
61 |         self.logging.wandb_project_name = "tbsim"
62 |         self.logging.log_every_n_steps = 10
63 |         self.logging.flush_every_n_steps = 100
64 |
65 |
66 | class OrcaEnvConfig(TrajdataEnvConfig):
67 |     def __init__(self):
68 |         super(OrcaEnvConfig, self).__init__()
69 |
70 |         #
71 |         # with map
72 |         #
73 |         self.data_generation_params.trajdata_incl_map = True
74 |         self.data_generation_params.trajdata_max_agents_distance = np.inf  # 0.001
75 | self.rasterizer.num_sem_layers = 2 76 | 77 | # 78 | # no map 79 | # 80 | # self.data_generation_params.trajdata_incl_map = False 81 | # self.data_generation_params.trajdata_max_agents_distance = np.inf 82 | # self.rasterizer.num_sem_layers = 0 83 | 84 | self.data_generation_params.trajdata_only_types = ["pedestrian"] 85 | 86 | # NOTE: rasterization info must still be provided even if incl_map=False 87 | # since still used for agent states 88 | # number of semantic layers that will be used (based on which trajdata dataset is being used) 89 | # how to group layers together to viz RGB image 90 | self.rasterizer.rgb_idx_groups = ([1], [0], [1]) 91 | # raster image size [pixels] 92 | self.rasterizer.raster_size = 224 93 | # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. 94 | self.rasterizer.pixel_size = 1.0 / 12.0 95 | # where the agent is on the map, (0.0, 0.0) is the center and image width is 2.0, i.e. (1.0, 0.0) is the right edge 96 | self.rasterizer.ego_center = (-0.5, 0.0) 97 | # if incl_map = True, but no map is available, will fill dummy map with this value 98 | self.rasterizer.no_map_fill_value = 0.5 # -1.0 -------------------------------------------------------------------------------- /tbsim/configs/trajdata_config.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | from tbsim.configs.base import TrainConfig, EnvConfig, AlgoConfig 4 | 5 | # 6 | # Base configurations for using unified data loader (trajdata) 7 | # 8 | 9 | class TrajdataTrainConfig(TrainConfig): 10 | def __init__(self): 11 | super(TrajdataTrainConfig, self).__init__() 12 | 13 | # essentially passes through args to unified dataset 14 | self.datamodule_class = "PassUnifiedDataModule" 15 | 16 | self.trajdata_cache_location = "~/.unified_data_cache" 17 | # list of desired_data for training set 18 | self.trajdata_source_train = ["nusc_trainval-train"] 19 | # list of desired_data for validation set 20 | self.trajdata_source_valid = ["nusc_trainval-val"] 21 | # dict mapping dataset IDs -> root path 22 | # all datasets that will be used must be included here 23 | self.trajdata_data_dirs = { 24 | "nusc_trainval" : "../behavior-generation-dataset/nuscenes", 25 | "nusc_test" : "../behavior-generation-dataset/nuscenes", 26 | "nusc_mini" : "../behavior-generation-dataset/nuscenes", 27 | } 28 | 29 | # whether to rebuild the cache or not 30 | self.trajdata_rebuild_cache = False 31 | 32 | # vec map params (only used by algorithm which uses vec map) 33 | self.training_vec_map_params = { 34 | 'S_seg': 15, 35 | 'S_point': 80, 36 | 'map_max_dist': 80, 37 | 'max_heading_error': 0.25*np.pi, 38 | 'ahead_threshold': -40, 39 | 'dist_weight': 1.0, 40 | 'heading_weight': 0.1, 41 | } 42 | 43 | 44 | class TrajdataEnvConfig(EnvConfig): 45 | def __init__(self): 46 | super(TrajdataEnvConfig, self).__init__() 47 | 48 | # NOTE: this should NOT be changed in sub-classes 49 | self.name = "trajdata" 50 | 51 | # 52 | # general data options 53 | # 54 | self.data_generation_params.trajdata_centric = "agent" # or "scene" 55 | # which types of agents to include from ['unknown', 'vehicle', 'pedestrian', 'bicycle', 'motorcycle'] 56 | self.data_generation_params.trajdata_only_types = ["vehicle", "pedestrian"] 57 | self.data_generation_params.trajdata_predict_types = None 58 | # list of scene description filters 59 | self.data_generation_params.trajdata_scene_desc_contains = None 60 | # whether or not to include the map in the data 
61 |         self.data_generation_params.trajdata_incl_map = True
62 |         # max distance to be considered neighbors
63 |         self.data_generation_params.trajdata_max_agents_distance = np.inf
64 |         # standardize position and heading for the predicted agent
65 |         self.data_generation_params.trajdata_standardize_data = True
66 |
67 |         #
68 |         # map params -- default for nuscenes
69 |         # NOTE: rasterization info must still be provided even if incl_map=False
70 |         # since still used for agent states
71 |         # whether or not to rasterize the agent histories
72 |         self.rasterizer.include_hist = True
73 |         # number of semantic layers that will be used (based on which trajdata dataset is being used)
74 |         self.rasterizer.num_sem_layers = 7
75 |         # which layers constitute the drivable area
76 |         # None uses the default drivable layers for the given data source
77 |         # empty list assumes the entire map is drivable (even regions with 0 in all layers)
78 |         # non-empty list only uses the specified layer indices as drivable region
79 |         self.rasterizer.drivable_layers = None
80 |         # how to group layers together to viz RGB image
81 |         self.rasterizer.rgb_idx_groups = ([0, 1, 2], [3, 4], [5, 6])
82 |         # raster image size [pixels]
83 |         self.rasterizer.raster_size = 224
84 |         # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to.
85 |         self.rasterizer.pixel_size = 0.5
86 |         # where the agent is on the map, (0.0, 0.0) is the center
87 |         self.rasterizer.ego_center = (-0.5, 0.0)
88 |         # if incl_map = True, but no map is available, will fill dummy map with this value
89 |         self.rasterizer.no_map_fill_value = -1.0
90 |
-------------------------------------------------------------------------------- /tbsim/configs/trajdata_drivesim_config.py: --------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | from tbsim.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig
5 |
6 |
7 | class DriveSimTrajdataTrainConfig(TrajdataTrainConfig):
8 |     def __init__(self):
9 |         super(DriveSimTrajdataTrainConfig, self).__init__()
10 |
11 |         self.trajdata_cache_location = "~/.unified_data_cache"
12 |         self.trajdata_source_train = ["main"]
13 |         self.trajdata_source_valid = ["main"]
14 |         # dict mapping dataset IDs -> root path
15 |         # all datasets that will be used must be included here
16 |         self.trajdata_data_dirs = {
17 |             "drivesim": "home"
18 |         }
19 |
20 |         # for debug
21 |         self.trajdata_rebuild_cache = False
22 |
23 |         self.rollout.enabled = True
24 |         self.rollout.save_video = True
25 |         self.rollout.every_n_steps = 10000
26 |         self.rollout.warm_start_n_steps = 0
27 |
28 |         # training config
29 |         # assuming 1 sec (10 steps) past, 2 sec (20 steps) future
30 |         self.training.batch_size = 100
31 |         self.training.num_steps = 100000
32 |         self.training.num_data_workers = 8
33 |
34 |         self.save.every_n_steps = 10000
35 |         self.save.best_k = 10
36 |
37 |         # validation config
38 |         self.validation.enabled = True
39 |         self.validation.batch_size = 32
40 |         self.validation.num_data_workers = 6
41 |         self.validation.every_n_steps = 500
42 |         self.validation.num_steps_per_epoch = 5  # 50
43 |
44 |         self.on_ngc = False
45 |         self.logging.terminal_output_to_txt = True  # whether to log stdout to txt file
46 |         self.logging.log_tb = False  # enable tensorboard logging
47 |         self.logging.log_wandb = True  # enable wandb logging
48 |         self.logging.wandb_project_name = "tbsim"
49 |         self.logging.log_every_n_steps = 10
50 |         self.logging.flush_every_n_steps = 100
51 |
52 |
53 | class DriveSimTrajdataEnvConfig(TrajdataEnvConfig):
54 |     def __init__(self):
55 |         super(DriveSimTrajdataEnvConfig, self).__init__()
56 |
57 |         self.data_generation_params.trajdata_centric = "agent"  # or "scene"
58 |         # which types of agents to include from ['unknown', 'vehicle', 'pedestrian', 'bicycle', 'motorcycle']
59 |         self.data_generation_params.trajdata_only_types = ["vehicle"]
60 |         # which types of agents to predict
61 |         self.data_generation_params.trajdata_predict_types = ["vehicle"]
62 |         # list of scene description filters
63 |         self.data_generation_params.trajdata_scene_desc_contains = None
64 |         # whether or not to include the map in the data
65 |         # TODO: handle mixed map-nomap datasets
66 |         self.data_generation_params.trajdata_incl_map = True
67 |         # max distance to be considered neighbors
68 |         self.data_generation_params.trajdata_max_agents_distance = 300.0
69 |         # standardize position and heading for the predicted agent
70 |         self.data_generation_params.trajdata_standardize_data = True
71 |
72 |         # NOTE: rasterization info must still be provided even if incl_map=False
73 |         # since still used for agent states
74 |         # number of semantic layers that will be used (based on which trajdata dataset is being used)
75 |         self.rasterizer.num_sem_layers = 3  # 7
76 |         # how to group layers together to viz RGB image
77 |         self.rasterizer.rgb_idx_groups = ([0], [1], [2])
78 |         # raster image size [pixels]
79 |         self.rasterizer.raster_size = 224
80 |         # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to.
81 |         self.rasterizer.pixel_size = 1.0 / 2.0  # 2 px/m
82 |         # where the agent is on the map, (0.0, 0.0) is the center
83 |         self.rasterizer.ego_center = (-0.5, 0.0)
84 |
85 |         # maximum number of agents to consider during training
86 |         self.data_generation_params.other_agents_num = 50
-------------------------------------------------------------------------------- /tbsim/configs/trajdata_eupeds_config.py: --------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 |
4 | from tbsim.configs.base import TrainConfig, EnvConfig, AlgoConfig
5 | from tbsim.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig
6 |
7 |
8 | class EupedsTrainConfig(TrajdataTrainConfig):
9 |     def __init__(self):
10 |         super(EupedsTrainConfig, self).__init__()
11 |
12 |         self.trajdata_cache_location = "~/.unified_data_cache"
13 |         # leaves out the ETH-Univ dataset for training
14 |         self.trajdata_source_train = ["eupeds_eth-train_loo"]
15 |         self.trajdata_source_valid = ["eupeds_eth-val_loo"]
16 |         # dict mapping dataset IDs -> root path
17 |         # all datasets that will be used must be included here
18 |         self.trajdata_data_dirs = {
19 |             "eupeds_eth" : "./datasets/eth_ucy",
20 |             "eupeds_hotel" : "./datasets/eth_ucy",
21 |             "eupeds_univ" : "./datasets/eth_ucy",
22 |             "eupeds_zara1" : "./datasets/eth_ucy",
23 |             "eupeds_zara2" : "./datasets/eth_ucy"
24 |         }
25 |
26 |         # for debug
27 |         self.trajdata_rebuild_cache = False
28 |
29 |         self.rollout.enabled = False
30 |         self.rollout.save_video = True
31 |         self.rollout.every_n_steps = 5000
32 |         self.rollout.warm_start_n_steps = 0
33 |
34 |         # training config
35 |         # assuming dt=0.4, history_frames=8, future_frames=12 (benchmark setting)
36 |         self.training.batch_size = 400
37 |         self.training.num_steps = 72000
38 |         self.training.num_data_workers = 8
39 |
40 |         self.save.every_n_steps = 1000
41 |         self.save.best_k = 10
42 |
43 |         # validation config
44 |         self.validation.enabled = True
45 |
self.validation.batch_size = 32 46 | self.validation.num_data_workers = 4 47 | self.validation.every_n_steps = 70 48 | self.validation.num_steps_per_epoch = 20 49 | 50 | self.on_ngc = False 51 | self.logging.terminal_output_to_txt = True # whether to log stdout to txt file 52 | self.logging.log_tb = False # enable tensorboard logging 53 | self.logging.log_wandb = True # enable wandb logging 54 | self.logging.wandb_project_name = "tbsim" 55 | self.logging.log_every_n_steps = 10 56 | self.logging.flush_every_n_steps = 100 57 | 58 | 59 | class EupedsEnvConfig(TrajdataEnvConfig): 60 | def __init__(self): 61 | super(EupedsEnvConfig, self).__init__() 62 | 63 | # no maps to include 64 | self.data_generation_params.trajdata_incl_map = False 65 | self.data_generation_params.trajdata_only_types = ["pedestrian"] 66 | self.data_generation_params.trajdata_max_agents_distance = np.inf 67 | 68 | # NOTE: rasterization info must still be provided even if incl_map=False 69 | # since still used for agent states 70 | # number of semantic layers that will be used (based on which trajdata dataset is being used) 71 | self.rasterizer.num_sem_layers = 0 72 | # raster image size [pixels] 73 | self.rasterizer.raster_size = 224 74 | # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. 75 | self.rasterizer.pixel_size = 1. / 10. 76 | # where the agent is on the map, (0.0, 0.0) is the center and image width is 2.0, i.e. (1.0, 0.0) is the right edge 77 | self.rasterizer.ego_center = (0.0, 0.0) -------------------------------------------------------------------------------- /tbsim/configs/trajdata_l5kit_config.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | from tbsim.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig 5 | 6 | 7 | class L5KitTrajdataTrainConfig(TrajdataTrainConfig): 8 | def __init__(self): 9 | super(L5KitTrajdataTrainConfig, self).__init__() 10 | 11 | self.trajdata_cache_location = "~/.unified_data_cache" 12 | self.trajdata_source_train = ["lyft_train-train"] 13 | self.trajdata_source_valid = ["lyft_val-val"] 14 | # dict mapping dataset IDs -> root path 15 | # all datasets that will be used must be included here 16 | self.trajdata_data_dirs = { 17 | "lyft_train" : "../behavior-generation-dataset/lyft_prediction/scenes/train.zarr", 18 | "lyft_val" : "../behavior-generation-dataset/lyft_prediction/scenes/validate.zarr", 19 | } 20 | # for debug 21 | self.trajdata_rebuild_cache = False 22 | 23 | self.rollout.enabled = True 24 | self.rollout.save_video = True 25 | self.rollout.every_n_steps = 10000 26 | self.rollout.warm_start_n_steps = 0 27 | 28 | # training config 29 | # assuming 1 sec (10 steps) past, 2 sec (20 steps) future 30 | self.training.batch_size = 100 31 | self.training.num_steps = 100000 32 | self.training.num_data_workers = 8 33 | 34 | self.save.every_n_steps = 10000 35 | self.save.best_k = 10 36 | 37 | # validation config 38 | self.validation.enabled = True 39 | self.validation.batch_size = 32 40 | self.validation.num_data_workers = 4 41 | self.validation.every_n_steps = 500 42 | self.validation.num_steps_per_epoch = 50 43 | 44 | self.on_ngc = False 45 | self.logging.terminal_output_to_txt = True # whether to log stdout to txt file 46 | self.logging.log_tb = False # enable tensorboard logging 47 | self.logging.log_wandb = True # enable wandb logging 48 | self.logging.wandb_project_name = "tbsim" 49 | self.logging.log_every_n_steps = 10 50 | 
self.logging.flush_every_n_steps = 100 51 | 52 | 53 | class L5KitTrajdataEnvConfig(TrajdataEnvConfig): 54 | def __init__(self): 55 | super(L5KitTrajdataEnvConfig, self).__init__() 56 | 57 | self.data_generation_params.trajdata_centric = "agent" # or "scene" 58 | # which types of agents to include from ['unknown', 'vehicle', 'pedestrian', 'bicycle', 'motorcycle'] 59 | self.data_generation_params.trajdata_only_types = ["vehicle"] 60 | # list of scene description filters 61 | self.data_generation_params.trajdata_scene_desc_contains = None 62 | # whether or not to include the map in the data 63 | # TODO: handle mixed map-nomap datasets 64 | self.data_generation_params.trajdata_incl_map = True 65 | # max distance to be considered neighbors 66 | self.data_generation_params.trajdata_max_agents_distance = 30.0 67 | # standardize position and heading for the predicted agent 68 | self.data_generation_params.trajdata_standardize_data = True 69 | 70 | # NOTE: rasterization info must still be provided even if incl_map=False 71 | # since still used for agent states 72 | # number of semantic layers that will be used (based on which trajdata dataset is being used) 73 | self.rasterizer.num_sem_layers = 3 74 | # how to group layers together to viz RGB image 75 | self.rasterizer.rgb_idx_groups = ([0], [1], [2]) 76 | # raster image size [pixels] 77 | self.rasterizer.raster_size = 224 78 | # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. 79 | self.rasterizer.pixel_size = 1.0 / 2.0 # 2 px/m 80 | # where the agent is on the map, (0.0, 0.0) is the center 81 | self.rasterizer.ego_center = (-0.5, 0.0) 82 | 83 | # maximum number of agents to consider during training 84 | self.data_generation_params.other_agents_num = 20 -------------------------------------------------------------------------------- /tbsim/configs/trajdata_nuplan_all_config.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | from tbsim.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig 5 | 6 | 7 | class NuplanTrajdataAllTrainConfig(TrajdataTrainConfig): 8 | def __init__(self): 9 | super(NuplanTrajdataAllTrainConfig, self).__init__() 10 | 11 | self.trajdata_cache_location = "~/.unified_data_cache" 12 | self.trajdata_source_train = ["nuplan_mini-mini_train"] 13 | self.trajdata_source_valid = ["nuplan_mini-mini_val"] 14 | # dict mapping dataset IDs -> root path 15 | # all datasets that will be used must be included here 16 | self.trajdata_data_dirs = { 17 | "nuplan_mini" : "../behavior-generation-dataset/nuplan/dataset/nuplan-v1.1", 18 | } 19 | 20 | # for debug 21 | self.trajdata_rebuild_cache = False 22 | 23 | self.rollout.enabled = True 24 | self.rollout.save_video = True 25 | self.rollout.every_n_steps = 10000 26 | self.rollout.warm_start_n_steps = 0 27 | 28 | # training config 29 | # assuming 1 sec (10 steps) past, 2 sec (20 steps) future 30 | self.training.batch_size = 2 # 100 31 | self.training.num_steps = 100000 32 | self.training.num_data_workers = 8 33 | 34 | self.save.every_n_steps = 10000 35 | self.save.best_k = 10 36 | 37 | # validation config 38 | self.validation.enabled = True 39 | self.validation.batch_size = 2 # 32 40 | self.validation.num_data_workers = 6 41 | self.validation.every_n_steps = 500 42 | self.validation.num_steps_per_epoch = 5 # 50 43 | 44 | self.on_ngc = False 45 | self.logging.terminal_output_to_txt = True # whether to log stdout to txt file 46 | self.logging.log_tb = 
False # enable tensorboard logging 47 | self.logging.log_wandb = True # enable wandb logging 48 | self.logging.wandb_project_name = "tbsim" 49 | self.logging.log_every_n_steps = 10 50 | self.logging.flush_every_n_steps = 100 51 | 52 | 53 | class NuplanTrajdataAllEnvConfig(TrajdataEnvConfig): 54 | def __init__(self): 55 | super(NuplanTrajdataAllEnvConfig, self).__init__() 56 | 57 | self.data_generation_params.trajdata_centric = "agent" # or "scene" 58 | # which types of agents to include from ['unknown', 'vehicle', 'pedestrian', 'bicycle', 'motorcycle'] 59 | self.data_generation_params.trajdata_only_types = ["vehicle", "pedestrian"] 60 | # which types of agents to predict 61 | self.data_generation_params.trajdata_predict_types = ["vehicle", "pedestrian"] 62 | # list of scene description filters 63 | self.data_generation_params.trajdata_scene_desc_contains = None 64 | # whether or not to include the map in the data 65 | # TODO: handle mixed map-nomap datasets 66 | self.data_generation_params.trajdata_incl_map = True 67 | # maximum number of agents to consider during training. Note: it is currently only effective when centric is "scene". 68 | self.data_generation_params.other_agents_num = 20 69 | # max distance to be considered neighbors 70 | self.data_generation_params.trajdata_max_agents_distance = 30.0 71 | # standardize position and heading for the predicted agent 72 | self.data_generation_params.trajdata_standardize_data = True 73 | 74 | # NOTE: rasterization info must still be provided even if incl_map=False 75 | # since still used for agent states 76 | # number of semantic layers that will be used (based on which trajdata dataset is being used) 77 | self.rasterizer.num_sem_layers = 3 # 7 78 | # how to group layers together to viz RGB image 79 | self.rasterizer.rgb_idx_groups = ([0], [1], [2]) 80 | # raster image size [pixels] 81 | self.rasterizer.raster_size = 224 82 | # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. 
83 | self.rasterizer.pixel_size = 1.0 / 2.0 # 2 px/m 84 | # where the agent is on the map, (0.0, 0.0) is the center 85 | self.rasterizer.ego_center = (-0.5, 0.0) -------------------------------------------------------------------------------- /tbsim/configs/trajdata_nuplan_config.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | from tbsim.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig 5 | 6 | 7 | class NuplanTrajdataTrainConfig(TrajdataTrainConfig): 8 | def __init__(self): 9 | super(NuplanTrajdataTrainConfig, self).__init__() 10 | 11 | self.trajdata_cache_location = "~/.unified_data_cache" 12 | self.trajdata_source_train = ["nuplan_mini-mini_train"] 13 | self.trajdata_source_valid = ["nuplan_mini-mini_val"] 14 | # dict mapping dataset IDs -> root path 15 | # all datasets that will be used must be included here 16 | self.trajdata_data_dirs = { 17 | "nuplan_mini" : "../behavior-generation-dataset/nuplan/dataset/nuplan-v1.1", 18 | } 19 | 20 | # for debug 21 | self.trajdata_rebuild_cache = False 22 | 23 | self.rollout.enabled = True 24 | self.rollout.save_video = True 25 | self.rollout.every_n_steps = 10000 26 | self.rollout.warm_start_n_steps = 0 27 | 28 | # training config 29 | # assuming 1 sec (10 steps) past, 2 sec (20 steps) future 30 | self.training.batch_size = 50 # 100 31 | self.training.num_steps = 100000 32 | self.training.num_data_workers = 8 33 | 34 | self.save.every_n_steps = 10000 35 | self.save.best_k = 10 36 | 37 | # validation config 38 | self.validation.enabled = True 39 | self.validation.batch_size = 32 40 | self.validation.num_data_workers = 6 41 | self.validation.every_n_steps = 500 42 | self.validation.num_steps_per_epoch = 5 # 50 43 | 44 | self.on_ngc = False 45 | self.logging.terminal_output_to_txt = True # whether to log stdout to txt file 46 | self.logging.log_tb = False # enable tensorboard logging 47 | self.logging.log_wandb = True # enable wandb logging 48 | self.logging.wandb_project_name = "tbsim" 49 | self.logging.log_every_n_steps = 10 50 | self.logging.flush_every_n_steps = 100 51 | 52 | 53 | class NuplanTrajdataEnvConfig(TrajdataEnvConfig): 54 | def __init__(self): 55 | super(NuplanTrajdataEnvConfig, self).__init__() 56 | 57 | self.data_generation_params.trajdata_centric = "agent" # or "scene" 58 | # which types of agents to include from ['unknown', 'vehicle', 'pedestrian', 'bicycle', 'motorcycle'] 59 | self.data_generation_params.trajdata_only_types = ["vehicle"] 60 | # which types of agents to predict 61 | self.data_generation_params.trajdata_predict_types = ["vehicle"] 62 | # list of scene description filters 63 | self.data_generation_params.trajdata_scene_desc_contains = None 64 | # whether or not to include the map in the data 65 | # TODO: handle mixed map-nomap datasets 66 | self.data_generation_params.trajdata_incl_map = True 67 | # max distance to be considered neighbors 68 | self.data_generation_params.trajdata_max_agents_distance = 50 69 | # standardize position and heading for the predicted agent 70 | self.data_generation_params.trajdata_standardize_data = True 71 | 72 | # NOTE: rasterization info must still be provided even if incl_map=False 73 | # since still used for agent states 74 | # number of semantic layers that will be used (based on which trajdata dataset is being used) 75 | self.rasterizer.num_sem_layers = 3 # 7 76 | # how to group layers together to viz RGB image 77 | self.rasterizer.rgb_idx_groups = ([0], [1], [2]) 78 | # raster 
image size [pixels] 79 | self.rasterizer.raster_size = 224 80 | # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. 81 | self.rasterizer.pixel_size = 1.0 / 2.0 # 2 px/m 82 | # where the agent is on the map, (0.0, 0.0) is the center 83 | self.rasterizer.ego_center = (-0.5, 0.0) 84 | 85 | # max_agent_num (int, optional): The maximum number of agents to include in a batch for scene-centric batching. 86 | self.data_generation_params.other_agents_num = None 87 | 88 | # max_neighbor_num (int, optional): The maximum number of neighbors to include in a batch for agent-centric batching. 89 | self.data_generation_params.max_neighbor_num = 20 -------------------------------------------------------------------------------- /tbsim/configs/trajdata_nuplan_ped_config.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | from tbsim.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig 5 | 6 | 7 | class NuplanTrajdataPedTrainConfig(TrajdataTrainConfig): 8 | def __init__(self): 9 | super(NuplanTrajdataPedTrainConfig, self).__init__() 10 | 11 | self.trajdata_cache_location = "~/.unified_data_cache" 12 | self.trajdata_source_train = ["nuplan_mini-mini_train"] 13 | self.trajdata_source_valid = ["nuplan_mini-mini_val"] 14 | # dict mapping dataset IDs -> root path 15 | # all datasets that will be used must be included here 16 | self.trajdata_data_dirs = { 17 | "nuplan_mini" : "../behavior-generation-dataset/nuplan/dataset/nuplan-v1.1", 18 | } 19 | 20 | # for debug 21 | self.trajdata_rebuild_cache = False 22 | 23 | self.rollout.enabled = True 24 | self.rollout.save_video = True 25 | self.rollout.every_n_steps = 10000 26 | self.rollout.warm_start_n_steps = 0 27 | 28 | # training config 29 | # assuming 1 sec (10 steps) past, 2 sec (20 steps) future 30 | self.training.batch_size = 100 31 | self.training.num_steps = 100000 32 | self.training.num_data_workers = 8 33 | 34 | self.save.every_n_steps = 10000 35 | self.save.best_k = 10 36 | 37 | # validation config 38 | self.validation.enabled = True 39 | self.validation.batch_size = 32 40 | self.validation.num_data_workers = 6 41 | self.validation.every_n_steps = 500 42 | self.validation.num_steps_per_epoch = 5 # 50 43 | 44 | self.on_ngc = False 45 | self.logging.terminal_output_to_txt = True # whether to log stdout to txt file 46 | self.logging.log_tb = False # enable tensorboard logging 47 | self.logging.log_wandb = True # enable wandb logging 48 | self.logging.wandb_project_name = "tbsim" 49 | self.logging.log_every_n_steps = 10 50 | self.logging.flush_every_n_steps = 100 51 | 52 | 53 | class NuplanTrajdataPedEnvConfig(TrajdataEnvConfig): 54 | def __init__(self): 55 | super(NuplanTrajdataPedEnvConfig, self).__init__() 56 | 57 | self.data_generation_params.trajdata_centric = "agent" # or "scene" 58 | # which types of agents to include from ['unknown', 'vehicle', 'pedestrian', 'bicycle', 'motorcycle'] 59 | self.data_generation_params.trajdata_only_types = ["pedestrian"] 60 | # which types of agents to predict 61 | self.data_generation_params.trajdata_predict_types = ["pedestrian"] 62 | # list of scene description filters 63 | self.data_generation_params.trajdata_scene_desc_contains = None 64 | # whether or not to include the map in the data 65 | # TODO: handle mixed map-nomap datasets 66 | self.data_generation_params.trajdata_incl_map = True 67 | # maximum number of agents to consider during training. 
Note: it is currently only effective when centric is "scene". 68 | self.data_generation_params.other_agents_num = 20 69 | # max distance to be considered neighbors 70 | self.data_generation_params.trajdata_max_agents_distance = 30.0 71 | # standardize position and heading for the predicted agent 72 | self.data_generation_params.trajdata_standardize_data = True 73 | 74 | # NOTE: rasterization info must still be provided even if incl_map=False 75 | # since still used for agent states 76 | # number of semantic layers that will be used (based on which trajdata dataset is being used) 77 | self.rasterizer.num_sem_layers = 3 # 7 78 | # how to group layers together to viz RGB image 79 | self.rasterizer.rgb_idx_groups = ([0], [1], [2]) 80 | # raster image size [pixels] 81 | self.rasterizer.raster_size = 224 82 | # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. 83 | self.rasterizer.pixel_size = 1.0 / 2.0 # 2 px/m 84 | # where the agent is on the map, (0.0, 0.0) is the center 85 | self.rasterizer.ego_center = (-0.5, 0.0) -------------------------------------------------------------------------------- /tbsim/configs/trajdata_nuplan_scene_config.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | from tbsim.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig 5 | 6 | 7 | class NuplanTrajdataSceneTrainConfig(TrajdataTrainConfig): 8 | def __init__(self): 9 | super(NuplanTrajdataSceneTrainConfig, self).__init__() 10 | 11 | self.trajdata_cache_location = "~/.unified_data_cache" 12 | self.trajdata_source_train = ["nuplan_mini-mini_train"] 13 | self.trajdata_source_valid = ["nuplan_mini-mini_val"] 14 | # dict mapping dataset IDs -> root path 15 | # all datasets that will be used must be included here 16 | self.trajdata_data_dirs = { 17 | "nuplan_mini" : "../behavior-generation-dataset/nuplan/dataset/nuplan-v1.1", 18 | } 19 | 20 | # for debug 21 | self.trajdata_rebuild_cache = False 22 | 23 | self.rollout.enabled = True 24 | self.rollout.save_video = True 25 | self.rollout.every_n_steps = 10000 26 | self.rollout.warm_start_n_steps = 0 27 | 28 | # training config 29 | self.training.batch_size = 4 # 100 30 | self.training.num_steps = 100000 31 | self.training.num_data_workers = 8 32 | 33 | self.save.every_n_steps = 10000 34 | self.save.best_k = 10 35 | 36 | # validation config 37 | self.validation.enabled = True 38 | self.validation.batch_size = 1 # 32 39 | self.validation.num_data_workers = 6 40 | self.validation.every_n_steps = 500 41 | self.validation.num_steps_per_epoch = 5 # 50 42 | 43 | self.on_ngc = False 44 | self.logging.terminal_output_to_txt = True # whether to log stdout to txt file 45 | self.logging.log_tb = False # enable tensorboard logging 46 | self.logging.log_wandb = True # enable wandb logging 47 | self.logging.wandb_project_name = "tbsim" 48 | self.logging.log_every_n_steps = 10 49 | self.logging.flush_every_n_steps = 100 50 | 51 | # vec map params (only used by algorithms that consume the vectorized map) 52 | self.training_vec_map_params = { 53 | 'S_seg': 15, 54 | 'S_point': 80, 55 | 'map_max_dist': 80, 56 | 'max_heading_error': 0.25*np.pi, 57 | 'ahead_threshold': -40, 58 | 'dist_weight': 1.0, 59 | 'heading_weight': 0.1, 60 | } 61 | 62 | 63 | class NuplanTrajdataSceneEnvConfig(TrajdataEnvConfig): 64 | def __init__(self): 65 | super(NuplanTrajdataSceneEnvConfig, self).__init__() 66 | 67 | self.data_generation_params.trajdata_centric = "scene" # ["agent", 
"scene"] 68 | # which types of agents to include from ['unknown', 'vehicle', 'pedestrian', 'bicycle', 'motorcycle'] 69 | self.data_generation_params.trajdata_only_types = ["vehicle"] 70 | # which types of agents to predict 71 | self.data_generation_params.trajdata_predict_types = ["vehicle"] 72 | # list of scene description filters 73 | self.data_generation_params.trajdata_scene_desc_contains = None 74 | # whether or not to include the map in the data 75 | # TODO: handle mixed map-nomap datasets 76 | self.data_generation_params.trajdata_incl_map = True 77 | # For both training and testing: 78 | # if scene-centric, max distance to scene center to be included for batching 79 | # if agent-centric, max distance to each agent to be considered neighbors 80 | self.data_generation_params.trajdata_max_agents_distance = 50.0 81 | # standardize position and heading for the predicted agnet 82 | self.data_generation_params.trajdata_standardize_data = True 83 | 84 | # NOTE: rasterization info must still be provided even if incl_map=False 85 | # since still used for agent states 86 | # number of semantic layers that will be used (based on which trajdata dataset is being used) 87 | self.rasterizer.num_sem_layers = 3 # 7 88 | # how to group layers together to viz RGB image 89 | self.rasterizer.rgb_idx_groups = ([0], [1], [2]) 90 | # raster image size [pixels] 91 | self.rasterizer.raster_size = 224 92 | # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. 93 | self.rasterizer.pixel_size = 1.0 / 2.0 # 2 px/m 94 | # where the agent is on the map, (0.0, 0.0) is the center 95 | self.rasterizer.ego_center = (-0.5, 0.0) 96 | 97 | # max_agent_num (int, optional): The maximum number of agents to include in a batch for scene-centric batching. 98 | self.data_generation_params.other_agents_num = 20 # None # 20 99 | 100 | # max_neighbor_num (int, optional): The maximum number of neighbors to include in a batch for agent-centric batching. 
-------------------------------------------------------------------------------- /tbsim/configs/trajdata_nusc_all_config.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | from tbsim.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig 5 | 6 | 7 | class NuscTrajdataAllTrainConfig(TrajdataTrainConfig): 8 | def __init__(self): 9 | super(NuscTrajdataAllTrainConfig, self).__init__() 10 | 11 | self.trajdata_cache_location = "~/.unified_data_cache" 12 | self.trajdata_source_train = ["nusc_trainval-train", "nusc_trainval-train_val"] 13 | self.trajdata_source_valid = ["nusc_trainval-val"] 14 | # dict mapping dataset IDs -> root path 15 | # all datasets that will be used must be included here 16 | self.trajdata_data_dirs = { 17 | "nusc_trainval" : "../behavior-generation-dataset/nuscenes", 18 | "nusc_test" : "../behavior-generation-dataset/nuscenes", 19 | "nusc_mini" : "../behavior-generation-dataset/nuscenes", 20 | } 21 | 22 | # for debug 23 | self.trajdata_rebuild_cache = False 24 | 25 | self.rollout.enabled = True 26 | self.rollout.save_video = True 27 | self.rollout.every_n_steps = 10000 28 | self.rollout.warm_start_n_steps = 0 29 | 30 | # training config 31 | # assuming 1 sec (10 steps) past, 2 sec (20 steps) future 32 | self.training.batch_size = 100 33 | self.training.num_steps = 100000 34 | self.training.num_data_workers = 8 35 | 36 | self.save.every_n_steps = 10000 37 | self.save.best_k = 10 38 | 39 | # validation config 40 | self.validation.enabled = True 41 | self.validation.batch_size = 32 42 | self.validation.num_data_workers = 6 43 | self.validation.every_n_steps = 500 44 | self.validation.num_steps_per_epoch = 5 # 50 45 | 46 | self.on_ngc = False 47 | self.logging.terminal_output_to_txt = True # whether to log stdout to txt file 48 | self.logging.log_tb = False # enable tensorboard logging 49 | self.logging.log_wandb = True # enable wandb logging 50 | self.logging.wandb_project_name = "tbsim" 51 | self.logging.log_every_n_steps = 10 52 | self.logging.flush_every_n_steps = 100 53 | 54 | 55 | class NuscTrajdataAllEnvConfig(TrajdataEnvConfig): 56 | def __init__(self): 57 | super(NuscTrajdataAllEnvConfig, self).__init__() 58 | 59 | self.data_generation_params.trajdata_centric = "agent" # or "scene" 60 | # which types of agents to include from ['unknown', 'vehicle', 'pedestrian', 'bicycle', 'motorcycle'] 61 | self.data_generation_params.trajdata_only_types = ["vehicle", "pedestrian"] 62 | # which types of agents to predict 63 | self.data_generation_params.trajdata_predict_types = ["vehicle", "pedestrian"] 64 | # list of scene description filters 65 | self.data_generation_params.trajdata_scene_desc_contains = None 66 | # whether or not to include the map in the data 67 | # TODO: handle mixed map-nomap datasets 68 | self.data_generation_params.trajdata_incl_map = True 69 | # max distance to be considered neighbors 70 | self.data_generation_params.trajdata_max_agents_distance = 30.0 71 | # standardize position and heading for the predicted agent 72 | self.data_generation_params.trajdata_standardize_data = True 73 | 74 | # NOTE: rasterization info must still be provided even if incl_map=False 75 | # since still used for agent states 76 | # number of semantic layers that will be used (based on which trajdata dataset is being used) 77 | self.rasterizer.num_sem_layers = 3 # 7 78 | # how to group layers together to viz RGB image 79 | self.rasterizer.rgb_idx_groups = ([0], [1], [2]) 80 | # raster 
image size [pixels] 81 | self.rasterizer.raster_size = 224 82 | # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. 83 | self.rasterizer.pixel_size = 1.0 / 2.0 # 2 px/m 84 | # where the agent is on the map, (0.0, 0.0) is the center 85 | self.rasterizer.ego_center = (-0.5, 0.0) 86 | 87 | # maximum number of agents to consider during training 88 | self.data_generation_params.other_agents_num = 20 -------------------------------------------------------------------------------- /tbsim/configs/trajdata_nusc_config.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | from tbsim.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig 5 | 6 | 7 | class NuscTrajdataTrainConfig(TrajdataTrainConfig): 8 | def __init__(self): 9 | super(NuscTrajdataTrainConfig, self).__init__() 10 | 11 | self.trajdata_cache_location = "~/.unified_data_cache" 12 | self.trajdata_source_train = ["nusc_trainval-train", "nusc_trainval-train_val"] 13 | self.trajdata_source_valid = ["nusc_trainval-val"] 14 | # dict mapping dataset IDs -> root path 15 | # all datasets that will be used must be included here 16 | self.trajdata_data_dirs = { 17 | "nusc_trainval" : "../behavior-generation-dataset/nuscenes", 18 | "nusc_test" : "../behavior-generation-dataset/nuscenes", 19 | "nusc_mini" : "../behavior-generation-dataset/nuscenes", 20 | } 21 | 22 | # for debug 23 | self.trajdata_rebuild_cache = False 24 | 25 | self.rollout.enabled = True 26 | self.rollout.save_video = True 27 | self.rollout.every_n_steps = 10000 28 | self.rollout.warm_start_n_steps = 0 29 | 30 | # training config 31 | # assuming 1 sec (10 steps) past, 2 sec (20 steps) future 32 | self.training.batch_size = 100 # 4 # 100 33 | self.training.num_steps = 100000 34 | self.training.num_data_workers = 8 35 | 36 | self.save.every_n_steps = 10000 37 | self.save.best_k = 10 38 | 39 | # validation config 40 | self.validation.enabled = True 41 | self.validation.batch_size = 32 # 4 # 32 42 | self.validation.num_data_workers = 6 43 | self.validation.every_n_steps = 500 44 | self.validation.num_steps_per_epoch = 5 # 50 45 | 46 | self.on_ngc = False 47 | self.logging.terminal_output_to_txt = True # whether to log stdout to txt file 48 | self.logging.log_tb = False # enable tensorboard logging 49 | self.logging.log_wandb = True # enable wandb logging 50 | self.logging.wandb_project_name = "tbsim" 51 | self.logging.log_every_n_steps = 10 52 | self.logging.flush_every_n_steps = 100 53 | 54 | 55 | class NuscTrajdataEnvConfig(TrajdataEnvConfig): 56 | def __init__(self): 57 | super(NuscTrajdataEnvConfig, self).__init__() 58 | 59 | self.data_generation_params.trajdata_centric = "agent" # "agent", "scene" 60 | # which types of agents to include from ['unknown', 'vehicle', 'pedestrian', 'bicycle', 'motorcycle'] 61 | self.data_generation_params.trajdata_only_types = ["vehicle"] 62 | # which types of agents to predict 63 | self.data_generation_params.trajdata_predict_types = ["vehicle"] 64 | # list of scene description filters 65 | self.data_generation_params.trajdata_scene_desc_contains = None 66 | # whether or not to include the map in the data 67 | # TODO: handle mixed map-nomap datasets 68 | self.data_generation_params.trajdata_incl_map = True 69 | # max distance to be considered neighbors 70 | self.data_generation_params.trajdata_max_agents_distance = np.inf 71 | # standardize position and heading for the predicted agent 72 | 
self.data_generation_params.trajdata_standardize_data = True 73 | 74 | # NOTE: rasterization info must still be provided even if incl_map=False 75 | # since still used for agent states 76 | # number of semantic layers that will be used (based on which trajdata dataset is being used) 77 | self.rasterizer.num_sem_layers = 3 # 7 78 | # how to group layers together to viz RGB image 79 | self.rasterizer.rgb_idx_groups = ([0], [1], [2]) 80 | # raster image size [pixels] 81 | self.rasterizer.raster_size = 224 82 | # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. 83 | self.rasterizer.pixel_size = 1.0 / 2.0 # 2 px/m 84 | # where the agent is on the map, (0.0, 0.0) is the center 85 | self.rasterizer.ego_center = (-0.5, 0.0) 86 | 87 | # max_agent_num (int, optional): The maximum number of agents to include in a batch for scene-centric batching. 88 | self.data_generation_params.other_agents_num = None 89 | 90 | # max_neighbor_num (int, optional): The maximum number of neighbors to include in a batch for agent-centric batching. -------------------------------------------------------------------------------- /tbsim/configs/trajdata_nusc_ped_config.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | from tbsim.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig 5 | 6 | 7 | class NuscTrajdataPedTrainConfig(TrajdataTrainConfig): 8 | def __init__(self): 9 | super(NuscTrajdataPedTrainConfig, self).__init__() 10 | 11 | self.trajdata_cache_location = "~/.unified_data_cache" 12 | self.trajdata_source_train = ["nusc_trainval-train", "nusc_trainval-train_val"] 13 | self.trajdata_source_valid = ["nusc_trainval-val"] 14 | # dict mapping dataset IDs -> root path 15 | # all datasets that will be used must be included here 16 | self.trajdata_data_dirs = { 17 | "nusc_trainval" : "../behavior-generation-dataset/nuscenes", 18 | "nusc_test" : "../behavior-generation-dataset/nuscenes", 19 | "nusc_mini" : "../behavior-generation-dataset/nuscenes", 20 | } 21 | 22 | # for debug 23 | self.trajdata_rebuild_cache = False 24 | 25 | self.rollout.enabled = True 26 | self.rollout.save_video = True 27 | self.rollout.every_n_steps = 10000 28 | self.rollout.warm_start_n_steps = 0 29 | 30 | # training config 31 | # assuming 1 sec (10 steps) past, 2 sec (20 steps) future 32 | self.training.batch_size = 100 33 | self.training.num_steps = 100000 34 | self.training.num_data_workers = 8 35 | 36 | self.save.every_n_steps = 10000 37 | self.save.best_k = 10 38 | 39 | # validation config 40 | self.validation.enabled = True 41 | self.validation.batch_size = 32 42 | self.validation.num_data_workers = 6 43 | self.validation.every_n_steps = 500 44 | self.validation.num_steps_per_epoch = 5 # 50 45 | 46 | self.on_ngc = False 47 | self.logging.terminal_output_to_txt = True # whether to log stdout to txt file 48 | self.logging.log_tb = False # enable tensorboard logging 49 | self.logging.log_wandb = True # enable wandb logging 50 | self.logging.wandb_project_name = "tbsim" 51 | self.logging.log_every_n_steps = 10 52 | self.logging.flush_every_n_steps = 100 53 | 54 | 55 | class NuscTrajdataPedEnvConfig(TrajdataEnvConfig): 56 | def __init__(self): 57 | super(NuscTrajdataPedEnvConfig, self).__init__() 58 | 59 | # # 60 | # # with map, rasterized history 61 | # # 62 | # self.data_generation_params.trajdata_incl_map = True 63 | # self.data_generation_params.trajdata_max_agents_distance = np.inf 64 | # 
self.rasterizer.num_sem_layers = 7 65 | # self.rasterizer.drivable_layers = [] #[0, 1, 2] 66 | # self.rasterizer.include_hist = True # depends on the model being used 67 | 68 | # 69 | # with map, non-rasterized history 70 | # 71 | self.data_generation_params.trajdata_incl_map = True 72 | self.data_generation_params.trajdata_max_agents_distance = 15.0 73 | self.rasterizer.num_sem_layers = 3 74 | self.rasterizer.drivable_layers = [] #[0, 1, 2] every layer is "drivable" for a pedestrian 75 | self.rasterizer.include_hist = False # depends on the model being used 76 | 77 | # which types of neighbor agents 78 | # self.data_generation_params.trajdata_only_types = ["vehicle", "pedestrian", "bicycle", "motorcycle"] 79 | self.data_generation_params.trajdata_only_types = ["pedestrian"] 80 | # which types of agents to predict 81 | self.data_generation_params.trajdata_predict_types = ["pedestrian"] 82 | 83 | # NOTE: rasterization info must still be provided even if incl_map=False 84 | # since still used for agent states 85 | # how to group layers together to viz RGB image 86 | self.rasterizer.rgb_idx_groups = ([0], [1], [2]) 87 | # raster image size [pixels] 88 | self.rasterizer.raster_size = 224 89 | # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. 90 | self.rasterizer.pixel_size = 1.0 / 12.0 # 12 px/m 91 | # where the agent is on the map, (0.0, 0.0) is the center 92 | self.rasterizer.ego_center = (-0.5, 0.0) 93 | # if incl_map = True, but no map is available, will fill dummy map with this value 94 | self.rasterizer.no_map_fill_value = 0.5 # -1.0 -------------------------------------------------------------------------------- /tbsim/configs/trajdata_nusc_scene_config.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | from tbsim.configs.trajdata_config import TrajdataTrainConfig, TrajdataEnvConfig 5 | 6 | 7 | class NuscTrajdataSceneTrainConfig(TrajdataTrainConfig): 8 | def __init__(self): 9 | super(NuscTrajdataSceneTrainConfig, self).__init__() 10 | 11 | self.trajdata_cache_location = "~/.unified_data_cache" 12 | self.trajdata_source_train = ["nusc_trainval-train", "nusc_trainval-train_val"] 13 | self.trajdata_source_valid = ["nusc_trainval-val"] 14 | # dict mapping dataset IDs -> root path 15 | # all datasets that will be used must be included here 16 | self.trajdata_data_dirs = { 17 | "nusc_trainval" : "../behavior-generation-dataset/nuscenes", 18 | "nusc_test" : "../behavior-generation-dataset/nuscenes", 19 | "nusc_mini" : "../behavior-generation-dataset/nuscenes", 20 | } 21 | 22 | # for debug 23 | self.trajdata_rebuild_cache = False 24 | 25 | self.rollout.enabled = True 26 | self.rollout.save_video = True 27 | self.rollout.every_n_steps = 10000 28 | self.rollout.warm_start_n_steps = 0 29 | 30 | # training config 31 | # assuming 1 sec (10 steps) past, 2 sec (20 steps) future 32 | self.training.batch_size = 4 # 100 33 | self.training.num_steps = 100000 34 | self.training.num_data_workers = 8 35 | 36 | self.save.every_n_steps = 10000 37 | self.save.best_k = 10 38 | 39 | # validation config 40 | self.validation.enabled = True 41 | self.validation.batch_size = 1 # 2 # 4 # 32 42 | self.validation.num_data_workers = 6 43 | self.validation.every_n_steps = 500 44 | self.validation.num_steps_per_epoch = 5 # 50 45 | 46 | self.on_ngc = False 47 | self.logging.terminal_output_to_txt = True # whether to log stdout to txt file 48 | self.logging.log_tb = False # enable tensorboard 
logging 49 | self.logging.log_wandb = True # enable wandb logging 50 | self.logging.wandb_project_name = "tbsim" 51 | self.logging.log_every_n_steps = 10 52 | self.logging.flush_every_n_steps = 100 53 | 54 | # vec map params (only used by algorithms that consume the vectorized map) 55 | self.training_vec_map_params = { 56 | 'S_seg': 15, 57 | 'S_point': 80, 58 | 'map_max_dist': 80, 59 | 'max_heading_error': 0.25*np.pi, 60 | 'ahead_threshold': -40, 61 | 'dist_weight': 1.0, 62 | 'heading_weight': 0.1, 63 | } 64 | 65 | 66 | class NuscTrajdataSceneEnvConfig(TrajdataEnvConfig): 67 | def __init__(self): 68 | super(NuscTrajdataSceneEnvConfig, self).__init__() 69 | 70 | self.data_generation_params.trajdata_centric = "scene" # ["agent", "scene"] 71 | # which types of agents to include from ['unknown', 'vehicle', 'pedestrian', 'bicycle', 'motorcycle'] 72 | self.data_generation_params.trajdata_only_types = ["vehicle"] 73 | # which types of agents to predict 74 | self.data_generation_params.trajdata_predict_types = ["vehicle"] 75 | # list of scene description filters 76 | self.data_generation_params.trajdata_scene_desc_contains = None 77 | # whether or not to include the map in the data 78 | # TODO: handle mixed map-nomap datasets 79 | self.data_generation_params.trajdata_incl_map = True 80 | # For both training and testing: 81 | # if scene-centric, max distance to scene center to be included for batching 82 | # if agent-centric, max distance to each agent to be considered neighbors 83 | self.data_generation_params.trajdata_max_agents_distance = 50 # np.inf # 30 84 | # standardize position and heading for the predicted agent 85 | self.data_generation_params.trajdata_standardize_data = True 86 | 87 | # NOTE: rasterization info must still be provided even if incl_map=False 88 | # since still used for agent states 89 | # number of semantic layers that will be used (based on which trajdata dataset is being used) 90 | self.rasterizer.num_sem_layers = 3 # 7 91 | # how to group layers together to viz RGB image 92 | self.rasterizer.rgb_idx_groups = ([0], [1], [2]) 93 | # raster image size [pixels] 94 | self.rasterizer.raster_size = 224 95 | # raster's spatial resolution [meters per pixel]: the size in the real world one pixel corresponds to. 96 | self.rasterizer.pixel_size = 1.0 / 2.0 # 2 px/m 97 | # where the agent is on the map, (0.0, 0.0) is the center 98 | self.rasterizer.ego_center = (-0.5, 0.0) 99 | 100 | # max_agent_num (int, optional): The maximum number of agents to include in a batch for scene-centric batching. 101 | self.data_generation_params.other_agents_num = 20 # None # 20 102 | 103 | # max_neighbor_num (int, optional): The maximum number of neighbors to include in a batch for agent-centric batching. 
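
[Editor's note] These configs are plain Python objects whose fields are all assigned in __init__, so the quickest way to see what an experiment will actually use is to instantiate them and override fields before handing them to the training script. A minimal sketch, assuming the two classes above are importable as defined; the override values are arbitrary examples, not recommended settings:

from tbsim.configs.trajdata_nusc_scene_config import (
    NuscTrajdataSceneTrainConfig,
    NuscTrajdataSceneEnvConfig,
)

train_cfg = NuscTrajdataSceneTrainConfig()
env_cfg = NuscTrajdataSceneEnvConfig()

# Example overrides for a small local run (illustrative values only).
train_cfg.training.batch_size = 2
env_cfg.data_generation_params.trajdata_max_agents_distance = 30.0

print(env_cfg.data_generation_params.trajdata_centric)  # "scene"
print(env_cfg.rasterizer.num_sem_layers)                # 3
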
-------------------------------------------------------------------------------- /tbsim/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/CTG/f916c008c3ecf2360bfa050639606eaab7c207f5/tbsim/datasets/__init__.py -------------------------------------------------------------------------------- /tbsim/datasets/factory.py: -------------------------------------------------------------------------------- 1 | """DataModule / Dataset factory""" 2 | from tbsim.utils.config_utils import translate_l5kit_cfg, translate_trajdata_cfg, translate_pass_trajdata_cfg 3 | from tbsim.datasets.l5kit_datamodules import L5MixedDataModule, L5RasterizedDataModule 4 | from tbsim.datasets.trajdata_datamodules import UnifiedDataModule, PassUnifiedDataModule 5 | 6 | def datamodule_factory(cls_name: str, config): 7 | """ 8 | A factory for creating pl.DataModule. 9 | 10 | Valid module class names: "L5MixedDataModule", "L5RasterizedDataModule", "UnifiedDataModule", "PassUnifiedDataModule" 11 | Args: 12 | cls_name (str): name of the datamodule class 13 | config (Config): an Experiment config object 14 | 15 | Returns: 16 | A DataModule 17 | 18 | """ 19 | if cls_name.startswith("L5"): # TODO: make this less hacky 20 | l5_config = translate_l5kit_cfg(config) 21 | datamodule = eval(cls_name)(l5_config=l5_config, train_config=config.train) 22 | elif cls_name.startswith("Unified"): 23 | trajdata_config = translate_trajdata_cfg(config) 24 | datamodule = eval(cls_name)(data_config=trajdata_config, train_config=config.train) 25 | elif cls_name.startswith("PassUnified"): 26 | trajdata_config = translate_pass_trajdata_cfg(config) 27 | datamodule = eval(cls_name)(data_config=trajdata_config, train_config=config.train) 28 | else: 29 | raise NotImplementedError("{} is not a supported datamodule type".format(cls_name)) 30 | return datamodule -------------------------------------------------------------------------------- /tbsim/datasets/l5kit_datamodules.py: -------------------------------------------------------------------------------- 1 | """Functions and classes for dataset I/O""" 2 | import abc 3 | from collections import OrderedDict 4 | 5 | from typing import Optional 6 | import os 7 | import pytorch_lightning as pl 8 | from torch.utils.data import DataLoader 9 | 10 | from l5kit.rasterization import build_rasterizer 11 | from l5kit.rasterization.rasterizer import Rasterizer 12 | from l5kit.data import LocalDataManager, ChunkedDataset 13 | from l5kit.dataset import EgoDataset, AgentDataset 14 | 15 | from tbsim.configs.base import TrainConfig 16 | from tbsim.l5kit.vectorizer import build_vectorizer 17 | from tbsim.l5kit.l5_ego_dataset import ( 18 | EgoDatasetMixed 19 | ) 20 | 21 | from tbsim.l5kit.l5_agent_dataset import AgentDatasetMixed, AgentDataset 22 | 23 | 24 | class LazyRasterizer(Rasterizer): 25 | """ 26 | Only creates the actual rasterizer when a member function is called. 27 | 28 | A Rasterizer class is non-pickleable, which means that pickle complains about it when we try to do 29 | multi-process training (e.g., multiGPU training). This class is a way to circumvent the issue by only 30 | creating the rasterizer object when it's being used in the spawned processes. 
31 | """ 32 | def __init__(self, l5_config, data_manager): 33 | super(LazyRasterizer, self).__init__() 34 | self._l5_config = l5_config 35 | self._dm = data_manager 36 | self._rasterizer = None 37 | 38 | @property 39 | def rasterizer(self): 40 | if self._rasterizer is None: 41 | self._rasterizer = build_rasterizer(self._l5_config, self._dm) 42 | return self._rasterizer 43 | 44 | def rasterize(self, *args, **kwargs): 45 | return self.rasterizer.rasterize(*args, **kwargs) 46 | 47 | def to_rgb(self,*args, **kwargs): 48 | return self.rasterizer.to_rgb(*args, **kwargs) 49 | 50 | def num_channels(self) -> int: 51 | return self.rasterizer.num_channels() 52 | 53 | 54 | class L5BaseDatasetModule(abc.ABC): 55 | pass 56 | 57 | 58 | class L5RasterizedDataModule(pl.LightningDataModule, L5BaseDatasetModule): 59 | def __init__( 60 | self, 61 | l5_config: dict, 62 | train_config: TrainConfig, 63 | ): 64 | super().__init__() 65 | self.train_dataset = None 66 | self.valid_dataset = None 67 | self.env_dataset = None 68 | self.experience_dataset = None # replay buffer 69 | self.rasterizer = None 70 | self._train_config = train_config 71 | self._l5_config = l5_config 72 | self._mode = train_config.dataset_mode 73 | 74 | assert self._mode in ["ego", "agents"] 75 | 76 | @property 77 | def modality_shapes(self): 78 | dm = LocalDataManager(None) 79 | rasterizer = build_rasterizer(self._l5_config, dm) 80 | h, w = self._l5_config["raster_params"]["raster_size"] 81 | return OrderedDict(image=(rasterizer.num_channels(), h, w)) 82 | 83 | def setup(self, stage: Optional[str] = None): 84 | os.environ["L5KIT_DATA_FOLDER"] = os.path.abspath(self._train_config.dataset_path) 85 | dm = LocalDataManager(None) 86 | self.rasterizer = LazyRasterizer(self._l5_config, dm) 87 | 88 | train_zarr = ChunkedDataset(dm.require(self._train_config.dataset_train_key)).open() 89 | valid_zarr = ChunkedDataset(dm.require(self._train_config.dataset_valid_key)).open() 90 | 91 | self.env_dataset = EgoDataset(self._l5_config, valid_zarr, self.rasterizer) 92 | 93 | if self._mode == "ego": 94 | self.train_dataset = EgoDataset(self._l5_config, train_zarr, self.rasterizer) 95 | self.valid_dataset = EgoDataset(self._l5_config, valid_zarr, self.rasterizer) 96 | else: 97 | read_cached_mask = not self._train_config.on_ngc 98 | self.train_dataset = AgentDataset(self._l5_config, train_zarr, self.rasterizer, read_cached_mask=read_cached_mask) 99 | self.valid_dataset = AgentDataset(self._l5_config, valid_zarr, self.rasterizer, read_cached_mask=read_cached_mask) 100 | 101 | def train_dataloader(self): 102 | return DataLoader( 103 | dataset=self.train_dataset, 104 | shuffle=True, 105 | batch_size=self._train_config.training.batch_size, 106 | num_workers=self._train_config.training.num_data_workers, 107 | drop_last=True, 108 | persistent_workers=True, 109 | ) 110 | 111 | def val_dataloader(self): 112 | return DataLoader( 113 | dataset=self.valid_dataset, 114 | shuffle=True, 115 | batch_size=self._train_config.validation.batch_size, 116 | num_workers=self._train_config.validation.num_data_workers, 117 | drop_last=True, 118 | persistent_workers=True, 119 | ) 120 | 121 | def test_dataloader(self): 122 | pass 123 | 124 | def predict_dataloader(self): 125 | pass 126 | 127 | 128 | class L5MixedDataModule(L5RasterizedDataModule): 129 | def __init__( 130 | self, 131 | l5_config, 132 | train_config: TrainConfig, 133 | ): 134 | super(L5MixedDataModule, self).__init__( 135 | l5_config=l5_config, train_config=train_config) 136 | self.vectorizer = None 137 | 138 | def 
setup(self, stage: Optional[str] = None): 139 | os.environ["L5KIT_DATA_FOLDER"] = os.path.abspath(self._train_config.dataset_path) 140 | dm = LocalDataManager(None) 141 | self.rasterizer = build_rasterizer(self._l5_config, dm) 142 | self.vectorizer = build_vectorizer(self._l5_config, dm) 143 | 144 | train_zarr = ChunkedDataset(dm.require(self._train_config.dataset_train_key)).open() 145 | valid_zarr = ChunkedDataset(dm.require(self._train_config.dataset_valid_key)).open() 146 | if self._mode == "ego": 147 | self.train_dataset = EgoDatasetMixed(self._l5_config, train_zarr, self.vectorizer, self.rasterizer) 148 | self.valid_dataset = EgoDatasetMixed(self._l5_config, valid_zarr, self.vectorizer, self.rasterizer) 149 | else: 150 | read_cached_mask = not self._train_config.on_ngc 151 | self.train_dataset = AgentDatasetMixed(self._l5_config, train_zarr, self.vectorizer, self.rasterizer, read_cached_mask=read_cached_mask) 152 | self.valid_dataset = AgentDatasetMixed(self._l5_config, valid_zarr, self.vectorizer, self.rasterizer, read_cached_mask=read_cached_mask) -------------------------------------------------------------------------------- /tbsim/dynamics/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from tbsim.dynamics.single_integrator import SingleIntegrator 4 | from tbsim.dynamics.unicycle import Unicycle 5 | from tbsim.dynamics.bicycle import Bicycle 6 | from tbsim.dynamics.double_integrator import DoubleIntegrator 7 | from tbsim.dynamics.base import Dynamics, DynType, forward_dynamics 8 | 9 | 10 | def get_dynamics_model(dyn_type: Union[str, DynType]): 11 | if dyn_type in ["Unicycle", DynType.UNICYCLE]: 12 | return Unicycle 13 | elif dyn_type in ["SingleIntegrator", DynType.SI]: 14 | return SingleIntegrator 15 | elif dyn_type in ["DoubleIntegrator", DynType.DI]: 16 | return DoubleIntegrator 17 | else: 18 | raise NotImplementedError("Dynamics model {} is not implemented".format(dyn_type)) 19 | -------------------------------------------------------------------------------- /tbsim/dynamics/base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math, copy, time 4 | import abc 5 | from copy import deepcopy 6 | 7 | 8 | class DynType: 9 | """ 10 | Holds dynamics model types - one per dynamics model class. 11 | These act as identifiers for the different dynamics models. 
12 | """ 13 | 14 | UNICYCLE = 1 15 | SI = 2 16 | DI = 3 17 | BICYCLE = 4 18 | 19 | 20 | class Dynamics(abc.ABC): 21 | @abc.abstractmethod 22 | def __init__(self, name, **kwargs): 23 | self._name = name 24 | self.xdim = 4 25 | self.udim = 2 26 | 27 | @abc.abstractmethod 28 | def __call__(self, x, u): 29 | return 30 | 31 | @abc.abstractmethod 32 | def step(self, x, u, dt, bound=True): 33 | return 34 | 35 | @abc.abstractmethod 36 | def name(self): 37 | return self._name 38 | 39 | @abc.abstractmethod 40 | def type(self): 41 | return 42 | 43 | @abc.abstractmethod 44 | def ubound(self, x): 45 | return 46 | 47 | @staticmethod 48 | def state2pos(x): 49 | return 50 | 51 | @staticmethod 52 | def state2yaw(x): 53 | return 54 | 55 | 56 | def forward_dynamics( 57 | dyn_model: Dynamics, 58 | initial_states: torch.Tensor, 59 | actions: torch.Tensor, 60 | step_time: float, 61 | ): 62 | """ 63 | Integrate the state forward with initial state x0, action u 64 | Args: 65 | dyn_model (dynamics.Dynamics): dynamics model 66 | initial_states (Torch.tensor): state tensor of size [B, (A), 4] 67 | actions (Torch.tensor): action tensor of size [B, (A), T, 2] 68 | step_time (float): delta time between steps 69 | Returns: 70 | state tensor of size [B, (A), T, 4] 71 | """ 72 | num_steps = actions.shape[-2] 73 | x = [initial_states] + [None] * num_steps 74 | for t in range(num_steps): 75 | x[t + 1] = dyn_model.step(x[t], actions[..., t, :], step_time) 76 | 77 | x = torch.stack(x[1:], dim=-2) 78 | pos = dyn_model.state2pos(x) 79 | yaw = dyn_model.state2yaw(x) 80 | return x, pos, yaw 81 | -------------------------------------------------------------------------------- /tbsim/dynamics/bicycle.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | from tbsim.dynamics.base import Dynamics, DynType 5 | 6 | 7 | def bicycle_model(state, acc, ddh, vehicle_length, dt, max_hdot=math.pi * 2.0, max_s=50.0): 8 | """ 9 | Simple differentiable bicycle model that does not allow reverse 10 | Args: 11 | state (torch.Tensor): a batch of current kinematic state [B, ..., 5] (x, y, yaw, speed, hdot) 12 | acc (torch.Tensor): a batch of acceleration profile [B, ...] (acc) 13 | ddh (torch.Tensor): a batch of heading acceleration profile [B, ...] (heading) 14 | vehicle_length (torch.Tensor): a batch of vehicle length [B, ...] 
(length) 15 | dt (float): time between steps 16 | max_hdot (float): maximum change of heading (rad/s) 17 | max_s (float): maximum speed (m/s) 18 | 19 | Returns: 20 | New kinematic state (torch.Tensor) 21 | """ 22 | # state: (x, y, h, speed, hdot) 23 | assert state.shape[-1] == 5 24 | newhdot = (state[..., 4] + ddh * dt).clamp(-max_hdot, max_hdot) 25 | newh = state[..., 2] + dt * state[..., 3].abs() / vehicle_length * newhdot 26 | news = (state[..., 3] + acc * dt).clamp(0.0, max_s) # no reverse 27 | newy = state[..., 1] + news * newh.sin() * dt 28 | newx = state[..., 0] + news * newh.cos() * dt 29 | 30 | newstate = torch.empty_like(state) 31 | newstate[..., 0] = newx 32 | newstate[..., 1] = newy 33 | newstate[..., 2] = newh 34 | newstate[..., 3] = news 35 | newstate[..., 4] = newhdot 36 | 37 | return newstate 38 | 39 | 40 | class Bicycle(Dynamics): 41 | 42 | def __init__( 43 | self, 44 | acc_bound=(-10, 8), 45 | ddh_bound=(-math.pi * 2.0, math.pi * 2.0), 46 | max_speed=50.0, 47 | max_hdot=math.pi * 2.0 48 | ): 49 | """ 50 | A simple bicycle dynamics model 51 | Args: 52 | acc_bound (tuple): acceleration bound (m/s^2) 53 | ddh_bound (tuple): angular acceleration bound (rad/s^2) 54 | max_speed (float): maximum speed, must be positive 55 | max_hdot (float): maximum turning speed, must be positive 56 | """ 57 | super(Bicycle, self).__init__(name="bicycle") 58 | self.xdim = 6 59 | self.udim = 2 60 | assert max_speed >= 0 61 | assert max_hdot >= 0 62 | self.acc_bound = acc_bound 63 | self.ddh_bound = ddh_bound 64 | self.max_speed = max_speed 65 | self.max_hdot = max_hdot 66 | 67 | def get_normalized_controls(self, u): 68 | u = torch.sigmoid(u) # normalize to [0, 1] 69 | acc = self.acc_bound[0] + (self.acc_bound[1] - self.acc_bound[0]) * u[..., 0] 70 | ddh = self.ddh_bound[0] + (self.ddh_bound[1] - self.ddh_bound[0]) * u[..., 1] 71 | return acc, ddh 72 | 73 | def get_clipped_controls(self, u): 74 | acc = torch.clip(u[..., 0], self.acc_bound[0], self.acc_bound[1]) 75 | ddh = torch.clip(u[..., 1], self.ddh_bound[0], self.ddh_bound[1]) 76 | return acc, ddh 77 | 78 | def step(self, x, u, dt, normalize=True): 79 | """ 80 | Take a step with the dynamics model 81 | Args: 82 | x (torch.Tensor): current state [B, ..., 6] (x, y, h, speed, dh, veh_length) 83 | u (torch.Tensor): (un-normalized) actions [B, ..., 2] (acc, ddh) 84 | dt (float): time between steps 85 | normalize (bool): whether to normalize the actions 86 | 87 | Returns: 88 | next_x (torch.Tensor): next state after taking the action 89 | """ 90 | assert x.shape[-1] == self.xdim 91 | assert u.shape[:-1] == x.shape[:-1] 92 | assert u.shape[-1] == self.udim 93 | if normalize: 94 | acc, ddh = self.get_normalized_controls(u) 95 | else: 96 | acc, ddh = self.get_clipped_controls(u) 97 | next_x = x.clone() # keep the extent the same 98 | next_x[..., :5] = bicycle_model( 99 | state=x[..., :5], 100 | acc=acc, 101 | ddh=ddh, 102 | vehicle_length=x[..., 5], 103 | dt=dt, 104 | max_hdot=self.max_hdot, 105 | max_s=self.max_speed 106 | ) 107 | return next_x 108 | @staticmethod 109 | def calculate_vel(pos, yaw, dt, mask): 110 | 111 | vel = (pos[..., 1:, 0:1] - pos[..., :-1, 0:1]) / dt * torch.cos( 112 | yaw[..., 1:, :] 113 | ) + (pos[..., 1:, 1:2] - pos[..., :-1, 1:2]) / dt * torch.sin( 114 | yaw[..., 1:, :] 115 | ) 116 | # right finite difference velocity 117 | vel_r = torch.cat((vel[..., 0:1, :], vel), dim=-2) 118 | # left finite difference velocity 119 | vel_l = torch.cat((vel, vel[..., -1:, :]), dim=-2) 120 | mask_r = torch.roll(mask, 1, dims=-1) 121 | 
mask_r[..., 0] = False 122 | mask_r = mask_r & mask 123 | 124 | mask_l = torch.roll(mask, -1, dims=-1) 125 | mask_l[..., -1] = False 126 | mask_l = mask_l & mask 127 | vel = ( 128 | (mask_l & mask_r).unsqueeze(-1) * (vel_r + vel_l) / 2 129 | + (mask_l & (~mask_r)).unsqueeze(-1) * vel_l 130 | + (mask_r & (~mask_l)).unsqueeze(-1) * vel_r 131 | ) 132 | 133 | return vel 134 | 135 | def type(self): 136 | return DynType.BICYCLE 137 | 138 | def state2pos(self, x): 139 | return x[..., :2] 140 | 141 | def state2yaw(self, x): 142 | return x[..., 2:3] 143 | 144 | def __call__(self, x, u): 145 | pass 146 | 147 | def ubound(self, x): 148 | pass 149 | 150 | def name(self): 151 | return self._name 152 | -------------------------------------------------------------------------------- /tbsim/dynamics/double_integrator.py: -------------------------------------------------------------------------------- 1 | from tbsim.dynamics.base import DynType, Dynamics 2 | import torch 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | 7 | 8 | class DoubleIntegrator(Dynamics): 9 | def __init__(self, name, abound, vbound=None): 10 | self._name = name 11 | self._type = DynType.DI 12 | self.xdim = abound.shape[0] * 2 13 | self.udim = abound.shape[0] 14 | self.cyclic_state = list() 15 | self.vbound = np.array(vbound) if vbound is not None else None  # keep None so the `is None` check in ubound works 16 | self.abound = np.array(abound) 17 | 18 | def __call__(self, x, u): 19 | assert x.shape[:-1] == u.shape[:-1] 20 | if isinstance(x, np.ndarray): 21 | return np.hstack((x[..., 2:], u)) 22 | elif isinstance(x, torch.Tensor): 23 | return torch.cat((x[..., 2:], u), dim=-1) 24 | else: 25 | raise NotImplementedError 26 | 27 | def step(self, x, u, dt, bound=True): 28 | 29 | if isinstance(x, np.ndarray): 30 | if bound: 31 | lb, ub = self.ubound(x) 32 | u = np.clip(u, lb, ub) 33 | xn = np.hstack( 34 | ((x[..., 2:4] + 0.5 * u * dt) * dt + x[..., 0:2], x[..., 2:4] + u * dt) 35 | ) 36 | elif isinstance(x, torch.Tensor): 37 | if bound: 38 | lb, ub = self.ubound(x) 39 | u = torch.clip(u, min=lb, max=ub) 40 | xn = torch.clone(x) 41 | xn[..., 0:2] += (x[..., 2:4] + 0.5 * u * dt) * dt 42 | xn[..., 2:4] += u * dt 43 | else: 44 | raise NotImplementedError 45 | return xn 46 | 47 | def name(self): 48 | return self._name 49 | 50 | def type(self): 51 | return self._type 52 | 53 | def ubound(self, x): 54 | if self.vbound is None: 55 | if isinstance(x, np.ndarray): 56 | lb = np.ones_like(x[..., 2:]) * self.abound[:, 0] 57 | ub = np.ones_like(x[..., 2:]) * self.abound[:, 1] 58 | 59 | elif isinstance(x, torch.Tensor): 60 | lb = torch.ones_like(x[..., 2:]) * torch.from_numpy( 61 | self.abound[:, 0] 62 | ).to(x.device) 63 | ub = torch.ones_like(x[..., 2:]) * torch.from_numpy( 64 | self.abound[:, 1] 65 | ).to(x.device) 66 | 67 | else: 68 | raise NotImplementedError 69 | else: 70 | if isinstance(x, np.ndarray): 71 | lb = (x[..., 2:] > self.vbound[:, 0]) * self.abound[:, 0] 72 | ub = (x[..., 2:] < self.vbound[:, 1]) * self.abound[:, 1] 73 | 74 | elif isinstance(x, torch.Tensor): 75 | lb = ( 76 | x[..., 2:] > torch.from_numpy(self.vbound[:, 0]).to(x.device) 77 | ) * torch.from_numpy(self.abound[:, 0]).to(x.device) 78 | ub = ( 79 | x[..., 2:] < torch.from_numpy(self.vbound[:, 1]).to(x.device) 80 | ) * torch.from_numpy(self.abound[:, 1]).to(x.device) 81 | else: 82 | raise NotImplementedError 83 | return lb, ub 84 | 85 | @staticmethod 86 | def state2pos(x): 87 | return x[..., 0:2] 88 | 89 | @staticmethod 90 | def state2yaw(x): 91 | # return torch.atan2(x[..., 3:], x[..., 2:3]) 92 | return torch.zeros_like(x[..., 0:1]) 93 | 
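
[Editor's note] forward_dynamics in tbsim/dynamics/base.py ties these models together: it integrates any Dynamics instance over a sequence of actions and returns the state trajectory plus the positions and yaws extracted from it. A minimal sketch using the DoubleIntegrator above; the bounds, batch shapes, and step time are arbitrary example values, not values used by the repo:

import numpy as np
import torch

from tbsim.dynamics import DoubleIntegrator, forward_dynamics

# Per-axis acceleration and velocity bounds (float32 to match the torch inputs).
dyn = DoubleIntegrator(
    name="di",
    abound=np.array([[-3.0, 3.0], [-3.0, 3.0]], dtype=np.float32),
    vbound=np.array([[-10.0, 10.0], [-10.0, 10.0]], dtype=np.float32),
)

B, T = 8, 20
x0 = torch.zeros(B, 4)          # state: (x, y, vx, vy)
actions = torch.randn(B, T, 2)  # accelerations; step() clips them via ubound()

x, pos, yaw = forward_dynamics(dyn, x0, actions, step_time=0.1)
print(x.shape, pos.shape, yaw.shape)  # [8, 20, 4], [8, 20, 2], [8, 20, 1]
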
-------------------------------------------------------------------------------- /tbsim/dynamics/single_integrator.py: -------------------------------------------------------------------------------- 1 | from tbsim.dynamics.base import DynType, Dynamics 2 | import torch 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | 7 | class SingleIntegrator(Dynamics): 8 | def __init__(self, name, vbound): 9 | self._name = name 10 | self._type = DynType.SI 11 | self.xdim = vbound.shape[0] 12 | self.udim = vbound.shape[0] 13 | self.cyclic_state = list() 14 | self.vbound = np.array(vbound) 15 | 16 | def __call__(self, x, u): 17 | assert x.shape[:-1] == u.shape[:-1] 18 | 19 | return u 20 | 21 | def step(self, x, u, dt, bound=True): 22 | assert x.shape[:-1] == u.shape[:-1] 23 | if bound: 24 | lb, ub = self.ubound(x) 25 | if isinstance(x, np.ndarray): 26 | u = np.clip(u, lb, ub) 27 | elif isinstance(x, torch.Tensor): 28 | u = torch.clip(u, min=lb, max=ub) 29 | 30 | return x + u * dt 31 | 32 | def name(self): 33 | return self._name 34 | 35 | def type(self): 36 | return self._type 37 | 38 | def ubound(self, x): 39 | if isinstance(x, np.ndarray): 40 | lb = np.ones_like(x) * self.vbound[:, 0] 41 | ub = np.ones_like(x) * self.vbound[:, 1] 42 | return lb, ub 43 | elif isinstance(x, torch.Tensor): 44 | lb = torch.ones_like(x) * torch.from_numpy(self.vbound[:, 0]).to(x.device)  # match the device of x, as in DoubleIntegrator.ubound 45 | ub = torch.ones_like(x) * torch.from_numpy(self.vbound[:, 1]).to(x.device) 46 | return lb, ub 47 | else: 48 | raise NotImplementedError 49 | 50 | @staticmethod 51 | def state2pos(x): 52 | return x[..., 0:2] 53 | -------------------------------------------------------------------------------- /tbsim/dynamics/unicycle.py: -------------------------------------------------------------------------------- 1 | from tbsim.dynamics.base import DynType, Dynamics 2 | import torch 3 | import numpy as np 4 | from copy import deepcopy 5 | 6 | 7 | class Unicycle(Dynamics): 8 | def __init__( 9 | self, name, max_steer=0.5, max_yawvel=8, acce_bound=[-6, 4], vbound=[-10, 30] 10 | ): 11 | self._name = name 12 | self._type = DynType.UNICYCLE 13 | self.xdim = 4 14 | self.udim = 2 15 | self.cyclic_state = [3] 16 | self.acce_bound = acce_bound 17 | self.vbound = vbound 18 | self.max_steer = max_steer 19 | self.max_yawvel = max_yawvel 20 | 21 | def __call__(self, x, u): 22 | assert x.shape[:-1] == u.shape[:-1] 23 | if isinstance(x, np.ndarray): 24 | assert isinstance(u, np.ndarray) 25 | theta = x[..., 3:4] 26 | dxdt = np.hstack( 27 | (np.cos(theta) * x[..., 2:3], np.sin(theta) * x[..., 2:3], u) 28 | ) 29 | elif isinstance(x, torch.Tensor): 30 | assert isinstance(u, torch.Tensor) 31 | theta = x[..., 3:4] 32 | dxdt = torch.cat( 33 | (torch.cos(theta) * x[..., 2:3], 34 | torch.sin(theta) * x[..., 2:3], u), 35 | dim=-1, 36 | ) 37 | else: 38 | raise NotImplementedError 39 | return dxdt 40 | 41 | def step(self, x, u, dt, bound=True): 42 | assert x.shape[:-1] == u.shape[:-1] 43 | # print('x.shape:', x.shape, 'u.shape:', u.shape) 44 | if isinstance(x, np.ndarray): 45 | assert isinstance(u, np.ndarray) 46 | if bound: 47 | lb, ub = self.ubound(x) 48 | u = np.clip(u, lb, ub) 49 | 50 | theta = x[..., 3:4] 51 | dxdt = np.hstack( 52 | ( 53 | np.cos(theta) * (x[..., 2:3] + u[..., 0:1] * dt * 0.5), 54 | np.sin(theta) * (x[..., 2:3] + u[..., 0:1] * dt * 0.5), 55 | u, 56 | ) 57 | ) 58 | elif isinstance(x, torch.Tensor): 59 | assert isinstance(u, torch.Tensor) 60 | # print('original before clip u[0]', u[0]) 61 | if bound: 62 | lb, ub = self.ubound(x) 63 | # s = (u - lb) / torch.clip(ub - 
lb, min=1e-3) 64 | # u = lb + (ub - lb) * torch.sigmoid(s) 65 | u = torch.clip(u, lb, ub) 66 | # print('original after clip u[0]', u[0]) 67 | theta = x[..., 3:4] 68 | dxdt = torch.cat( 69 | ( 70 | torch.cos(theta) * (x[..., 2:3] + u[..., 0:1] * dt * 0.5), 71 | torch.sin(theta) * (x[..., 2:3] + u[..., 0:1] * dt * 0.5), 72 | u, 73 | ), 74 | dim=-1, 75 | ) 76 | else: 77 | raise NotImplementedError 78 | # print("x.size()", x.size()) 79 | # print('x[1]', 'dxdt[1]*dt') 80 | # print(x[1], dxdt[1]*dt) 81 | return x + dxdt * dt 82 | 83 | def name(self): 84 | return self._name 85 | 86 | def type(self): 87 | return self._type 88 | 89 | def ubound(self, x): 90 | if isinstance(x, np.ndarray): 91 | v = x[..., 2:3] 92 | yawbound = np.minimum( 93 | self.max_steer * v, 94 | self.max_yawvel / np.clip(np.abs(v), a_min=0.1, a_max=None), 95 | ) 96 | acce_lb = np.clip( 97 | np.clip(self.vbound[0] - v, None, self.acce_bound[1]), 98 | self.acce_bound[0], 99 | None, 100 | ) 101 | acce_ub = np.clip( 102 | np.clip(self.vbound[1] - v, self.acce_bound[0], None), 103 | None, 104 | self.acce_bound[1], 105 | ) 106 | lb = np.hstack((acce_lb, -yawbound)) 107 | ub = np.hstack((acce_ub, yawbound)) 108 | return lb, ub 109 | elif isinstance(x, torch.Tensor): 110 | v = x[..., 2:3] 111 | yawbound = torch.minimum( 112 | self.max_steer * torch.abs(v), 113 | self.max_yawvel / torch.clip(torch.abs(v), min=0.1), 114 | ) 115 | yawbound = torch.clip(yawbound, min=0.1) 116 | acce_lb = torch.clip( 117 | torch.clip(self.vbound[0] - v, max=self.acce_bound[1]), 118 | min=self.acce_bound[0], 119 | ) 120 | acce_ub = torch.clip( 121 | torch.clip(self.vbound[1] - v, min=self.acce_bound[0]), 122 | max=self.acce_bound[1], 123 | ) 124 | lb = torch.cat((acce_lb, -yawbound), dim=-1) 125 | ub = torch.cat((acce_ub, yawbound), dim=-1) 126 | return lb, ub 127 | 128 | else: 129 | raise NotImplementedError 130 | 131 | @staticmethod 132 | def state2pos(x): 133 | return x[..., 0:2] 134 | 135 | @staticmethod 136 | def state2yaw(x): 137 | return x[..., 3:] 138 | 139 | @staticmethod 140 | def calculate_vel(pos, yaw, dt, mask): 141 | if isinstance(pos, torch.Tensor): 142 | vel = (pos[..., 1:, 0:1] - pos[..., :-1, 0:1]) / dt * torch.cos( 143 | yaw[..., 1:, :] 144 | ) + (pos[..., 1:, 1:2] - pos[..., :-1, 1:2]) / dt * torch.sin( 145 | yaw[..., 1:, :] 146 | ) 147 | # right finite difference velocity 148 | vel_r = torch.cat((vel[..., 0:1, :], vel), dim=-2) 149 | # left finite difference velocity 150 | vel_l = torch.cat((vel, vel[..., -1:, :]), dim=-2) 151 | mask_r = torch.roll(mask, 1, dims=-1) 152 | mask_r[..., 0] = False 153 | mask_r = mask_r & mask 154 | 155 | mask_l = torch.roll(mask, -1, dims=-1) 156 | mask_l[..., -1] = False 157 | mask_l = mask_l & mask 158 | vel = ( 159 | (mask_l & mask_r).unsqueeze(-1) * (vel_r + vel_l) / 2 160 | + (mask_l & (~mask_r)).unsqueeze(-1) * vel_l 161 | + (mask_r & (~mask_l)).unsqueeze(-1) * vel_r 162 | ) 163 | elif isinstance(pos, np.ndarray): 164 | vel = (pos[..., 1:, 0:1] - pos[..., :-1, 0:1]) / dt * np.cos( 165 | yaw[..., 1:, :] 166 | ) + (pos[..., 1:, 1:2] - pos[..., :-1, 1:2]) / dt * np.sin(yaw[..., 1:, :]) 167 | # right finite difference velocity 168 | vel_r = np.concatenate((vel[..., 0:1, :], vel), axis=-2) 169 | # left finite difference velocity 170 | vel_l = np.concatenate((vel, vel[..., -1:, :]), axis=-2) 171 | mask_r = np.roll(mask, 1, axis=-1) 172 | mask_r[..., 0] = False 173 | mask_r = mask_r & mask 174 | mask_l = np.roll(mask, -1, axis=-1) 175 | mask_l[..., -1] = False 176 | mask_l = mask_l & mask 177 | vel = ( 178 | 
np.expand_dims(mask_l & mask_r, -1) * (vel_r + vel_l) / 2 179 | + np.expand_dims(mask_l & (~mask_r), -1) * vel_l 180 | + np.expand_dims(mask_r & (~mask_l), -1) * vel_r 181 | ) 182 | else: 183 | raise NotImplementedError 184 | return vel 185 | @staticmethod 186 | def inverse_dyn(x, xp, dt): 187 | return (xp[..., 2:] - x[..., 2:]) / dt -------------------------------------------------------------------------------- /tbsim/envs/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class SimulationException(Exception): 5 | pass 6 | 7 | 8 | class BaseEnv(abc.ABC): 9 | """TODO: Make a Simulator MetaClass""" 10 | 11 | @abc.abstractmethod 12 | def reset(self, scene_indices=None, start_frame_index=None): 13 | return 14 | 15 | @abc.abstractmethod 16 | def reset_multi_episodes_metrics(self): 17 | pass 18 | 19 | @abc.abstractmethod 20 | def step(self, action, num_steps_to_take, render): 21 | return 22 | 23 | @abc.abstractmethod 24 | def update_random_seed(self, seed): 25 | return 26 | 27 | @abc.abstractmethod 28 | def get_metrics(self): 29 | return 30 | 31 | @abc.abstractmethod 32 | def get_multi_episode_metrics(self): 33 | return 34 | 35 | @abc.abstractmethod 36 | def render(self, actions_to_take): 37 | return 38 | 39 | @abc.abstractmethod 40 | def get_info(self): 41 | return 42 | 43 | @abc.abstractmethod 44 | def get_observation(self): 45 | return 46 | 47 | @abc.abstractmethod 48 | def get_reward(self): 49 | return 50 | 51 | @abc.abstractmethod 52 | def is_done(self): 53 | return 54 | 55 | 56 | class BatchedEnv(abc.ABC): 57 | @abc.abstractmethod 58 | def num_instances(self): 59 | return 60 | -------------------------------------------------------------------------------- /tbsim/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/CTG/f916c008c3ecf2360bfa050639606eaab7c207f5/tbsim/evaluation/__init__.py -------------------------------------------------------------------------------- /tbsim/evaluation/metric_composers.py: -------------------------------------------------------------------------------- 1 | from tbsim.algos.algos import ( 2 | DiscreteVAETrafficModel, 3 | ) 4 | 5 | from tbsim.algos.metric_algos import ( 6 | OccupancyMetric, 7 | ) 8 | 9 | import tbsim.envs.env_metrics as EnvMetrics 10 | 11 | from tbsim.utils.batch_utils import batch_utils 12 | from tbsim.utils.config_utils import get_experiment_config_from_file 13 | from tbsim.configs.base import ExperimentConfig 14 | 15 | from tbsim.utils.experiment_utils import get_checkpoint 16 | 17 | try: 18 | from Pplan.Sampling.spline_planner import SplinePlanner 19 | from Pplan.Sampling.trajectory_tree import TrajTree 20 | except ImportError: 21 | print("Cannot import Pplan") 22 | 23 | 24 | class MetricsComposer(object): 25 | """Wrapper for building learned metrics from trained checkpoints.""" 26 | def __init__(self, eval_config, device, ckpt_root_dir="checkpoints/"): 27 | self.device = device 28 | self.ckpt_root_dir = ckpt_root_dir 29 | self.eval_config = eval_config 30 | self._exp_config = None 31 | 32 | def get_modality_shapes(self, exp_cfg: ExperimentConfig): 33 | return batch_utils().get_modality_shapes(exp_cfg) 34 | 35 | def get_metrics(self): 36 | raise NotImplementedError 37 | 38 | 39 | class CVAEMetrics(MetricsComposer): 40 | def get_metrics(self, eval_config, perturbations=None, rolling=False, env="l5kit", **kwargs): 41 | 42 | 
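# Annotation (added comment, not in the original source): the call below resolves a
# trained CVAE metric checkpoint and its saved experiment config under ckpt_root_dir
# (via the NGC job id / checkpoint key from the eval config); that config is then used
# to rebuild the model with matching modality shapes before the weights are loaded.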
ckpt_path, config_path = get_checkpoint( 43 | ngc_job_id=eval_config.ckpt.cvae_metric.ngc_job_id, 44 | ckpt_key=eval_config.ckpt.cvae_metric.ckpt_key, 45 | ckpt_root_dir=self.ckpt_root_dir 46 | ) 47 | 48 | controller_cfg = get_experiment_config_from_file(config_path) 49 | modality_shapes = batch_utils().get_modality_shapes(controller_cfg) 50 | CVAE_model = DiscreteVAETrafficModel.load_from_checkpoint( 51 | ckpt_path, 52 | algo_config=controller_cfg.algo, 53 | modality_shapes=modality_shapes 54 | ).to(self.device).eval() 55 | if not rolling: 56 | return EnvMetrics.LearnedCVAENLL(metric_algo=CVAE_model, perturbations=perturbations) 57 | else: 58 | if "rolling_horizon" in kwargs: 59 | rolling_horizon = kwargs["rolling_horizon"] 60 | else: 61 | rolling_horizon = None 62 | return EnvMetrics.LearnedCVAENLLRolling(metric_algo=CVAE_model, rolling_horizon=rolling_horizon, perturbations=perturbations) 63 | 64 | 65 | class OccupancyMetrics(MetricsComposer): 66 | def get_metrics(self, eval_config, perturbations = None, rolling=False, env="l5kit", **kwargs): 67 | ckpt_path, config_path = get_checkpoint( 68 | ngc_job_id=eval_config.ckpt.occupancy_metric.ngc_job_id, 69 | ckpt_key=eval_config.ckpt.occupancy_metric.ckpt_key, 70 | ckpt_root_dir=self.ckpt_root_dir 71 | ) 72 | 73 | cfg = get_experiment_config_from_file(config_path) 74 | 75 | modality_shapes = batch_utils().get_modality_shapes(cfg) 76 | occupancy_model = OccupancyMetric.load_from_checkpoint( 77 | ckpt_path, 78 | algo_config=cfg.algo, 79 | modality_shapes=modality_shapes 80 | ).to(self.device).eval() 81 | 82 | if not rolling: 83 | return EnvMetrics.Occupancy_likelihood(metric_algo=occupancy_model, perturbations=perturbations) 84 | else: 85 | if "rolling_horizon" in kwargs: 86 | rolling_horizon = kwargs["rolling_horizon"] 87 | else: 88 | rolling_horizon = None 89 | return EnvMetrics.Occupancy_rolling(metric_algo=occupancy_model, rolling_horizon=rolling_horizon, perturbations=perturbations) 90 | 91 | -------------------------------------------------------------------------------- /tbsim/l5kit/vis_rasterizer.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from enum import IntEnum 3 | from typing import Dict, List, Optional 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | from l5kit.data.filter import filter_tl_faces_by_status 9 | from l5kit.data.map_api import InterpolationMethod, MapAPI, TLFacesColors 10 | from l5kit.geometry import rotation33_as_yaw, transform_point, transform_points 11 | from l5kit.rasterization.rasterizer import Rasterizer 12 | from l5kit.rasterization.render_context import RenderContext 13 | 14 | 15 | # sub-pixel drawing precision constants 16 | CV2_SUB_VALUES = {"shift": 9, "lineType": cv2.LINE_AA} 17 | CV2_SHIFT_VALUE = 2 ** CV2_SUB_VALUES["shift"] 18 | INTERPOLATION_POINTS = 20 19 | 20 | 21 | class RasterEls(IntEnum): # map elements 22 | LANE_NOTL = 0 23 | ROAD = 1 24 | CROSSWALK = 2 25 | 26 | 27 | COLORS = { 28 | TLFacesColors.GREEN.name: (0, 255, 0), 29 | TLFacesColors.RED.name: (255, 0, 0), 30 | TLFacesColors.YELLOW.name: (255, 255, 0), 31 | # RasterEls.LANE_NOTL.name: (255, 217, 82), 32 | RasterEls.LANE_NOTL.name: (164, 184, 196), 33 | # RasterEls.ROAD.name: (17, 17, 31), 34 | RasterEls.ROAD.name: (200, 211, 213), 35 | RasterEls.CROSSWALK.name: (96, 117, 138), 36 | # RasterEls.CROSSWALK.name: (255, 117, 69), 37 | # "Background": (37, 40, 61), 38 | "Background": (255, 255, 255), 39 | } 40 | 41 | 42 | def indices_in_bounds(center: 
np.ndarray, bounds: np.ndarray, half_extent: float) -> np.ndarray: 43 | """ 44 | Get indices of elements for which the bounding box described by bounds intersects the one defined around 45 | center (square with side 2*half_side) 46 | 47 | Args: 48 | center (float): XY of the center 49 | bounds (np.ndarray): array of shape Nx2x2 [[x_min,y_min],[x_max, y_max]] 50 | half_extent (float): half the side of the bounding box centered around center 51 | 52 | Returns: 53 | np.ndarray: indices of elements inside radius from center 54 | """ 55 | x_center, y_center = center 56 | 57 | x_min_in = x_center > bounds[:, 0, 0] - half_extent 58 | y_min_in = y_center > bounds[:, 0, 1] - half_extent 59 | x_max_in = x_center < bounds[:, 1, 0] + half_extent 60 | y_max_in = y_center < bounds[:, 1, 1] + half_extent 61 | return np.nonzero(x_min_in & y_min_in & x_max_in & y_max_in)[0] 62 | 63 | 64 | def cv2_subpixel(coords: np.ndarray) -> np.ndarray: 65 | """ 66 | Cast coordinates to numpy.int but keep fractional part by previously multiplying by 2**CV2_SHIFT 67 | cv2 calls will use shift to restore original values with higher precision 68 | 69 | Args: 70 | coords (np.ndarray): XY coords as float 71 | 72 | Returns: 73 | np.ndarray: XY coords as int for cv2 shift draw 74 | """ 75 | coords = coords * CV2_SHIFT_VALUE 76 | coords = coords.astype(np.int32) 77 | return coords 78 | 79 | 80 | class VisualizationRasterizer(object): 81 | """ 82 | Rasteriser for visualization purposes 83 | """ 84 | 85 | def __init__( 86 | self, render_context: RenderContext, semantic_map_path: str, world_to_ecef: np.ndarray, 87 | ): 88 | self.render_context = render_context 89 | self.raster_size = render_context.raster_size_px 90 | self.pixel_size = render_context.pixel_size_m 91 | self.ego_center = render_context.center_in_raster_ratio 92 | 93 | self.world_to_ecef = world_to_ecef 94 | 95 | self.mapAPI = MapAPI(semantic_map_path, world_to_ecef) 96 | 97 | def rasterize( 98 | self, 99 | ego_translation_m, 100 | ego_yaw_rad, 101 | ) -> np.ndarray: 102 | raster_from_world = self.render_context.raster_from_world(ego_translation_m, ego_yaw_rad) 103 | world_from_raster = np.linalg.inv(raster_from_world) 104 | 105 | # get XY of center pixel in world coordinates 106 | center_in_raster_px = np.asarray(self.raster_size) * (0.5, 0.5) 107 | center_in_world_m = transform_point(center_in_raster_px, world_from_raster) 108 | 109 | sem_im = self.render_semantic_map(center_in_world_m, raster_from_world) 110 | return sem_im.astype(np.float32) / 255 111 | 112 | def render_semantic_map( 113 | self, center_in_world: np.ndarray, raster_from_world: np.ndarray 114 | ) -> np.ndarray: 115 | """Renders the semantic map at given x,y coordinates. 
116 | 117 | Args: 118 | center_in_world (np.ndarray): XY of the image center in world ref system 119 | raster_from_world (np.ndarray): 120 | Returns: 121 | np.ndarray: RGB raster 122 | 123 | """ 124 | 125 | img = np.ones(shape=(self.raster_size[1], self.raster_size[0], 3)) 126 | img *= [[COLORS["Background"]]] 127 | img = img.astype(np.uint8) 128 | 129 | # filter using half a radius from the center 130 | raster_radius = float(np.linalg.norm(self.raster_size * self.pixel_size)) / 2 131 | 132 | # get all lanes as interpolation so that we can transform them all together 133 | 134 | lane_indices = indices_in_bounds(center_in_world, self.mapAPI.bounds_info["lanes"]["bounds"], raster_radius) 135 | lanes_mask: Dict[str, np.ndarray] = defaultdict(lambda: np.zeros(len(lane_indices) * 2, dtype=bool)) 136 | lanes_area = np.zeros((len(lane_indices) * 2, INTERPOLATION_POINTS, 2)) 137 | 138 | for idx, lane_idx in enumerate(lane_indices): 139 | lane_idx = self.mapAPI.bounds_info["lanes"]["ids"][lane_idx] 140 | 141 | # interpolate over polyline to always have the same number of points 142 | lane_coords = self.mapAPI.get_lane_as_interpolation( 143 | lane_idx, INTERPOLATION_POINTS, InterpolationMethod.INTER_ENSURE_LEN 144 | ) 145 | lanes_area[idx * 2] = lane_coords["xyz_left"][:, :2] 146 | lanes_area[idx * 2 + 1] = lane_coords["xyz_right"][::-1, :2] 147 | 148 | lane_type = RasterEls.LANE_NOTL.name 149 | lane_tl_ids = set(self.mapAPI.get_lane_traffic_control_ids(lane_idx)) 150 | # for tl_id in lane_tl_ids.intersection(active_tl_ids): 151 | # for tl_id in lane_tl_ids: 152 | # lane_type = self.mapAPI.get_color_for_face(tl_id) 153 | 154 | lanes_mask[lane_type][idx * 2: idx * 2 + 2] = True 155 | 156 | if len(lanes_area): 157 | lanes_area = cv2_subpixel(transform_points(lanes_area.reshape((-1, 2)), raster_from_world)) 158 | 159 | for lane_area in lanes_area.reshape((-1, INTERPOLATION_POINTS * 2, 2)): 160 | # need to for-loop otherwise some of them are empty 161 | cv2.fillPoly(img, [lane_area], COLORS[RasterEls.ROAD.name], **CV2_SUB_VALUES) 162 | 163 | lanes_area = lanes_area.reshape((-1, INTERPOLATION_POINTS, 2)) 164 | for name, mask in lanes_mask.items(): # draw each type of lane with its own color 165 | cv2.polylines(img, lanes_area[mask], False, COLORS[name], **CV2_SUB_VALUES) 166 | 167 | # plot crosswalks 168 | crosswalks = [] 169 | for idx in indices_in_bounds(center_in_world, self.mapAPI.bounds_info["crosswalks"]["bounds"], raster_radius): 170 | crosswalk = self.mapAPI.get_crosswalk_coords(self.mapAPI.bounds_info["crosswalks"]["ids"][idx]) 171 | xy_cross = cv2_subpixel(transform_points(crosswalk["xyz"][:, :2], raster_from_world)) 172 | crosswalks.append(xy_cross) 173 | 174 | cv2.polylines(img, crosswalks, True, COLORS[RasterEls.CROSSWALK.name], **CV2_SUB_VALUES) 175 | 176 | return img 177 | 178 | def to_rgb(self, in_im: np.ndarray, **kwargs: dict) -> np.ndarray: 179 | return (in_im * 255).astype(np.uint8) 180 | 181 | def num_channels(self) -> int: 182 | return 3 183 | -------------------------------------------------------------------------------- /tbsim/models/GAN_regularizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tbsim.dynamics.base import DynType 3 | import tbsim.utils.l5_utils as L5Utils 4 | from tbsim.models.cnn_roi_encoder import obtain_map_enc 5 | 6 | 7 | def pred2obs( 8 | dyn_list, 9 | step_time, 10 | src_pos, 11 | src_yaw, 12 | src_mask, 13 | data_batch, 14 | pos_pred, 15 | yaw_pred, 16 | pred_mask, 17 | raw_type, 18 | 
src_lanes, 19 | CNNmodel, 20 | algo_config, 21 | f_steps=1, 22 | M=1, 23 | ): 24 | """generate observation for the predicted scene f_step steps into the future 25 | 26 | Args: 27 | src_pos (torch.tensor[torch.float]): xy position in src 28 | src_yaw (torch.tensor[torch.float]): yaw in src 29 | src_mask (torch.tensor[torch.bool]): mask for src 30 | data_batch (dict): input data dictionary 31 | pos_pred (torch.tensor[torch.float]): predicted xy trajectory 32 | yaw_pred (torch.tensor[torch.float]): predicted yaw 33 | pred_mask (torch.tensor[torch.bool]): mask for prediction 34 | raw_type (torch.tensor[torch.int]): type of agents 35 | src_lanes (torch.tensor[torch.float]): lane info 36 | f_steps (int, optional): [description]. Defaults to 1. 37 | 38 | Returns: 39 | torch.tensor[torch.float]: new src for the transformer 40 | torch.tensor[torch.bool]: new src mask 41 | torch.tensor[torch.float]: new map encoding 42 | """ 43 | if pos_pred.ndim == 5: 44 | src_pos = src_pos.unsqueeze(1).repeat(1, M, 1, 1, 1) 45 | src_yaw = src_yaw.unsqueeze(1).repeat(1, M, 1, 1, 1) 46 | src_mask = src_mask.unsqueeze(1).repeat(1, M, 1, 1) 47 | pred_mask = pred_mask.unsqueeze(1).repeat(1, M, 1, 1) 48 | raw_type = raw_type.unsqueeze(1).repeat(1, M, 1) 49 | pos_new = torch.cat( 50 | (src_pos[..., f_steps:, :], pos_pred[..., :f_steps, :]), dim=-2 51 | ) 52 | yaw_new = torch.cat( 53 | (src_yaw[..., f_steps:, :], yaw_pred[..., :f_steps, :]), dim=-2 54 | ) 55 | src_mask_new = torch.cat( 56 | (src_mask[..., f_steps:], pred_mask[..., :f_steps]), dim=-1 57 | ) 58 | vel_new = dyn_list[DynType.UNICYCLE].calculate_vel( 59 | pos_new, yaw_new, step_time, src_mask_new 60 | ) 61 | src_new, _, _ = L5Utils.raw2feature( 62 | pos_new, 63 | vel_new, 64 | yaw_new, 65 | raw_type, 66 | src_mask_new, 67 | torch.zeros_like(src_lanes) if src_lanes is not None else None, 68 | ) 69 | 70 | if M == 1: 71 | map_emb_new = obtain_map_enc( 72 | data_batch["image"], 73 | CNNmodel, 74 | pos_new, 75 | yaw_new, 76 | data_batch["raster_from_agent"], 77 | src_mask_new, 78 | torch.tensor(algo_config.CNN.patch_size).to(src_pos.device), 79 | algo_config.CNN.output_size, 80 | mode="last", 81 | ) 82 | 83 | else: 84 | map_emb_new = list() 85 | for i in range(M): 86 | 87 | map_emb_new_i = obtain_map_enc( 88 | data_batch["image"], 89 | CNNmodel, 90 | pos_new[:, i], 91 | yaw_new[:, i], 92 | data_batch["raster_from_agent"], 93 | src_mask_new, 94 | torch.tensor(algo_config.CNN.patch_size).to(src_pos.device), 95 | algo_config.CNN.output_size, 96 | mode="last", 97 | ) 98 | map_emb_new.append(map_emb_new_i) 99 | map_emb_new = torch.stack(map_emb_new, dim=1) 100 | return src_new, src_mask_new, map_emb_new 101 | 102 | 103 | def pred2obs_static( 104 | dyn_list, 105 | step_time, 106 | data_batch, 107 | pos_pred, 108 | yaw_pred, 109 | pred_mask, 110 | raw_type, 111 | src_lanes, 112 | CNNmodel, 113 | algo_config, 114 | M=1, 115 | ): 116 | """generate observation for every step of the predictions 117 | 118 | Args: 119 | data_batch (dict): input data dictionary 120 | pos_pred (torch.tensor[torch.float]): predicted xy trajectory 121 | yaw_pred (torch.tensor[torch.float]): predicted yaw 122 | pred_mask (torch.tensor[torch.bool]): mask for prediction 123 | raw_type (torch.tensor[torch.int]): type of agents 124 | src_lanes (torch.tensor[torch.float]): lane info 125 | Returns: 126 | torch.tensor[torch.float]: new src for the transformer 127 | torch.tensor[torch.bool]: new src mask 128 | torch.tensor[torch.float]: new map encoding 129 | """ 130 | if pos_pred.ndim == 5: 131 | pred_mask 
= pred_mask.unsqueeze(1).repeat(1, M, 1, 1) 132 | raw_type = raw_type.unsqueeze(1).repeat(1, M, 1) 133 | 134 | pred_vel = dyn_list[DynType.UNICYCLE].calculate_vel( 135 | pos_pred, yaw_pred, step_time, pred_mask 136 | ) 137 | src_new, _, _ = L5Utils.raw2feature( 138 | pos_pred, 139 | pred_vel, 140 | yaw_pred, 141 | raw_type, 142 | pred_mask, 143 | torch.zeros_like(src_lanes) if src_lanes is not None else None, 144 | add_noise=True, 145 | ) 146 | 147 | if M == 1: 148 | map_emb_new = obtain_map_enc( 149 | data_batch["image"], 150 | CNNmodel, 151 | pos_pred, 152 | yaw_pred, 153 | data_batch["raster_from_agent"], 154 | pred_mask, 155 | torch.tensor(algo_config.CNN.patch_size).to(pos_pred.device), 156 | algo_config.CNN.output_size, 157 | mode="all", 158 | ) 159 | else: 160 | map_emb_new = list() 161 | for i in range(M): 162 | map_emb_new_i = obtain_map_enc( 163 | data_batch["image"], 164 | CNNmodel, 165 | pos_pred[:, i], 166 | yaw_pred[:, i], 167 | data_batch["raster_from_agent"], 168 | pred_mask, 169 | torch.tensor(algo_config.CNN.patch_size).to(pos_pred.device), 170 | algo_config.CNN.output_size, 171 | mode="all", 172 | ) 173 | map_emb_new.append(map_emb_new_i) 174 | map_emb_new = torch.stack(map_emb_new, dim=1) 175 | 176 | return src_new, pred_mask, map_emb_new 177 | -------------------------------------------------------------------------------- /tbsim/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/CTG/f916c008c3ecf2360bfa050639606eaab7c207f5/tbsim/models/__init__.py -------------------------------------------------------------------------------- /tbsim/models/context_encoders.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def init(module, weight_init, bias_init, gain=1): 7 | ''' 8 | This function provides weight and bias initializations for linear layers. 9 | ''' 10 | weight_init(module.weight.data, gain=gain) 11 | bias_init(module.bias.data) 12 | return module 13 | 14 | 15 | class MapEncoderCNN(nn.Module): 16 | ''' 17 | Regular CNN encoder for road image. 
18 | ''' 19 | def __init__(self, d_k=64, dropout=0.1, c=10): 20 | super(MapEncoderCNN, self).__init__() 21 | self.dropout = dropout 22 | self.c = c 23 | init_ = lambda m: init(m, nn.init.xavier_normal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2)) 24 | # MAP ENCODER 25 | fm_size = 7 26 | self.map_encoder = nn.Sequential( 27 | init_(nn.Conv2d(3, 32, kernel_size=4, stride=1)), nn.ReLU(), 28 | init_(nn.Conv2d(32, 32, kernel_size=4, stride=2)), nn.ReLU(), 29 | init_(nn.Conv2d(32, 32, kernel_size=3, stride=2)), nn.ReLU(), 30 | init_(nn.Conv2d(32, 32, kernel_size=3, stride=2)), nn.ReLU(), 31 | init_(nn.Conv2d(32, fm_size*self.c, kernel_size=2, stride=2)), nn.ReLU(), 32 | nn.Dropout2d(p=self.dropout) 33 | ) 34 | self.map_feats = nn.Sequential( 35 | init_(nn.Linear(7*7*fm_size, d_k)), nn.ReLU(), 36 | init_(nn.Linear(d_k, d_k)), nn.ReLU(), 37 | ) 38 | 39 | def forward(self, roads): 40 | ''' 41 | :param roads: road image with size (B, 128, 128, 3) 42 | :return: road features, with one for every mode (B, c, d_k) 43 | ''' 44 | B = roads.size(0) # batch size 45 | return self.map_feats(self.map_encoder(roads).view(B, self.c, -1)) 46 | 47 | 48 | class MapEncoderPts(nn.Module): 49 | ''' 50 | This class operates on the road lanes provided as a tensor with shape 51 | (B, num_road_segs, num_pts_per_road_seg, map_attr) 52 | ''' 53 | def __init__(self, d_k, map_attr=2, dropout=0.1): 54 | super(MapEncoderPts, self).__init__() 55 | self.dropout = dropout 56 | self.d_k = d_k 57 | self.map_attr = map_attr 58 | init_ = lambda m: init(m, nn.init.xavier_normal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2)) 59 | 60 | self.road_pts_lin = nn.Sequential(init_(nn.Linear(map_attr, self.d_k))) 61 | self.road_pts_attn_layer = nn.MultiheadAttention(self.d_k, num_heads=8, dropout=self.dropout) 62 | self.norm1 = nn.LayerNorm(self.d_k, eps=1e-5) 63 | self.norm2 = nn.LayerNorm(self.d_k, eps=1e-5) 64 | self.map_feats = nn.Sequential( 65 | init_(nn.Linear(self.d_k, self.d_k)), nn.ReLU(), nn.Dropout(self.dropout), 66 | init_(nn.Linear(self.d_k, self.d_k)), 67 | ) 68 | 69 | def get_road_pts_mask(self, roads): 70 | road_segment_mask = torch.sum(roads[:, :, :, -1], dim=2) == 0 71 | road_pts_mask = (1.0 - roads[:, :, :, -1]).type(torch.BoolTensor).to(roads.device).view(-1, roads.shape[2]) 72 | road_pts_mask[:, 0][road_pts_mask.sum(-1) == roads.shape[2]] = False # Ensures no NaNs due to empty rows. 73 | return road_segment_mask, road_pts_mask 74 | 75 | def forward(self, roads, agents_emb): 76 | ''' 77 | :param roads: (B, S, P, k_attr+1) where B is batch size, S is num road segments, P is 78 | num pts per road segment. 79 | :param agents_emb: (T_obs, B, d_k) where T_obs is the observation horizon. This tensor is obtained from 80 | AutoBot's encoder, and basically represents the observed socio-temporal context of agents. 81 | :return: embedded road segments with shape (S, B, d_k), plus a per-segment padding mask of shape (B, S) 82 | ''' 83 | B = roads.shape[0] 84 | S = roads.shape[1] 85 | P = roads.shape[2] 86 | road_segment_mask, road_pts_mask = self.get_road_pts_mask(roads) 87 | road_pts_feats = self.road_pts_lin(roads[:, :, :, :self.map_attr]).view(B*S, P, -1).permute(1, 0, 2) 88 | 89 | # Combining information from each road segment using attention with agent contextual embeddings as queries. 
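# Annotation (added comment, not in the original source): agents_emb[-1] is the most
# recent per-agent embedding of shape (B, d_k); below it is tiled across the S road
# segments and reshaped into a length-1 query sequence with batch size B*S, since
# nn.MultiheadAttention here expects (seq_len, batch, d_k) inputs.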
90 | agents_emb = agents_emb[-1].unsqueeze(2).repeat(1, 1, S, 1).view(-1, self.d_k).unsqueeze(0) 91 | road_seg_emb = self.road_pts_attn_layer(query=agents_emb, key=road_pts_feats, value=road_pts_feats, 92 | key_padding_mask=road_pts_mask)[0] 93 | road_seg_emb = self.norm1(road_seg_emb) 94 | road_seg_emb2 = road_seg_emb + self.map_feats(road_seg_emb) 95 | road_seg_emb2 = self.norm2(road_seg_emb2) 96 | road_seg_emb = road_seg_emb2.view(B, S, -1) 97 | 98 | return road_seg_emb.permute(1, 0, 2), road_segment_mask 99 | 100 | 101 | class MapEncoderPtsMA(nn.Module): 102 | ''' 103 | This class operates on the multi-agent road lanes provided as a tensor with shape 104 | (B, num_agents, num_road_segs, num_pts_per_road_seg, k_attr+1) 105 | ''' 106 | def __init__(self, d_k, map_attr=3, dropout=0.1): 107 | super(MapEncoderPtsMA, self).__init__() 108 | self.dropout = dropout 109 | self.d_k = d_k 110 | init_ = lambda m: init(m, nn.init.xavier_normal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2)) 111 | 112 | self.map_attr = map_attr 113 | 114 | # Seed parameters for the map 115 | self.map_seeds = nn.Parameter(torch.Tensor(1, 1, self.d_k), requires_grad=True) 116 | nn.init.xavier_uniform_(self.map_seeds) 117 | 118 | self.road_pts_lin = nn.Sequential(init_(nn.Linear(self.map_attr, self.d_k))) 119 | self.road_pts_attn_layer = nn.MultiheadAttention(self.d_k, num_heads=8, dropout=self.dropout) 120 | self.norm1 = nn.LayerNorm(self.d_k, eps=1e-5) 121 | self.norm2 = nn.LayerNorm(self.d_k, eps=1e-5) 122 | self.map_feats = nn.Sequential( 123 | init_(nn.Linear(self.d_k, self.d_k*3)), nn.ReLU(), nn.Dropout(self.dropout), 124 | init_(nn.Linear(self.d_k*3, self.d_k)), 125 | ) 126 | 127 | def get_road_pts_mask(self, roads): 128 | road_segment_mask = torch.sum(roads[:, :, :, :, -1], dim=3) == 0 129 | road_pts_mask = (1.0 - roads[:, :, :, :, -1]).type(torch.BoolTensor).to(roads.device).view(-1, roads.shape[3]) 130 | 131 | # The next lines ensure that we do not obtain NaNs during training for missing agents or for empty roads. 132 | road_pts_mask[:, 0][road_pts_mask.sum(-1) == roads.shape[3]] = False # for empty agents 133 | road_segment_mask[:, :, 0][road_segment_mask.sum(-1) == road_segment_mask.shape[2]] = False # for empty roads 134 | return road_segment_mask, road_pts_mask 135 | 136 | def forward(self, roads): 137 | ''' 138 | :param roads: (B, M, S, P, k_attr+1) where B is batch size, M is num_agents, S is num road segments, P is 139 | num pts per road segment. 140 | :return: embedded road segments with shape (S) 141 | ''' 142 | B = roads.shape[0] 143 | M = roads.shape[1] 144 | S = roads.shape[2] 145 | P = roads.shape[3] 146 | road_segment_mask, road_pts_mask = self.get_road_pts_mask(roads) 147 | road_pts_feats = self.road_pts_lin(roads[:, :, :, :, :self.map_attr]).view(B*M*S, P, -1).permute(1, 0, 2) 148 | 149 | # Combining information from each road segment using attention with agent contextual embeddings as queries. 
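# Annotation (added comment, not in the original source): unlike MapEncoderPts above,
# the attention query here is a learned map-seed parameter tiled to batch size B*M*S,
# so road features are pooled without conditioning on agent embeddings (the
# agent-conditioned query variant is kept commented out below).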
150 | map_seeds = self.map_seeds.repeat(1, B * M * S, 1) 151 | # agents_emb = agents_emb[-1].detach().unsqueeze(2).repeat(1, 1, S, 1).view(-1, self.d_k).unsqueeze(0) 152 | road_seg_emb = self.road_pts_attn_layer(query=map_seeds, key=road_pts_feats, value=road_pts_feats, 153 | key_padding_mask=road_pts_mask)[0] 154 | road_seg_emb = self.norm1(road_seg_emb) 155 | road_seg_emb2 = road_seg_emb + self.map_feats(road_seg_emb) 156 | road_seg_emb2 = self.norm2(road_seg_emb2) 157 | road_seg_emb = road_seg_emb2.view(B, M, S, -1) 158 | 159 | return road_seg_emb.permute(2, 0, 1, 3), road_segment_mask 160 | 161 | -------------------------------------------------------------------------------- /tbsim/models/learned_metrics.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | import tbsim.models.base_models as base_models 7 | import tbsim.utils.tensor_utils as TensorUtils 8 | 9 | 10 | class PermuteEBM(nn.Module): 11 | """Raster-based model for planning. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | model_arch: str, 17 | input_image_shape, 18 | map_feature_dim: int, 19 | traj_feature_dim: int, 20 | embedding_dim: int, 21 | embed_layer_dims: tuple 22 | ) -> None: 23 | 24 | super().__init__() 25 | self.map_encoder = base_models.RasterizedMapEncoder( 26 | model_arch=model_arch, 27 | input_image_shape=input_image_shape, 28 | feature_dim=map_feature_dim, 29 | use_spatial_softmax=False, 30 | output_activation=nn.ReLU 31 | ) 32 | self.traj_encoder = base_models.RNNTrajectoryEncoder( 33 | trajectory_dim=3, 34 | rnn_hidden_size=100, 35 | feature_dim=traj_feature_dim 36 | ) 37 | self.embed_net = base_models.MLP( 38 | input_dim=traj_feature_dim + map_feature_dim, 39 | output_dim=embedding_dim, 40 | output_activation=nn.ReLU, 41 | layer_dims=embed_layer_dims 42 | ) 43 | self.score_net = nn.Linear(embedding_dim, 1) 44 | 45 | def forward(self, data_batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: 46 | image_batch = data_batch["image"] 47 | trajs = torch.cat((data_batch["target_positions"], data_batch["target_yaws"]), dim=2) 48 | bs = image_batch.shape[0] 49 | 50 | map_feat = self.map_encoder(image_batch) # [B, D_m] 51 | traj_feat = self.traj_encoder(trajs) # [B, D_t] 52 | 53 | # construct contrastive samples 54 | map_feat_rep = TensorUtils.unsqueeze_expand_at(map_feat, size=bs, dim=1) # [B, B, D_m] 55 | traj_feat_rep = TensorUtils.unsqueeze_expand_at(traj_feat, size=bs, dim=0) # [B, B, D_t] 56 | cat_rep = torch.cat((map_feat_rep, traj_feat_rep), dim=-1) # [B, B, D_m + D_t] 57 | ebm_rep = TensorUtils.time_distributed(cat_rep, self.embed_net) # [B, B, D] 58 | 59 | # calculate embeddings and scores for InfoNCE loss 60 | scores = TensorUtils.time_distributed(ebm_rep, self.score_net).squeeze(-1) # [B, B] 61 | out_dict = dict(features=ebm_rep, scores=scores) 62 | 63 | return out_dict 64 | 65 | def get_scores(self, data_batch): 66 | image_batch = data_batch["image"] 67 | trajs = torch.cat((data_batch["target_positions"], data_batch["target_yaws"]), dim=2) 68 | 69 | map_feat = self.map_encoder(image_batch) # [B, D_m] 70 | traj_feat = self.traj_encoder(trajs) # [B, D_t] 71 | cat_rep = torch.cat((map_feat, traj_feat), dim=-1) # [B, D_m + D_t] 72 | ebm_rep = self.embed_net(cat_rep) 73 | scores = self.score_net(ebm_rep) 74 | out_dict = dict(features=ebm_rep, scores=scores) 75 | 76 | return out_dict 77 | 78 | def compute_losses(self, pred_batch, data_batch): 79 | scores = pred_batch["scores"] 80 | bs = 
scores.shape[0] 81 | labels = torch.arange(bs).to(scores.device) 82 | loss = nn.CrossEntropyLoss()(scores, labels) 83 | losses = dict(infoNCE_loss=loss) 84 | 85 | return losses -------------------------------------------------------------------------------- /tbsim/models/roi_align.py: -------------------------------------------------------------------------------- 1 | from logging import raiseExceptions 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from tbsim.utils.geometry_utils import batch_nd_transform_points 6 | 7 | 8 | def bilinear_interpolate(img, x, y, floattype=torch.float, flip_y=False): 9 | """Return bilinear interpolation of 4 nearest pts w.r.t to x,y from img 10 | Args: 11 | img (torch.Tensor): Tensor of size cxwxh. Usually one channel of feature layer 12 | x (torch.Tensor): Float dtype, x axis location for sampling 13 | y (torch.Tensor): Float dtype, y axis location for sampling 14 | 15 | Returns: 16 | torch.Tensor: interpolated value 17 | """ 18 | if flip_y: 19 | y = img.shape[-2]-1-y 20 | if img.device.type == "cuda": 21 | x0 = torch.floor(x).type(torch.cuda.LongTensor) 22 | y0 = torch.floor(y).type(torch.cuda.LongTensor) 23 | elif img.device.type == "cpu": 24 | x0 = torch.floor(x).type(torch.LongTensor) 25 | y0 = torch.floor(y).type(torch.LongTensor) 26 | else: 27 | raise ValueError("device not recognized") 28 | x1 = x0 + 1 29 | y1 = y0 + 1 30 | 31 | x0 = torch.clamp(x0, 0, img.shape[-1] - 1) 32 | x1 = torch.clamp(x1, 0, img.shape[-1] - 1) 33 | y0 = torch.clamp(y0, 0, img.shape[-2] - 1) 34 | y1 = torch.clamp(y1, 0, img.shape[-2] - 1) 35 | 36 | Ia = img[..., y0, x0] 37 | Ib = img[..., y1, x0] 38 | Ic = img[..., y0, x1] 39 | Id = img[..., y1, x1] 40 | 41 | step = (x1.type(floattype) - x0.type(floattype)) * ( 42 | y1.type(floattype) - y0.type(floattype) 43 | ) 44 | step = torch.clamp(step, 1e-3, 2) 45 | norm_const = 1 / step 46 | 47 | wa = (x1.type(floattype) - x) * (y1.type(floattype) - y) * norm_const 48 | wb = (x1.type(floattype) - x) * (y - y0.type(floattype)) * norm_const 49 | wc = (x - x0.type(floattype)) * (y1.type(floattype) - y) * norm_const 50 | wd = (x - x0.type(floattype)) * (y - y0.type(floattype)) * norm_const 51 | return ( 52 | Ia * wa.unsqueeze(0) 53 | + Ib * wb.unsqueeze(0) 54 | + Ic * wc.unsqueeze(0) 55 | + Id * wd.unsqueeze(0) 56 | ) 57 | 58 | 59 | def ROI_align(features, ROI, outdim): 60 | """Given feature layers and proposals return bilinear interpolated 61 | points in feature layer 62 | 63 | Args: 64 | features (torch.Tensor): Tensor of shape channels x width x height 65 | proposal (list of torch.Tensor): x0,y0,W1,W2,H1,H2,psi 66 | """ 67 | 68 | bs, num_channels, h, w = features.shape 69 | 70 | xg = ( 71 | torch.cat( 72 | ( 73 | torch.arange(0, outdim).view(-1, 1) - (outdim - 1) / 2, 74 | torch.zeros([outdim, 1]), 75 | ), 76 | dim=-1, 77 | ) 78 | / outdim 79 | ) 80 | yg = ( 81 | torch.cat( 82 | ( 83 | torch.zeros([outdim, 1]), 84 | torch.arange(0, outdim).view(-1, 1) - (outdim - 1) / 2, 85 | ), 86 | dim=-1, 87 | ) 88 | / outdim 89 | ) 90 | gg = xg.view(1, -1, 2) + yg.view(-1, 1, 2) 91 | gg = gg.to(features.device) 92 | res = list() 93 | for i in range(bs): 94 | if ROI[i] is not None: 95 | W1 = ROI[i][..., 2:3] 96 | W2 = ROI[i][..., 3:4] 97 | H1 = ROI[i][..., 4:5] 98 | H2 = ROI[i][..., 5:6] 99 | psi = ROI[i][..., 6:] 100 | WH = torch.cat((W1 + W2, H1 + H2), dim=-1) 101 | offset = torch.cat(((W1 - W2) / 2, (H1 - H2) / 2), dim=-1) 102 | s = torch.sin(psi).unsqueeze(-1) 103 | c = torch.cos(psi).unsqueeze(-1) 104 | rotM = 
torch.cat( 105 | (torch.cat((c, -s), dim=-1), torch.cat((s, c), dim=-1)), dim=-2 106 | ) 107 | ggi = gg * WH[..., None, None, :] - offset[..., None, None, :] 108 | ggi = ggi @ rotM[..., None, :, :] + ROI[i][..., None, None, 0:2] 109 | 110 | x_sample = ggi[..., 0].flatten() 111 | y_sample = ggi[..., 1].flatten() 112 | res.append( 113 | bilinear_interpolate(features[i], x_sample, y_sample).view( 114 | ggi.shape[0], num_channels, *ggi.shape[1:-1] 115 | ) 116 | ) 117 | else: 118 | res.append(None) 119 | 120 | return res 121 | 122 | 123 | def generate_ROIs( 124 | pos, 125 | yaw, 126 | raster_from_agent, 127 | mask, 128 | patch_size, 129 | mode="last", 130 | ): 131 | """ 132 | This version generates ROI for all agents only at most recent time step unless specified otherwise 133 | """ 134 | if mode == "all": 135 | bs = pos.shape[0] 136 | yaw = yaw.type(torch.float) 137 | Mat = raster_from_agent.view(-1, 1, 1, 3, 3).type(torch.float) 138 | raster_xy = batch_nd_transform_points(pos, Mat) 139 | raster_mult = torch.linalg.norm( 140 | raster_from_agent[0, 0, 0:2], dim=[-1]).item() 141 | patch_size = patch_size.type(torch.float) 142 | patch_size *= raster_mult 143 | ROI = [None] * bs 144 | index = [None] * bs 145 | for i in range(bs): 146 | ii, jj = torch.where(mask[i]) 147 | index[i] = (ii, jj) 148 | if patch_size.ndim == 1: 149 | patches_size = patch_size.repeat(ii.shape[0], 1) 150 | else: 151 | sizes = patch_size[i, ii] 152 | patches_size = torch.cat( 153 | ( 154 | sizes[:, 0:1] * 0.5, 155 | sizes[:, 0:1] * 0.5, 156 | sizes[:, 1:2] * 0.5, 157 | sizes[:, 1:2] * 0.5, 158 | ), 159 | dim=-1, 160 | ) 161 | ROI[i] = torch.cat( 162 | ( 163 | raster_xy[i, ii, jj], 164 | patches_size, 165 | yaw[i, ii, jj], 166 | ), 167 | dim=-1, 168 | ).to(pos.device) 169 | return ROI, index 170 | elif mode == "last": 171 | num = torch.arange(0, mask.shape[2]).view(1, 1, -1).to(mask.device) 172 | nummask = num * mask 173 | last_idx, _ = torch.max(nummask, dim=2) 174 | bs = pos.shape[0] 175 | Mat = raster_from_agent.view(-1, 1, 1, 3, 3).type(torch.float) 176 | raster_xy = batch_nd_transform_points(pos, Mat) 177 | raster_mult = torch.linalg.norm( 178 | raster_from_agent[0, 0, 0:2], dim=[-1]).item() 179 | patch_size = patch_size.type(torch.float) 180 | patch_size *= raster_mult 181 | agent_mask = mask.any(dim=2) 182 | ROI = [None] * bs 183 | index = [None] * bs 184 | for i in range(bs): 185 | ii = torch.where(agent_mask[i])[0] 186 | index[i] = ii 187 | if patch_size.ndim == 1: 188 | patches_size = patch_size.repeat(ii.shape[0], 1) 189 | else: 190 | sizes = patch_size[i, ii] 191 | patches_size = torch.cat( 192 | ( 193 | sizes[:, 0:1] * 0.5, 194 | sizes[:, 0:1] * 0.5, 195 | sizes[:, 1:2] * 0.5, 196 | sizes[:, 1:2] * 0.5, 197 | ), 198 | dim=-1, 199 | ) 200 | ROI[i] = torch.cat( 201 | ( 202 | raster_xy[i, ii, last_idx[i, ii]], 203 | patches_size, 204 | yaw[i, ii, last_idx[i, ii]], 205 | ), 206 | dim=-1, 207 | ) 208 | return ROI, index 209 | else: 210 | raise ValueError("mode must be 'all' or 'last'") 211 | 212 | 213 | def Indexing_ROI_result(CNN_out, index, emb_size): 214 | """put the lists of ROI align result into embedding tensor with the help of index""" 215 | bs = len(CNN_out) 216 | map_emb = torch.zeros(emb_size).to(CNN_out[0].device) 217 | if map_emb.ndim == 3: 218 | for i in range(bs): 219 | map_emb[i, index[i]] = CNN_out[i] 220 | elif map_emb.ndim == 4: 221 | for i in range(bs): 222 | ii, jj = index[i] 223 | map_emb[i, ii, jj] = CNN_out[i] 224 | else: 225 | raise ValueError("wrong dimension for the map embedding!") 226 | 
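# Annotation (added comment, not in the original source): entries of map_emb for
# agents/timesteps without a valid ROI are left at their zero initialization; only
# the indexed positions were overwritten above.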
227 | return map_emb 228 | -------------------------------------------------------------------------------- /tbsim/models/temporal.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import einops 4 | from einops.layers.torch import Rearrange 5 | import pdb 6 | 7 | from .diffuser_helpers import ( 8 | SinusoidalPosEmb, 9 | Downsample1d, 10 | Upsample1d, 11 | Conv1dBlock, 12 | ) 13 | 14 | import tbsim.utils.tensor_utils as TensorUtils 15 | 16 | class ResidualTemporalMapBlockConcat(nn.Module): 17 | 18 | def __init__(self, inp_channels, out_channels, time_embed_dim, horizon, kernel_size=5): 19 | super().__init__() 20 | 21 | self.time_mlp = nn.Sequential( 22 | nn.Mish(), 23 | nn.Linear(time_embed_dim, out_channels), 24 | Rearrange('batch t -> batch t 1'), 25 | ) 26 | 27 | self.blocks = nn.ModuleList([ 28 | Conv1dBlock(inp_channels, out_channels, kernel_size), 29 | Conv1dBlock(out_channels, out_channels, kernel_size), 30 | ]) 31 | 32 | self.residual_conv = nn.Conv1d(inp_channels, out_channels, 1) \ 33 | if inp_channels != out_channels else nn.Identity() 34 | 35 | def forward(self, x, t): 36 | ''' 37 | x : [ batch_size x inp_channels x horizon ] 38 | t : [ batch_size x embed_dim ] 39 | returns: 40 | out : [ batch_size x out_channels x horizon ] 41 | ''' 42 | 43 | out = self.blocks[0](x) + self.time_mlp(t) 44 | out = self.blocks[1](out) 45 | return out + self.residual_conv(x) 46 | 47 | 48 | 49 | class TemporalMapUnet(nn.Module): 50 | 51 | def __init__( 52 | self, 53 | horizon, 54 | transition_dim, 55 | cond_dim, 56 | output_dim, 57 | dim=32, 58 | dim_mults=(1, 2, 4, 8), 59 | diffuser_building_block='concat' 60 | ): 61 | super().__init__() 62 | 63 | if diffuser_building_block == 'concat': 64 | ResidualTemporalMapBlock = ResidualTemporalMapBlockConcat 65 | else: 66 | raise NotImplementedError 67 | 68 | dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] 69 | in_out = list(zip(dims[:-1], dims[1:])) 70 | print(f'[ models/temporal ] Channel dimensions: {in_out}') 71 | 72 | time_dim = dim 73 | 74 | self.time_mlp = nn.Sequential( 75 | SinusoidalPosEmb(time_dim), 76 | nn.Linear(time_dim, time_dim * 4), 77 | nn.Mish(), 78 | nn.Linear(time_dim * 4, time_dim), 79 | ) 80 | 81 | cond_dim = cond_dim + time_dim 82 | 83 | self.downs = nn.ModuleList([]) 84 | self.ups = nn.ModuleList([]) 85 | num_resolutions = len(in_out) 86 | 87 | for ind, (dim_in, dim_out) in enumerate(in_out): 88 | is_last = ind >= (num_resolutions - 1) 89 | 90 | self.downs.append(nn.ModuleList([ 91 | ResidualTemporalMapBlock(dim_in, dim_out, time_embed_dim=cond_dim, horizon=horizon), 92 | ResidualTemporalMapBlock(dim_out, dim_out, time_embed_dim=cond_dim, horizon=horizon), 93 | Downsample1d(dim_out) if not is_last else nn.Identity() 94 | ])) 95 | 96 | if not is_last: 97 | horizon = horizon // 2 98 | 99 | mid_dim = dims[-1] 100 | self.mid_block1 = ResidualTemporalMapBlock(mid_dim, mid_dim, time_embed_dim=cond_dim, horizon=horizon) 101 | self.mid_block2 = ResidualTemporalMapBlock(mid_dim, mid_dim, time_embed_dim=cond_dim, horizon=horizon) 102 | 103 | final_up_dim = None 104 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 105 | is_last = ind >= (num_resolutions - 1) 106 | 107 | self.ups.append(nn.ModuleList([ 108 | ResidualTemporalMapBlock(dim_out * 2, dim_in, time_embed_dim=cond_dim, horizon=horizon), 109 | ResidualTemporalMapBlock(dim_in, dim_in, time_embed_dim=cond_dim, horizon=horizon), 110 | Upsample1d(dim_in) if not is_last else 
nn.Identity() 111 | ])) 112 | final_up_dim = dim_in 113 | 114 | if not is_last: 115 | horizon = horizon * 2 116 | 117 | self.final_conv = nn.Sequential( 118 | Conv1dBlock(final_up_dim, final_up_dim, kernel_size=5), 119 | nn.Conv1d(final_up_dim, output_dim, 1), 120 | ) 121 | 122 | def forward(self, x, aux_info, time): 123 | ''' 124 | x : [ B*N, T, D ] or [ B*N, M, T, D ] 125 | aux_info['cond_feat'] : [B*N, C] or [ B*N, M, C ] 126 | ''' 127 | len_x_shape = len(x.shape) 128 | if len_x_shape == 4: 129 | BN, M, T, _ = x.shape 130 | # [ B*N, M, T, D ] -> [ B*N*M, T, D ] 131 | x = x.reshape(BN * M, T, -1) 132 | # [ B*N, M, C ] -> [ B*N*M, C ] 133 | cond_feat = aux_info["cond_feat"].reshape(BN * M, -1) 134 | # [ B*N ] -> [ B*N*M ] 135 | time = TensorUtils.repeat_by_expand_at(time, repeats=M, dim=0) 136 | else: 137 | cond_feat = aux_info['cond_feat'] 138 | x = einops.rearrange(x, 'b h t -> b t h') 139 | # print('rearrange x.size()', x.size()) 140 | # print('time', time) 141 | # print('time.size()', time.size()) 142 | # (B*N) -> (B*N, K_d) 143 | t = self.time_mlp(time) 144 | # print('t.size()', t.size()) 145 | t = torch.cat([t, cond_feat], dim=-1) 146 | # raise 147 | h = [] 148 | for resnet, resnet2, downsample in self.downs: 149 | # print('down1 x.size()', x.size()) 150 | x = resnet(x, t) 151 | # print('down2 x.size()', x.size()) 152 | x = resnet2(x, t) 153 | # print('down3 x.size()', x.size()) 154 | h.append(x) 155 | x = downsample(x) 156 | x = self.mid_block1(x, t) 157 | # print('mid1 x.size()', x.size()) 158 | x = self.mid_block2(x, t) 159 | # print('mid2 x.size()', x.size()) 160 | for resnet, resnet2, upsample in self.ups: 161 | # print('x.shape', x.shape) 162 | # print('h[-1].shape', h[-1].shape) 163 | x = torch.cat((x, h.pop()), dim=1) 164 | # print('up1 x.size()', x.size()) 165 | x = resnet(x, t) 166 | # print('up2 x.size()', x.size()) 167 | x = resnet2(x, t) 168 | # print('up3 x.size()', x.size()) 169 | x = upsample(x) 170 | 171 | 172 | x = self.final_conv(x) 173 | # print('final conv x.size()', x.size()) 174 | 175 | x = einops.rearrange(x, 'b t h -> b h t') 176 | # print('output x.size()', x.size()) 177 | if len_x_shape == 4: 178 | x = x.reshape(BN, M, T, -1) 179 | return x -------------------------------------------------------------------------------- /tbsim/policies/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/CTG/f916c008c3ecf2360bfa050639606eaab7c207f5/tbsim/policies/__init__.py -------------------------------------------------------------------------------- /tbsim/policies/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | 3 | 4 | class Policy(abc.ABC): 5 | def __init__(self, device, *args, **kwargs): 6 | self.device = device 7 | 8 | @abc.abstractmethod 9 | def get_action(self, obs_dict, **kwargs): 10 | """Predict an action based on the input observation """ 11 | pass 12 | 13 | @abc.abstractmethod 14 | def eval(self): 15 | """Set the policy to evaluation mode""" 16 | pass -------------------------------------------------------------------------------- /tbsim/policies/common.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from copy import deepcopy 4 | 5 | import tbsim.utils.tensor_utils as TensorUtils 6 | from tbsim.utils.geometry_utils import transform_points_tensor 7 | from l5kit.geometry import transform_points 8 | 9 | 10 | class Trajectory(object): 11 | 
"""Container for sequences of 2D positions and yaws""" 12 | def __init__(self, positions, yaws): 13 | assert positions.shape[:-1] == yaws.shape[:-1] 14 | assert positions.shape[-1] == 2 15 | assert yaws.shape[-1] == 1 16 | self._positions = positions 17 | self._yaws = yaws 18 | 19 | @property 20 | def trajectories(self): 21 | if isinstance(self.positions, np.ndarray): 22 | return np.concatenate([self._positions, self._yaws], axis=-1) 23 | else: 24 | return torch.cat([self._positions, self._yaws], dim=-1) 25 | 26 | @property 27 | def positions(self): 28 | return TensorUtils.clone(self._positions) 29 | 30 | @positions.setter 31 | def positions(self, x): 32 | self._positions = TensorUtils.clone(x) 33 | 34 | @property 35 | def yaws(self): 36 | return TensorUtils.clone(self._yaws) 37 | 38 | @yaws.setter 39 | def yaws(self, x): 40 | self._yaws = TensorUtils.clone(x) 41 | 42 | def to_dict(self): 43 | return dict( 44 | positions=self.positions, 45 | yaws=self.yaws 46 | ) 47 | 48 | @classmethod 49 | def from_dict(cls, d): 50 | return cls(**d) 51 | 52 | def transform(self, trans_mats, rot_rads): 53 | if isinstance(self.positions, np.ndarray): 54 | pos = transform_points(self.positions, trans_mats) 55 | else: 56 | pos = transform_points_tensor(self.positions, trans_mats) 57 | 58 | yaw = self.yaws + rot_rads 59 | return self.__class__(pos, yaw) 60 | 61 | def to_numpy(self): 62 | return self.__class__(**TensorUtils.to_numpy(self.to_dict())) 63 | 64 | 65 | class Action(Trajectory): 66 | pass 67 | 68 | 69 | class Plan(Trajectory): 70 | """Container for sequences of 2D positions, yaws, controls, availabilities.""" 71 | def __init__(self, positions, yaws, availabilities, controls=None): 72 | assert positions.shape[:-1] == yaws.shape[:-1] 73 | assert positions.shape[-1] == 2 74 | assert yaws.shape[-1] == 1 75 | assert availabilities.shape == positions.shape[:-1] 76 | self._positions = positions 77 | self._yaws = yaws 78 | self._availabilities = availabilities 79 | self._controls = controls 80 | 81 | @property 82 | def availabilities(self): 83 | return TensorUtils.clone(self._availabilities) 84 | 85 | @property 86 | def controls(self): 87 | return TensorUtils.clone(self._controls) 88 | 89 | def to_dict(self): 90 | p = dict( 91 | positions=self.positions, 92 | yaws=self.yaws, 93 | availabilities=self.availabilities, 94 | ) 95 | if self._controls is not None: 96 | p["controls"] = self.controls 97 | return p 98 | 99 | def transform(self, trans_mats, rot_rads): 100 | if isinstance(self.positions, np.ndarray): 101 | pos = transform_points(self.positions, trans_mats) 102 | else: 103 | pos = transform_points_tensor(self.positions, trans_mats) 104 | 105 | yaw = self.yaws + rot_rads 106 | return self.__class__(pos, yaw, self.availabilities, controls=self.controls) 107 | 108 | 109 | class RolloutAction(object): 110 | """Actions used to control agent rollouts""" 111 | def __init__(self, ego=None, ego_info=None, agents=None, agents_info=None): 112 | assert ego is None or isinstance(ego, Action) 113 | assert agents is None or isinstance(agents, Action) 114 | assert ego_info is None or isinstance(ego_info, dict) 115 | assert agents_info is None or isinstance(agents_info, dict) 116 | 117 | self.ego = ego 118 | self.ego_info = ego_info 119 | self.agents = agents 120 | self.agents_info = agents_info 121 | 122 | @property 123 | def has_ego(self): 124 | return self.ego is not None 125 | 126 | @property 127 | def has_agents(self): 128 | return self.agents is not None 129 | 130 | def transform(self, ego_trans_mats, ego_rot_rads, 
agents_trans_mats=None, agents_rot_rads=None): 131 | trans_action = RolloutAction() 132 | if self.has_ego: 133 | trans_action.ego = self.ego.transform( 134 | trans_mats=ego_trans_mats, rot_rads=ego_rot_rads) 135 | if self.ego_info is not None: 136 | trans_action.ego_info = deepcopy(self.ego_info) 137 | if "plan" in trans_action.ego_info: 138 | plan = Plan.from_dict(trans_action.ego_info["plan"]) 139 | trans_action.ego_info["plan"] = plan.transform( 140 | trans_mats=ego_trans_mats, rot_rads=ego_rot_rads 141 | ).to_dict() 142 | if self.has_agents: 143 | assert agents_trans_mats is not None and agents_rot_rads is not None 144 | trans_action.agents = self.agents.transform( 145 | trans_mats=agents_trans_mats, rot_rads=agents_rot_rads) 146 | if self.agents_info is not None: 147 | trans_action.agents_info = deepcopy(self.agents_info) 148 | if "plan" in trans_action.agents_info: 149 | plan = Plan.from_dict(trans_action.agents_info["plan"]) 150 | trans_action.agents_info["plan"] = plan.transform( 151 | trans_mats=agents_trans_mats, rot_rads=agents_rot_rads 152 | ).to_dict() 153 | return trans_action 154 | 155 | def to_dict(self): 156 | d = dict() 157 | if self.has_ego: 158 | d["ego"] = self.ego.to_dict() 159 | d["ego_info"] = deepcopy(self.ego_info) 160 | if self.has_agents: 161 | d["agents"] = self.agents.to_dict() 162 | d["agents_info"] = deepcopy(self.agents_info) 163 | return d 164 | 165 | def to_numpy(self): 166 | return self.__class__( 167 | ego=self.ego.to_numpy() if self.has_ego else None, 168 | ego_info=TensorUtils.to_numpy( 169 | self.ego_info) if self.has_ego else None, 170 | agents=self.agents.to_numpy() if self.has_agents else None, 171 | agents_info=TensorUtils.to_numpy( 172 | self.agents_info) if self.has_agents else None, 173 | ) 174 | 175 | @classmethod 176 | def from_dict(cls, d): 177 | d = deepcopy(d) 178 | if "ego" in d: 179 | d["ego"] = Action.from_dict(d["ego"]) 180 | if "agents" in d: 181 | d["agents"] = Action.from_dict(d["agents"]) 182 | return cls(**d) 183 | 184 | -------------------------------------------------------------------------------- /tbsim/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/CTG/f916c008c3ecf2360bfa050639606eaab7c207f5/tbsim/utils/__init__.py -------------------------------------------------------------------------------- /tbsim/utils/config_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tbsim.configs.registry import get_registered_experiment_config 3 | from tbsim.configs.base import ExperimentConfig 4 | from tbsim.configs.config import Dict 5 | 6 | 7 | def translate_l5kit_cfg(cfg: ExperimentConfig): 8 | """ 9 | Translate a tbsim config to a l5kit config 10 | 11 | Args: 12 | cfg (ExperimentConfig): an ExperimentConfig instance 13 | 14 | Returns: 15 | cfg for l5kit 16 | """ 17 | rcfg = dict() 18 | 19 | rcfg["raster_params"] = cfg.env.rasterizer.to_dict() 20 | rcfg["raster_params"]["dataset_meta_key"] = cfg.train.dataset_meta_key 21 | rcfg["model_params"] = cfg.algo 22 | if "data_generation_params" in cfg.env.keys(): 23 | rcfg["data_generation_params"] = cfg.env["data_generation_params"] 24 | return rcfg 25 | 26 | 27 | def get_experiment_config_from_file(file_path, locked=False): 28 | ext_cfg = json.load(open(file_path, "r")) 29 | cfg = get_registered_experiment_config(ext_cfg["registered_name"]) 30 | cfg.update(**ext_cfg) 31 | cfg.lock(locked) 32 | return cfg 33 | 34 | 35 | def 
translate_trajdata_cfg(cfg: ExperimentConfig): 36 | rcfg = Dict() 37 | # assert cfg.algo.step_time == 0.5 # TODO: support interpolation 38 | if "scene_centric" in cfg.algo and cfg.algo.scene_centric: 39 | rcfg.centric="scene" 40 | else: 41 | rcfg.centric="agent" 42 | if "standardize_data" in cfg.env.data_generation_params: 43 | rcfg.standardize_data = cfg.env.data_generation_params.standardize_data 44 | else: 45 | rcfg.standardize_data = True 46 | rcfg.step_time = cfg.algo.step_time 47 | rcfg.trajdata_source_root = cfg.train.trajdata_source_root 48 | rcfg.trajdata_source_train = cfg.train.trajdata_source_train 49 | rcfg.trajdata_source_train_val = cfg.train.trajdata_source_train_val 50 | rcfg.trajdata_source_valid = cfg.train.trajdata_source_valid 51 | rcfg.dataset_path = cfg.train.dataset_path 52 | rcfg.history_num_frames = cfg.algo.history_num_frames 53 | rcfg.future_num_frames = cfg.algo.future_num_frames 54 | rcfg.other_agents_num = cfg.env.data_generation_params.other_agents_num 55 | rcfg.max_agents_distance = cfg.env.data_generation_params.max_agents_distance 56 | rcfg.max_agents_distance_simulation = cfg.env.simulation.distance_th_close 57 | rcfg.pixel_size = cfg.env.rasterizer.pixel_size 58 | rcfg.raster_size = int(cfg.env.rasterizer.raster_size) 59 | rcfg.raster_center = cfg.env.rasterizer.ego_center 60 | rcfg.yaw_correction_speed = cfg.env.data_generation_params.yaw_correction_speed 61 | if "vectorize_lane" in cfg.env.data_generation_params: 62 | rcfg.vectorize_lane = cfg.env.data_generation_params.vectorize_lane 63 | else: 64 | rcfg.vectorize_lane = "None" 65 | 66 | rcfg.lock() 67 | return rcfg 68 | 69 | 70 | def translate_pass_trajdata_cfg(cfg: ExperimentConfig): 71 | """ 72 | Translate a unified passthrough config to trajdata. 73 | """ 74 | rcfg = Dict() 75 | rcfg.step_time = cfg.algo.step_time 76 | rcfg.trajdata_cache_location = cfg.train.trajdata_cache_location 77 | rcfg.trajdata_source_train = cfg.train.trajdata_source_train 78 | rcfg.trajdata_source_valid = cfg.train.trajdata_source_valid 79 | rcfg.trajdata_data_dirs = cfg.train.trajdata_data_dirs 80 | rcfg.trajdata_rebuild_cache = cfg.train.trajdata_rebuild_cache 81 | 82 | rcfg.history_num_frames = cfg.algo.history_num_frames 83 | rcfg.future_num_frames = cfg.algo.future_num_frames 84 | 85 | rcfg.trajdata_centric = cfg.env.data_generation_params.trajdata_centric 86 | rcfg.trajdata_only_types = cfg.env.data_generation_params.trajdata_only_types 87 | rcfg.trajdata_predict_types = cfg.env.data_generation_params.trajdata_predict_types 88 | rcfg.trajdata_incl_map = cfg.env.data_generation_params.trajdata_incl_map 89 | rcfg.other_agents_num = cfg.env.data_generation_params.other_agents_num 90 | rcfg.max_agents_distance = cfg.env.data_generation_params.trajdata_max_agents_distance 91 | rcfg.trajdata_standardize_data = cfg.env.data_generation_params.trajdata_standardize_data 92 | rcfg.trajdata_scene_desc_contains = cfg.env.data_generation_params.trajdata_scene_desc_contains 93 | 94 | rcfg.pixel_size = cfg.env.rasterizer.pixel_size 95 | rcfg.raster_size = int(cfg.env.rasterizer.raster_size) 96 | rcfg.raster_center = cfg.env.rasterizer.ego_center 97 | rcfg.num_sem_layers = cfg.env.rasterizer.num_sem_layers 98 | rcfg.drivable_layers = cfg.env.rasterizer.drivable_layers 99 | rcfg.no_map_fill_value = cfg.env.rasterizer.no_map_fill_value 100 | rcfg.raster_include_hist = cfg.env.rasterizer.include_hist 101 | 102 | rcfg.lock() 103 | return rcfg 104 | -------------------------------------------------------------------------------- 
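A minimal usage sketch for the config helpers above; the JSON path is hypothetical, but get_experiment_config_from_file and translate_pass_trajdata_cfg are the functions defined in this file:

# load a registered experiment config and apply the JSON overrides (path is illustrative)
cfg = get_experiment_config_from_file("experiments/ctg_nusc/config.json")
# flatten the env/train/algo sections into the locked Dict consumed by the trajdata datamodules
trajdata_cfg = translate_pass_trajdata_cfg(cfg)
print(trajdata_cfg.step_time, trajdata_cfg.history_num_frames, trajdata_cfg.future_num_frames)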
/tbsim/utils/diffuser_utils/arrays.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numpy as np 3 | import torch 4 | import pdb 5 | 6 | DTYPE = torch.float 7 | DEVICE = 'cuda:0' 8 | 9 | #-----------------------------------------------------------------------------# 10 | #------------------------------ numpy <--> torch -----------------------------# 11 | #-----------------------------------------------------------------------------# 12 | 13 | def to_np(x): 14 | if torch.is_tensor(x): 15 | x = x.detach().cpu().numpy() 16 | return x 17 | 18 | def to_torch(x, dtype=None, device=None): 19 | dtype = dtype or DTYPE 20 | device = device or DEVICE 21 | if type(x) is dict: 22 | return {k: to_torch(v, dtype, device) for k, v in x.items()} 23 | elif torch.is_tensor(x): 24 | return x.to(device).type(dtype) 25 | # import pdb; pdb.set_trace() 26 | return torch.tensor(x, dtype=dtype, device=device) 27 | 28 | def to_device(x, device=DEVICE): 29 | if torch.is_tensor(x): 30 | return x.to(device) 31 | elif type(x) is dict: 32 | return {k: to_device(v, device) for k, v in x.items()} 33 | else: 34 | print(f'Unrecognized type in `to_device`: {type(x)}') 35 | pdb.set_trace() 36 | # return [x.to(device) for x in xs] 37 | 38 | # def atleast_2d(x, axis=0): 39 | # ''' 40 | # works for both np arrays and torch tensors 41 | # ''' 42 | # while len(x.shape) < 2: 43 | # shape = (1, *x.shape) if axis == 0 else (*x.shape, 1) 44 | # x = x.reshape(*shape) 45 | # return x 46 | 47 | # def to_2d(x): 48 | # dim = x.shape[-1] 49 | # return x.reshape(-1, dim) 50 | 51 | def batchify(batch): 52 | ''' 53 | convert a single dataset item to a batch suitable for passing to a model by 54 | 1) converting np arrays to torch tensors and 55 | 2) ensuring that everything has a batch dimension 56 | ''' 57 | fn = lambda x: to_torch(x[None]) 58 | 59 | batched_vals = [] 60 | for field in batch._fields: 61 | val = getattr(batch, field) 62 | val = apply_dict(fn, val) if type(val) is dict else fn(val) 63 | batched_vals.append(val) 64 | return type(batch)(*batched_vals) 65 | 66 | def apply_dict(fn, d, *args, **kwargs): 67 | return { 68 | k: fn(v, *args, **kwargs) 69 | for k, v in d.items() 70 | } 71 | 72 | def normalize(x): 73 | """ 74 | scales `x` to [0, 1] 75 | """ 76 | x = x - x.min() 77 | x = x / x.max() 78 | return x 79 | 80 | def to_img(x): 81 | normalized = normalize(x) 82 | array = to_np(normalized) 83 | array = np.transpose(array, (1,2,0)) 84 | return (array * 255).astype(np.uint8) 85 | 86 | def set_device(device): 87 | global DEVICE; DEVICE = device # declare DEVICE global so the module-level default is actually updated (a bare assignment only binds a local) 88 | if 'cuda' in device: 89 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 90 | 91 | def batch_to_device(batch, device='cuda:0'): 92 | vals = [ 93 | to_device(getattr(batch, field), device) 94 | for field in batch._fields 95 | ] 96 | return type(batch)(*vals) 97 | 98 | def _to_str(num): 99 | if num >= 1e6: 100 | return f'{(num/1e6):.2f} M' 101 | else: 102 | return f'{(num/1e3):.2f} k' 103 | 104 | #-----------------------------------------------------------------------------# 105 | #----------------------------- parameter counting ----------------------------# 106 | #-----------------------------------------------------------------------------# 107 | 108 | def param_to_module(param): 109 | module_name = param[::-1].split('.', maxsplit=1)[-1][::-1] 110 | return module_name 111 | 112 | def report_parameters(model, topk=10): 113 | counts = {k: p.numel() for k, p in model.named_parameters()} 114 | n_parameters = sum(counts.values())
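# report the total first, then the top-k largest parameter tensors along with the modules that own them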
115 | print(f'[ utils/arrays ] Total parameters: {_to_str(n_parameters)}') 116 | 117 | modules = dict(model.named_modules()) 118 | sorted_keys = sorted(counts, key=lambda x: -counts[x]) 119 | max_length = max([len(k) for k in sorted_keys]) 120 | for i in range(topk): 121 | key = sorted_keys[i] 122 | count = counts[key] 123 | module = param_to_module(key) 124 | print(' '*8, f'{key:10}: {_to_str(count)} | {modules[module]}') 125 | 126 | remaining_parameters = sum([counts[k] for k in sorted_keys[topk:]]) 127 | print(' '*8, f'... and {len(counts)-topk} others accounting for {_to_str(remaining_parameters)} parameters') 128 | return n_parameters 129 | -------------------------------------------------------------------------------- /tbsim/utils/diffuser_utils/progress.py: -------------------------------------------------------------------------------- 1 | import time 2 | import math 3 | import pdb 4 | 5 | class Progress: 6 | 7 | def __init__(self, total, name = 'Progress', ncol=3, max_length=20, indent=0, line_width=100, speed_update_freq=100): 8 | self.total = total 9 | self.name = name 10 | self.ncol = ncol 11 | self.max_length = max_length 12 | self.indent = indent 13 | self.line_width = line_width 14 | self._speed_update_freq = speed_update_freq 15 | 16 | self._step = 0 17 | self._prev_line = '\033[F' 18 | self._clear_line = ' ' * self.line_width 19 | 20 | self._pbar_size = self.ncol * self.max_length 21 | self._complete_pbar = '#' * self._pbar_size 22 | self._incomplete_pbar = ' ' * self._pbar_size 23 | 24 | self.lines = [''] 25 | self.fraction = '{} / {}'.format(0, self.total) 26 | 27 | self.resume() 28 | 29 | 30 | def update(self, description='', n=1): # default allows bare update() calls, as in the __main__ demo below 31 | self._step += n 32 | if self._step % self._speed_update_freq == 0: 33 | self._time0 = time.time() 34 | self._step0 = self._step 35 | self.set_description(description) 36 | 37 | def resume(self): 38 | self._skip_lines = 1 39 | print('\n', end='') 40 | self._time0 = time.time() 41 | self._step0 = self._step 42 | 43 | def pause(self): 44 | self._clear() 45 | self._skip_lines = 1 46 | 47 | def set_description(self, params=[]): 48 | 49 | if type(params) == dict: 50 | params = sorted([ 51 | (key, val) 52 | for key, val in params.items() 53 | ]) 54 | 55 | ############ 56 | # Position # 57 | ############ 58 | self._clear() 59 | 60 | ########### 61 | # Percent # 62 | ########### 63 | percent, fraction = self._format_percent(self._step, self.total) 64 | self.fraction = fraction 65 | 66 | ######### 67 | # Speed # 68 | ######### 69 | speed = self._format_speed(self._step) 70 | 71 | ########## 72 | # Params # 73 | ########## 74 | num_params = len(params) 75 | nrow = math.ceil(num_params / self.ncol) 76 | params_split = self._chunk(params, self.ncol) 77 | params_string, lines = self._format(params_split) 78 | self.lines = lines 79 | 80 | 81 | description = '{} | {}{}'.format(percent, speed, params_string) 82 | print(description) 83 | self._skip_lines = nrow + 1 84 | 85 | def append_description(self, descr): 86 | self.lines.append(descr) 87 | 88 | def _clear(self): 89 | position = self._prev_line * self._skip_lines 90 | empty = '\n'.join([self._clear_line for _ in range(self._skip_lines)]) 91 | print(position, end='') 92 | print(empty) 93 | print(position, end='') 94 | 95 | def _format_percent(self, n, total): 96 | if total: 97 | percent = n / float(total) 98 | 99 | complete_entries = int(percent * self._pbar_size) 100 | incomplete_entries = self._pbar_size - complete_entries 101 | 102 | pbar = self._complete_pbar[:complete_entries] +
self._incomplete_pbar[:incomplete_entries] 103 | fraction = '{} / {}'.format(n, total) 104 | string = '{} [{}] {:3d}%'.format(fraction, pbar, int(percent*100)) 105 | else: 106 | fraction = '{}'.format(n) 107 | string = '{} iterations'.format(n) 108 | return string, fraction 109 | 110 | def _format_speed(self, n): 111 | num_steps = n - self._step0 112 | t = time.time() - self._time0 113 | speed = num_steps / t 114 | string = '{:.1f} Hz'.format(speed) 115 | if num_steps > 0: 116 | self._speed = string 117 | return string 118 | 119 | def _chunk(self, l, n): 120 | return [l[i:i+n] for i in range(0, len(l), n)] 121 | 122 | def _format(self, chunks): 123 | lines = [self._format_chunk(chunk) for chunk in chunks] 124 | lines.insert(0,'') 125 | padding = '\n' + ' '*self.indent 126 | string = padding.join(lines) 127 | return string, lines 128 | 129 | def _format_chunk(self, chunk): 130 | line = ' | '.join([self._format_param(param) for param in chunk]) 131 | return line 132 | 133 | def _format_param(self, param): 134 | k, v = param 135 | return '{} : {}'.format(k, v)[:self.max_length] 136 | 137 | def stamp(self): 138 | if self.lines != ['']: 139 | params = ' | '.join(self.lines) 140 | string = '[ {} ] {}{} | {}'.format(self.name, self.fraction, params, self._speed) 141 | self._clear() 142 | print(string, end='\n') 143 | self._skip_lines = 1 144 | else: 145 | self._clear() 146 | self._skip_lines = 0 147 | 148 | def close(self): 149 | self.pause() 150 | 151 | class Silent: 152 | 153 | def __init__(self, *args, **kwargs): 154 | pass 155 | 156 | def __getattr__(self, attr): 157 | return lambda *args: None 158 | 159 | 160 | if __name__ == '__main__': 161 | silent = Silent() 162 | silent.update() 163 | silent.stamp() 164 | 165 | num_steps = 1000 166 | progress = Progress(num_steps) 167 | for i in range(num_steps): 168 | progress.update() 169 | params = [ 170 | ['A', '{:06d}'.format(i)], 171 | ['B', '{:06d}'.format(i)], 172 | ['C', '{:06d}'.format(i)], 173 | ['D', '{:06d}'.format(i)], 174 | ['E', '{:06d}'.format(i)], 175 | ['F', '{:06d}'.format(i)], 176 | ['G', '{:06d}'.format(i)], 177 | ['H', '{:06d}'.format(i)], 178 | ] 179 | progress.set_description(params) 180 | time.sleep(0.01) 181 | progress.close() 182 | -------------------------------------------------------------------------------- /tbsim/utils/ema.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Optional, Union, Dict, Any 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | from pytorch_lightning.utilities import rank_zero_only 7 | 8 | 9 | class EMA(pl.Callback): 10 | """Implements EMA (exponential moving average) for any kind of model. 11 | EMA weights will be used during validation and stored separately from original model weights. 12 | 13 | How to use EMA: 14 | - Sometimes, the last EMA checkpoint isn't the best, as metrics computed with EMA weights can oscillate over long stretches of training. See 15 | https://github.com/rwightman/pytorch-image-models/issues/102 16 | - Batch Norm layers (and likely any other type of norm layer) don't need to be updated at the end. See 17 | discussions in: https://github.com/rwightman/pytorch-image-models/issues/106#issuecomment-609461088 and 18 | https://github.com/rwightman/pytorch-image-models/issues/224 19 | - For object detection, SWA usually works better.
See https://github.com/timgaripov/swa/issues/16 20 | 21 | Implementation detail: 22 | - See EMA in Pytorch Lightning: https://github.com/PyTorchLightning/pytorch-lightning/issues/10914 23 | - When running multi-GPU, we broadcast the EMA weights and the original weights in order to only hold 1 copy in memory. 24 | This is especially relevant when storing EMA weights on CPU + pinned memory as pinned memory is a limited 25 | resource. In addition, we want to avoid duplicated operations in ranks != 0 to reduce jitter and improve 26 | performance. 27 | """ 28 | def __init__(self, decay: float = 0.995, ema_device: Optional[Union[torch.device, str]] = None, pin_memory=True, every_n_train_steps: int = 10): 29 | super().__init__() 30 | self.decay = decay 31 | self.ema_device: str = f"{ema_device}" if ema_device else None # perform ema on different device from the model 32 | self.ema_pin_memory = pin_memory if torch.cuda.is_available() else False # Only works if CUDA is available 33 | self.ema_state_dict: Dict[str, torch.Tensor] = {} 34 | self.original_state_dict = {} 35 | self._ema_state_dict_ready = False 36 | self._every_n_train_steps = every_n_train_steps 37 | 38 | @staticmethod 39 | def get_state_dict(pl_module: pl.LightningModule): 40 | """Returns state dictionary from pl_module. Override if you want to filter some parameters and/or buffers out. 41 | For example, if pl_module has metrics, you don't want to return their parameters. 42 | 43 | code: 44 | # Only consider modules that can be seen by optimizers. Lightning modules can have other nn.Modules attached 45 | # like losses, metrics, etc. 46 | patterns_to_ignore = ("metrics1", "metrics2") 47 | return dict(filter(lambda i: not i[0].startswith(patterns_to_ignore), pl_module.state_dict().items())) 48 | """ 49 | return pl_module.state_dict() 50 | 51 | def on_train_start(self, trainer: "pl.Trainer", pl_module: pl.LightningModule) -> None: 52 | # Only keep track of EMA weights in rank zero. 53 | if not self._ema_state_dict_ready and pl_module.global_rank == 0: 54 | self.ema_state_dict = deepcopy(self.get_state_dict(pl_module)) 55 | if self.ema_device: 56 | self.ema_state_dict = {k: tensor.to(device=self.ema_device) for k, tensor in self.ema_state_dict.items()} 57 | 58 | if self.ema_device == "cpu" and self.ema_pin_memory: 59 | self.ema_state_dict = {k: tensor.pin_memory() for k, tensor in self.ema_state_dict.items()} 60 | 61 | self._ema_state_dict_ready = True 62 | 63 | @rank_zero_only 64 | def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: pl.LightningModule, *args, **kwargs) -> None: 65 | # Update EMA weights 66 | if self._every_n_train_steps >= 1 and (trainer.global_step % self._every_n_train_steps == 0): 67 | # print('update ema') 68 | with torch.no_grad(): 69 | for key, value in self.get_state_dict(pl_module).items(): 70 | ema_value = self.ema_state_dict[key] 71 | ema_value.copy_(self.decay * ema_value + (1. - self.decay) * value, non_blocking=True) 72 | 73 | def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: 74 | if not self._ema_state_dict_ready: 75 | return # Skip Lightning sanity validation check if no EMA weights have been loaded from a checkpoint.
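# Back up the current training weights, then broadcast the rank-0 EMA weights so that every rank validates with the same averaged parameters.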
76 | 77 | self.original_state_dict = deepcopy(self.get_state_dict(pl_module)) 78 | ema_state_dict = pl_module.trainer.training_type_plugin.broadcast(self.ema_state_dict, 0) 79 | self.ema_state_dict = ema_state_dict 80 | # Swap the EMA weights in for validation; on_validation_end puts the originals back 81 | pl_module.load_state_dict(self.ema_state_dict, strict=False) 82 | 83 | def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: 84 | if not self._ema_state_dict_ready: 85 | return # Skip Lightning sanity validation check if no EMA weights have been loaded from a checkpoint. 86 | 87 | # Restore the original training weights that were swapped out at validation start 88 | pl_module.load_state_dict(self.original_state_dict, strict=False) 89 | 90 | def on_save_checkpoint( 91 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any] 92 | ) -> dict: 93 | return {"ema_state_dict": self.ema_state_dict, "_ema_state_dict_ready": self._ema_state_dict_ready} 94 | 95 | def on_load_checkpoint( 96 | self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any] 97 | ) -> None: 98 | self._ema_state_dict_ready = callback_state["_ema_state_dict_ready"] 99 | self.ema_state_dict = callback_state["ema_state_dict"] -------------------------------------------------------------------------------- /tbsim/utils/ftocp.py: -------------------------------------------------------------------------------- 1 | from casadi import * 2 | from numpy import * 3 | import itertools 4 | import numpy as np 5 | 6 | class FTOCP(object): 7 | """ Finite Time Optimal Control Problem (FTOCP) 8 | Methods: 9 | - solve: solves the FTOCP given the initial condition x0 and terminal constraints 10 | - buildNonlinearProgram: builds the nonlinear program solved by the above solve method 11 | - model: given x_t and u_t computes x_{t+1} = Ax_t + Bu_t 12 | """ 13 | def __init__(self, N, M, dt,W,L,max_steer=0.5, max_yawvel=8, acce_bound=[-6,4],vbound=[-5.,40.]): 14 | # Define variables 15 | self.N = N 16 | self.dt = dt 17 | self.n = 4 18 | self.d = 2 19 | self.M = M 20 | self.xRef = None 21 | self.W = W 22 | self.L = L 23 | self.max_steer = max_steer 24 | self.max_yawvel = max_yawvel 25 | self.acce_bound = acce_bound 26 | self.obs = list() 27 | 28 | self.x_lb = [-np.inf,-np.inf,vbound[0],-2*np.pi] 29 | self.x_ub = [np.inf,np.inf,vbound[1],2*np.pi] 30 | 31 | self.u_lb = [self.acce_bound[0],-self.max_steer*vbound[1]] 32 | self.u_ub = [self.acce_bound[1],self.max_steer*vbound[1]] 33 | 34 | 35 | self.feasible = 0 36 | self.xPredOld =[] 37 | self.yPredOld =[] 38 | 39 | self.solOld =[] 40 | self.xGuessTot = None 41 | 42 | 43 | def buildandsolve(self,x0_val,ypreds, agent_extent, xdes,w): 44 | # Define variables 45 | n = self.n 46 | d = self.d 47 | N = self.N 48 | M = self.M 49 | 50 | Nnodes = ypreds.shape[0] 51 | 52 | 53 | X = SX.sym('X', n*(M*N+1)) 54 | x0 = X[0:n] 55 | xbr = [None]*M 56 | 57 | U = SX.sym('U', d*(M*(N-1)+1)) 58 | u0 = U[0:d] 59 | ubr = [None]*M 60 | for i in range(M): 61 | xbr[i] = X[(i*N+1)*n:(i*N+N+1)*n].reshape((n,N)).T 62 | ubr[i] = U[(i*(N-1)+1)*d:((i+1)*(N-1)+1)*d].reshape((d,N-1)).T 63 | slack0 = SX.sym('s', N*M) 64 | slack = slack0.reshape((N,M)).T 65 | 66 | # Define dynamic constraints 67 | constraint = list() 68 | for i in range(M): 69 | constraint = vertcat(constraint,xbr[i][0,0]-x0[0]-self.dt*x0[2]*casadi.cos(x0[3])) 70 | constraint = vertcat(constraint,xbr[i][0,1]-x0[1]-self.dt*x0[2]*casadi.sin(x0[3])) 71 | constraint = vertcat(constraint,xbr[i][0,2]-x0[2]-self.dt*u0[0]) 72 | constraint = vertcat(constraint,xbr[i][0,3]-x0[3]-self.dt*u0[1]) 73 | 74 | for j in range(N-1): 75 | constraint =
vertcat(constraint,xbr[i][j+1,0]-xbr[i][j,0]-self.dt*xbr[i][j,2]*casadi.cos(xbr[i][j,3])) 76 | constraint = vertcat(constraint,xbr[i][j+1,1]-xbr[i][j,1]-self.dt*xbr[i][j,2]*casadi.sin(xbr[i][j,3])) 77 | constraint = vertcat(constraint,xbr[i][j+1,2]-xbr[i][j,2]-self.dt*ubr[i][j,0]) 78 | constraint = vertcat(constraint,xbr[i][j+1,3]-xbr[i][j,3]-self.dt*ubr[i][j,1]) 79 | dyn_constr_count = constraint.shape[0] 80 | 81 | for i in range(M): 82 | 83 | constraint = vertcat(constraint,ubr[i][0,1]-softmax(x0[2],1.)*self.max_steer) 84 | constraint = vertcat(constraint,-ubr[i][0,1]-softmax(x0[2],1.)*self.max_steer) 85 | constraint = vertcat(constraint,ubr[i][0,1]*x0[2]-self.max_yawvel) 86 | constraint = vertcat(constraint,-ubr[i][0,1]*x0[2]-self.max_yawvel) 87 | for j in range(N-1): 88 | constraint = vertcat(constraint,ubr[i][j,1]-softmax(xbr[i][j+1,2],1.)*self.max_steer) 89 | constraint = vertcat(constraint,-ubr[i][j,1]-softmax(xbr[i][j+1,2],1.)*self.max_steer) 90 | constraint = vertcat(constraint,ubr[i][j,1]*xbr[i][j+1,2]-self.max_yawvel) 91 | constraint = vertcat(constraint,-ubr[i][j,1]*xbr[i][j+1,2]-self.max_yawvel) 92 | 93 | ubound_constr_count = constraint.shape[0]-dyn_constr_count 94 | # Obstacle constraints 95 | 96 | if Nnodes>0: 97 | for i in range(M): 98 | for j in range(Nnodes): 99 | for k in range(N): 100 | constraint = vertcat(constraint, ( ( xbr[i][k,0] - ypreds[j][i][k,0] )**2/(self.L/1.414+agent_extent[j,0]/1.414)**2 + 101 | ( xbr[i][k,1] - ypreds[j][i][k,1] )**2/(self.W/1.414+agent_extent[j,1]/1.414)**2 + slack[i,k] ) ) 102 | 103 | collision_count = constraint.shape[0]-dyn_constr_count-ubound_constr_count 104 | 105 | # Defining Cost 106 | cost = 0 107 | cost_x = 1. 108 | cost_y = 5. 109 | cost_v = 1. 110 | cost_acc = 0.5 111 | cost_ste = 2.0 112 | cost_slack = 1e6 113 | cost_R = DM([cost_acc,cost_ste]) 114 | cost_Q = DM([cost_x,cost_y,cost_v]) 115 | cost = sum1(u0**2*cost_R) 116 | for i in range(M): 117 | for k in range(N-1): 118 | 119 | cost+=(sum1((xbr[i][k,:3].T-xdes[k][:3])**2*cost_Q)+sum1(ubr[i][k,:].T**2*cost_R)+slack[i,k]*cost_slack)*w[i] 120 | 121 | cost+= (sum1((xbr[i][N-1,:3].T-xdes[N-1][:3])**2*cost_Q)+slack[i,N-1]*cost_slack)*w[i] 122 | 123 | 124 | 125 | # Set IPOPT options 126 | # opts = {"verbose":False,"ipopt.print_level":0,"print_time":0}#, "ipopt.acceptable_constr_viol_tol":0.001}#,"ipopt.acceptable_tol":1e-4}#, "expand":True} 127 | # opts = {"verbose":False,"ipopt.print_level":0,"print_time":0,"ipopt.mu_strategy":"adaptive"}#, "ipopt.acceptable_constr_viol_tol":0.001}#,"ipopt.acceptable_tol":1e-4}#, "expand":True} 128 | opts = {"verbose":False,"ipopt.print_level":0,"print_time":0,"ipopt.mu_strategy":"adaptive","ipopt.mu_init":1e-5,"ipopt.mu_min":1e-15,"ipopt.barrier_tol_factor":1}#, "ipopt.acceptable_constr_viol_tol":0.001}#,"ipopt.acceptable_tol":1e-4}#, "expand":True} 129 | nlp = {'x':vertcat(X,U, slack0), 'f':cost, 'g':constraint} 130 | self.solver = nlpsol('solver', 'ipopt', nlp, opts) 131 | 132 | # Set lower bound of inequality constraint to zero to force: 1) n*N state dynamics and 2) inequality constraints (set to zero as we have slack) 133 | self.lbg_dyanmics = [0]*dyn_constr_count + [-10000]*ubound_constr_count + [1]*collision_count 134 | self.ubg_dyanmics = [0]*dyn_constr_count + [0]*ubound_constr_count + [10000]*collision_count 135 | 136 | self.lbx = x0_val.tolist() + self.x_lb*(N*M) + self.u_lb*(M*(N-1)+1) + [0]*(N*M) 137 | self.ubx = x0_val.tolist() + self.x_ub*(N*M) + self.u_ub*(M*(N-1)+1) + [np.inf]*(N*M) 138 | if self.xGuessTot is not None and 
self.xGuessTot.shape[0]==nlp['x'].shape[0]: 139 | sol = self.solver(lbx=self.lbx, ubx=self.ubx, lbg=self.lbg_dyanmics, ubg=self.ubg_dyanmics, x0 = self.xGuessTot) 140 | 141 | else: 142 | sol = self.solver(lbx=self.lbx, ubx=self.ubx, lbg=self.lbg_dyanmics, ubg=self.ubg_dyanmics) 143 | # sol = self.solver(lbx=self.lbx, ubx=self.ubx, lbg=self.lbg_dyanmics, ubg=self.ubg_dyanmics) 144 | # Check solution flag 145 | if (self.solver.stats()['success']): 146 | self.feasible = 1 147 | else: 148 | sol = self.solver(lbx=self.lbx, ubx=self.ubx, lbg=self.lbg_dyanmics, ubg=self.ubg_dyanmics) 149 | 150 | 151 | 152 | # Store optimal solution 153 | x = np.array(sol["x"]) 154 | self.xSol = x[0:n*(M*N+1)].reshape((M*N+1,n)) 155 | self.uSol = x[n*(M*N+1):n*(M*N+1)+d*(M*(N-1)+1)].reshape((M*(N-1)+1,d)) 156 | self.slack = x[n*(M*N+1)+d*(M*(N-1)+1):] 157 | 158 | self.xGuessTot = x 159 | # Check solution flag 160 | if (self.solver.stats()['success']): 161 | self.feasible = 1 162 | else: 163 | self.feasible = 0 164 | 165 | 166 | def softmax(x,y,gamma=10): 167 | return (exp(x*gamma)*x+exp(y*gamma)*y)/(exp(x*gamma)+exp(y*gamma)) -------------------------------------------------------------------------------- /tbsim/utils/lane_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tbsim.utils.geometry_utils as GeoUtils 3 | import tbsim.utils.tensor_utils as TensorUtils 4 | 5 | 6 | def get_edge(lane,dir,W=2.0,num_pts = None): 7 | 8 | if dir == "L": 9 | 10 | if lane.left_edge is not None: 11 | lane.left_edge = lane.left_edge.interpolate(num_pts) 12 | xy = lane.left_edge.xy 13 | if lane.left_edge.has_heading: 14 | h = lane.left_edge.h 15 | else: 16 | dxy = xy[1:]-xy[:-1] 17 | h = GeoUtils.round_2pi(np.arctan2(dxy[:,1],dxy[:,0])) 18 | h = np.hstack((h,h[-1])) 19 | else: 20 | lane.center = lane.center.interpolate(num_pts) 21 | angle = lane.center.h+np.pi/2 22 | offset = np.stack([W*np.cos(angle),W*np.sin(angle)],-1) 23 | xy = lane.center.xy+offset 24 | h = lane.center.h 25 | elif dir == "R": 26 | 27 | if lane.right_edge is not None: 28 | lane.right_edge = lane.right_edge.interpolate(num_pts) 29 | xy = lane.right_edge.xy 30 | if lane.right_edge.has_heading: 31 | h = lane.right_edge.h 32 | else: 33 | dxy = xy[1:]-xy[:-1] 34 | h = GeoUtils.round_2pi(np.arctan2(dxy[:,1],dxy[:,0])) 35 | h = np.hstack((h,h[-1])) 36 | else: 37 | lane.center = lane.center.interpolate(num_pts) 38 | angle = lane.center.h-np.pi/2 39 | offset = np.stack([W*np.cos(angle),W*np.sin(angle)],-1) 40 | xy = lane.center.xy+offset 41 | h = lane.center.h 42 | elif dir =="C": 43 | lane.center = lane.center.interpolate(num_pts) 44 | xy = lane.center.xy 45 | if lane.center.has_heading: 46 | h = lane.center.h 47 | else: 48 | dxy = xy[1:]-xy[:-1] 49 | h = GeoUtils.round_2pi(np.arctan2(dxy[:,1],dxy[:,0])) 50 | h = np.hstack((h,h[-1])) 51 | return xy,h 52 | 53 | def get_bdry_xyh(lane1,lane2=None,dir="L",W=3.6,num_pts = 25): 54 | if lane2 is None: 55 | xy,h = get_edge(lane1,dir,W,num_pts*2) 56 | else: 57 | xy1,h1 = get_edge(lane1,dir,W,num_pts) 58 | xy2,h2 = get_edge(lane2,dir,W,num_pts) 59 | xy = np.concatenate((xy1,xy2),0) 60 | h = np.concatenate((h1,h2),0) 61 | return xy,h 62 | -------------------------------------------------------------------------------- /tbsim/utils/log_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains utility classes and functions for logging to stdout, stderr, 3 | and to tensorboard. 
4 | """ 5 | import sys 6 | 7 | 8 | class PrintLogger(object): 9 | """ 10 | This class redirects print statements to both console and a file. 11 | """ 12 | def __init__(self, log_file): 13 | self.terminal = sys.stdout 14 | print('STDOUT will be forked to %s' % log_file) 15 | self.log_file = open(log_file, "a") 16 | 17 | def write(self, message): 18 | self.terminal.write(message) 19 | self.log_file.write(message) 20 | self.log_file.flush() 21 | 22 | def flush(self): 23 | # this flush method is needed for python 3 compatibility. 24 | # this handles the flush command by doing nothing. 25 | # you might want to specify some extra behavior here. 26 | pass 27 | -------------------------------------------------------------------------------- /tbsim/utils/timer.py: -------------------------------------------------------------------------------- 1 | 2 | import time 3 | import numpy as np 4 | from contextlib import contextmanager 5 | 6 | 7 | class Timer(object): 8 | """A simple timer.""" 9 | def __init__(self): 10 | self.total_time = 0. 11 | self.calls = 0 12 | self.start_time = 0. 13 | self.diff = 0. 14 | self.average_time = 0. 15 | self.times = [] 16 | 17 | def recent_average_time(self, latest_n): 18 | return np.mean(np.array(self.times)[-latest_n:]) 19 | 20 | def tic(self): 21 | # using time.time instead of time.clock because time.clock 22 | # does not normalize for multithreading 23 | self.start_time = time.time() 24 | 25 | def toc(self, average=True): 26 | self.diff = time.time() - self.start_time 27 | self.times.append(self.diff) 28 | self.total_time += self.diff 29 | self.calls += 1 30 | self.average_time = self.total_time / self.calls 31 | if average: 32 | return self.average_time 33 | else: 34 | return self.diff 35 | 36 | @contextmanager 37 | def timed(self): 38 | self.tic() 39 | yield 40 | self.toc() 41 | 42 | 43 | class Timers(object): 44 | def __init__(self): 45 | self._timers = {} 46 | 47 | def tic(self, key): 48 | if key not in self._timers: 49 | self._timers[key] = Timer() 50 | self._timers[key].tic() 51 | 52 | def toc(self, key): 53 | self._timers[key].toc() 54 | 55 | @contextmanager 56 | def timed(self, key): 57 | self.tic(key) 58 | yield 59 | self.toc(key) 60 | 61 | def __str__(self): 62 | msg = [] 63 | for k, v in self._timers.items(): 64 | msg.append('%s: %f' % (k, v.average_time)) 65 | return ', '.join(msg) -------------------------------------------------------------------------------- /tbsim/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains some PyTorch utilities. 3 | """ 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import torch 7 | import torch.optim as optim 8 | import functools 9 | from tqdm.auto import tqdm 10 | 11 | 12 | def soft_update(source, target, tau): 13 | """ 14 | Soft update from the parameters of a @source torch module to a @target torch module 15 | with strength @tau. The update follows target = target * (1 - tau) + source * tau. 16 | 17 | Args: 18 | source (torch.nn.Module): source network to push target network parameters towards 19 | target (torch.nn.Module): target network to update 20 | """ 21 | for target_param, param in zip(target.parameters(), source.parameters()): 22 | target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau) # operate on .data so the in-place update of leaf parameters stays outside autograd 23 | 24 | 25 | def hard_update(source, target): 26 | """ 27 | Hard update @target parameters to match @source.
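For example, hard_update(source=online_net, target=target_net) copies every parameter of online_net into target_net once, as when initializing a target network (names illustrative).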
28 | 29 | Args: 30 | source (torch.nn.Module): source network to provide parameters 31 | target (torch.nn.Module): target network to update parameters for 32 | """ 33 | for target_param, param in zip(target.parameters(), source.parameters()): 34 | target_param.data.copy_(param) # .data keeps the in-place copy outside autograd 35 | 36 | 37 | def get_torch_device(try_to_use_cuda): 38 | """ 39 | Return torch device. If using cuda (GPU), will also set cudnn.benchmark to True 40 | to optimize CNNs. 41 | 42 | Args: 43 | try_to_use_cuda (bool): if True and cuda is available, will use GPU 44 | 45 | Returns: 46 | device (torch.Device): device to use for models 47 | """ 48 | if try_to_use_cuda and torch.cuda.is_available(): 49 | torch.backends.cudnn.benchmark = True 50 | device = torch.device("cuda:0") 51 | else: 52 | device = torch.device("cpu") 53 | return device 54 | 55 | 56 | def reparameterize(mu, logvar): 57 | """ 58 | Reparameterize for the backpropagation of z instead of q. 59 | This makes it so that we can backpropagate through the sampling of z from 60 | our encoder when feeding the sampled variable to the decoder. 61 | 62 | (See "The reparameterization trick" section of https://arxiv.org/abs/1312.6114) 63 | 64 | Args: 65 | mu (torch.Tensor): batch of means from the encoder distribution 66 | logvar (torch.Tensor): batch of log variances from the encoder distribution 67 | 68 | Returns: 69 | z (torch.Tensor): batch of sampled latents from the encoder distribution that 70 | support backpropagation 71 | """ 72 | # logvar = \log(\sigma^2) = 2 * \log(\sigma) 73 | # \sigma = \exp(0.5 * logvar) 74 | 75 | # clamped for numerical stability 76 | logstd = (0.5 * logvar).clamp(-4, 15) 77 | std = torch.exp(logstd) 78 | 79 | # Sample \epsilon from normal distribution 80 | # use std to create a new tensor, so we don't have to care 81 | # about running on GPU or not 82 | eps = std.new(std.size()).normal_() 83 | 84 | # Then multiply with the standard deviation and add the mean 85 | z = eps.mul(std).add_(mu) 86 | 87 | return z 88 | 89 | 90 | def optimizer_from_optim_params(net_optim_params, net): 91 | """ 92 | Helper function to return a torch Optimizer from the optim_params 93 | section of the config for a particular network. 94 | 95 | Args: 96 | net_optim_params (Config): optim_params part of algo_config corresponding 97 | to @net. This determines the optimizer that is created. 98 | 99 | net (torch.nn.Module): module whose parameters this optimizer will be 100 | responsible for 101 | 102 | Returns: 103 | optimizer (torch.optim.Optimizer): optimizer 104 | """ 105 | return optim.Adam( 106 | params=net.parameters(), 107 | lr=net_optim_params["learning_rate"]["initial"], 108 | weight_decay=net_optim_params["regularization"]["L2"], 109 | ) 110 | 111 | 112 | def lr_scheduler_from_optim_params(net_optim_params, net, optimizer): 113 | """ 114 | Helper function to return a LRScheduler from the optim_params 115 | section of the config for a particular network. Returns None 116 | if a scheduler is not needed. 117 | 118 | Args: 119 | net_optim_params (Config): optim_params part of algo_config corresponding 120 | to @net. This determines whether a learning rate scheduler is created.
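For example, an epoch_schedule of [50, 100] with decay_factor 0.1 multiplies the learning rate by 0.1 at epochs 50 and 100 (illustrative values).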
121 | 122 | net (torch.nn.Module): module whose parameters this optimizer will be 123 | responsible for 124 | 125 | optimizer (torch.optim.Optimizer): optimizer for this net 126 | 127 | Returns: 128 | lr_scheduler (torch.optim.lr_scheduler or None): learning rate scheduler 129 | """ 130 | lr_scheduler = None 131 | if len(net_optim_params["learning_rate"]["epoch_schedule"]) > 0: 132 | # decay LR according to the epoch schedule 133 | lr_scheduler = optim.lr_scheduler.MultiStepLR( 134 | optimizer=optimizer, 135 | milestones=net_optim_params["learning_rate"]["epoch_schedule"], 136 | gamma=net_optim_params["learning_rate"]["decay_factor"], 137 | ) 138 | return lr_scheduler 139 | 140 | 141 | def backprop_for_loss(net, optim, loss, max_grad_norm=None, retain_graph=False): 142 | """ 143 | Backpropagate @loss and update the parameters of 144 | @net. 145 | 146 | Args: 147 | net (torch.nn.Module): network to update 148 | 149 | optim (torch.optim.Optimizer): optimizer to use 150 | 151 | loss (torch.Tensor): loss to use for backpropagation 152 | 153 | max_grad_norm (float): if provided, used to clip gradients 154 | 155 | retain_graph (bool): if True, graph is not freed after backward call 156 | 157 | Returns: 158 | grad_norms (float): sum of squared per-parameter gradient norms from backpropagation 159 | """ 160 | 161 | # backprop 162 | optim.zero_grad() 163 | loss.backward(retain_graph=retain_graph) 164 | 165 | # gradient clipping 166 | if max_grad_norm is not None: 167 | torch.nn.utils.clip_grad_norm_(net.parameters(), max_grad_norm) 168 | 169 | # compute grad norms 170 | grad_norms = 0.0 171 | for p in net.parameters(): 172 | # only count parameters that actually received gradients 173 | if p.grad is not None: 174 | grad_norms += p.grad.data.norm(2).pow(2).item() 175 | 176 | # step 177 | optim.step() 178 | 179 | return grad_norms 180 | 181 | 182 | class dummy_context_mgr: 183 | """ 184 | A dummy context manager - useful for having conditional scopes (such 185 | as @maybe_no_grad). Nothing happens in this scope.
186 | """ 187 | 188 | def __enter__(self): 189 | return None 190 | 191 | def __exit__(self, exc_type, exc_value, traceback): 192 | return False 193 | 194 | 195 | def maybe_no_grad(no_grad): 196 | """ 197 | Args: 198 | no_grad (bool): if True, the returned context will be torch.no_grad(), otherwise 199 | it will be a dummy context 200 | """ 201 | return torch.no_grad() if no_grad else dummy_context_mgr() 202 | 203 | 204 | def rgetattr(obj, attr, *args): 205 | "recursively get attributes" 206 | 207 | def _getattr(obj, attr): 208 | return getattr(obj, attr, *args) 209 | 210 | return functools.reduce(_getattr, [obj] + attr.split(".")) 211 | 212 | 213 | def rsetattr(obj, attr, val): 214 | "recursively set attributes" 215 | pre, _, post = attr.rpartition(".") 216 | return setattr(rgetattr(obj, pre) if pre else obj, post, val) 217 | 218 | 219 | class ProgressBar(pl.Callback): 220 | def __init__( 221 | self, global_progress: bool = True, leave_global_progress: bool = True 222 | ): 223 | super().__init__() 224 | 225 | self.global_progress = global_progress 226 | self.global_desc = "Epoch: {epoch}/{max_epoch}" 227 | self.leave_global_progress = leave_global_progress 228 | self.global_pb = None 229 | 230 | def on_fit_start(self, trainer, pl_module): 231 | desc = self.global_desc.format( 232 | epoch=trainer.current_epoch + 1, max_epoch=trainer.max_epochs 233 | ) 234 | 235 | self.global_pb = tqdm( 236 | desc=desc, 237 | total=trainer.max_epochs, 238 | initial=trainer.current_epoch, 239 | leave=self.leave_global_progress, 240 | disable=not self.global_progress, 241 | ) 242 | 243 | def on_fit_end(self, trainer, pl_module): 244 | self.global_pb.close() 245 | self.global_pb = None 246 | 247 | def on_epoch_end(self, trainer, pl_module): 248 | 249 | # Set description 250 | desc = self.global_desc.format( 251 | epoch=trainer.current_epoch + 1, max_epoch=trainer.max_epochs 252 | ) 253 | self.global_pb.set_description(desc) 254 | 255 | # Set logs and metrics 256 | # logs = pl_module.logs 257 | # for k, v in logs.items(): 258 | # if isinstance(v, torch.Tensor): 259 | # logs[k] = v.squeeze().item() 260 | # self.global_pb.set_postfix(logs) 261 | 262 | # Update progress 263 | self.global_pb.update(1) 264 | -------------------------------------------------------------------------------- /tbsim/utils/train_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains several utility functions used to define the main training loop. It 3 | mainly consists of functions to assist with logging, rollouts, and the @run_epoch function, 4 | which is the core training logic for models in this repository. 5 | """ 6 | import os 7 | import socket 8 | import shutil 9 | 10 | 11 | def infinite_iter(data_loader): 12 | """ 13 | Get an infinite generator 14 | Args: 15 | data_loader (DataLoader): data loader to iterate through 16 | 17 | """ 18 | c_iter = iter(data_loader) 19 | while True: 20 | try: 21 | data = next(c_iter) 22 | except StopIteration: 23 | c_iter = iter(data_loader) 24 | data = next(c_iter) 25 | yield data 26 | 27 | 28 | def get_exp_dir(exp_name, output_dir, save_checkpoints=True, auto_remove_exp_dir=False): 29 | """ 30 | Create experiment directory from config. If an identical experiment directory 31 | exists and @auto_remove_exp_dir is False (default), the function will prompt 32 | the user on whether to remove and replace it, or keep the existing one and 33 | add a new run subdirectory (run0, run1, ...) for the current run.
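For example, a second run with exp_name "test" and output_dir "experiments/" would create experiments/test/run1/ containing checkpoints/, logs/, and videos/ sub-folders (illustrative names).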
34 | 35 | Args: 36 | exp_name (str): name of the experiment 37 | output_dir (str): output directory of the experiment 38 | save_checkpoints (bool): whether to create a checkpoints directory for saving models 39 | auto_remove_exp_dir (bool): if True, automatically remove the existing experiment 40 | folder if it exists at the same path. 41 | 42 | Returns: 43 | base_output_dir (str): path to the experiment directory 44 | log_dir (str): path to created log directory (sub-folder in experiment directory) 45 | ckpt_dir (str or None): path to created checkpoints directory, or None if @save_checkpoints is False 46 | video_dir (str): path to video directory (sub-folder in experiment directory) to store rollout videos 47 | version_str (str): name of the run sub-folder for this run, e.g. "run0" 48 | """ 49 | 50 | # create directories where model parameters, tensorboard logs, and videos are dumped 51 | base_output_dir = output_dir 52 | if not os.path.isabs(base_output_dir): 53 | base_output_dir = os.path.abspath(base_output_dir) 54 | base_output_dir = os.path.join(base_output_dir, exp_name) 55 | if os.path.exists(base_output_dir): 56 | if not auto_remove_exp_dir: 57 | ans = input( 58 | "WARNING: model directory ({}) already exists! \noverwrite? (y/n)\n".format( 59 | base_output_dir 60 | ) 61 | ) 62 | else: 63 | ans = "y" 64 | if ans == "y": 65 | print("REMOVING") 66 | shutil.rmtree(base_output_dir) 67 | os.makedirs(base_output_dir, exist_ok=True) 68 | 69 | # version the run 70 | existing_runs = [ 71 | a 72 | for a in os.listdir(base_output_dir) 73 | if os.path.isdir(os.path.join(base_output_dir, a)) 74 | ] 75 | run_counts = [-1] 76 | for ep in existing_runs: 77 | m = ep.split("run") 78 | if len(m) == 2 and m[0] == "": 79 | if m[1].isnumeric(): 80 | run_counts.append(int(m[1])) 81 | version_str = "run{}".format(max(run_counts) + 1) 82 | 83 | # only make model directory if model saving is enabled 84 | ckpt_dir = None 85 | if save_checkpoints: 86 | ckpt_dir = os.path.join(base_output_dir, version_str, "checkpoints") 87 | os.makedirs(ckpt_dir) 88 | 89 | # tensorboard directory 90 | log_dir = os.path.join(base_output_dir, version_str, "logs") 91 | os.makedirs(log_dir) 92 | 93 | # video directory 94 | video_dir = os.path.join(base_output_dir, version_str, "videos") 95 | os.makedirs(video_dir) 96 | return base_output_dir, log_dir, ckpt_dir, video_dir, version_str 97 | --------------------------------------------------------------------------------
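A short sketch of how the helpers in train_utils.py compose in a training script; the experiment name, output directory, and toy DataLoader are illustrative, not taken from this repo:

import torch
from torch.utils.data import DataLoader, TensorDataset
from tbsim.utils.train_utils import get_exp_dir, infinite_iter

# create experiments/demo_exp/run<k>/{checkpoints,logs,videos} without prompting
base_dir, log_dir, ckpt_dir, video_dir, version = get_exp_dir(
    exp_name="demo_exp", output_dir="experiments", auto_remove_exp_dir=True
)

# wrap a finite DataLoader so training can draw more batches than one epoch holds
loader = DataLoader(TensorDataset(torch.zeros(8, 3)), batch_size=4)
batches = infinite_iter(loader)
for step in range(10):  # 10 > len(loader); the generator restarts transparently
    (batch,) = next(batches)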