├── .github └── FUNDING.yml ├── LICENSE ├── README.md ├── assets ├── overview.png └── results.png ├── carformer ├── .gitignore ├── Makefile ├── carformer │ ├── __init__.py │ ├── config │ │ ├── __init__.py │ │ ├── backbone │ │ │ ├── llama-base.yaml │ │ │ ├── llama-micro.yaml │ │ │ └── llama.yaml │ │ ├── carformer_config.py │ │ ├── config.yaml │ │ ├── dataset │ │ │ └── b2d.yaml │ │ ├── deepspeed │ │ │ ├── basic.yaml │ │ │ ├── default.yaml │ │ │ ├── disable.yaml │ │ │ ├── zero2.yaml │ │ │ └── zero3.yaml │ │ ├── experiments │ │ │ ├── init_rgb_from_ckpt.yaml │ │ │ ├── onlyfrc.yaml │ │ │ ├── overfit.yaml │ │ │ └── use_light_encoder.yaml │ │ ├── hyperparams │ │ │ ├── default.yaml │ │ │ ├── optimizer │ │ │ │ ├── adam.yaml │ │ │ │ ├── adamw.yaml │ │ │ │ └── sgd.yaml │ │ │ └── scheduler │ │ │ │ ├── cosine_annealing.yaml │ │ │ │ └── linear_warmup.yaml │ │ ├── logging │ │ │ └── wandb.yaml │ │ ├── training │ │ │ ├── action │ │ │ │ ├── path.yaml │ │ │ │ └── waypoints.yaml │ │ │ ├── bev │ │ │ │ ├── rgb.yaml │ │ │ │ └── rgb_backbone │ │ │ │ │ ├── internvl2-4b.yaml │ │ │ │ │ ├── internvl2-76b.yaml │ │ │ │ │ ├── internvl2pt5-8b.yaml │ │ │ │ │ └── llava1pt6.yaml │ │ │ ├── default.yaml │ │ │ ├── goal │ │ │ │ ├── default_goal.yaml │ │ │ │ ├── dual_target_point.yaml │ │ │ │ └── target_point.yaml │ │ │ ├── quantized.yaml │ │ │ ├── reward │ │ │ │ └── reward.yaml │ │ │ └── state │ │ │ │ ├── lights.yaml │ │ │ │ └── speed.yaml │ │ └── user │ │ │ └── example.yaml │ ├── data │ │ ├── __init__.py │ │ ├── data.py │ │ ├── data_parser.py │ │ ├── data_utils.py │ │ ├── utils.py │ │ └── wrapper.py │ ├── perception │ │ ├── __init__.py │ │ └── rgb.py │ ├── ponderer.py │ ├── ponderer_lit.py │ ├── utils │ │ ├── __init__.py │ │ ├── distributed.py │ │ ├── distributedsampler.py │ │ └── utils.py │ └── visualization │ │ └── visutils.py ├── requirements.txt ├── setup.py └── train.py ├── docs └── TRAIN_EVAL.md ├── misc ├── leaderboard_evaluator_local.py └── run_eval_leaderboard.py ├── requirements.txt └── team_code ├── config ├── config.yaml ├── eval │ └── b2d.yaml ├── experiments │ └── eta.yaml └── user │ └── example.yaml ├── eta_agent.py └── planner.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | github: [OpenDriveLab] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 3 | patreon: # Replace with a single Patreon username 4 | open_collective: # Replace with a single Open Collective username 5 | ko_fi: # Replace with a single Ko-fi username 6 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 7 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 8 | liberapay: # Replace with a single Liberapay username 9 | issuehunt: # Replace with a single IssueHunt username 10 | otechie: # Replace with a single Otechie username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 13 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Shadi Hamdan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without 
restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
 1 | # 🤔 ETA
 2 | 
 3 | ## Highlight
 4 | We propose "**E**fficiency through **T**hinking **A**head" (ETA), an asynchronous dual system in which a large model pre-processes information from past frames while a small model processes the current frame, enabling real-time decisions with strong performance.
 5 | ETA overview
 6 | 
 7 | ## News
 8 | - **`[2025/06/10]`** [ETA](https://arxiv.org/abs/2506.07725) paper and code release!
 9 | 
10 | ## Results
11 | ETA results
12 | 
13 | 
14 | ## Table of Contents
15 | 1. [Highlights](#highlight)
16 | 2. [News](#news)
17 | 3. [Results](#results)
18 | 4. [Getting Started](#gettingstarted)
19 |    - [Training](docs/TRAIN_EVAL.md#trainingsetup)
20 |    - [Evaluation](docs/TRAIN_EVAL.md#evalsetup)
21 | 5. [TODO List](#todolist)
22 | 6. [License and Citation](#licenseandcitation)
23 | 7. [Other Resources](#otherresources)
24 | 
25 | ## Getting Started
26 | 
27 | To get started with ETA:
28 | - Training
29 |   - [Download Bench2Drive Data](docs/TRAIN_EVAL.md#b2ddata)
30 |   - [Setup other prerequisites](docs/TRAIN_EVAL.md#trainingsetup)
31 |   - [Start training](docs/TRAIN_EVAL.md#training)
32 | - Evaluation
33 |   - [Setup Bench2Drive](docs/TRAIN_EVAL.md#evalsetup)
34 |   - [Setup files and configs](docs/TRAIN_EVAL.md#evalfilesetup)
35 |   - [Download checkpoints](docs/TRAIN_EVAL.md#evalcheckpoints)
36 |   - [Run evaluation](docs/TRAIN_EVAL.md#runeval)
37 | 
38 | ## TODO List
39 | - [x] ETA Training code
40 | - [x] ETA Evaluation
41 | - [x] Inference Code
42 | - [ ] Checkpoints
43 | 
44 | ## Acknowledgements
45 | 
46 | This codebase builds on open-source code from [CARLA Garage](https://github.com/autonomousvision/carla_garage) and [Bench2DriveZoo](https://github.com/Thinklab-SJTU/Bench2DriveZoo/) among others. We thank the authors for their contributions. This project is funded by the European Union (ERC, ENSURE, 101116486) with additional compute support from Leonardo Booster (EuroHPC Joint Undertaking, EHPC-AI-2024A01-060). Views and opinions expressed are however those of the author(s) only and do not necessarily reflect those of the European Union or the European Research Council. Neither the European Union nor the granting authority can be held responsible for them. This study is also supported by the National Natural Science Foundation of China (62206172) and the Shanghai Committee of Science and Technology (23YF1462000).
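As a rough illustration of the asynchronous dual system described in the Highlight section, here is a minimal sketch of the idea. It is illustrative only: `large_model`, `small_model`, and `fuse_and_predict` are hypothetical placeholders, not the actual ETA/CarFormer API, and a real agent would hook this into the CARLA sensor loop rather than iterate over a fixed list of frames.

```python
# Minimal, hypothetical sketch of the asynchronous dual-system idea:
# the large model "thinks ahead" on the previous frame in the background
# while the small model handles the current frame at low latency.
# All model/function names here are placeholders, not the ETA codebase API.
import threading
from queue import Queue

import torch


def run_dual_system(frames, large_model, small_model, fuse_and_predict):
    """Return one action per frame t >= 1, fusing the large model's latent
    for frame t-1 with the small model's latent for frame t."""
    latent_queue = Queue(maxsize=1)

    def think_ahead(frame):
        # Large model runs asynchronously on the *past* frame.
        with torch.no_grad():
            latent_queue.put(large_model(frame))

    # Warm-up: pre-process the first frame before the driving loop starts.
    worker = threading.Thread(target=think_ahead, args=(frames[0],))
    worker.start()

    actions = []
    for t in range(1, len(frames)):
        # The past-frame latent had a full frame interval to be computed, so it
        # should already be ready here. (A real-time agent would fall back to a
        # stale or predicted latent instead of blocking.)
        past_latent = latent_queue.get()
        worker.join()

        # Small model processes the *current* frame at low latency.
        with torch.no_grad():
            current_latent = small_model(frames[t])

        actions.append(fuse_and_predict(past_latent, current_latent))

        # Kick off the large model on the current frame for the next step.
        worker = threading.Thread(target=think_ahead, args=(frames[t],))
        worker.start()

    worker.join()
    return actions
```

The point of the sketch is that the expensive model never sits on the critical path: by the time frame *t* arrives, its output for frame *t-1* is already available, so the per-step latency is dominated by the small model.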
47 | 48 | ## License and Citation 49 | This project is released under the [MIT License](LICENSE). If you find our project useful for your research, please consider citing our paper with the following BibTeX: 50 | 51 | 52 | ```bibtex 53 | @article{hamdan2025eta, 54 | title={ETA: Efficiency through Thinking Ahead, A Dual Approach to Self-Driving with Large Models}, 55 | author={Hamdan, Shadi and Sima, Chonghao and Yang, Zetong and Li, Hongyang and G{\"u}ney, Fatma}, 56 | journal={arXiv preprint arXiv:2506.07725}, 57 | year={2025} 58 | } 59 | ``` 60 | -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/ETA/cdbe5e1174cf33028d5cc26921f4fb72bf9b92f5/assets/overview.png -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/ETA/cdbe5e1174cf33028d5cc26921f4fb72bf9b92f5/assets/results.png -------------------------------------------------------------------------------- /carformer/.gitignore: -------------------------------------------------------------------------------- 1 | # Carformer visualizations 2 | visualizations/**/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | -------------------------------------------------------------------------------- /carformer/Makefile: -------------------------------------------------------------------------------- 1 | # By default echo the help message 2 | .DEFAULT_GOAL := help 3 | 4 | help: ## Display this help message 5 | @echo "Usage: make [target]" 6 | @echo "" 7 | @echo "Targets:" 8 | @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z0-9_-]+:.*?## / {printf " %-20s %s\n", $$1, $$2}' $(MAKEFILE_LIST) 9 | 10 | ETA_base_model_s42: 11 | OMP_NUM_THREADS=6 python train.py deepspeed=zero2 hyperparams.batch_size=5 hyperparams.num_epochs=40 user=example wandb_tag=$@ dataset.dataset_path_rel=B2D-base gpus=4 nodes=8 training.parallel_dataset_init=False backbone=llama-base training.weighted_sampling=True training/goal=dual_target_point training.normalize_goal=True training.bucket_weights.type=preferturns training.rgb_crop.crop_size=896 start_saving_epoch=30 seed=42 12 | 13 | ETA_async_model_s42: 14 | OMP_NUM_THREADS=6 python train.py deepspeed=zero2 hyperparams.batch_size=5 hyperparams.num_epochs=40 user=example wandb_tag=$@ dataset.dataset_path_rel=B2D-base gpus=4 nodes=8 training.parallel_dataset_init=False training.weighted_sampling=True training/goal=dual_target_point training.normalize_goal=True training.bucket_weights.type=preferturns training.use_future_vehicle_forecast=True training.use_predicted_latent_with_gap=True training.action_gap=1 training.future_horizon=1 training.ema_enabled=False +experiments=[use_light_encoder] training.ignore_past_for_length=True training.rgb_crop.crop_size=896 seed=42 15 | -------------------------------------------------------------------------------- /carformer/carformer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/ETA/cdbe5e1174cf33028d5cc26921f4fb72bf9b92f5/carformer/carformer/__init__.py -------------------------------------------------------------------------------- /carformer/carformer/config/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from .carformer_config import CarformerConfig 4 | 5 | _filename_ascii_strip_re = re.compile(r"[^A-Za-z0-9_.,=-]") 6 | 7 | 8 | # Adapted from https://tedboy.github.io/flask/_modules/werkzeug/utils.html#secure_filename 9 | # Changed in order to work without importing flask 10 | def sanitize_shorten_ckppath(path): 11 | path = path.replace("use_light_encoder", "ltenc") 12 | # ema_every_steps = getattr(self.ponderer.config.training, "ema_every_steps", 1) 13 | # ema_start = getattr(self.ponderer.config.training, "ema_start", 0) 14 | # ema_end_epoch = getattr(self.ponderer.config.training, "ema_end_epoch", -1) 15 | path = path.replace("gen_mask_for_action", "amsk") 16 | path = path.replace("turns", "trn") 17 | path = path.replace("optimizer", "opt") 18 | path = path.replace("ema_every_steps", "emaint") 19 | 
path = path.replace("ema_start", "emast") 20 | path = path.replace("ema_end_epoch", "emagnd") 21 | path = path.replace("backbone", "bb") 22 | path = path.replace("training.rgb_crop.crop_size", "trncrp") 23 | path = path.replace("training.ema_decay", "emadc") 24 | path = path.replace("training/", "") 25 | path = path.replace("bev/", "") 26 | path = path.replace("llava1pt6", "lv16") 27 | path = path.replace("rgb_front", "rgb") 28 | path = path.replace("enabled", "T") 29 | path = path.replace("deepspeed", "ds") 30 | path = path.replace("use_gt_frc_only", "gtfrconly") 31 | path = path.replace("combofrc", "cfrc") 32 | path = path.replace("ignore_past_for_length", "igpastln") 33 | path = path.replace("zero_out_frc_branch", "zfbrc") 34 | path = path.replace("use_gt_frc", "gtfrc") 35 | path = path.replace("use_past_horizon", "up") 36 | path = path.replace("backbone.", "") 37 | path = path.replace("optimizer.", "") 38 | path = path.replace("kwargs.", "") 39 | path = path.replace("weight_decay", "") 40 | path = path.replace("hyperparams.", "") 41 | path = path.replace("hyperparams/", "") 42 | path = path.replace("training.", "") 43 | path = path.replace("use_real_latent_ratio", "rltrtio") 44 | path = path.replace("past_horizon", "phrz") 45 | path = path.replace("llama-", "lma-") 46 | path = path.replace("micro", "mc") 47 | path = path.replace("mini", "mn") 48 | path = path.replace("future_horizon", "hrz") 49 | path = path.replace("future_horizon", "ftr") 50 | path = path.replace("ema_enabled", "ema") 51 | path = path.replace("use_predicted_latent_with_gap", "prdltnt") 52 | path = path.replace("bucket_weights.type", "bktype") 53 | path = path.replace("preferturns", "pturns") 54 | path = path.replace("prefer", "p") 55 | path = path.replace("+experiments", "exps") 56 | path = path.replace("num_epochs", "eps") 57 | path = path.replace("batch_size", "bs") 58 | path = path.replace("False", "F") 59 | path = path.replace("True", "T") 60 | path = path.replace("forecast_steps", "frc_steps") 61 | path = path.replace("loss_params.", "") 62 | path = path.replace("action", "actn") 63 | path = path.replace("forecast", "frc") 64 | path = path.replace("classification", "cls") 65 | path = path.replace("state", "stt") 66 | path = path.replace("dataset", "dts") 67 | path = path.replace("subsample_ratio", "smplrtio") 68 | path = path.replace("reconstruction", "rcns") 69 | path = path.replace("wandb_tag", "wnb") 70 | path = path.replace("gradient_accumulation_steps", "gacc") 71 | path = path.replace("dropout", "dp") 72 | path = path.replace("use_future_vehicle_forcast", "dofrc") 73 | path = path.replace("normalize_goal", "nrmgoal") 74 | path = path.replace("weighted_sampling", "wtsmpl") 75 | path = path.replace("use_future_vehicle_frc", "dofrc") 76 | path = path.replace("light_select_layer", "slclr") 77 | path = path.replace("use_light_encoder_backbone", "ltenc") 78 | path = path.replace("actn_gap", "agap") 79 | 80 | for sep in os.path.sep, os.path.altsep: 81 | if sep: 82 | path = path.replace(sep, " ") 83 | path = str(_filename_ascii_strip_re.sub("", "_".join(path.split()))).strip("._") 84 | path = path.replace("training_bev_rgb_backbonergb_backbone", "rgbbb") 85 | path = path.replace("training_goal", "gl") 86 | path = path.replace("dual_target_point", "2tp") 87 | path = path.replace("rgb_backbone", "rgbb") 88 | return path 89 | 90 | 91 | def config_init(): 92 | from dataclasses import dataclass 93 | 94 | from omegaconf import MISSING, OmegaConf 95 | 96 | import hydra 97 | from hydra.core.config_store import ConfigStore 
98 | 99 | cs = ConfigStore.instance() 100 | 101 | def merge_keys(*cfg): 102 | all_keys = set() 103 | for c in cfg: 104 | all_keys.update([str(x) for x in c.keys()]) 105 | return "-".join(sorted(all_keys)) 106 | 107 | # If has key, return True, else return False 108 | def has_key(cfg, key): 109 | return key in cfg 110 | 111 | def get_key(cfg, key, *args): 112 | # If args is not empty, recurse after getting the key 113 | if len(args) > 0: 114 | if key in cfg: 115 | try: 116 | return get_key(cfg[key], *args) 117 | except KeyError: 118 | return get_key(cfg[key], "reward") 119 | raise KeyError(f"Key {args} not found in {cfg[key]}") 120 | else: 121 | raise KeyError(f"Key {key} not found in {cfg}") 122 | 123 | if key in cfg: 124 | return cfg[key] 125 | else: 126 | raise KeyError(f"Key {key} not found in {cfg}") 127 | 128 | def resolve_quantizer_path(keys, data_format, quantizer_dict): 129 | # quantizer_key = "plant" if plant_data else "legacy" 130 | if data_format == "plant": 131 | quantizer_key = "plant" 132 | else: 133 | quantizer_key = "legacy" 134 | return get_key(quantizer_dict, quantizer_key, keys) 135 | 136 | OmegaConf.register_new_resolver("merge_keys", merge_keys) 137 | 138 | OmegaConf.register_new_resolver("has_key", has_key) 139 | 140 | OmegaConf.register_new_resolver("get_key", get_key) 141 | OmegaConf.register_new_resolver("eval", eval) 142 | 143 | bool_to_str = lambda x: "true" if x else "false" 144 | OmegaConf.register_new_resolver("bool_to_str", bool_to_str) 145 | OmegaConf.register_new_resolver("resolve_quantizer_path", resolve_quantizer_path) 146 | OmegaConf.register_new_resolver("sanitize", sanitize_shorten_ckppath) 147 | -------------------------------------------------------------------------------- /carformer/carformer/config/backbone/llama-base.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - llama 3 | architectures: 4 | - LlamaForCausalLM 5 | hidden_size: 768 6 | intermediate_size: 2048 7 | max_position_embeddings: 2048 8 | num_attention_heads: 16 9 | num_hidden_layers: 36 10 | vocab_size: 2 11 | init_from_lm_ckp: false 12 | init_name_or_path: -------------------------------------------------------------------------------- /carformer/carformer/config/backbone/llama-micro.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - llama 3 | architectures: 4 | - LlamaForCausalLM 5 | hidden_size: 768 6 | intermediate_size: 2048 7 | max_position_embeddings: 2048 8 | num_attention_heads: 16 9 | num_hidden_layers: 12 10 | vocab_size: 2 11 | init_from_lm_ckp: false 12 | init_name_or_path: -------------------------------------------------------------------------------- /carformer/carformer/config/backbone/llama.yaml: -------------------------------------------------------------------------------- 1 | architectures: 2 | - LlamaForCausalLM 3 | bos_token_id: 1 4 | eos_token_id: 2 5 | hidden_act: silu 6 | hidden_size: 2048 7 | initializer_range: 0.02 8 | intermediate_size: 5632 9 | max_position_embeddings: 2048 10 | model_type: llama 11 | num_attention_heads: 32 12 | num_hidden_layers: 22 13 | num_key_value_heads: 4 14 | pretraining_tp: 1 15 | rms_norm_eps: 1.0e-05 16 | rope_scaling: 17 | tie_word_embeddings: false 18 | torch_dtype: float32 19 | use_cache: true 20 | vocab_size: 32000 21 | attention_dropout: 0.0 -------------------------------------------------------------------------------- /carformer/carformer/config/carformer_config.py: 
-------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ CarFormer configuration, adapted from GPT-2 configuration """ 17 | from collections import OrderedDict 18 | from typing import Any, List, Mapping, Optional 19 | 20 | from transformers import PreTrainedTokenizer, TensorType, is_torch_available 21 | 22 | from transformers.configuration_utils import PretrainedConfig 23 | from transformers.utils import logging 24 | from transformers import CONFIG_MAPPING 25 | 26 | import os 27 | import yaml 28 | from typing import Union 29 | from transformers import AutoConfig 30 | 31 | logger = logging.get_logger(__name__) 32 | 33 | 34 | class CarformerConfig(PretrainedConfig): 35 | """ 36 | This is the configuration class to store the configuration of a [`Wanderer`] model or [`Ponderer`] model. It is used to 37 | instantiate the model according to the specified arguments, defining the model architecture. 38 | 39 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 40 | documentation from [`PretrainedConfig`] for more information. 41 | 42 | Args: 43 | vocab_size (`int`, *optional*, defaults to 50257): 44 | Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the 45 | `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`]. 46 | n_positions (`int`, *optional*, defaults to 1024): 47 | The maximum sequence length that this model might ever be used with. Typically set this to something large 48 | just in case (e.g., 512 or 1024 or 2048). 49 | n_embd (`int`, *optional*, defaults to 768): 50 | Dimensionality of the embeddings and hidden states. 51 | n_layer (`int`, *optional*, defaults to 12): 52 | Number of hidden layers in the Transformer encoder. 53 | n_head (`int`, *optional*, defaults to 12): 54 | Number of attention heads for each attention layer in the Transformer encoder. 55 | n_inner (`int`, *optional*, defaults to None): 56 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd 57 | activation_function (`str`, *optional*, defaults to `"gelu"`): 58 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`. 59 | resid_pdrop (`float`, *optional*, defaults to 0.1): 60 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. 61 | embd_pdrop (`int`, *optionabackbone_configl*, defaults to 0.1): 62 | The dropout ratio for the embeddings. 63 | attn_pdrop (`float`, *optional*, defaults to 0.1): 64 | The dropout ratio for the attention. 65 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-5): 66 | The epsilon to use in the layer normalization layers. 
67 | initializer_range (`float`, *optional*, defaults to 0.02): 68 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 69 | summary_type (`string`, *optional*, defaults to `"cls_index"`): 70 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 71 | [`TFGPT2DoubleHeadsModel`]. 72 | 73 | Has to be one of the following options: 74 | 75 | - `"last"`: Take the last token hidden state (like XLNet). 76 | - `"first"`: Take the first token hidden state (like BERT). 77 | - `"mean"`: Take the mean of all tokens hidden states. 78 | - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2). 79 | - `"attn"`: Not implemented now, use multi-head attention. 80 | summary_use_proj (`bool`, *optional*, defaults to `True`): 81 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 82 | [`TFGPT2DoubleHeadsModel`]. 83 | 84 | Whether or not to add a projection after the vector extraction. 85 | summary_activation (`str`, *optional*): 86 | Argument used when doing sequence summary. Used in for the multiple choice head in 87 | [`GPT2DoubleHeadsModel`]. 88 | 89 | Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation. 90 | summary_proj_to_labels (`bool`, *optional*, defaults to `True`): 91 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 92 | [`TFGPT2DoubleHeadsModel`]. 93 | 94 | Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes. 95 | summary_first_dropout (`float`, *optional*, defaults to 0.1): 96 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and 97 | [`TFGPT2DoubleHeadsModel`]. 98 | 99 | The dropout ratio to be used after the projection and activation. 100 | scale_attn_weights (`bool`, *optional*, defaults to `True`): 101 | Scale attention weights by dividing by sqrt(hidden_size).. 102 | use_cache (`bool`, *optional*, defaults to `True`): 103 | Whether or not the model should return the last key/values attentions (not used by all models). 104 | scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`): 105 | Whether to additionally scale attention weights by `1 / layer_idx + 1`. 106 | reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`): 107 | Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention 108 | dot-product/softmax to float() when training with mixed precision. 
109 | 110 | Example: 111 | 112 | ```python 113 | >>> from transformers import GPT2Config, GPT2Model 114 | 115 | >>> # Initializing a GPT2 configuration 116 | >>> configuration = GPT2Config() 117 | 118 | >>> # Initializing a model (with random weights) from the configuration 119 | >>> model = GPT2Model(configuration) 120 | 121 | >>> # Accessing the model configuration 122 | >>> configuration = model.config 123 | ```""" 124 | 125 | model_type = "gpt2" 126 | keys_to_ignore_at_inference = ["past_key_values"] 127 | attribute_map = { 128 | "hidden_size": "hidden_size", 129 | "max_position_embeddings": "max_position_embeddings", 130 | "num_attention_heads": "num_attention_heads", 131 | "num_hidden_layers": "num_hidden_layers", 132 | } 133 | 134 | def __init__( 135 | self, 136 | **kwargs, 137 | ): 138 | # import ipdb; ipdb.set_trace() 139 | # Load config from dict 140 | # import ipdb; ipdb.set_trace() 141 | # self.backbone_config = PretrainedConfig.from_dict(backbone) 142 | if "backbone" in kwargs: 143 | backbone = kwargs.pop("backbone") 144 | if "model_type" in backbone: 145 | backbone_cfg_cls = CONFIG_MAPPING[backbone["model_type"]] 146 | self.backbone = backbone_cfg_cls.from_dict(backbone) 147 | else: 148 | # import ipdb; ipdb.set_trace() 149 | self.backbone = AutoConfig.from_pretrained( 150 | backbone["init_name_or_path"] 151 | ) 152 | elif "init_name_or_path" in kwargs: 153 | self.backbone = AutoConfig.from_pretrained(kwargs["init_name_or_path"]) 154 | # import ipdb; ipdb.set_trace() 155 | super().__init__(**kwargs) 156 | if hasattr(self, "n_embd"): 157 | self.hidden_size = self.n_embd 158 | 159 | if hasattr(self, "training"): 160 | if "use_future_vehicle_forcast" in self.training: 161 | print("Fixing typo in config") 162 | self.training["use_future_vehicle_forecast"] = self.training[ 163 | "use_future_vehicle_forcast" 164 | ] 165 | 166 | self.dotify() 167 | 168 | def dotify(self): 169 | # Convert all dict attributes to dotdict 170 | # from omegaconf.dictconfig import DictConfig 171 | from dotmap import DotMap as ddict 172 | 173 | # Iterate over all attributes 174 | for key, value in self.__dict__.items(): 175 | if isinstance(value, dict): 176 | self.__dict__[key] = ddict(value, _dynamic=False) 177 | 178 | def dedotify(self): 179 | # Convert all dotdict attributes to dict 180 | # Iterate over all attributes 181 | from dotmap import DotMap as ddict 182 | 183 | for key, value in self.__dict__.items(): 184 | if isinstance(value, ddict): 185 | self.__dict__[key] = dict(value) 186 | 187 | def __getattr__(self, attr: str, default=None): 188 | # Backward compatibility with old attribute names 189 | # Raise warning on each access to deprecated attributes 190 | # print(attr) 191 | if attr == "backbone": 192 | raise AttributeError(f"Attribute {attr} not found") 193 | # print(attr) 194 | 195 | if getattr(self, "backbone", None) is not None: 196 | # import warnings 197 | # print(f"attr {attr} not found, activating deprecated failsafe") 198 | # warnings.warn( 199 | # f"Attribute {attr} will be deprecated in favor of backbone.{attr}. 
Please update your code.", 200 | # DeprecationWarning, 201 | # ) 202 | return getattr(self.backbone, attr) 203 | else: 204 | raise AttributeError(f"Attribute {attr} not found") 205 | 206 | @staticmethod 207 | def from_hydra(hydra_cfg: Any) -> "CarformerConfig": 208 | from omegaconf import OmegaConf 209 | 210 | return CarformerConfig.from_dict( 211 | OmegaConf.to_container(hydra_cfg, resolve=True) 212 | ) 213 | 214 | def save_pretrained( 215 | self, 216 | save_directory: Union[str, os.PathLike], 217 | push_to_hub: bool = False, 218 | **kwargs, 219 | ): 220 | """ 221 | Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the 222 | [`~PretrainedConfig.from_pretrained`] class method. 223 | 224 | Args: 225 | save_directory (`str` or `os.PathLike`): 226 | Directory where the configuration JSON file will be saved (will be created if it does not exist). 227 | push_to_hub (`bool`, *optional*, defaults to `False`): 228 | Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the 229 | repository you want to push to with `repo_id` (will default to the name of `save_directory` in your 230 | namespace). 231 | kwargs: 232 | Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. 233 | """ 234 | if os.path.isfile(save_directory): 235 | raise AssertionError( 236 | f"Provided path ({save_directory}) should be a directory, not a file" 237 | ) 238 | 239 | os.makedirs(save_directory, exist_ok=True) 240 | 241 | # Call super to save config.json 242 | super().save_pretrained(save_directory, push_to_hub=push_to_hub, **kwargs) 243 | -------------------------------------------------------------------------------- /carformer/carformer/config/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - user: shadi 3 | - backbone: llama-mini 4 | - training: quantized 5 | - hyperparams: default 6 | - dataset: b2d 7 | - logging: wandb 8 | - deepspeed: disable 9 | - _self_ 10 | 11 | hydra: 12 | run: 13 | dir: checkpoints/${expname}/${sanitize:'${hydra.job.override_dirname}_bev=${training.bev_type}'}/${now:%Y-%m-%d_%H-%M-%S} 14 | job: 15 | chdir: False 16 | config: 17 | override_dirname: 18 | exclude_keys: 19 | - expname 20 | - training/bev 21 | - num_workers 22 | - dataset.dataset_path_rel 23 | - user 24 | - preload_in_memory 25 | - augmentable_preloader 26 | - preload 27 | - visualize 28 | - amp 29 | - multi_gpu_strategy 30 | - user.dataset_dir 31 | - cpu 32 | - gpus 33 | - ckpt_path 34 | - deepspeed 35 | - training.parallel_dataset_init 36 | - force_save 37 | - visualize_start_epoch 38 | - force_log 39 | - overfit_batches 40 | - start_saving_epoch 41 | - nodes 42 | - logging.track_weight_changes 43 | kv_sep: '=' 44 | item_sep: '_' 45 | 46 | seed: 1234 47 | debug: False 48 | visualize: True 49 | visualize_start_epoch: -1 50 | visualize_interval: 1 51 | overfit: 0 52 | cpu: False 53 | gpus: ${oc.decode:${oc.env:WORLD_SIZE,1}} 54 | multi_gpu_strategy: ddp 55 | amp: True 56 | num_workers: 20 57 | early_stopping: False 58 | early_stopping_patience: 5 59 | early_stopping_metric: action_classification_loss 60 | 61 | save_every: 1 62 | start_saving_epoch: -1 63 | force_save: False 64 | force_log: False 65 | 66 | expname: TRAINING 67 | wandb_name: training_PlanT_${hydra:job.override_dirname} 68 | wandb_tag: 69 | save_dir: ${hydra:run.dir} 70 | cp_command_dir: ${user.working_dir}/cpcommands 71 | 72 | data_dir: 
${user.dataset_dir}/${dataset.dataset_path_rel} # Path to the data directory and name of data folder 73 | preload: True 74 | preload_in_memory: False 75 | augmentable_preloader: False 76 | cache_dir: ${user.working_dir}/.cache/ 77 | wipe_cache: False 78 | 79 | use_deepspeed: False 80 | ckpt_path: null 81 | gradient_checkpointing: False 82 | overfit_batches: 1 83 | nodes: 1 -------------------------------------------------------------------------------- /carformer/carformer/config/dataset/b2d.yaml: -------------------------------------------------------------------------------- 1 | dataset_path_rel: B2D-base 2 | data_format: "b2d" 3 | subsample_ratio: 1.0 4 | fps: 10 -------------------------------------------------------------------------------- /carformer/carformer/config/deepspeed/basic.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | train_micro_batch_size_per_gpu: ${hyperparams.batch_size} 4 | gradient_accumulation_steps: ${hyperparams.gradient_accumulation_steps} 5 | optimizer: 6 | type: ${hyperparams.optimizer.name} 7 | params: 8 | lr: 5e-5 9 | betas: 10 | - 0.9 11 | - 0.999 12 | eps: 1e-6 13 | weight_decay: ${hyperparams.optimizer.kwargs.weight_decay} 14 | bf16: 15 | enabled: true 16 | zero_optimization: false -------------------------------------------------------------------------------- /carformer/carformer/config/deepspeed/default.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | use_deepspeed: true 3 | train_batch_size: ${deepspeed.train_micro_batch_size_per_gpu} 4 | -------------------------------------------------------------------------------- /carformer/carformer/config/deepspeed/disable.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | use_deepspeed: false -------------------------------------------------------------------------------- /carformer/carformer/config/deepspeed/zero2.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | bf16: 4 | enabled: true 5 | zero_optimization: 6 | stage: 2 7 | overlap_comm: true 8 | contiguous_gradients: true 9 | sub_group_size: 1000000000 10 | reduce_bucket_size: 500000000 11 | stage3_prefetch_bucket_size: 500000000 12 | stage3_param_persistence_threshold: 1000000 13 | stage3_max_live_parameters: 1000000000 14 | stage3_max_reuse_distance: 1000000000 15 | stage3_gather_16bit_weights_on_model_save: false 16 | # optimizer: 17 | # type: AdamW 18 | # params: 19 | # lr: 5e-5 20 | # betas: 21 | # - 0.9 22 | # - 0.999 23 | # eps: 1e-6 24 | # weight_decay: ${hyperparams.optimizer.kwargs.weight_decay} 25 | gradient_clipping: 1 26 | steps_per_print: 5 27 | train_micro_batch_size_per_gpu: ${hyperparams.batch_size} 28 | zero_allow_untested_optimizer: true -------------------------------------------------------------------------------- /carformer/carformer/config/deepspeed/zero3.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | bf16: 4 | enabled: true 5 | zero_optimization: 6 | stage: 3 7 | overlap_comm: true 8 | contiguous_gradients: true 9 | sub_group_size: 1000000000 10 | reduce_bucket_size: 500000000 11 | stage3_prefetch_bucket_size: 500000000 12 | stage3_param_persistence_threshold: 1000000 13 | stage3_max_live_parameters: 1000000000 14 | stage3_max_reuse_distance: 1000000000 15 | 
stage3_gather_16bit_weights_on_model_save: false 16 | # optimizer: 17 | # type: AdamW 18 | # params: 19 | # lr: 5e-5 20 | # betas: 21 | # - 0.9 22 | # - 0.999 23 | # eps: 1e-6 24 | # weight_decay: ${hyperparams.optimizer.kwargs.weight_decay} 25 | gradient_clipping: 1 26 | steps_per_print: 5 27 | train_micro_batch_size_per_gpu: ${hyperparams.batch_size} 28 | stage3_gather_16bit_weights_on_model_save: true 29 | -------------------------------------------------------------------------------- /carformer/carformer/config/experiments/init_rgb_from_ckpt.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | training: 3 | rgb_backbone: 4 | init_from_ckpt: 5 | enabled: true -------------------------------------------------------------------------------- /carformer/carformer/config/experiments/onlyfrc.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | training: 3 | loss_params: 4 | wp_loss: 0.0 5 | path_loss: 0.0 6 | mask_loss: 0.0 7 | state_forecast: 0.5 -------------------------------------------------------------------------------- /carformer/carformer/config/experiments/overfit.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | hydra: 4 | run: 5 | dir: ./checkpoints/debug/tmp_debug 6 | 7 | overfit: 1 8 | 9 | wandb_tag: overfit 10 | 11 | training: 12 | max_instances: ${eval:${hyperparams.batch_size} * 10} 13 | splits: 14 | train: val 15 | 16 | dataset: 17 | subsample_ratio: 1.0 # Since we are already using max_instances, we don't need to subsample. 18 | 19 | backbone: 20 | attn_pdrop: 0.0 21 | resid_pdrop: 0.0 22 | embd_pdrop: 0.0 -------------------------------------------------------------------------------- /carformer/carformer/config/experiments/use_light_encoder.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /training/bev/rgb_backbone@training.light_rgb_backbone: llava1pt6 4 | - /backbone@backbone: llama-micro 5 | - _self_ 6 | 7 | light_select_layer: 8 8 | training: 9 | utilize_fast_current_latent: True 10 | use_light_as_query: False 11 | light_rgb_backbone: 12 | select_layer: ${light_select_layer} 13 | ema_enabled: False 14 | model_path: ${user.working_dir}/${training.light_rgb_backbone.model_path_rel} 15 | -------------------------------------------------------------------------------- /carformer/carformer/config/hyperparams/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - optimizer: adamw 3 | - scheduler: linear_warmup 4 | debug: False 5 | lr: 3e-5 6 | max_grad_norm: 1.0 7 | batch_size: 128 8 | gradient_accumulation_steps: 1 9 | num_epochs: 200 10 | patience: 40 -------------------------------------------------------------------------------- /carformer/carformer/config/hyperparams/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | name: Adam 2 | kwargs: 3 | weight_decay: 0.0 -------------------------------------------------------------------------------- /carformer/carformer/config/hyperparams/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | name: AdamW 2 | kwargs: 3 | weight_decay: 1e-4 -------------------------------------------------------------------------------- /carformer/carformer/config/hyperparams/optimizer/sgd.yaml: 
-------------------------------------------------------------------------------- 1 | name: SGD 2 | kwargs: 3 | weight_decay: 1e-4 -------------------------------------------------------------------------------- /carformer/carformer/config/hyperparams/scheduler/cosine_annealing.yaml: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/ETA/cdbe5e1174cf33028d5cc26921f4fb72bf9b92f5/carformer/carformer/config/hyperparams/scheduler/cosine_annealing.yaml -------------------------------------------------------------------------------- /carformer/carformer/config/hyperparams/scheduler/linear_warmup.yaml: -------------------------------------------------------------------------------- 1 | name: linear 2 | warmup_ratio: 0.05 3 | kwargs: 4 | num_warmup_steps: ${eval:'int(${hyperparams.num_epochs} * ${hyperparams.scheduler.warmup_ratio})'} 5 | num_training_steps: ${hyperparams.num_epochs} -------------------------------------------------------------------------------- /carformer/carformer/config/logging/wandb.yaml: -------------------------------------------------------------------------------- 1 | project: ponderer 2 | entity: ${user.wandb_entity} 3 | mode: online -------------------------------------------------------------------------------- /carformer/carformer/config/training/action/path.yaml: -------------------------------------------------------------------------------- 1 | path: 2 | name: path 3 | width: ${eval:'${training.num_path} * 2'} 4 | -------------------------------------------------------------------------------- /carformer/carformer/config/training/action/waypoints.yaml: -------------------------------------------------------------------------------- 1 | waypoints: 2 | name: waypoints 3 | width: ${eval:'${training.num_waypoints} * 2'} 4 | future_horizon: ${training.num_waypoints} 5 | -------------------------------------------------------------------------------- /carformer/carformer/config/training/bev/rgb.yaml: -------------------------------------------------------------------------------- 1 | # @package training 2 | defaults: 3 | - _self_ 4 | - rgb_backbone: llava1pt6 5 | 6 | bev: 7 | rgb_front: 8 | name: rgb_front 9 | 10 | rgb_crop: 11 | type: dualcenter 12 | crop_size: ${eval:'${training.rgb_backbone.input_size} * 2'} 13 | resize: ${training.rgb_backbone.input_size} 14 | 15 | 16 | tokenized_state: False 17 | object_level: False 18 | ema_enabled: False 19 | ema_decay: 0.992 20 | ema_every_steps: 1 21 | ema_start: 0 22 | ema_end_epoch: -1 -------------------------------------------------------------------------------- /carformer/carformer/config/training/bev/rgb_backbone/internvl2-4b.yaml: -------------------------------------------------------------------------------- 1 | model_path_rel: bin/rgb/internvl2-4b-visionenc/ 2 | model_path: ${user.working_dir}/${training.rgb_backbone.model_path_rel} 3 | 4 | frozen: false 5 | outputs: 6 | whole: false 7 | patches: true 8 | 9 | projection_dim: ${backbone.hidden_size} 10 | 11 | input_size: 448 12 | select_layer: -1 13 | try_to_truncate_layers: true 14 | downsample: True 15 | 16 | ema_enabled: ${training.ema_enabled} -------------------------------------------------------------------------------- /carformer/carformer/config/training/bev/rgb_backbone/internvl2-76b.yaml: -------------------------------------------------------------------------------- 1 | model_path_rel: bin/rgb/internvl2-76b-visionenc/ 2 | model_path: 
${user.working_dir}/${training.rgb_backbone.model_path_rel} 3 | 4 | frozen: false 5 | outputs: 6 | whole: false 7 | patches: true 8 | 9 | projection_dim: ${backbone.hidden_size} 10 | 11 | input_size: 448 12 | select_layer: -1 13 | try_to_truncate_layers: true 14 | downsample: True 15 | 16 | ema_enabled: ${training.ema_enabled} -------------------------------------------------------------------------------- /carformer/carformer/config/training/bev/rgb_backbone/internvl2pt5-8b.yaml: -------------------------------------------------------------------------------- 1 | model_path_rel: bin/rgb/internvl2pt5-8b/ 2 | model_path: ${user.working_dir}/${training.rgb_backbone.model_path_rel} 3 | 4 | frozen: false 5 | outputs: 6 | whole: false 7 | patches: true 8 | 9 | projection_dim: ${backbone.hidden_size} 10 | 11 | input_size: 448 12 | select_layer: -1 13 | try_to_truncate_layers: true 14 | downsample: True 15 | 16 | ema_enabled: ${training.ema_enabled} 17 | -------------------------------------------------------------------------------- /carformer/carformer/config/training/bev/rgb_backbone/llava1pt6.yaml: -------------------------------------------------------------------------------- 1 | model_path_rel: bin/rgb/llava-v1.6-vicuna-visionenc 2 | model_path: ${user.working_dir}/${training.rgb_backbone.model_path_rel} 3 | 4 | frozen: 5 | model: False 6 | ema_model: False 7 | projector: False 8 | 9 | outputs: 10 | whole: false 11 | patches: true 12 | 13 | projection_dim: ${backbone.hidden_size} 14 | 15 | input_size: 336 16 | select_layer: -2 17 | dropout_attn: 0.0 18 | masking_rate: 0.0 19 | 20 | try_to_truncate_layers: true 21 | 22 | override_kwargs: 23 | attention_dropout: ${training.rgb_backbone.dropout_attn} 24 | 25 | downsample: True 26 | ema_enabled: ${training.ema_enabled} 27 | 28 | init_from_ckpt: 29 | enabled: False 30 | ckpt_path_rel: bin/rgb/llava-v14-micro.pt 31 | ckpt_path: ${user.working_dir}/${training.rgb_backbone.init_from_ckpt.ckpt_path_rel} 32 | projector: true 33 | ema_model: true 34 | model: true 35 | freeze: true -------------------------------------------------------------------------------- /carformer/carformer/config/training/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - action: 3 | - path 4 | - steer 5 | - goal: 6 | - dual_target_point 7 | - reward: 8 | - reward 9 | - bev: 10 | - rgb 11 | - state: 12 | - speed 13 | - _self_ 14 | loss_params: 15 | default: 16 | classification: 0 17 | action: 18 | classification: 1 19 | reconstruction: 1 20 | 21 | action_type: ${merge_keys:${training.action}} 22 | state_type: ${merge_keys:${training.state}, ${training.bev}} 23 | non_bev_state_type: ${merge_keys:${training.state}} 24 | bev_type: ${merge_keys:${training.bev}} 25 | goal_type: ${merge_keys:${training.goal}} 26 | reward_type: ${merge_keys:${training.reward}} 27 | condition_on_goal: true 28 | goal_conditioning_type: local 29 | max_token_types: 4 30 | quantized: false 31 | 32 | # Data settings 33 | integrate_rewards_to_go: false 34 | context_length: 1 35 | frame_stride: 5 36 | inter_window_stride: 2 37 | skip_noisy: true 38 | trim_first_and_last: true 39 | trim_count: 1 40 | max_instances: -1 41 | drop_last: true 42 | future_horizon: 0 43 | past_horizon: 0 44 | use_future_ego_waypoints: ${has_key:${training.action},waypoints} 45 | use_future_vehicle_forecast: false 46 | forecast_steps: 1 47 | include_noisy_in_action: false 48 | splits: 49 | train: train 50 | val: val 51 | dynamic_batching: true # false, whether or not 
to crop padding 52 | weighted_sampling: false # false, whether or not to sample based on class weights 53 | bucket_weights: 54 | type: uniform 55 | total_ratio: 0.6 56 | weights: 57 | - 1.0 # general 58 | - 1.0 # acc_scratch 59 | - 2.0 # acc_light_pedal 60 | - 2.0 # acc_medium_pedal 61 | - 1.0 # acc_heavy_pedal 62 | - 1.0 # acc_brake 63 | - 1.0 # acc_coast 64 | - 3.0 # steer_right 65 | - 3.0 # steer_left 66 | - 1.0 # vehicle_hazard_front 67 | - 1.0 # vehicle_hazard_back 68 | - 1.0 # vehicle_hazard_side 69 | - 1.0 # stop_sign 70 | - 1.0 # red_light 71 | - 1.0 # swerving 72 | - 1.0 # pedestrian 73 | 74 | get_weight_reduce_fn: mean 75 | get_noisy_reduce_fn: last 76 | 77 | # Training settings 78 | split_ratio: 0.8 79 | 80 | # Caching 81 | dataset_caching: 82 | enabled: True 83 | cache_metadata: True 84 | cache_slow_attributes: True 85 | cache_dir: ${cache_dir} 86 | parallel_dataset_init: True 87 | parallel_dataset_workers: 16 -------------------------------------------------------------------------------- /carformer/carformer/config/training/goal/default_goal.yaml: -------------------------------------------------------------------------------- 1 | # @package training 2 | goal_conditioning_type: local 3 | goal_continuous: True -------------------------------------------------------------------------------- /carformer/carformer/config/training/goal/dual_target_point.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_goal 3 | dual_target_point: 4 | name: dual_target_point 5 | width: 4 6 | mean: [[ 5.1162534 , -0.1575937 ],[26.005814 , -0.09633584]] 7 | std: [[ 0.8992543, 0.9467049],[15.428605 , 6.614513 ]] -------------------------------------------------------------------------------- /carformer/carformer/config/training/goal/target_point.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default_goal 3 | target_point: 4 | name: target_point 5 | width: 2 6 | mean: [5.1162534, -0.1575937] 7 | std: [0.8992543, 0.9467049] -------------------------------------------------------------------------------- /carformer/carformer/config/training/quantized.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - default 3 | - override action: 4 | - waypoints 5 | - path 6 | - override goal: 7 | - target_point 8 | quantized: true 9 | 10 | loss_params: 11 | wp_loss: 1.0 12 | path_loss: 1.0 13 | mask_loss: 0.0625 14 | state_forecast: 0.5 15 | bev: {} # Nothing 16 | 17 | # For waypoints 18 | num_waypoints: 10 19 | num_path: 20 20 | future_horizon: ${training.num_waypoints} 21 | # Forecasting 22 | use_future_vehicle_forecast: false 23 | vae_target_supervision: null 24 | 25 | use_predicted_latent_with_gap: false 26 | pred_latent_layers: 2 27 | pred_latent_ffn_hidden: 2048 28 | pred_latent_ffn_dropout: 0.1 29 | pred_latent_post_mlp: false 30 | pred_latent_use_metadata: true 31 | use_real_latent_ratio: 0.0 32 | 33 | # action gap (for future action offset) 34 | action_gap: 35 | 36 | # Supplemental supervision 37 | gen_masks_for_action: True 38 | 39 | normalize_goal: True 40 | create_goal_mask: True 41 | 42 | use_gt_frc: false 43 | zero_out_frc_branch: false 44 | use_gt_frc_only: false 45 | ignore_past_for_length: false 46 | 47 | use_past_horizon_states: false -------------------------------------------------------------------------------- /carformer/carformer/config/training/reward/reward.yaml: 
-------------------------------------------------------------------------------- 1 | reward: 2 | name: reward -------------------------------------------------------------------------------- /carformer/carformer/config/training/state/lights.yaml: -------------------------------------------------------------------------------- 1 | lights: 2 | name: lights -------------------------------------------------------------------------------- /carformer/carformer/config/training/state/speed.yaml: -------------------------------------------------------------------------------- 1 | speed: 2 | name: speed -------------------------------------------------------------------------------- /carformer/carformer/config/user/example.yaml: -------------------------------------------------------------------------------- 1 | dataset_dir: /leonardo_work/shamdan/ 2 | working_dir: /leonardo/home/userexternal/shamdan0/research/Carformer 3 | 4 | wandb_entity: shadihamdan 5 | -------------------------------------------------------------------------------- /carformer/carformer/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parser import * 2 | from .data_utils import * 3 | from .data import * 4 | from .wrapper import * 5 | from .utils import * 6 | -------------------------------------------------------------------------------- /carformer/carformer/data/data_parser.py: -------------------------------------------------------------------------------- 1 | # import orjson as json 2 | import json 3 | import gzip 4 | import os 5 | from PIL import Image 6 | import numpy as np 7 | from .data_utils import ( 8 | iterative_line_interpolation, 9 | get_hazard_directions, 10 | is_walker_hazard, 11 | ) 12 | 13 | home_dir = os.environ["HOME"] 14 | path_cache = {} 15 | 16 | 17 | class Parser: 18 | def __init__( 19 | self, 20 | root_path, 21 | state_type, 22 | action_type, 23 | reward_type, 24 | goal_type, 25 | folder_to_ext=None, 26 | size=None, 27 | cache_dir=None, 28 | ): 29 | self.root_path = root_path 30 | self.state_type = state_type 31 | self.action_type = action_type 32 | self.reward_type = reward_type 33 | self.goal_type = goal_type 34 | 35 | self.states = state_type.split("-") 36 | self.actions = action_type.split("-") 37 | self.rewards = reward_type.split("-") 38 | self.goals = goal_type.split("-") 39 | 40 | if folder_to_ext is None: 41 | self._folder_to_ext = {} 42 | # Check the extension of the first file in every folder in path 43 | for folder in os.listdir(self.root_path): 44 | folder_path = os.path.join(self.root_path, folder) 45 | if os.path.isdir(folder_path): 46 | fl = os.listdir(folder_path) 47 | # Get extension 48 | if len(fl) > 0: 49 | ext = os.path.splitext(fl[0])[1] 50 | self._folder_to_ext[folder] = ext 51 | else: 52 | self._folder_to_ext = folder_to_ext 53 | 54 | if size is not None: 55 | self.length = size 56 | else: 57 | # check if self.root_path/anno/00000.json.gz exists. 
If not, set length to 0 and return 58 | sanity_file = os.path.join(self.root_path, "anno", "00000.json.gz") 59 | if not os.path.exists(sanity_file): 60 | self.length = 0 61 | return 62 | self.length = len(os.listdir(os.path.join(self.root_path, "anno"))) 63 | 64 | self.cache_dir = cache_dir 65 | self.path_cache = None 66 | 67 | def get_state( 68 | self, 69 | idx, 70 | preprocessing_functions=None, 71 | filtering_functions=None, 72 | skip_keys=None, 73 | ): 74 | ts_prefix = str(idx).zfill(5) 75 | state_dict = self.gzip_json_load(os.path.join("anno", f"{ts_prefix}.json.gz")) 76 | 77 | state = {} 78 | 79 | for s in self.states: 80 | if skip_keys is not None and s in skip_keys: 81 | continue 82 | if "rgb" in s: 83 | rgb = Image.open( 84 | os.path.join(self.root_path, "camera", s, f"{ts_prefix}.jpg") 85 | ) 86 | 87 | speed = state_dict["speed"] 88 | 89 | action = np.mean(state_dict["bounding_boxes"][0]["world2ego"]) 90 | 91 | state[s] = rgb 92 | elif s == "speed": 93 | state["speed"] = state_dict["speed"] 94 | else: 95 | raise ValueError(f"State type {s} not recognized") 96 | 97 | if preprocessing_functions is not None and s in preprocessing_functions: 98 | state[s] = preprocessing_functions[s](state[s]) 99 | 100 | return state 101 | 102 | def get_action(self, idx, include_noise=False, skip_keys=None): 103 | ts_prefix = str(idx).zfill(5) 104 | state_dict = self.gzip_json_load(os.path.join("anno", f"{ts_prefix}.json.gz")) 105 | 106 | action = {} 107 | 108 | for a in self.actions: 109 | if skip_keys is not None and a in skip_keys: 110 | continue 111 | if a == "waypoints": 112 | # Get the current ego matrix from the measurement dict 113 | if skip_keys is not None and "ego_matrix" in skip_keys: 114 | continue 115 | action["ego_matrix"] = state_dict["bounding_boxes"][0]["world2ego"] 116 | elif a == "path": 117 | # Default to 20 waypoints like CarLLaVA 118 | ego_matrix = state_dict["bounding_boxes"][0]["world2ego"] 119 | pointer_idx = idx 120 | points = [] 121 | if self.path_cache is not None: 122 | relevant_points = self.path_cache[pointer_idx:] 123 | 124 | relevant_points_inv = np.linalg.inv(relevant_points) 125 | ego_matrix = np.asarray(ego_matrix) 126 | points = np.einsum( 127 | "Xi, BiY -> BXY", ego_matrix, relevant_points_inv 128 | )[:, :2, 3] 129 | 130 | # Multiply by -1 to get the correct y coordinate 131 | points[:, 1] = -points[:, 1] 132 | else: 133 | while True: 134 | print("Rebuilding path cache") 135 | path = os.path.join( 136 | self.root_path, 137 | "anno", 138 | f"{str(pointer_idx).zfill(5)}.json.gz", 139 | ) 140 | 141 | if path in path_cache: 142 | cur_point = path_cache[path] 143 | else: 144 | if not os.path.exists(path): 145 | break 146 | 147 | cur_dct = self.gzip_json_load( 148 | os.path.join( 149 | "anno", f"{str(pointer_idx).zfill(5)}.json.gz" 150 | ) 151 | ) 152 | cur_point = cur_dct["bounding_boxes"][0]["world2ego"] 153 | 154 | x, y = np.dot(ego_matrix, np.linalg.inv(cur_point))[:2, 3] 155 | 156 | points.append((x, -y)) 157 | pointer_idx += 1 158 | 159 | path = iterative_line_interpolation(points) 160 | 161 | action["path"] = path 162 | else: 163 | raise ValueError(f"Action type {a} not recognized") 164 | 165 | if include_noise: 166 | return action, False 167 | else: 168 | return action 169 | 170 | def get_noisy(self, idx): 171 | return False 172 | 173 | def get_reward(self, idx, skip_keys=None): 174 | rewards = {} 175 | 176 | for r in self.rewards: 177 | if skip_keys is not None and r in skip_keys: 178 | continue 179 | rewards[r] = 0.0 180 | 181 | return rewards 182 | 183 | 
def get_goal(self, idx, skip_keys=None): 184 | ts_prefix = str(idx).zfill(5) 185 | state_dict = self.gzip_json_load(os.path.join("anno", f"{ts_prefix}.json.gz")) 186 | 187 | goal = {} 188 | 189 | for g in self.goals: 190 | if skip_keys is not None and g in skip_keys: 191 | continue 192 | if g == "target_point": 193 | # Get the target point from the state 194 | ego = state_dict["bounding_boxes"][0] 195 | theta = ego["rotation"][-1] * np.pi / 180 196 | 197 | command_near_xy = np.array( 198 | [ 199 | state_dict["x_command_near"] - state_dict["x"], 200 | -state_dict["y_command_near"] + state_dict["y"], 201 | ] 202 | ) 203 | rotation_matrix = np.array( 204 | [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] 205 | ) 206 | local_command_xy = rotation_matrix @ command_near_xy 207 | # command_far_xy = np.array([state_dict["x_command_far"]-state_dict['x'],-state_dict["y_command_far"]+state_dict['y']]) 208 | # local_command_far_xy = rotation_matrix @ command_far_xy 209 | 210 | goal["target_point"] = local_command_xy 211 | elif g == "dual_target_point": 212 | # Get the target point from the state 213 | # local_command_point = np.array(state_dict["target_point"]) 214 | # goal["target_point"] = local_command_point 215 | ego = state_dict["bounding_boxes"][0] 216 | theta = ego["rotation"][-1] * np.pi / 180 217 | 218 | command_near_xy = np.array( 219 | [ 220 | state_dict["x_command_near"] - state_dict["x"], 221 | -state_dict["y_command_near"] + state_dict["y"], 222 | ] 223 | ) 224 | rotation_matrix = np.array( 225 | [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] 226 | ) 227 | local_command_xy = rotation_matrix @ command_near_xy 228 | command_far_xy = np.array( 229 | [ 230 | state_dict["x_command_far"] - state_dict["x"], 231 | -state_dict["y_command_far"] + state_dict["y"], 232 | ] 233 | ) 234 | local_command_far_xy = rotation_matrix @ command_far_xy 235 | 236 | goal["dual_target_point"] = np.stack( 237 | [local_command_xy, local_command_far_xy], axis=0 238 | ) 239 | else: 240 | raise ValueError(f"Goal type {g} not recognized") 241 | 242 | return goal 243 | 244 | def get_size(self): 245 | return self.length 246 | 247 | def get_weight(self, idx): 248 | return self.get_buckets(idx) 249 | 250 | def get_buckets(self, idx): 251 | # Get which bucket the current idx belongs to 252 | assert self.buckets is not None, "Buckets not initialized" 253 | 254 | return self.buckets[idx] 255 | 256 | @staticmethod 257 | def get_folder_to_ext(dir): 258 | folder_to_ext = {} 259 | for folder in os.listdir(dir): 260 | folder_path = os.path.join(dir, folder) 261 | if os.path.isdir(folder_path): 262 | fl = os.listdir(folder_path) 263 | # Get extension 264 | if len(fl) > 0: 265 | ext = os.path.splitext(fl[0])[1] 266 | folder_to_ext[folder] = ext 267 | else: 268 | raise ValueError(f"Folder {folder} is empty in directory {dir}") 269 | 270 | return folder_to_ext 271 | 272 | def gzip_json_load(self, rel_file_path, root_path=None): 273 | if root_path is None: 274 | root_path = self.root_path 275 | with gzip.open(os.path.join(root_path, rel_file_path), "r") as f: 276 | return json.loads(f.read().decode("utf-8")) 277 | 278 | def path_is_cached(self, cache_dir): 279 | return os.path.exists( 280 | os.path.join(cache_dir, f"{self.root_path.split('/')[-1]}.json.gz") 281 | ) 282 | 283 | def bucket_is_cached(self, path_cache_dir): 284 | return os.path.exists( 285 | os.path.join(path_cache_dir, f"{self.root_path.split('/')[-1]}.npz") 286 | ) 287 | 288 | def cache_buckets(self, bucket_cache_dir): 289 | if 
self.bucket_is_cached(bucket_cache_dir): 290 | return 291 | 292 | bucket_names, all_buckets = self.get_all_buckets() 293 | 294 | cache_path = os.path.join( 295 | bucket_cache_dir, f"{self.root_path.split('/')[-1]}.npz" 296 | ) 297 | 298 | np.savez(cache_path, bucket_names=bucket_names, buckets=all_buckets) 299 | 300 | def cache_path(self, cache_dir): 301 | if self.path_is_cached(cache_dir): 302 | self.load_path_cache(cache_dir) 303 | return 304 | 305 | if "path" in self.action_type: 306 | all_paths = self.get_all_paths() 307 | 308 | cache_path = os.path.join( 309 | cache_dir, f"{self.root_path.split('/')[-1]}.json.gz" 310 | ) 311 | 312 | with gzip.open(cache_path, "w") as f: 313 | f.write(json.dumps(all_paths).encode("utf-8")) 314 | 315 | def get_all_paths(self): 316 | if "path" not in self.action_type: 317 | return None 318 | all_egos = [] 319 | for i in range(self.length): 320 | ts_prefix = str(i).zfill(5) 321 | state_dict = self.gzip_json_load( 322 | os.path.join("anno", f"{ts_prefix}.json.gz") 323 | ) 324 | cur_point = state_dict["bounding_boxes"][0]["world2ego"] 325 | all_egos.append(cur_point) 326 | return all_egos 327 | 328 | def get_all_buckets(self): 329 | # Buckets are useful for bucketed sampling 330 | # Buckets defined as follows: 331 | # acceleration buckets: 332 | idx = 0 333 | all_state_dicts = [] 334 | from skit.profiling import Ticker 335 | 336 | t = Ticker(verbose=False, track=True) 337 | 338 | for i in range(self.length): 339 | ts_prefix = str(i).zfill(5) 340 | state_dict = self.gzip_json_load( 341 | os.path.join("anno", f"{ts_prefix}.json.gz") 342 | ) 343 | all_state_dicts.append(state_dict) 344 | 345 | swerving_scenarios = [ 346 | "Accident", 347 | "BlockedIntersection", 348 | "ConstructionObstacle", 349 | "HazardAtSideLane", 350 | "ParkedObstacle", 351 | "VehicleOpensDoorTwoWays", 352 | ] 353 | 354 | data_is_swerving = any([x in self.root_path for x in swerving_scenarios]) 355 | 356 | all_buckets = [] 357 | general_bucket_name = ["general"] 358 | acceleration_bucket_names = [ 359 | "acc_scratch", 360 | "acc_light_pedal", 361 | "acc_medium_pedal", 362 | "acc_heavy_pedal", 363 | "acc_brake", 364 | "acc_coast", 365 | ] 366 | steer_bucket_names = [ 367 | "steer_right", 368 | "steer_left", 369 | ] 370 | vehicle_hazard_bucket_names = [ 371 | "vehicle_hazard_front", 372 | "vehicle_hazard_back", 373 | "vehicle_hazard_side", 374 | ] 375 | stop_sign_bucket_names = ["stop_sign"] 376 | red_light_bucket_names = ["red_light"] 377 | swerving_bucket_names = ["swerving"] 378 | pedestrian_bucket_names = ["pedestrian"] 379 | 380 | for i in range(self.length): 381 | acceleration = all_state_dicts[i]["throttle"] 382 | brake = all_state_dicts[i]["brake"] 383 | speed = all_state_dicts[i]["speed"] 384 | 385 | acceleration_bucket = [ 386 | 1 if (acceleration > 0.2 and brake < 1.0 and speed < 0.05) else 0, 387 | 1 if (acceleration > 0.2 and acceleration < 0.5) else 0, 388 | 1 if (acceleration > 0.5 and acceleration < 0.9) else 0, 389 | 1 if (acceleration > 0.9) else 0, 390 | 1 if (brake > 0.2) else 0, 391 | 1 if (acceleration < 0.2 and brake < 1.0) else 0, 392 | ] 393 | 394 | steer = all_state_dicts[i]["steer"] 395 | 396 | steer_bucket = [ 397 | 1 if (steer > 0.2) else 0, 398 | 1 if (steer < -0.2) else 0, 399 | ] 400 | 401 | # other_vehicles = all_state_dicts[i]["vehicle_hazard"] 402 | 403 | vehicle_hazard_bucket = get_hazard_directions( 404 | all_state_dicts[i]["bounding_boxes"] 405 | ) 406 | # print(i, vehicle_hazard_bucket) 407 | 408 | vehicle_hazard_bucket = [ 409 | ( 410 | 1 if any([x < 
30 for x in vehicle_hazard_bucket]) else 0 411 | ), # Heading from front 412 | ( 413 | 1 if any([x > 150 for x in vehicle_hazard_bucket]) else 0 414 | ), # Heading from back 415 | ( 416 | 1 if any([x > 30 and x < 150 for x in vehicle_hazard_bucket]) else 0 417 | ), # Heading from side 418 | ] 419 | 420 | if data_is_swerving and abs(steer) > 0.1: 421 | swerving_bucket = 1 422 | else: 423 | swerving_bucket = 0 424 | 425 | all_objs = all_state_dicts[i]["bounding_boxes"] 426 | 427 | stopsigns = [ 428 | x 429 | for x in all_objs 430 | if x["class"] == "traffic_sign" and x["type_id"] == "traffic.stop" 431 | ] 432 | 433 | stop_sign_bucket = 1 if any([x["affects_ego"] for x in stopsigns]) else 0 434 | 435 | redtrafficlights = [ 436 | x for x in all_objs if x["class"] == "traffic_light" and x["state"] == 0 437 | ] 438 | 439 | red_light_bucket = ( 440 | 1 if any([x["affects_ego"] for x in redtrafficlights]) else 0 441 | ) 442 | 443 | pedestrian_hazards = is_walker_hazard(all_objs) 444 | pedestrian_hazard_bucket = 1 if pedestrian_hazards else 0 445 | 446 | instance_buckets = ( 447 | [1] 448 | + acceleration_bucket 449 | + steer_bucket 450 | + vehicle_hazard_bucket 451 | + [ 452 | stop_sign_bucket, 453 | red_light_bucket, 454 | swerving_bucket, 455 | pedestrian_hazard_bucket, 456 | ] 457 | ) 458 | 459 | all_buckets.append(instance_buckets) 460 | 461 | bucket_names = ( 462 | general_bucket_name 463 | + acceleration_bucket_names 464 | + steer_bucket_names 465 | + vehicle_hazard_bucket_names 466 | + stop_sign_bucket_names 467 | + red_light_bucket_names 468 | + swerving_bucket_names 469 | + pedestrian_bucket_names 470 | ) 471 | 472 | all_buckets = np.asarray(all_buckets) 473 | 474 | return bucket_names, all_buckets 475 | 476 | def validate_cache(self, path_cache_dir): 477 | if not "path" in self.action_type: 478 | return True 479 | 480 | to_compare = self.load_path_cache(path_cache_dir) 481 | 482 | all_egos = self.get_all_paths() 483 | 484 | return np.allclose(np.asarray(all_egos), np.asarray(to_compare)) 485 | 486 | def load_path_cache(self, path_cache_dir): 487 | if not "path" in self.action_type: 488 | return None 489 | 490 | cache_path = os.path.join( 491 | path_cache_dir, f"{self.root_path.split('/')[-1]}.json.gz" 492 | ) 493 | if not os.path.exists(cache_path): 494 | return None 495 | 496 | path_cache = self.gzip_json_load(cache_path, root_path="") 497 | 498 | self.path_cache = np.asarray(path_cache) 499 | 500 | def load_bucket_cache(self, bucket_cache_dir): 501 | cache_path = os.path.join( 502 | bucket_cache_dir, f"{self.root_path.split('/')[-1]}.npz" 503 | ) 504 | if not os.path.exists(cache_path): 505 | return None 506 | 507 | bucket_cache = np.load(cache_path) 508 | 509 | bucket_names = bucket_cache["bucket_names"] 510 | # cast names to string 511 | bucket_names = [str(x) for x in bucket_names] 512 | buckets = bucket_cache["buckets"] 513 | 514 | self.buckets = buckets 515 | self.bucket_names = bucket_names 516 | -------------------------------------------------------------------------------- /carformer/carformer/data/data_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torchvision.transforms import Compose, Resize, CenterCrop 4 | 5 | 6 | def get_rgb_preprocessing_function_from_config(config): 7 | return get_rgb_preprocessing_function( 8 | config.rgb_crop.type, 9 | config.rgb_crop.crop_size, 10 | config.rgb_crop.resize, 11 | ) 12 | 13 | 14 | def get_rgb_preprocessing_function(crop_type, crop_size, resize): 
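# Added note (illustrative; not in the original data_utils.py): this helper
# returns a torchvision Compose of a center crop plus an optional resize.
# With crop_type="dualcenter" the crop keeps a wide two-camera strip of
# height crop_size and width 2*crop_size; e.g. assumed values crop_size=448,
# resize=336 give a 448x896 crop resized to 336x672 (HxW). With
# crop_type="center" a square crop_size x crop_size patch is taken and,
# if resize > 0, resized to resize x resize. The concrete numbers come from
# config.rgb_crop in practice and are only assumptions here.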
15 | transforms = [] 16 | 17 | if crop_type == "dualcenter": 18 | # Width: 2xcrop_size, height: crop_size 19 | # Then resize to 2xresize, resize 20 | transforms.append(CenterCrop((crop_size, 2 * crop_size))) 21 | if resize > 0: 22 | transforms.append(Resize((resize, 2 * resize))) 23 | elif crop_type == "center": 24 | transforms.append(CenterCrop(crop_size)) 25 | if resize > 0: 26 | transforms.append(Resize(resize)) 27 | else: 28 | raise NotImplementedError 29 | 30 | transform = Compose(transforms) 31 | 32 | return transform 33 | 34 | 35 | def get_virtual_lidar_to_vehicle_transform(): 36 | # This is a fake lidar coordinate 37 | T = np.eye(4) 38 | T[0, 3] = 1.3 39 | T[1, 3] = 0.0 40 | T[2, 3] = 2.5 41 | return T 42 | 43 | 44 | def get_vehicle_to_virtual_lidar_transform(): 45 | return np.linalg.inv(get_virtual_lidar_to_vehicle_transform()) 46 | 47 | 48 | def transform_waypoints(waypoints): 49 | """transform waypoints to be origin at ego_matrix""" 50 | vehicle_matrix = np.array(waypoints[0]) 51 | for i in range( 52 | 0, len(waypoints) 53 | ): # TODO: Start from 1 because 0 is ego vehicle initial position 54 | matrix = np.array(waypoints[i]) 55 | waypoints[i] = vehicle_matrix @ np.linalg.inv(matrix) 56 | 57 | return waypoints 58 | 59 | 60 | ######################################################################## 61 | 62 | ############## WARPING 63 | 64 | ###################################################################3#### 65 | 66 | 67 | def compute_relative_transform(origin, current, pix_per_meter=5): 68 | result = torch.bmm(torch.linalg.inv(origin), (current)) 69 | 70 | return result 71 | 72 | 73 | def get_affine_grid_transform(origin, current, inp_size=400, pix_per_meter=5): 74 | relative_transform = compute_relative_transform(origin, current, pix_per_meter) 75 | 76 | translation = relative_transform[:, :2, 3:] / ((inp_size / 2) / pix_per_meter) 77 | translation[:, [0, 1]] = translation[:, [1, 0]] 78 | 79 | affine_grid_transform = torch.cat( 80 | (torch.transpose(relative_transform[:, :2, :2], 2, 1), translation), axis=2 81 | ) 82 | 83 | # rot x, y. dont take height. 84 | # affine_grid_transform = torch.from_numpy(affine_grid_transform).float() 85 | 86 | return affine_grid_transform 87 | 88 | 89 | def warp_sequence(x, ego_matrices, mode="nearest", spatial_extent=None): 90 | """ 91 | Batch-compatible warping function. 92 | 93 | Warps a sequence based on the first frame using ego vehicle transformation matrices. 94 | """ 95 | sequence_length = x.shape[1] 96 | if sequence_length == 1: 97 | return x 98 | 99 | out = [x[:, 0]] 100 | 101 | # print('X. 
SHAPE ', x.shape) 102 | base_frame = ego_matrices[:, 0] # torch.from_numpy() 103 | 104 | for t in range(1, sequence_length): 105 | curr_frame = ego_matrices[:, t] # torch.from_numpy() 106 | aff_grid = get_affine_grid_transform( 107 | base_frame, curr_frame, inp_size=x.shape[-1], pix_per_meter=5 108 | ) # .unsqueeze(0) 109 | 110 | grid = torch.nn.functional.affine_grid( 111 | aff_grid, size=x[:, 0].shape, align_corners=False 112 | ) 113 | 114 | warped_bev = torch.nn.functional.grid_sample( 115 | (x[:, t]), 116 | grid.float(), 117 | mode="nearest", 118 | padding_mode="zeros", 119 | align_corners=False, 120 | ) 121 | 122 | out.append(warped_bev) 123 | 124 | return torch.stack(out, 1) 125 | 126 | 127 | ### Forecast utils 128 | def extract_forecast_targets( 129 | timesteps, 130 | context_length, 131 | future_horizon, 132 | forecast_steps=1, 133 | use_slots=True, 134 | object_level=True, 135 | ): 136 | assert len(timesteps) == context_length + future_horizon 137 | for i in range(context_length): 138 | timesteps[i].state["target_rgb_front"] = timesteps[i + forecast_steps].state[ 139 | "rgb_front" 140 | ] 141 | 142 | 143 | def circle_line_segment_intersection( 144 | circle_center, circle_radius, pt1, pt2, full_line=True, tangent_tol=1e-9 145 | ): 146 | """Find the points at which a circle intersects a line-segment. This can happen at 0, 1, or 2 points. 147 | 148 | :param circle_center: The (x, y) location of the circle center 149 | :param circle_radius: The radius of the circle 150 | :param pt1: The (x, y) location of the first point of the segment 151 | :param pt2: The (x, y) location of the second point of the segment 152 | :param full_line: True to find intersections along full line - not just in the segment. 153 | False will just return intersections within the segment. 154 | :param tangent_tol: Numerical tolerance at which we decide the intersections are close enough to consider it a 155 | tangent 156 | :return Sequence[Tuple[float, float]]: A list of length 0, 1, or 2, where each element is a point at which the 157 | circle intercepts a line segment. 
158 | 159 | Note: We follow: http://mathworld.wolfram.com/Circle-LineIntersection.html 160 | Credit: https://stackoverflow.com/a/59582674/9173068 161 | """ 162 | 163 | if np.linalg.norm(pt1 - pt2) < 0.000000001: 164 | # print('Problem') 165 | return [] 166 | 167 | (p1x, p1y), (p2x, p2y), (cx, cy) = pt1, pt2, circle_center 168 | (x1, y1), (x2, y2) = (p1x - cx, p1y - cy), (p2x - cx, p2y - cy) 169 | dx, dy = (x2 - x1), (y2 - y1) 170 | dr = (dx**2 + dy**2) ** 0.5 171 | big_d = x1 * y2 - x2 * y1 172 | discriminant = circle_radius**2 * dr**2 - big_d**2 173 | 174 | if discriminant < 0: # No intersection between circle and line 175 | return [] 176 | else: # There may be 0, 1, or 2 intersections with the segment 177 | # This makes sure the order along the segment is correct 178 | # intersections = [(cx + (big_d * dy + sign * (-1 if dy < 0 else 1) * dx * discriminant**.5) / dr**2, 179 | # cy + (-big_d * dx + sign * abs(dy) * discriminant**.5) / dr**2) 180 | # for sign in ((1, -1) if dy < 0 else (-1, 1))] 181 | 182 | # Write explicitly to avoid iteration 183 | if dy < 0: 184 | sign_1 = 1 185 | sign_2 = -1 186 | else: 187 | sign_1 = -1 188 | sign_2 = 1 189 | 190 | intersections = [ 191 | ( 192 | cx + (big_d * dy + sign_1 * sign_2 * dx * discriminant**0.5) / dr**2, 193 | cy + (-big_d * dx + sign_1 * abs(dy) * discriminant**0.5) / dr**2, 194 | ), 195 | ( 196 | cx + (big_d * dy + sign_2 * sign_2 * dx * discriminant**0.5) / dr**2, 197 | cy + (-big_d * dx + sign_2 * abs(dy) * discriminant**0.5) / dr**2, 198 | ), 199 | ] 200 | 201 | if ( 202 | not full_line 203 | ): # If only considering the segment, filter out intersections that do not fall within the segment 204 | fraction_along_segment = [ 205 | (xi - p1x) / dx if abs(dx) > abs(dy) else (yi - p1y) / dy 206 | for xi, yi in intersections 207 | ] 208 | intersections = [ 209 | pt 210 | for pt, frac in zip(intersections, fraction_along_segment) 211 | if 0 <= frac <= 1 212 | ] 213 | # If line is tangent to circle, return just one point (as both intersections have same location) 214 | if len(intersections) == 2 and abs(discriminant) <= tangent_tol: 215 | return [intersections[0]] 216 | else: 217 | return intersections 218 | 219 | 220 | import numpy as np 221 | 222 | 223 | def iterative_line_interpolation(route, num_points=20): 224 | if not isinstance(route, np.ndarray): 225 | route = np.array(route) 226 | 227 | interpolated_route_points = [] 228 | 229 | min_distance = 0.5 230 | last_interpolated_point = np.array([0.0, 0.0]) 231 | current_route_index = 0 232 | current_point = route[current_route_index] 233 | last_point = route[current_route_index] 234 | 235 | while len(interpolated_route_points) < num_points: 236 | # First point should be min_distance away from the vehicle. 237 | dist = np.linalg.norm(current_point - last_interpolated_point) 238 | if dist < min_distance: 239 | current_route_index += 1 240 | last_point = current_point 241 | 242 | if current_route_index < route.shape[0]: 243 | current_point = route[current_route_index] 244 | intersection = circle_line_segment_intersection( 245 | circle_center=last_interpolated_point, 246 | circle_radius=min_distance, 247 | pt1=last_point, 248 | pt2=current_point, 249 | full_line=False, 250 | ) 251 | 252 | else: # We hit the end of the input route. 
We extrapolate the last 2 points 253 | current_point = route[-1] 254 | last_point = route[-2] 255 | intersection = circle_line_segment_intersection( 256 | circle_center=last_interpolated_point, 257 | circle_radius=min_distance, 258 | pt1=last_point, 259 | pt2=current_point, 260 | full_line=True, 261 | ) 262 | 263 | # 3 cases: 0 intersection, 1 intersection, 2 intersection 264 | if len(intersection) > 1: # 2 intersections 265 | # Take the one that is closer to current point 266 | point_1 = np.array(intersection[0]) 267 | point_2 = np.array(intersection[1]) 268 | direction = current_point - last_point 269 | dot_p1_to_last = np.dot(point_1, direction) 270 | dot_p2_to_last = np.dot(point_2, direction) 271 | 272 | if dot_p1_to_last > dot_p2_to_last: 273 | intersection_point = point_1 274 | else: 275 | intersection_point = point_2 276 | add_point = True 277 | elif len(intersection) == 1: # 1 Intersections 278 | intersection_point = np.array(intersection[0]) 279 | add_point = True 280 | else: # 0 Intersection 281 | add_point = False 282 | 283 | if add_point: 284 | last_interpolated_point = intersection_point 285 | interpolated_route_points.append(intersection_point) 286 | min_distance = 1.0 # After the first point we want each point to be 1 m away from the last. 287 | 288 | interpolated_route_points = np.array(interpolated_route_points) 289 | 290 | return interpolated_route_points 291 | 292 | 293 | def interpolate_waypoints(wps, num_points=10): 294 | assert len(wps) > 1, "Need at least 2 waypoints to interpolate" 295 | last_vector = wps[-1] - wps[-2] 296 | 297 | if len(wps) >= num_points: 298 | return wps[:num_points] 299 | 300 | interpolated_points = [] 301 | for i in range(len(wps) - num_points): 302 | # interpolated_points.append(wps[i] + last_vector * (i + 1)) 303 | interpolated_points.append((-99999, -99999)) 304 | 305 | return np.concatenate([wps, np.stack(interpolated_points)], axis=0) 306 | 307 | 308 | class iterative_intepolator: 309 | def __init__(self, num_points=20): 310 | self.num_points = num_points 311 | self.min_distance = 0.5 312 | self.last_interpolated_point = np.array([0.0, 0.0]) 313 | self.current_route_index = 0 314 | self.current_point = None 315 | self.last_point = None 316 | self.interpolated_route_points = [] 317 | self.last_inputs = [] 318 | 319 | def __call__(self, point): 320 | if self.current_point is None: 321 | self.current_point = point 322 | self.last_point = point 323 | 324 | if point is not None: 325 | self.last_inputs.append(point) 326 | self.last_inputs = self.last_inputs[-2:] 327 | 328 | dist = np.linalg.norm(self.current_point - self.last_interpolated_point) 329 | if dist < self.min_distance: 330 | # self.current_route_index += 1 331 | self.last_point = self.current_point 332 | return len(self) 333 | 334 | self.current_point = np.asarray(point) 335 | intersection = circle_line_segment_intersection( 336 | circle_center=self.last_interpolated_point, 337 | circle_radius=self.min_distance, 338 | pt1=self.last_point, 339 | pt2=self.current_point, 340 | full_line=False, 341 | ) 342 | else: 343 | current_point = self.last_inputs[-1] 344 | last_point = self.last_inputs[-2] 345 | intersection = circle_line_segment_intersection( 346 | circle_center=last_interpolated_point, 347 | circle_radius=min_distance, 348 | pt1=last_point, 349 | pt2=current_point, 350 | full_line=True, 351 | ) 352 | 353 | if len(intersection) > 1: # 2 intersections 354 | # Take the one that is closer to current point 355 | point_1 = np.array(intersection[0]) 356 | point_2 = 
np.array(intersection[1]) 357 | direction = current_point - last_point 358 | dot_p1_to_last = np.dot(point_1, direction) 359 | dot_p2_to_last = np.dot(point_2, direction) 360 | 361 | if dot_p1_to_last > dot_p2_to_last: 362 | intersection_point = point_1 363 | else: 364 | intersection_point = point_2 365 | elif len(intersection) == 1: # 1 Intersections 366 | intersection_point = np.array(intersection[0]) 367 | else: # 0 Intersection 368 | return len(self) 369 | 370 | self.last_interpolated_point = intersection_point 371 | self.interpolated_route_points.append(intersection_point) 372 | self.min_distance = 1.0 # After the first point we want each point to be 1 m away from the last. 373 | 374 | def get_interpolated_points(self): 375 | if len(self.interpolated_route_points) < self.num_points: 376 | for i in range(len(self.interpolated_route_points), self.num_points): 377 | self(None) 378 | 379 | print(len(self)) 380 | 381 | return self.interpolated_route_points 382 | 383 | def __len__(self): 384 | return self.num_points 385 | 386 | 387 | # Adapted from https://github.com/OpenDriveLab/TCP/blob/9ec4db0f0424801cdd607f1de930290830c5e88e/leaderboard/team_code/auto_pilot.py#L339 388 | def get_hazard_directions(vehicle_list): 389 | ego_vehicles = [x for x in vehicle_list if x["class"] == "ego_vehicle"] 390 | 391 | if len(ego_vehicles) == 0: 392 | return [] 393 | 394 | if len(ego_vehicles) > 1: 395 | print("More than one ego vehicle found") 396 | return [] 397 | 398 | ego_vehicle = ego_vehicles[0] 399 | 400 | z = ego_vehicle["location"][1] 401 | 402 | o1 = _orientation(ego_vehicle["rotation"][-1]) 403 | p1 = np.asarray(ego_vehicle["location"][:2]) 404 | s1 = max(2, 3.0 * ego_vehicle["speed"]) # increases the threshold distance 405 | v1_hat = o1 406 | v1 = s1 * v1_hat 407 | 408 | hazard_directions = [] 409 | 410 | for target_vehicle in vehicle_list: 411 | if target_vehicle["class"] == "ego_vehicle": 412 | continue 413 | 414 | if target_vehicle.get("base_type", None) != "car": 415 | continue 416 | 417 | o2 = _orientation(target_vehicle["rotation"][-1]) 418 | p2 = np.asarray(target_vehicle["location"][:2]) 419 | s2 = max(5.0, 2.0 * target_vehicle["speed"]) 420 | v2_hat = o2 421 | v2 = s2 * v2_hat 422 | 423 | p2_p1 = p2 - p1 424 | distance = np.linalg.norm(p2_p1) 425 | p2_p1_hat = p2_p1 / (distance + 1e-4) 426 | 427 | angle_to_car = np.degrees(np.arccos(v1_hat.dot(p2_p1_hat))) 428 | 429 | angle_between_heading = np.degrees(np.arccos(np.clip(o1.dot(o2), -1, 1))) 430 | 431 | # print() 432 | angle_from_ego = np.degrees(np.arccos(v2_hat.dot(p2_p1_hat))) 433 | 434 | # to consider -ve angles too 435 | angle_to_car = min(angle_to_car, 360.0 - angle_to_car) 436 | angle_between_heading = min( 437 | angle_between_heading, 360.0 - angle_between_heading 438 | ) 439 | 440 | if angle_between_heading > 60.0 and not (angle_to_car < 15 and distance < s1): 441 | continue 442 | elif angle_to_car > 30.0: 443 | continue 444 | elif distance > s1: 445 | continue 446 | 447 | # print(target_vehicle["type_id"], target_vehicle["distance"], target_vehicle["color"]) 448 | # print("s1", s1, "dist", distance, "tgt_dist", target_vehicle["distance"]) 449 | # print(angle_to_car, angle_between_heading, distance, s1) 450 | # print("angle from ego", angle_from_ego) 451 | 452 | hazard_directions.append(angle_from_ego) 453 | 454 | return hazard_directions 455 | 456 | 457 | def _orientation(yaw): 458 | return np.float32([np.cos(np.radians(yaw)), np.sin(np.radians(yaw))]) 459 | 460 | 461 | def get_collision(p1, v1, p2, v2): 462 | A = 
np.stack([v1, -v2], 1) 463 | b = p2 - p1 464 | 465 | if abs(np.linalg.det(A)) < 1e-3: 466 | return False, None 467 | 468 | x = np.linalg.solve(A, b) 469 | collides = all(x >= 0) and all(x <= 4) # how many seconds until collision 470 | 471 | return collides, p1 + x[0] * v1 472 | 473 | 474 | def is_walker_hazard(objects_list): 475 | ego_vehicles = [x for x in objects_list if x["class"] == "ego_vehicle"] 476 | if len(ego_vehicles) == 0: 477 | print("No ego vehicle found") 478 | return False 479 | 480 | ego_vehicle = ego_vehicles[0] 481 | 482 | z = ego_vehicle["location"][-1] 483 | 484 | walkers = [x for x in objects_list if x["class"] == "walker"] 485 | p1 = np.asarray(ego_vehicle["location"][:2]) 486 | v1 = 10.0 * _orientation(ego_vehicle["rotation"][-1]) 487 | 488 | for walker in walkers: 489 | v2_hat = _orientation(walker["rotation"][-1]) 490 | s2 = walker["speed"] 491 | 492 | if s2 < 0.05: 493 | v2_hat *= s2 494 | 495 | p2 = -3.0 * v2_hat + np.asarray(walker["location"][:2]) 496 | v2 = 8.0 * v2_hat 497 | 498 | collides, collision_point = get_collision(p1, v1, p2, v2) 499 | 500 | if collides: 501 | return True 502 | 503 | return False 504 | -------------------------------------------------------------------------------- /carformer/carformer/data/utils.py: -------------------------------------------------------------------------------- 1 | from carformer.data import ( 2 | B2DSequenceDataset, 3 | DatasetPreloader, 4 | InMemoryDatasetPreloader, 5 | ) 6 | import os 7 | 8 | 9 | def get_datasets(config, model=None, return_all=False, splits=["train", "val"]): 10 | if return_all: 11 | return _get_entire_dataset( 12 | config.data_dir, 13 | config.training, 14 | config.dataset.data_format, 15 | preload=config.preload, 16 | preload_in_memory=config.preload_in_memory, 17 | wipe_cache=config.wipe_cache, 18 | cache_dir=config.cache_dir, 19 | model=model, 20 | ) 21 | 22 | return _get_datasets( 23 | config.data_dir, 24 | config.training, 25 | config.dataset.data_format, 26 | preload=config.preload, 27 | preload_in_memory=config.preload_in_memory, 28 | wipe_cache=config.wipe_cache, 29 | cache_dir=config.cache_dir, 30 | model=model, 31 | splits=splits, 32 | ) 33 | 34 | 35 | def _get_datasets( 36 | data_dir, 37 | train_cfg, 38 | data_format, 39 | preload=False, 40 | preload_in_memory=False, 41 | wipe_cache=False, 42 | cache_dir="", 43 | model=None, 44 | splits=["train", "val"], 45 | ): 46 | # If is_plant is a boolean, give a deprecated warning 47 | if isinstance(data_format, bool): 48 | print( 49 | "Warning: is_plant boolean is deprecated, please use the data_format field in the dataset configuration" 50 | ) 51 | 52 | if data_format == "plant": 53 | data_module = PlantSequenceDataset 54 | elif data_format == "b2d": 55 | data_module = B2DSequenceDataset 56 | elif data_format == "pdm": 57 | data_module = PDMSequenceDataset 58 | elif data_format == "tf": 59 | raise ValueError("Transfuser dataset not supported") 60 | # data_module = SequenceDataset 61 | else: 62 | raise ValueError(f"Invalid data format {data_format}") 63 | 64 | if "train" in splits: 65 | train_dataset = data_module( 66 | data_dir, 67 | train_cfg.splits.train, 68 | train_cfg, 69 | ) 70 | else: 71 | train_dataset = None 72 | 73 | if "val" in splits: 74 | val_dataset = data_module( 75 | data_dir, 76 | train_cfg.splits.val, 77 | train_cfg, 78 | ) 79 | else: 80 | val_dataset = None 81 | 82 | if preload: 83 | assert cache_dir != "", "Cache dir must be specified if preloading is enabled" 84 | 85 | preloader = ( 86 | DatasetPreloader if not 
preload_in_memory else InMemoryDatasetPreloader 87 | ) 88 | args = [] 89 | 90 | if train_dataset is not None: 91 | train_dataset = preloader( 92 | train_dataset, 93 | os.path.join(cache_dir, train_dataset.get_parametrized_dirname()), 94 | *args, 95 | wipe_cache=wipe_cache, 96 | ) 97 | 98 | if val_dataset is not None: 99 | val_dataset = preloader( 100 | val_dataset, 101 | os.path.join(cache_dir, val_dataset.get_parametrized_dirname()), 102 | *args, 103 | wipe_cache=wipe_cache, 104 | ) 105 | if train_dataset is not None: 106 | train_dataset.load_state() 107 | 108 | if val_dataset is not None: 109 | val_dataset.load_state() 110 | 111 | return train_dataset, val_dataset 112 | 113 | 114 | def _get_entire_dataset( 115 | data_dir, 116 | train_cfg, 117 | is_plant, 118 | preload=False, 119 | preload_in_memory=False, 120 | wipe_cache=False, 121 | cache_dir="", 122 | model=None, 123 | ): 124 | if data_format == "b2d": 125 | data_module = B2DSequenceDataset 126 | else: 127 | raise ValueError(f"Invalid data format {data_format}") 128 | 129 | all_dataset = data_module( 130 | data_dir, 131 | "all", 132 | train_cfg, 133 | ) 134 | 135 | if preload: 136 | assert cache_dir != "", "Cache dir must be specified if preloading is enabled" 137 | 138 | preloader = ( 139 | DatasetPreloader if not preload_in_memory else InMemoryDatasetPreloader 140 | ) 141 | 142 | args = [] 143 | 144 | all_dataset = preloader( 145 | all_dataset, 146 | os.path.join(cache_dir, all_dataset.get_parametrized_dirname()), 147 | *args, 148 | wipe_cache=wipe_cache, 149 | ) 150 | 151 | all_dataset.load_state() 152 | 153 | return all_dataset 154 | -------------------------------------------------------------------------------- /carformer/carformer/data/wrapper.py: -------------------------------------------------------------------------------- 1 | # A wrapper around a dataset 2 | # First, it loads batches, 1 instance at a time, then writes it into a disk cache 3 | # Future calls to load batches will load from the cache instead of the original dataset 4 | # This is to avoid the overhead of loading from disk every time 5 | 6 | from skit import DatasetPreloader, InMemoryDatasetPreloader, AugmentableDatasetPreloader 7 | -------------------------------------------------------------------------------- /carformer/carformer/perception/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenDriveLab/ETA/cdbe5e1174cf33028d5cc26921f4fb72bf9b92f5/carformer/carformer/perception/__init__.py -------------------------------------------------------------------------------- /carformer/carformer/perception/rgb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import yaml 4 | from transformers import AutoModel, AutoImageProcessor, AutoConfig, CONFIG_MAPPING 5 | 6 | try: 7 | from transformers.models.auto.image_processing_auto import ( 8 | image_processor_class_from_name, 9 | ) 10 | except ImportError: 11 | from transformers.models.auto.image_processing_auto import ( 12 | get_image_processor_class_from_name as image_processor_class_from_name, 13 | ) 14 | 15 | rgb_models = {} 16 | 17 | 18 | class RGBEncoder(nn.Module): 19 | def __init__(self, config): 20 | super(RGBEncoder, self).__init__() 21 | # Load config 22 | if isinstance(config, str): 23 | with open(config, "r") as file: 24 | try: 25 | config = yaml.safe_load(file) 26 | self.config = config 27 | except yaml.YAMLError as exc: 28 | print(exc) 29 | raise exc 
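# Illustrative minimal config sketch for RGBEncoder (assumed values; not one of
# the repo's shipped YAML files). The config object must support both dict-style
# access (self.config["outputs"]) and attribute access (self.config.projection_dim),
# e.g. an OmegaConf DictConfig loaded from YAML:
#
#   model_path: <hf-vision-encoder-id>   # placeholder HuggingFace checkpoint id
#   projection_dim: 2048                 # hidden size of the driving backbone (assumed)
#   input_size: 448                      # used by the center_crop helper (assumed)
#   select_layer: -2                     # which hidden_states entry to take features from
#   frozen: true                         # or a {model, ema_model, projector} mapping
#   outputs:
#     whole: false
#     patches: true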
30 | else: 31 | self.config = config 32 | 33 | pretrained_path = self.config.get("model_path", None) 34 | pretrained_override_kwargs = self.config.get("override_kwargs", {}) 35 | 36 | self.ema_enabled = self.config.get("ema_enabled", False) 37 | self.keep_time_dim = self.config.get("keep_time_dim", False) 38 | 39 | if pretrained_path is not None: 40 | print( 41 | "Overriding main model config using kwargs: ", 42 | pretrained_override_kwargs, 43 | ) 44 | self.cfg = AutoConfig.from_pretrained( 45 | pretrained_path, **pretrained_override_kwargs 46 | ) 47 | 48 | self.processor = AutoImageProcessor.from_pretrained(pretrained_path) 49 | self.model = AutoModel.from_pretrained( 50 | pretrained_path, **pretrained_override_kwargs 51 | ) 52 | 53 | self.config["override_kwargs"] = {} 54 | self.config["model_config"] = self.cfg.to_dict() 55 | self.config["processor_config"] = self.processor.to_dict() 56 | 57 | if self.ema_enabled: 58 | self.ema_model = AutoModel.from_pretrained( 59 | pretrained_path, **pretrained_override_kwargs 60 | ) 61 | else: 62 | self.ema_model = None 63 | else: 64 | assert ( 65 | "model_config" in self.config 66 | ), "model_config must be provided if model_path is not provided" 67 | assert ( 68 | "processor_config" in self.config 69 | ), "processor_config must be provided if model_path is not provided" 70 | 71 | model_type = self.config["model_config"]["model_type"] 72 | image_processor_type = self.config["processor_config"][ 73 | "image_processor_type" 74 | ] 75 | 76 | self.cfg = CONFIG_MAPPING[model_type].from_dict(self.config["model_config"]) 77 | self.processor = image_processor_class_from_name( 78 | image_processor_type 79 | ).from_dict(self.config["processor_config"]) 80 | self.model = AutoModel.from_config(self.cfg) 81 | 82 | if self.ema_enabled: 83 | self.ema_model = AutoModel.from_config(self.cfg) 84 | else: 85 | self.ema_model = None 86 | 87 | self.projector = nn.Linear( 88 | self.cfg.hidden_size, self.config.projection_dim 89 | ) # self.cfg: AutoConfig, backbone specific. 
self.config: user provided config, contains main backbone hidden dim 90 | 91 | self.select_layer = self.config.get("select_layer", None) 92 | # Dummy rgb encoder utils 93 | self.dummy_output = self.config.get("dummy_output", False) 94 | 95 | self.move_channels_last_and_flatten = self.config.get( 96 | "move_channels_last_and_flatten", False 97 | ) 98 | self.ignore_cls = self.config.get("ignore_cls", True) 99 | 100 | if self.select_layer is None: 101 | self.select_layer = -2 102 | print("Warning: select_layer not provided, using default -2") 103 | 104 | init_config = self.config.get("init_from_ckpt", None) 105 | if init_config is not None and init_config["enabled"]: 106 | init_config = init_config.copy() 107 | self.config["init_from_ckpt"]["enabled"] = False # Do not init again 108 | ckpt = torch.load(init_config["ckpt_path"], map_location="cpu") 109 | if init_config["model"]: 110 | mdl_dct = { 111 | k[len("model.") :]: v 112 | for k, v in ckpt.items() 113 | if k.startswith("model.") 114 | } 115 | 116 | self.model.load_state_dict(mdl_dct, strict=True) 117 | 118 | if init_config["ema_model"] and self.ema_enabled: 119 | ema_dct = { 120 | k[len("model.") :]: v 121 | for k, v in ckpt.items() 122 | if k.startswith("model.") 123 | } 124 | self.ema_model.load_state_dict(ema_dct, strict=True) 125 | 126 | if init_config["projector"]: 127 | proj_dct = { 128 | k[len("projector.") :]: v 129 | for k, v in ckpt.items() 130 | if k.startswith("projector.") 131 | } 132 | self.projector.load_state_dict(proj_dct, strict=True) 133 | 134 | self.config["frozen"]["projector"] = init_config["projector"] 135 | self.config["frozen"]["model"] = init_config["model"] 136 | self.config["frozen"]["ema_model"] = init_config["ema_model"] 137 | 138 | self.try_to_truncate_layers = self.config.get("try_to_truncate_layers", False) 139 | 140 | if self.try_to_truncate_layers: 141 | import transformers 142 | 143 | if isinstance( 144 | self.model, transformers.models.clip.modeling_clip.CLIPVisionModel 145 | ): 146 | layer_to_select = self.select_layer 147 | if layer_to_select < 0: 148 | layer_to_select = ( 149 | len(self.model.vision_model.encoder.layers) + layer_to_select 150 | ) 151 | 152 | self.model.vision_model.encoder.layers = ( 153 | self.model.vision_model.encoder.layers[: layer_to_select + 1] 154 | ) 155 | 156 | if self.ema_enabled: 157 | self.ema_model.vision_model.encoder.layers = ( 158 | self.ema_model.vision_model.encoder.layers[ 159 | : layer_to_select + 1 160 | ] 161 | ) 162 | 163 | self.select_layer = -1 164 | else: 165 | print( 166 | "Warning: try_to_truncate_layers is set to True, but model is not CLIP. " 167 | "This will not truncate any layers." 
168 | ) 169 | 170 | frozen = self.config.get("frozen", False) 171 | self.disable_ema_update = False 172 | if isinstance(frozen, bool): 173 | if frozen: 174 | self.model.requires_grad_(False) 175 | self.model.eval() 176 | self.frozen = True 177 | else: 178 | self.frozen = False 179 | self.frozen_dict = {} 180 | else: 181 | if frozen["model"]: 182 | self.model.requires_grad_(False) 183 | self.model.eval() 184 | if frozen["ema_model"]: 185 | if not self.ema_enabled: 186 | self.model.requires_grad_(False) 187 | self.model.eval() 188 | else: 189 | self.ema_model.requires_grad_(False) 190 | self.ema_model.eval() 191 | self.disable_ema_update = True 192 | if frozen["projector"]: 193 | self.projector.requires_grad_(False) 194 | self.frozen = False 195 | self.frozen_dict = { 196 | "model": frozen["model"] 197 | or (frozen["ema_model"] and not self.ema_enabled), 198 | "ema_model": frozen["ema_model"], 199 | "projector": frozen["projector"], 200 | } 201 | 202 | self.output_whole = self.config["outputs"]["whole"] 203 | self.output_patches = self.config["outputs"]["patches"] 204 | 205 | # center crop of width self.config.input_size*2, height self.config.input_size 206 | def center_crop(x): 207 | wdesired = self.config.input_size * 2 208 | hdesired = self.config.input_size 209 | wcur = x.shape[-2] 210 | hcur = x.shape[-3] 211 | 212 | return x[ 213 | ..., 214 | (hcur - hdesired) // 2 : (hcur + hdesired) // 2, 215 | (wcur - wdesired) // 2 : (wcur + wdesired) // 2, 216 | :, 217 | ] 218 | 219 | self.center_crop = center_crop 220 | 221 | self.camera_embeddings = nn.Embedding(2, self.cfg.hidden_size) 222 | 223 | self.masking_rate = self.config.get("masking_rate", 0.0) 224 | 225 | self.do_downsample = self.config.get("downsample", False) 226 | self.downsample_type = self.config.get("downsample_type", "conv") 227 | 228 | if self.do_downsample: 229 | if self.downsample_type == "conv": 230 | self.downsample = nn.Conv2d( 231 | self.cfg.hidden_size, 232 | self.cfg.hidden_size, 233 | kernel_size=2, 234 | stride=2, 235 | padding=0, 236 | bias=False, 237 | ) 238 | elif self.downsample_type == "avgpool": 239 | self.downsample = nn.AvgPool2d(2) 240 | else: 241 | raise ValueError("Invalid downsample type") 242 | else: 243 | self.downsample = None 244 | 245 | # Display warning that output_whole, patches is ignored currently and will be implemented later 246 | print( 247 | "Warning: output_whole, output_patches is ignored currently and will be implemented later" 248 | ) 249 | 250 | def forward(self, x, y=None, ema=False): 251 | # Dims: B, T, NPatch, 3, H, W 252 | if self.frozen: 253 | self.model.eval() 254 | 255 | with torch.no_grad(): 256 | return self.encode(x, y, ema=ema) 257 | else: 258 | return self.encode(x, y, ema=ema) 259 | 260 | def encode(self, x, y=None, ema=False): 261 | # print("x shape: ", x.shape, "y shape: ", y.shape if y is not None else None) 262 | # x = self.center_crop(x) # TODO: This is too hacky, fix this 263 | original_shape = x.shape 264 | if len(original_shape) == 5: 265 | T = original_shape[1] 266 | x = x.reshape( 267 | original_shape[0] * original_shape[1], *original_shape[2:] 268 | ) # B*T, H, W, 3 269 | else: 270 | T = 1 271 | # Encode 272 | prepped_x = self.processor( 273 | list(x.squeeze(1).permute(0, 3, 1, 2).half()), return_tensors="pt" 274 | )["pixel_values"] 275 | prepped_y = ( 276 | self.processor( 277 | list(y.squeeze(1).permute(0, 3, 1, 2).half()), 278 | return_tensors="pt", 279 | do_normalize=False, 280 | do_convert_rgb=False, 281 | )["pixel_values"] 282 | if y is not None 283 | else 
None 284 | ) 285 | 286 | # Prep the inputs before feeding to the vision encoders 287 | # Remove first patch, which is the entire image 288 | prepped_x = prepped_x[:, 1:] 289 | if self.frozen_dict.get("ema_model", False) and self.ema_enabled: 290 | self.ema_model.eval() 291 | 292 | if self.frozen_dict.get("model", False): 293 | self.model.eval() 294 | 295 | if prepped_y is not None: 296 | prepped_y = prepped_y[:, 1:] 297 | 298 | B, Npatch = prepped_x.shape[:2] 299 | By = prepped_y.shape[0] if prepped_y is not None else None 300 | 301 | if self.dummy_output: 302 | return self.projector( 303 | torch.zeros(B, Npatch, self.cfg.hidden_size) 304 | .to(self.model.device) 305 | .to(self.model.dtype) 306 | ) 307 | 308 | # Flatten, encode 309 | if self.keep_time_dim: 310 | prepped_x = prepped_x.reshape(B // T, T, *prepped_x.shape[1:]) 311 | # swap dim 1 and 2 312 | prepped_x = prepped_x.permute(0, 2, 1, 3, 4, 5) 313 | 314 | if prepped_y is not None: 315 | prepped_y = prepped_y.reshape(By, 1, *prepped_y.shape[1:]) 316 | prepped_y = prepped_y.permute(0, 2, 1, 3, 4, 5) 317 | 318 | elif ema and self.ema_enabled: 319 | vision_feats = self.ema_model( 320 | prepped_x.flatten(0, 1) 321 | .to(self.ema_model.device) 322 | .to(self.ema_model.dtype), 323 | output_hidden_states=True, 324 | )["hidden_states"] 325 | else: 326 | 327 | vision_feats = self.model( 328 | prepped_x.flatten(0, 1).to(self.model.device).to(self.model.dtype), 329 | output_hidden_states=True, 330 | )["hidden_states"] 331 | 332 | if prepped_y is not None: 333 | with torch.no_grad(): 334 | if hasattr(self.model, "vision_model"): 335 | vision_patch_labels = ( 336 | self.model.vision_model.embeddings.patch_embedding( 337 | prepped_y.flatten(0, 1) 338 | .to(self.model.device) 339 | .to(self.model.dtype) 340 | ) 341 | ) 342 | else: 343 | raise ValueError( 344 | "Model does not have a patcher, cannot process labels" 345 | ) 346 | 347 | if self.do_downsample: 348 | vision_patch_labels = self.downsample(vision_patch_labels) 349 | 350 | vision_patch_labels = vision_patch_labels.abs().sum(1).flatten(1, 2) 351 | 352 | if self.ignore_cls: 353 | # Vision feat to use, default from llama, ignore cls 354 | vision_feats = vision_feats[self.select_layer][:, 1:] 355 | else: 356 | assert y is None, "y is not None, but ignore_cls is False" 357 | assert self.do_downsample is False, "Cannot downsample with cls token" 358 | vision_feats = vision_feats[self.select_layer] 359 | 360 | if self.move_channels_last_and_flatten: 361 | vision_feats = vision_feats.permute(0, 2, 3, 1).flatten(1, 2) 362 | 363 | if self.do_downsample: 364 | vision_feat_side_size = int(vision_feats.shape[1] ** 0.5) 365 | 366 | vision_feats = vision_feats.view( 367 | B * Npatch, vision_feat_side_size, vision_feat_side_size, -1 368 | ) 369 | # Permute hidden dim to channel dim 370 | vision_feats = vision_feats.permute(0, 3, 1, 2) 371 | 372 | vision_feats = self.downsample(vision_feats) 373 | 374 | vision_feats = vision_feats.permute(0, 2, 3, 1).flatten( 375 | 1, 2 376 | ) # B*Npatch, H, W, C -> B*Npatch, H'*W', C 377 | 378 | vision_feats = vision_feats.view( 379 | B // T if self.keep_time_dim else B, Npatch, *vision_feats.shape[1:] 380 | ) 381 | vision_patch_labels = ( 382 | vision_patch_labels.view(By, Npatch, *vision_patch_labels.shape[1:]) 383 | if prepped_y is not None 384 | else None 385 | ) 386 | vision_feats = vision_feats + self.camera_embeddings( 387 | torch.tensor([0, 1], device=vision_feats.device) 388 | ).unsqueeze(0).unsqueeze(-2) 389 | 390 | # Merge Npatch with vectors per image 391 | 
vision_feats = vision_feats.flatten(1, 2) 392 | if prepped_y is not None: 393 | vision_patch_labels = vision_patch_labels.flatten(1, 2) 394 | 395 | if self.masking_rate > 0.0 and self.training: 396 | mask = ( 397 | torch.rand(B, vision_feats.shape[1], 1, device=vision_feats.device) 398 | > self.masking_rate 399 | ) 400 | vision_feats = vision_feats * mask 401 | 402 | # Project to required dim (2048) 403 | if self.frozen_dict.get("projector", False): 404 | self.projector.eval() 405 | 406 | vision_feats = self.projector(vision_feats) 407 | 408 | if len(original_shape) == 5: 409 | if self.keep_time_dim: 410 | vision_feats = vision_feats.reshape( 411 | original_shape[0], -1, *vision_feats.shape[1:] 412 | ) 413 | else: 414 | vision_feats = vision_feats.reshape( 415 | original_shape[0], original_shape[1], *vision_feats.shape[1:] 416 | ) 417 | vision_patch_labels = ( 418 | vision_patch_labels.reshape( 419 | original_shape[0], 1, *vision_patch_labels.shape[1:] 420 | ) 421 | if prepped_y is not None 422 | else None 423 | ) 424 | if prepped_y is not None: 425 | return vision_feats, vision_patch_labels 426 | 427 | return vision_feats 428 | 429 | def decode(self, z): 430 | raise NotImplementedError 431 | 432 | def interpret(self, z): 433 | raise NotImplementedError 434 | -------------------------------------------------------------------------------- /carformer/carformer/ponderer_lit.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import cv2 5 | import lightning as L 6 | import numpy as np 7 | import torch 8 | import wandb 9 | from torch.utils.data import DataLoader 10 | 11 | from carformer.data import get_datasets 12 | from carformer.ponderer import Ponderer 13 | from carformer.utils import ( 14 | WeightedDistributedSampler, 15 | ) 16 | from carformer.utils.distributed import get_rank 17 | from carformer.visualization.visutils import ( 18 | visualize_input_from_batch, 19 | visualize_trajectory_action_predictions, 20 | ) 21 | 22 | 23 | class PondererLit(L.LightningModule): 24 | def __init__(self, ponderer): 25 | super(PondererLit, self).__init__() 26 | self.ponderer = ponderer 27 | 28 | self.params = ponderer.cfg 29 | 30 | self.save_hyperparameters(ponderer.cfg.to_dict()) 31 | 32 | print("Save dir: ", self.params.save_dir) 33 | 34 | def setup(self, stage): 35 | if stage != "fit": 36 | return 37 | 38 | folder = self.params.save_dir 39 | 40 | if self.trainer.is_global_zero: 41 | os.makedirs(folder, exist_ok=True) 42 | os.makedirs(os.path.join(folder, "train_predictions"), exist_ok=True) 43 | os.makedirs(os.path.join(folder, "val_predictions"), exist_ok=True) 44 | 45 | # Save the model config and training args 46 | self.params.save_pretrained(folder) 47 | # Save all training arguments 48 | env_vars = [] 49 | 50 | vars_to_save = [ 51 | "HF_ENDPOINT", 52 | "OMP_NUM_THREADS", 53 | "MKL_NUM_THREADS", 54 | ] 55 | 56 | for var in vars_to_save: 57 | if var in os.environ: 58 | env_vars.append(f"{var}={os.environ[var]}") 59 | 60 | # Detect if ddp or not 61 | if "WORLD_SIZE" in os.environ: 62 | run_command = f"torchrun --nproc_per_node={os.environ['WORLD_SIZE']}" 63 | else: 64 | run_command = "python" 65 | 66 | # Save the command used to run the training into the file {folder}/command.sh 67 | with open(os.path.join(folder, "command.sh"), "w") as f: 68 | all_args = env_vars + [run_command] + sys.argv 69 | f.write(" ".join(all_args)) 70 | f.write("\n") 71 | # If wandb is used, log the url to the run 72 | try: 73 | url = wandb.run.get_url() 74 
| f.write(f"Wandb URL: {url}") 75 | f.write("\n") 76 | except: 77 | pass 78 | 79 | def training_step(self, batch, batch_idx): 80 | preds, loss, pred_labels, preprocessed_inputs = self.ponderer( 81 | batch, return_labels=True, return_inputs=True 82 | ) 83 | 84 | loss = {k: v.mean() for k, v in loss.items()} 85 | 86 | self.log_dict( 87 | {f"train/loss/{k}": v for k, v in loss.items()}, 88 | sync_dist=True, 89 | prog_bar=True, 90 | ) 91 | rank = get_rank() 92 | 93 | world_size = self.trainer.world_size 94 | # import ipdb; ipdb.set_trace() 95 | 96 | if ( 97 | batch_idx < 16 98 | and self.params.visualize 99 | and self.trainer.current_epoch > self.params.visualize_start_epoch 100 | and (self.trainer.current_epoch % self.params.visualize_interval == 0) 101 | and rank < 4 # only log from first 4 ranks 102 | ): 103 | for idd in range(max(self.params.hyperparams.batch_size // 8, 1)): 104 | # import ipdb; ipdb.set_trace() 105 | 106 | ac_val = batch["action"][0][0, 20][0].item() 107 | 108 | impath = visualize_input_from_batch( 109 | batch, 110 | idd, 111 | preds, 112 | pred_labels, 113 | self.params.save_dir, 114 | f"ep{self.trainer.current_epoch}_st{rank*8 + idd + world_size*4*batch_idx}_ac_{ac_val:.2f}", 115 | self.ponderer, 116 | "train", 117 | ) 118 | dir_preds = os.path.join(self.params.save_dir, "train_predictions") 119 | 120 | os.makedirs(dir_preds, exist_ok=True) 121 | 122 | path2 = visualize_trajectory_action_predictions( 123 | batch, 124 | preds, 125 | labels=pred_labels, 126 | save_dir=dir_preds, 127 | model=self.ponderer, 128 | save_suffix="ep{}_st{}".format( 129 | self.trainer.current_epoch, 130 | rank * 8 + idd + world_size * 4 * batch_idx, 131 | ), 132 | save_idx=idd, 133 | action_source="transformer-regression", 134 | visualize_gt=True, 135 | ) 136 | if not preds["bev"] is None: 137 | pred_heatmaps = preds["bev"].float().detach().cpu().numpy()[idd] 138 | 139 | label_heatmaps = ( 140 | pred_labels["bev_mask"].float().detach().cpu().numpy()[idd] 141 | ) 142 | 143 | size_per_patch = label_heatmaps.shape[-1] // 2 144 | 145 | patch_1 = ( 146 | label_heatmaps[:size_per_patch].reshape( 147 | int(size_per_patch**0.5), int(size_per_patch**0.5) 148 | ) 149 | > 0.5 150 | ) * 255 151 | patch_2 = ( 152 | label_heatmaps[size_per_patch:].reshape( 153 | int(size_per_patch**0.5), int(size_per_patch**0.5) 154 | ) 155 | > 0.5 156 | ) * 255 157 | 158 | patch_1_pred = ( 159 | pred_heatmaps[:size_per_patch].reshape( 160 | int(size_per_patch**0.5), int(size_per_patch**0.5) 161 | ) 162 | > 0.5 163 | ) * 255 164 | patch_2_pred = ( 165 | pred_heatmaps[size_per_patch:].reshape( 166 | int(size_per_patch**0.5), int(size_per_patch**0.5) 167 | ) 168 | > 0.5 169 | ) * 255 170 | 171 | patches = np.concatenate([patch_1, patch_2], axis=1) 172 | patches_pred = np.concatenate([patch_1_pred, patch_2_pred], axis=1) 173 | 174 | # Make pred red, label green 175 | patch = np.stack( 176 | [patches, patches_pred, np.zeros_like(patches)], axis=-1 177 | ) 178 | 179 | cv2.imwrite( 180 | os.path.join( 181 | dir_preds, 182 | f"ep{self.trainer.current_epoch}_st{rank*8 + idd + world_size*4*batch_idx}_patch.png", 183 | ), 184 | patch, 185 | ) 186 | 187 | return loss 188 | 189 | def validation_step(self, batch, batch_idx): 190 | preds, loss, pred_labels, preprocessed_inputs = self.ponderer( 191 | batch, return_labels=True, return_inputs=True 192 | ) 193 | 194 | loss = {k: v.mean() for k, v in loss.items()} 195 | 196 | self.log_dict({f"val/loss/{k}": v for k, v in loss.items()}, sync_dist=True) 197 | rank = get_rank() 198 | 199 | 
world_size = self.trainer.world_size 200 | 201 | if ( 202 | batch_idx < 16 203 | and self.params.visualize 204 | and self.trainer.current_epoch > self.params.visualize_start_epoch 205 | and (self.trainer.current_epoch % self.params.visualize_interval == 0) 206 | and rank < 4 # only log from first 4 ranks 207 | ): 208 | for idd in range(max(self.params.hyperparams.batch_size // 8, 1)): 209 | impath = visualize_input_from_batch( 210 | batch, 211 | idd, 212 | preds, 213 | pred_labels, 214 | self.params.save_dir, 215 | f"ep{self.trainer.current_epoch}_st{rank*8 + idd + world_size*4*batch_idx}", 216 | self.ponderer, 217 | "val", 218 | ) 219 | dir_preds = os.path.join(self.params.save_dir, "val_predictions") 220 | 221 | os.makedirs(dir_preds, exist_ok=True) 222 | 223 | path2 = visualize_trajectory_action_predictions( 224 | batch, 225 | preds, 226 | labels=pred_labels, 227 | save_dir=dir_preds, 228 | model=self.ponderer, 229 | save_suffix="ep{}_st{}".format( 230 | self.trainer.current_epoch, 231 | rank * 8 + idd + world_size * 4 * batch_idx, 232 | ), 233 | save_idx=idd, 234 | action_source="transformer-regression", 235 | ) 236 | if not preds["bev"] is None: 237 | pred_heatmaps = preds["bev"].float().detach().cpu().numpy()[idd] 238 | 239 | label_heatmaps = ( 240 | pred_labels["bev_mask"].float().detach().cpu().numpy()[idd] 241 | ) 242 | 243 | size_per_patch = label_heatmaps.shape[-1] // 2 244 | 245 | patch_1 = ( 246 | label_heatmaps[:size_per_patch].reshape( 247 | int(size_per_patch**0.5), int(size_per_patch**0.5) 248 | ) 249 | > 0.5 250 | ) * 255 251 | patch_2 = ( 252 | label_heatmaps[size_per_patch:].reshape( 253 | int(size_per_patch**0.5), int(size_per_patch**0.5) 254 | ) 255 | > 0.5 256 | ) * 255 257 | 258 | patch_1_pred = ( 259 | pred_heatmaps[:size_per_patch].reshape( 260 | int(size_per_patch**0.5), int(size_per_patch**0.5) 261 | ) 262 | > 0.5 263 | ) * 255 264 | patch_2_pred = ( 265 | pred_heatmaps[size_per_patch:].reshape( 266 | int(size_per_patch**0.5), int(size_per_patch**0.5) 267 | ) 268 | > 0.5 269 | ) * 255 270 | 271 | patches = np.concatenate([patch_1, patch_2], axis=1) 272 | patches_pred = np.concatenate([patch_1_pred, patch_2_pred], axis=1) 273 | 274 | # Make pred red, label green 275 | patch = np.stack( 276 | [patches, patches_pred, np.zeros_like(patches)], axis=-1 277 | ) 278 | 279 | cv2.imwrite( 280 | os.path.join( 281 | dir_preds, 282 | f"ep{self.trainer.current_epoch}_st{rank*8 + idd + world_size*4*batch_idx}_patch.png", 283 | ), 284 | patch, 285 | ) 286 | 287 | return loss 288 | 289 | def configure_optimizers(self): 290 | return_dict = {} 291 | 292 | params = self.parameters() 293 | 294 | opt_kwargs = self.params.hyperparams.optimizer.kwargs 295 | 296 | if "weight_decay" in opt_kwargs: 297 | # Pop weight decay from optimizer kwargs 298 | weight_decay = opt_kwargs.pop("weight_decay") 299 | 300 | # Create optimizer groups 301 | optim_groups = self.create_optimizer_groups(weight_decay) 302 | else: 303 | optim_groups = params 304 | 305 | opt = getattr(torch.optim, self.params.hyperparams.optimizer.name)( 306 | optim_groups, 307 | lr=self.params.hyperparams.lr, 308 | **self.params.hyperparams.optimizer.kwargs, 309 | ) 310 | stepping_batches = self.trainer.estimated_stepping_batches 311 | 312 | from torch.optim.lr_scheduler import ( 313 | CosineAnnealingWarmRestarts, 314 | ) 315 | 316 | scheduler = CosineAnnealingWarmRestarts( 317 | opt, T_0=stepping_batches // 30, T_mult=2, eta_min=1e-6 318 | ) 319 | 320 | return_dict["optimizer"] = opt 321 | 322 | return_dict["lr_scheduler"] = { 
323 | "scheduler": scheduler, 324 | "interval": "step", 325 | } 326 | 327 | return return_dict 328 | 329 | def train_dataloader(self): 330 | train_dataset, _ = get_datasets(self.params, self.ponderer, splits=["train"]) 331 | subsample_ratio = self.params.dataset.subsample_ratio 332 | if self.params.training.weighted_sampling: 333 | initial_weights = train_dataset.getweights() 334 | 335 | self.sample_weights = initial_weights 336 | 337 | bucket_names = train_dataset.get_bucket_names() 338 | self.bucket_names = bucket_names 339 | 340 | bucket_config = self.params.training.bucket_weights 341 | 342 | if bucket_config.type == "uniform": 343 | weights = np.asarray(initial_weights) 344 | weights = (weights / weights.sum(0)).mean(-1) 345 | subsample_ratio *= bucket_config.total_ratio 346 | if bucket_config.type == "preferturns": 347 | weights = np.asarray(initial_weights) 348 | weights = ( 349 | (weights / weights.sum(0)) * np.asarray(bucket_config.weights) 350 | ).sum(-1) / np.asarray(bucket_config.weights).sum() 351 | subsample_ratio *= bucket_config.total_ratio 352 | else: 353 | raise NotImplementedError( 354 | f"Bucketing type {bucket_config.type} not implemented yet" 355 | ) 356 | else: 357 | weights = None 358 | 359 | sampler = WeightedDistributedSampler( 360 | train_dataset, 361 | subsample_ratio, 362 | shuffle=True if not self.params.overfit else False, 363 | weights=weights, 364 | ) 365 | 366 | train_loader = DataLoader( 367 | train_dataset, 368 | batch_size=self.params.hyperparams.batch_size, 369 | sampler=sampler, 370 | num_workers=self.params.num_workers, 371 | ) 372 | 373 | self.train_loader = train_loader 374 | 375 | return train_loader 376 | 377 | def val_dataloader(self): 378 | _, val_dataset = get_datasets(self.params, self.ponderer, splits=["val"]) 379 | 380 | val_sampler = WeightedDistributedSampler( 381 | val_dataset, 382 | shuffle=False, 383 | ) 384 | 385 | val_loader = DataLoader( 386 | val_dataset, 387 | batch_size=self.params.hyperparams.batch_size, 388 | sampler=val_sampler, 389 | num_workers=self.params.num_workers, 390 | ) 391 | 392 | return val_loader 393 | 394 | @staticmethod 395 | def from_pretrained(path, epoch=30, deepspeed="auto", apply_deepspeed=True): 396 | assert os.path.exists(path), "Path {} does not exist".format(path) 397 | 398 | if deepspeed == "auto": 399 | # Infer whether the checkpoint is in deepspeed format 400 | if os.path.exists(os.path.join(path, "zero_to_fp32.py")): 401 | deepspeed = True 402 | else: 403 | deepspeed = False 404 | 405 | config_path = os.path.join(path, "config.json") 406 | 407 | assert os.path.exists(config_path), "Config path {} does not exist".format( 408 | config_path 409 | ) 410 | 411 | config = CarformerConfig.from_pretrained(config_path) 412 | 413 | # Fix quantizer paths if needed 414 | import inspect 415 | 416 | import carformer 417 | 418 | carformer_path = os.path.dirname(inspect.getfile(carformer)) 419 | # Go one layer up 420 | carformer_path = os.path.dirname(carformer_path) 421 | 422 | # Replace encoder path 423 | if "rgb_backbone" in config.training: 424 | if config.training.rgb_backbone.frozen: 425 | config.training.rgb_backbone.model_path = os.path.join( 426 | carformer_path, config.training.rgb_backbone.model_path_rel 427 | ) 428 | else: 429 | # Do not load path, the weights and config should already exist. 
430 | config.training.rgb_backbone.model_path = None 431 | 432 | model = Ponderer(config) 433 | 434 | if epoch in ["last", "best"]: 435 | checkpoint_path = os.path.join(path, "{}_model.pt".format(epoch)) 436 | else: 437 | if epoch is not None: 438 | if deepspeed: 439 | checkpoint_path = os.path.join( 440 | path, "{}".format(epoch) 441 | ) # Deepspeed style 442 | else: 443 | checkpoint_path = os.path.join( 444 | path, "epochs", "epoch_{}.pt".format(epoch) 445 | ) 446 | else: 447 | checkpoint_path = None 448 | 449 | if deepspeed == True: 450 | import deepspeed as ds 451 | 452 | arg_dict = {} 453 | if checkpoint_path is not None: 454 | arg_dict["checkpoint"] = checkpoint_path 455 | 456 | print("Loading args ", arg_dict) 457 | model = ds.init_inference(model, arg_dict) 458 | else: 459 | checkpoint = torch.load(checkpoint_path, map_location="cpu") 460 | dtype = next(iter(checkpoint["model"].values())).dtype 461 | model.to(dtype) 462 | model.load_state_dict(checkpoint["model"], strict=True) 463 | if apply_deepspeed: 464 | import deepspeed as ds 465 | 466 | return model 467 | 468 | def on_train_epoch_start(self): 469 | # Only do this on rank 0 470 | if not self.trainer.is_global_zero: 471 | return 472 | 473 | if self.params.training.weighted_sampling: 474 | weights = self.sample_weights 475 | self.train_loader.sampler.__iter__() 476 | 477 | indices = self.train_loader.sampler.last_indices 478 | 479 | bucket_names = self.bucket_names 480 | 481 | dist = weights[indices].sum(0) / ( 482 | self.train_loader.sampler.num_samples 483 | * self.train_loader.sampler.num_replicas 484 | ) 485 | baseline_dist = weights.sum(0) / len(weights) 486 | 487 | print("Weighted sampling statistics, epoch: ", self.trainer.current_epoch) 488 | for bucket_name, dist_val, baseline_val in zip( 489 | bucket_names, dist, baseline_dist 490 | ): 491 | print( 492 | f"Bucket {bucket_name}: {dist_val*100:.2f}% (Baseline: {baseline_val*100:.2f}%)", 493 | end="\t", 494 | ) 495 | 496 | def create_optimizer_groups(self, weight_decay): 497 | """ 498 | This long function is unfortunately doing something very simple and is 499 | being very defensive: 500 | We are separating out all parameters of the model into two buckets: 501 | those that will experience 502 | weight decay for regularization and those that won't 503 | (biases, and layernorm/embedding weights). 504 | We are then returning the optimizer groups. 
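For example, mirroring configure_optimizers above (a minimal sketch; the
weight_decay and lr values are illustrative, not the project defaults):
    optim_groups = self.create_optimizer_groups(weight_decay=0.01)
    optimizer = torch.optim.AdamW(optim_groups, lr=1e-4)
AdamW then applies weight decay only to the first group; the second group
(biases, norms, embeddings, etc.) is excluded.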
505 | """ 506 | 507 | # separate out all parameters to those that will and won't experience 508 | # regularizing weight decay 509 | decay = set() 510 | no_decay = set() 511 | whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d) 512 | blacklist_weight_modules = ( 513 | torch.nn.LayerNorm, 514 | torch.nn.Embedding, 515 | torch.nn.BatchNorm2d, 516 | ) 517 | for mn, m in self.named_modules(): 518 | for pn, _ in m.named_parameters(): 519 | fpn = f"{mn}.{pn}" if mn else pn # full param name 520 | # if "attn_pool" in fpn: 521 | # import ipdb; ipdb.set_trace() 522 | # print(fpn) 523 | # if len(no_decay&decay) > 0: 524 | # import ipdb; ipdb.set_trace() 525 | # fpn = pn # full param name 526 | if pn.endswith("bias"): 527 | # all biases will not be decayed 528 | no_decay.add(fpn) 529 | elif "attn_pool" in fpn or "vqvae" in fpn: 530 | no_decay.add(fpn) 531 | elif "norm.weight" in fpn: 532 | no_decay.add(fpn) 533 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules): 534 | # weights of whitelist modules will be weight decayed 535 | decay.add(fpn) 536 | elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules): 537 | # weights of blacklist modules will NOT be weight decayed 538 | no_decay.add(fpn) 539 | elif ( 540 | pn.endswith("weight") and "conv." in pn 541 | ): # Add decay for convolutional layers. 542 | decay.add(fpn) 543 | elif pn.endswith("weight") and ".bn" in pn: # No decay for batch norms. 544 | no_decay.add(fpn) 545 | elif pn.endswith("weight") and ".ln" in pn: # No decay for layer norms. 546 | no_decay.add(fpn) 547 | elif ( 548 | pn.endswith("weight") and "downsample.0.weight" in pn 549 | ): # Conv2D layer with stride 2 550 | decay.add(fpn) 551 | elif pn.endswith("weight") and "downsample.1.weight" in pn: # BN layer 552 | no_decay.add(fpn) 553 | elif pn.endswith("weight") and ".attn" in pn: # Attention linear layers 554 | decay.add(fpn) 555 | elif ( 556 | pn.endswith("weight") and "channel_to_" in pn 557 | ): # Convolutional layers for channel change 558 | decay.add(fpn) 559 | elif pn.endswith("weight") and ".mlp" in pn: # MLP linear layers 560 | decay.add(fpn) 561 | elif ( 562 | pn.endswith("weight") and "target_speed_network" in pn 563 | ): # MLP linear layers 564 | decay.add(fpn) 565 | elif ( 566 | pn.endswith("weight") and "join." in pn and not ".norm" in pn 567 | ): # MLP layers 568 | decay.add(fpn) 569 | elif ( 570 | pn.endswith("weight") and "join." 
in pn and ".norm" in pn 571 | ): # Norm layers 572 | no_decay.add(fpn) 573 | elif pn.endswith("weight") and "layernorm" in pn: # Norm layers 574 | no_decay.add(fpn) 575 | elif pn.endswith("weight") and ".norm" in fpn: # Norm layers 576 | no_decay.add(fpn) 577 | elif "class_embedding" in fpn: # cls embeds 578 | no_decay.add(fpn) 579 | elif pn.endswith("_ih") or pn.endswith("_hh"): 580 | # all recurrent weights will not be decayed 581 | no_decay.add(fpn) 582 | elif pn.endswith("_emb") or "_token" in pn: 583 | no_decay.add(fpn) 584 | elif pn.endswith("_embed"): 585 | no_decay.add(fpn) 586 | elif "pos_embed" in pn: 587 | no_decay.add(fpn) 588 | elif "patch_embed" in pn: 589 | decay.add(fpn) 590 | elif "bias_ih_l0" in pn or "bias_hh_l0" in pn: 591 | no_decay.add(fpn) 592 | elif "weight_ih_l0" in pn or "weight_hh_l0" in pn: 593 | decay.add(fpn) 594 | elif "_query" in pn or "weight_hh_l0" in pn: 595 | no_decay.add(fpn) 596 | elif "proj_weight" in pn: 597 | decay.add(fpn) 598 | elif "valid_bev_pixels" in pn: 599 | no_decay.add(fpn) 600 | elif "ls1" in pn or "ls2" in pn: 601 | no_decay.add(fpn) 602 | elif "position_embedding" in pn: 603 | no_decay.add(fpn) 604 | 605 | # validate that we considered every parameter 606 | param_dict = dict(self.named_parameters()) 607 | inter_params = decay & no_decay 608 | union_params = decay | no_decay 609 | assert ( 610 | len(inter_params) == 0 611 | ), f"parameters {str(inter_params)} made it into both decay/no_decay sets!" 612 | assert len(param_dict.keys() - union_params) == 0, ( 613 | f"parameters {str(param_dict.keys() - union_params)} were not " 614 | f"separated into either decay/no_decay set!" 615 | ) 616 | 617 | # create the pytorch optimizer object 618 | optim_groups = [ 619 | { 620 | "params": [param_dict[pn] for pn in sorted(list(decay))], 621 | "weight_decay": weight_decay, 622 | }, 623 | { 624 | "params": [param_dict[pn] for pn in sorted(list(no_decay))], 625 | "weight_decay": 0.0, 626 | }, 627 | ] 628 | return optim_groups 629 | 630 | def on_train_batch_end(self, *args, **kwargs): 631 | 632 | ema_enabled = self.ponderer.bev_encoder.ema_enabled 633 | disable_ema_updates = self.ponderer.bev_encoder.disable_ema_update 634 | if disable_ema_updates: 635 | return 636 | 637 | if ema_enabled: 638 | ema_every_steps = getattr( 639 | self.ponderer.config.training, "ema_every_steps", 1 640 | ) 641 | ema_start = getattr(self.ponderer.config.training, "ema_start", 0) 642 | ema_end_epoch = getattr(self.ponderer.config.training, "ema_end_epoch", -1) 643 | 644 | if self.global_step % ema_every_steps == 0 and self.global_step > ema_start: 645 | # Stop if end epoch is reached 646 | if ema_end_epoch == -1 or self.trainer.current_epoch < ema_end_epoch: 647 | ema_decay = self.ponderer.config.training.ema_decay 648 | 649 | vision_encoder = self.ponderer.bev_encoder 650 | 651 | model = vision_encoder.model 652 | 653 | ema_model = vision_encoder.ema_model 654 | self.ponderer.ema_update(model, ema_model, beta=ema_decay) 655 | -------------------------------------------------------------------------------- /carformer/carformer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .distributedsampler import * 3 | -------------------------------------------------------------------------------- /carformer/carformer/utils/distributed.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # DistributedDataParallel required imports 4 | import 
torch.distributed as dist 5 | import torch.multiprocessing as mp 6 | from torch.nn.parallel import DistributedDataParallel as DDP 7 | from torch.distributed import init_process_group, destroy_process_group 8 | import os # For os.environ 9 | 10 | 11 | def disable_printing(is_master): 12 | """ 13 | This function disables printing when not in master process 14 | """ 15 | import builtins as __builtin__ 16 | 17 | builtin_print = __builtin__.print 18 | 19 | def print(*args, **kwargs): 20 | force = kwargs.pop("force", False) 21 | if is_master or force: 22 | builtin_print(*args, **kwargs) 23 | 24 | __builtin__.print = print 25 | 26 | 27 | # For distributed training, we need to use the following function 28 | # This function sets up the environment for distributed training, and lets every process know 29 | # which process it is (rank) and how many processes there are (world_size) 30 | def ddp_setup(args, dist_url="env://", use_deepspeed=False): 31 | # We get the rank and world size from the environment variables 32 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 33 | rank = int(os.environ["RANK"]) 34 | world_size = int(os.environ["WORLD_SIZE"]) 35 | if args.gpus != world_size: 36 | # Warn 37 | print( 38 | "WARNING: args.gpus != os.environ['WORLD_SIZE']", 39 | flush=True, 40 | ) 41 | args.gpus = int(os.environ["WORLD_SIZE"]) 42 | else: 43 | raise ValueError("Rank and world size must be set. Please make sure you are using torchrun") 44 | 45 | # Single gpu: 46 | # python train_hydra_ds.py +deepspeed=zero3 47 | # 8 gpus, with 10 cpus per GPU: 48 | # OMP_NUM_THREADS=10 torchrun --nproc_per_node=8 train_hydra_ds.py +deepspeed=zero3 49 | 50 | 51 | print( 52 | "| Initializing process with rank {} out of {} processes |".format( 53 | rank, world_size 54 | ), 55 | flush=True, 56 | ) 57 | 58 | # This is a useful hack I like to do sometimes. This DISABLES printing on nodes that are not rank 0, to make the output cleaner 59 | disable_printing(rank == 0) 60 | torch.cuda.set_device(rank) 61 | 62 | if use_deepspeed: 63 | import deepspeed as ds 64 | ds.init_distributed(dist_backend="nccl") 65 | else: 66 | init_process_group( 67 | # backend="nccl", # just GPU, commented out in order to use both CPU and GPU 68 | init_method=dist_url, 69 | rank=rank, 70 | world_size=args.gpus, 71 | ) 72 | 73 | 74 | def is_dist_avail_and_initialized(): 75 | if not dist.is_available(): 76 | return False 77 | if not dist.is_initialized(): 78 | return False 79 | return True 80 | 81 | 82 | def get_rank(): 83 | if not is_dist_avail_and_initialized(): 84 | return 0 85 | return dist.get_rank() 86 | 87 | 88 | def is_main_process(): 89 | return get_rank() == 0 90 | 91 | 92 | def save_on_master(*args, **kwargs): 93 | if is_main_process(): 94 | torch.save(*args, **kwargs) 95 | -------------------------------------------------------------------------------- /carformer/carformer/utils/distributedsampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import math 4 | import numpy as np 5 | import warnings 6 | 7 | 8 | class WeightedDistributedSampler(torch.utils.data.Sampler): 9 | """Sampler that restricts data loading to a subset of the dataset. 10 | It is especially useful in conjunction with 11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 12 | process can pass a DistributedSampler instance as a DataLoader sampler, 13 | and load a subset of the original dataset that is exclusive to it. 14 | .. 
note:: 15 | Dataset is assumed to be of constant size. 16 | Arguments: 17 | dataset: Dataset used for sampling. 18 | num_replicas (optional): Number of processes participating in 19 | distributed training. 20 | bins: a list of bins containing the indices of the dataset in each bin. Bins are chosen at equal probability, and then a sample is chosen from the bin at equal probability. 21 | rank (optional): Rank of the current process within num_replicas. 22 | shuffle (optional): If true (default), sampler will shuffle the indices 23 | split_strategy (optional): 24 | Assuming 4 gpus of index 0-3: 25 | interleaved: 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3 26 | partition: 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3 27 | interleaved is default behavior, closer behavior to using single device 28 | partition is suitable for cases where you want to split the read operations to different disks, ideally used with shuffle=False for preprocessing purposes 29 | """ 30 | 31 | def __init__( 32 | self, 33 | dataset, 34 | subsample_ratio=1.0, 35 | num_replicas=None, 36 | rank=None, 37 | weights=None, 38 | shuffle=False, 39 | split_strategy="interleaved" 40 | ): 41 | # If weights, make sure they are a torch tensor 42 | if weights is not None: 43 | weights = torch.tensor(weights, dtype=torch.float) 44 | if num_replicas is None: 45 | if not dist.is_available(): 46 | # Default to single process 47 | num_replicas = 1 48 | else: 49 | if dist.is_initialized(): 50 | num_replicas = dist.get_world_size() 51 | else: 52 | # Warn user that num_replicas is set to 1 53 | warnings.warn( 54 | "DistributedSampler is initialized without " 55 | "dist being initialized. Set num_replicas " 56 | "to 1 instead." 57 | ) 58 | num_replicas = 1 59 | if rank is None: 60 | if not dist.is_available(): 61 | # Default to single process 62 | rank = 0 63 | else: 64 | if dist.is_initialized(): 65 | rank = dist.get_rank() 66 | else: 67 | # Warn user that rank is set to 0 68 | warnings.warn( 69 | "DistributedSampler is initialized without " 70 | "dist being initialized. Set rank " 71 | "to 0 instead." 72 | ) 73 | rank = 0 74 | 75 | self.dataset = dataset 76 | self.num_replicas = num_replicas 77 | self.rank = rank 78 | self.epoch = 0 79 | self.num_samples = int( 80 | math.ceil(len(self.dataset) * 1.0 / self.num_replicas * subsample_ratio) 81 | ) 82 | self.split_strategy = split_strategy 83 | assert self.split_strategy in ["interleaved", "partition"] 84 | self.total_size = self.num_samples * self.num_replicas 85 | self.shuffle = shuffle 86 | self.weights = weights 87 | 88 | self.last_indices = None 89 | self.iter_counter = 0 90 | 91 | def __iter__(self): 92 | # exit(0) 93 | if self.iter_counter > 2 and (self.shuffle or self.weights is not None): 94 | print(("="*10+"\n")*2) 95 | print("WARNING: Iterating over the sampler using the same epoch more than 2 times. 
Double check if the sampler is being reset properly.") 96 | print(("="*10+"\n")*2) 97 | 98 | # deterministically shuffle based on epoch 99 | g = torch.Generator() 100 | g.manual_seed(self.epoch) 101 | if self.weights is not None: 102 | # if self.epoch == 0: 103 | print("Generating indices with weights for epoch {}".format(self.epoch)) 104 | indices = torch.multinomial( 105 | self.weights, self.total_size, replacement=True, generator=g 106 | ) 107 | self.last_indices = indices 108 | indices = indices.tolist() 109 | # Print histogram of indices 110 | # if self.epoch == 0: 111 | # hist = np.histogram( 112 | # np.asarray(indices), bins=range(0, len(self.dataset) + 1) 113 | # ) 114 | # print(hist) 115 | else: 116 | if self.shuffle: 117 | indices = torch.randperm(len(self.dataset), generator=g) 118 | self.last_indices = indices 119 | indices = indices.tolist() 120 | else: 121 | indices = list(range(len(self.dataset))) 122 | self.last_indices = indices 123 | 124 | # add extra samples to make it evenly divisible 125 | if len(indices) < self.total_size: 126 | indices += indices[: (self.total_size - len(indices))] 127 | assert len(indices) == self.total_size 128 | 129 | # subsample 130 | if self.split_strategy == "interleaved": 131 | indices = indices[self.rank : self.total_size : self.num_replicas] 132 | elif self.split_strategy == "partition": 133 | indices = indices[(self.total_size // self.num_replicas)* (self.rank) : (self.total_size // self.num_replicas)* (self.rank+1)] 134 | else: 135 | raise ValueError("split_strategy must be one of 'interleaved' or 'partition'") 136 | # print(self.rank, indices[:100]) 137 | assert len(indices) == self.num_samples 138 | 139 | return iter(indices) 140 | 141 | def set_subsample_ratio(self, subsample_ratio): 142 | self.num_samples = int( 143 | math.ceil(len(self.dataset) * 1.0 / self.num_replicas * subsample_ratio) 144 | ) 145 | self.total_size = self.num_samples * self.num_replicas 146 | 147 | def __len__(self): 148 | return self.num_samples 149 | 150 | def set_epoch(self, epoch): 151 | self.epoch = epoch 152 | self.iter_counter = 0 153 | 154 | 155 | # # Bin sampler 156 | # class BinDistributedSampler(torch.utils.data.Sampler): 157 | # """Sampler that restricts data loading to a subset of the dataset. 158 | # It is especially useful in conjunction with 159 | # :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 160 | # process can pass a DistributedSampler instance as a DataLoader sampler, 161 | # and load a subset of the original dataset that is exclusive to it. 162 | # .. note:: 163 | # Dataset is assumed to be of constant size. 164 | # Arguments: 165 | # dataset: Dataset used for sampling. 166 | # num_replicas (optional): Number of processes participating in 167 | # distributed training. 168 | # rank (optional): Rank of the current process within num_replicas. 169 | # shuffle (optional): If true (default), sampler will shuffle the indices 170 | # """ 171 | 172 | # def __init__( 173 | # self, 174 | # dataset, 175 | # subsample_ratio=1.0, 176 | # num_replicas=None, 177 | # rank=None, 178 | # shuffle=False, 179 | # ): 180 | # if num_replicas is None: 181 | # if not dist.is_available(): 182 | # # Default to single process 183 | # num_replicas = 1 184 | # else: 185 | # if dist.is_initialized(): 186 | # num_replicas = dist.get_world_size() 187 | # else: 188 | # # Warn user that num_replicas is set to 1 189 | # warnings.warn( 190 | # "DistributedSampler is initialized without " 191 | # "dist being initialized. 
Set num_replicas " 192 | # "to 1 instead." 193 | # ) 194 | # num_replicas = 1 195 | # if rank is None: 196 | # if not dist.is_available(): 197 | # # Default to single process 198 | # rank = 0 199 | # else: 200 | # if dist.is_initialized(): 201 | # rank = dist.get_rank() 202 | # else: 203 | # # Warn user that rank is set to 0 204 | # warnings.warn( 205 | # "DistributedSampler is initialized without " 206 | # "dist being initialized. Set rank " 207 | # "to 0 instead." 208 | # ) 209 | # rank = 0 210 | 211 | # self.dataset = dataset 212 | # self.num_replicas = num_replicas 213 | # self.rank = rank 214 | # self.epoch = 0 215 | # self.num_samples = int( 216 | # math.ceil(len(self.dataset) * 1.0 / self.num_replicas * subsample_ratio) 217 | # ) 218 | # self.total_size = self.num_samples * self.num_replicas 219 | # self.shuffle = shuffle 220 | 221 | # def __iter__(self): 222 | # # deterministically shuffle based on epoch 223 | # g = torch.Generator() 224 | # g.manual_seed(self 225 | -------------------------------------------------------------------------------- /carformer/carformer/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from enum import IntEnum 4 | from collections import defaultdict 5 | from collections.abc import MutableMapping 6 | import random 7 | 8 | 9 | def seed_everything(seed): 10 | random.seed(seed) 11 | np.random.seed(seed) 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed(seed) 14 | 15 | 16 | class TokenTypeIDs(IntEnum): 17 | GOAL = 4 18 | STATE = 0 19 | BEV = 1 20 | ACTION = 2 21 | REWARD = 3 22 | EOS = -1 23 | 24 | 25 | def deinterleave(interleaved_tensors, interleaved_token_type_ids, axis=1): 26 | """ 27 | Deinterleave tensors along the specified axis. 28 | Args: 29 | interleaved_tensors: A tensor of shape (..., step_width, ...) 
to deinterleave 30 | interleaved_token_type_ids: A tensor of shape (batch, step_width) containing the token type ids 31 | axis: The axis to deinterleave along 32 | Returns: 33 | A dictionary of deinterleaved tensors, keyed by token type id 34 | """ 35 | results_by_tokentype = defaultdict(list) 36 | 37 | for tokenType in TokenTypeIDs: 38 | max_axis_length = 0 39 | for batch_idx in range(interleaved_token_type_ids.shape[0]): 40 | # Get the indices of the unique token type ids 41 | token_type_id_indices = torch.where( 42 | interleaved_token_type_ids[batch_idx] == tokenType 43 | )[0] 44 | 45 | # Get the tensor corresponding to the token type id 46 | results_by_tokentype[tokenType].append( 47 | torch.index_select( 48 | interleaved_tensors[batch_idx].unsqueeze(0), 49 | axis, 50 | token_type_id_indices, 51 | ) 52 | ) 53 | max_axis_length = max( 54 | max_axis_length, results_by_tokentype[tokenType][-1].shape[axis] 55 | ) 56 | 57 | # Pad the tensors to the max length 58 | output_shape = list(results_by_tokentype[tokenType][0].shape) 59 | for i in range(len(results_by_tokentype[tokenType])): 60 | output_shape[axis] = ( 61 | max_axis_length - results_by_tokentype[tokenType][i].shape[axis] 62 | ) 63 | if output_shape[axis] > 0: 64 | results_by_tokentype[tokenType][i] = torch.cat( 65 | [ 66 | results_by_tokentype[tokenType][i], 67 | torch.zeros( 68 | output_shape, dtype=results_by_tokentype[tokenType][i].dtype 69 | ).to(results_by_tokentype[tokenType][i].device), 70 | ], 71 | axis=axis, 72 | ) 73 | 74 | # Concatenate the tensors 75 | results_by_tokentype[tokenType] = torch.cat( 76 | results_by_tokentype[tokenType], axis=0 77 | ) 78 | 79 | return results_by_tokentype 80 | 81 | 82 | # Normalize version compatible with torch tensors 83 | def normalize_angle_torch(x): 84 | x = x % (2 * np.pi) # force in range [0, 2 pi) 85 | x = torch.where(x > np.pi, x - 2 * np.pi, x) # move to [-pi, pi) 86 | return x 87 | 88 | 89 | def unwrap_model(model): 90 | if hasattr(model, "module"): 91 | return model.module 92 | else: 93 | return model 94 | 95 | 96 | # Flatten backbone config nested dicts into a single dict 97 | # a: {b: c} -> {"a.b": c} 98 | def flatten_dict(d, parent_key="", sep="."): 99 | items = [] 100 | for k, v in d.items(): 101 | new_key = parent_key + sep + k if parent_key else k 102 | if isinstance(v, MutableMapping): 103 | items.extend(flatten_dict(v, new_key, sep=sep).items()) 104 | else: 105 | items.append((new_key, v)) 106 | return dict(items) 107 | 108 | 109 | def fuzzy_extract_state_dict_from_checkpoint(checkpoint, consume="ponderer."): 110 | if "model" in checkpoint: 111 | checkpoint = checkpoint["model"] 112 | 113 | if consume: # Remove the prefix from the keys 114 | checkpoint = { 115 | (k[len(consume) :] if k.startswith(consume) else k): v 116 | for k, v in checkpoint.items() 117 | } 118 | 119 | return checkpoint 120 | -------------------------------------------------------------------------------- /carformer/carformer/visualization/visutils.py: -------------------------------------------------------------------------------- 1 | from carformer.utils import TokenTypeIDs 2 | from PIL import Image, ImageDraw, ImageFilter 3 | import torch 4 | import os 5 | import numpy as np 6 | import cv2 7 | 8 | 9 | def extract_target_point_from_trajectory_goals(goals, goal_type): 10 | if "dual_target_point" in goal_type: 11 | goal_idx = goal_type.index("dual_target_point") 12 | elif "target_point" in goal_type: 13 | goal_idx = goal_type.index("target_point") 14 | else: 15 | raise ValueError( 16 | f"Unknown goal 
type {goal_type}. 'target_point' or 'dual_target_point' must be in the goal types." 17 | ) 18 | 19 | return goals.reshape(-1, 2).numpy() 20 | 21 | 22 | def draw_points_on_camera( 23 | camera, 24 | points, 25 | color=(0, 255, 255), 26 | first_origin=None, 27 | radius=4, 28 | blur=-1, 29 | ): 30 | rgb = Image.fromarray(cv2.cvtColor(camera, cv2.COLOR_BGR2RGB)) 31 | 32 | if blur > 0: 33 | # Start with a blank image and draw the points on it, blur it and then overlay it on the original image 34 | rgb_tmp = Image.new("RGBA", rgb.size, tuple([*color, 0])) 35 | draw = ImageDraw.Draw(rgb_tmp) 36 | else: 37 | draw = ImageDraw.Draw(rgb) 38 | 39 | for i, pt in enumerate(points): 40 | x, y = pt[:2].astype(int) 41 | 42 | y = first_origin[-1] + y 43 | x = first_origin[0] + x 44 | if x < 0 or x >= rgb.size[0] or y < 0 or y >= rgb.size[-1]: 45 | continue 46 | draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color) 47 | 48 | # point_list = first_origin + point_list * pix_per_meter 49 | 50 | # point_list[0] += i * skip_size 51 | 52 | # for j, point in enumerate(point_list): 53 | # cv2.circle( 54 | # img=canvas, 55 | # center=tuple(point.astype(np.int32)), 56 | # radius=np.rint(radius).astype(int), 57 | # color=tuple(c * (color_decay**j) for c in color), 58 | # thickness=cv2.FILLED, 59 | # ) 60 | 61 | if blur > 0: 62 | rgb_tmp = rgb_tmp.filter(ImageFilter.GaussianBlur(blur)) 63 | if rgb.mode == "RGB": 64 | rgb = rgb.convert("RGBA") 65 | 66 | rgb = Image.alpha_composite(rgb, rgb_tmp) 67 | 68 | # convert back to RGB 69 | rgb = rgb.convert("RGB") 70 | 71 | # draw.ellipse([first_origin[0]-radius, first_origin[-1]-radius, first_origin[0]+radius, first_origin[-1]+radius], fill=color) 72 | 73 | # Return cv2 format array 74 | return cv2.cvtColor(np.array(rgb), cv2.COLOR_RGB2BGR) 75 | 76 | 77 | def get_action_canvas( 78 | action, 79 | action_types, 80 | size=192, 81 | bev_canvas=None, 82 | copy_bev_canvas=False, 83 | bev_crop_type="front", 84 | pix_per_meter=5, 85 | supp_action=None, 86 | ): 87 | if "path" in action_types: 88 | return get_action_waypoints_path( 89 | action, 90 | action_types, 91 | size=size, 92 | bev_canvas=bev_canvas, 93 | copy_bev_canvas=copy_bev_canvas, 94 | bev_crop_type=bev_crop_type, 95 | pix_per_meter=pix_per_meter, 96 | supp_action=supp_action, 97 | ) 98 | else: 99 | raise ValueError( 100 | f"Unknown action type {action_types}. 'path' must be in the action types." 
101 | ) 102 | 103 | 104 | def get_action_waypoints_path( 105 | action, 106 | action_types, 107 | size=192, 108 | bev_canvas=None, 109 | copy_bev_canvas=False, 110 | bev_crop_type="front", 111 | pix_per_meter=5, 112 | supp_action=None, 113 | ): 114 | if bev_canvas is None: 115 | bev_canvas = np.zeros((size, size * action.shape[0], 3), dtype=np.uint8) 116 | else: 117 | if copy_bev_canvas: 118 | bev_canvas = bev_canvas.copy() 119 | 120 | waypoints = action[0, 20:].copy() 121 | 122 | waypoints[:, -1] *= -1 123 | 124 | path = action[0, :20] 125 | 126 | waypoints = point_to_canvas_coordinates_rel_to_center(waypoints, height=-1.6) / 2 127 | path = point_to_canvas_coordinates_rel_to_center(path, height=-1.6) / 2 128 | 129 | if bev_crop_type == "dualfront": 130 | origin = (size, size // 2) 131 | elif bev_crop_type == "front": 132 | origin = (size // 2, size) 133 | else: 134 | origin = (size // 2, size // 2) 135 | 136 | bev_canvas = draw_points_on_camera( 137 | bev_canvas, waypoints, color=(0, 0, 255), first_origin=origin, radius=5 138 | ) 139 | 140 | bev_canvas = draw_points_on_camera( 141 | bev_canvas, path, color=(255, 255, 0), first_origin=origin, radius=2 142 | ) 143 | if supp_action is not None: 144 | supp_waypoints = supp_action[0, 20:].copy() 145 | supp_waypoints[:, -1] *= -1 146 | supp_path = supp_action[0, :20] 147 | 148 | supp_waypoints = ( 149 | point_to_canvas_coordinates_rel_to_center(supp_waypoints, height=-1.6) / 2 150 | ) 151 | supp_path = ( 152 | point_to_canvas_coordinates_rel_to_center(supp_path, height=-1.6) / 2 153 | ) 154 | 155 | bev_canvas = draw_points_on_camera( 156 | bev_canvas, supp_waypoints, color=(0, 0, 155), first_origin=origin, radius=3 157 | ) 158 | bev_canvas = draw_points_on_camera( 159 | bev_canvas, supp_path, color=(155, 0, 155), first_origin=origin, radius=1 160 | ) 161 | 162 | if not copy_bev_canvas: 163 | raise ValueError("Not implemented") 164 | return bev_canvas 165 | 166 | 167 | def point_to_canvas_coordinates_rel_to_center( 168 | points, original_size=(1600, 800), height=-1 169 | ): # shape: Nx2 170 | rgb_front_intrinsic = np.asarray( 171 | [ 172 | [1142.5184053936916, 0.0, 800.0], 173 | [0.0, 1142.5184053936916, 450.0], 174 | [0.0, 0.0, 1.0], 175 | ] 176 | ) 177 | 178 | points = np.stack( 179 | [points[:, 1], np.ones_like(points[:, 0]) * height, points[:, 0]], axis=-1 180 | ) 181 | # points = np.concatenate([points, np.ones((points.shape[0], 1))], axis=1) 182 | # If any of points[:, 2:3] are between 1e-4 and -1e-4, then we set them to plus 1e-4 183 | mask = np.logical_and(points[:, 2:3] < 1e-4, points[:, 2:3] > -1e-4) 184 | 185 | points[:, 2:3][mask] = 1e-4 186 | 187 | points = points / points[:, 2:3] 188 | points = (rgb_front_intrinsic @ points.T[:3, :]).T 189 | 190 | points = -points 191 | 192 | points[:, 0] += original_size[0] // 2 193 | points[:, 1] += original_size[1] // 2 194 | 195 | return points 196 | 197 | 198 | def get_goal_canvas( 199 | commands, 200 | command_types, 201 | size=192, 202 | color=(255, 0, 255), 203 | bev_canvas=None, 204 | copy_bev_canvas=False, 205 | bev_crop_type="front", 206 | pix_per_meter=5, 207 | ): 208 | if "command" in command_types: 209 | return get_command_canvas(commands, size=size, color=color) 210 | else: 211 | return get_targetpoint_canvas( 212 | commands, 213 | command_types, 214 | size=size, 215 | color=color, 216 | bev_canvas=bev_canvas, 217 | copy_bev_canvas=copy_bev_canvas, 218 | bev_crop_type=bev_crop_type, 219 | pix_per_meter=pix_per_meter, 220 | ) 221 | 222 | 223 | def get_targetpoint_canvas( 224 | 
commands, 225 | command_types, 226 | color=(255, 0, 255), 227 | size=192, 228 | bev_canvas=None, 229 | copy_bev_canvas=False, 230 | bev_crop_type="front", 231 | pix_per_meter=5, 232 | ): 233 | if bev_canvas is None: 234 | bev_canvas = np.zeros((size, size * commands.shape[0], 3), dtype=np.uint8) 235 | else: 236 | if copy_bev_canvas: 237 | bev_canvas = bev_canvas.copy() 238 | 239 | # For now, assert size of action is 8 and no other action types 240 | assert len(command_types) == 1 and ( 241 | command_types[0] == "target_point" or command_types[0] == "dual_target_point" 242 | ) 243 | 244 | target_points = extract_target_point_from_trajectory_goals(commands, command_types) 245 | # print(target_points) 246 | target_points = ( 247 | point_to_canvas_coordinates_rel_to_center(target_points, height=-1.6) / 2 248 | ) 249 | 250 | # print(target_points) 251 | # canvas = draw_points_on_camera(canvas, target_points, color=(255, 0, 255), first_origin=origin, radius=10, blur=8) 252 | 253 | if bev_crop_type == "dualfront": 254 | origin = (size, size // 2) 255 | elif bev_crop_type == "front": 256 | origin = (size // 2, size) 257 | else: 258 | origin = (size // 2, size // 2) 259 | 260 | # Draw the target points 261 | bev_canvas = draw_points_on_camera( 262 | bev_canvas, target_points, color=color, first_origin=origin, radius=12, blur=8 263 | ) 264 | 265 | if not copy_bev_canvas: 266 | return None 267 | return bev_canvas 268 | 269 | 270 | def visualize_trajectory_action_predictions( 271 | batch, 272 | batch_outputs, 273 | save_dir=None, 274 | labels=None, 275 | save_idx=0, 276 | action_source="transformer-regression", 277 | save_suffix="", 278 | do_write=True, 279 | visualize_gt=True, 280 | *args, 281 | **kwargs, 282 | ): 283 | # Move batch to cpu 284 | # batch = {k: v.cpu() for k, v in batch.items() if isinstance(v, torch.Tensor)} 285 | # Copy batch shallowly 286 | batch = {k: v for k, v in batch.items()} 287 | 288 | # Convert labels to empty dict if None 289 | if labels is None: 290 | labels = {} 291 | 292 | if action_source == "transformer": 293 | if "output_dict" in batch_outputs: 294 | batch_outputs = batch_outputs["output_dict"] 295 | batch_outputs = {k: v for k, v in batch_outputs.items()} 296 | 297 | # Change the batch action to be the predicted action 298 | pred_actions = batch_outputs[TokenTypeIDs.ACTION].to(batch["action"].device) 299 | elif action_source == "gru": 300 | if "waypoints" in batch_outputs: 301 | pred_actions = batch_outputs["waypoints"].to(batch["action"].device) 302 | else: 303 | pred_actions = batch_outputs.to(batch["action"].device) 304 | elif action_source == "transformer-regression": 305 | waypoints = batch_outputs["action"]["wps"].float().detach() # 1x10x2 306 | path = batch_outputs["action"]["path"].float().detach() # 1x20x2 307 | 308 | waypoints = waypoints.cumsum(-2) 309 | path = path.cumsum(-2) 310 | if len(waypoints.shape) == 3: 311 | waypoints = waypoints.unsqueeze(1) 312 | path = path.unsqueeze(1) 313 | 314 | if visualize_gt: 315 | supp_action = batch["action"] # [save_idx].unsqueeze(0) 316 | else: 317 | supp_action = None 318 | 319 | batch["action"] = torch.zeros( 320 | (waypoints.shape[0], 1, 30, 2), dtype=path.dtype, device=path.device 321 | ) 322 | batch["action"][:, :, :20, :] = path 323 | batch["action"][:, :, 20:, :] = waypoints 324 | else: 325 | raise ValueError(f"Unknown action source {action_source}") 326 | 327 | # pred_actions_len = pred_actions.shape[0] 328 | 329 | # batch["action"] = torch.cat( 330 | # ( 331 | # batch["action"], 332 | # pred_actions.reshape( 
333 | # batch["action"].shape[0], -1, batch["action"].shape[-1] 334 | # ), 335 | # ), 336 | # dim=1, 337 | # ) 338 | # import ipdb; ipdb.set_trace() 339 | # return visualize_trajectory(batch, save_dir, save_idx=save_idx, *args, **kwargs) 340 | return visualize_input_from_batch( 341 | batch=batch, 342 | batch_idx=save_idx, 343 | batch_outputs=batch_outputs, 344 | labels=labels, 345 | save_dir=save_dir, 346 | save_affix=f"pred_" + "{}".format(save_suffix), 347 | model=kwargs.get("model", None), 348 | save_prefix=None, 349 | do_write=do_write, 350 | supp_action=supp_action, 351 | ) 352 | 353 | 354 | def get_bev_canvas( 355 | batch, 356 | batch_idx, 357 | model=None, 358 | batch_outputs=None, 359 | labels=None, 360 | include_targets=True, 361 | use_target_mask_if_available=True, 362 | ): 363 | if model is not None: 364 | if "rgb_front" in batch: 365 | bev_mode = "rgb_front" 366 | else: 367 | raise ValueError( 368 | "Model is provided but no rgb_front in batch. " 369 | "This is unexpected, please check your model and batch." 370 | ) 371 | else: 372 | if "rgb_front" in batch: 373 | bev_mode = "rgb_front" 374 | else: 375 | raise ValueError("No rgb in batch") 376 | 377 | targets_rgb = None 378 | if bev_mode == "rgb_front": 379 | # shape: BxTxHxWx3 380 | if use_target_mask_if_available and "goal_mask" in batch: 381 | rgb_front = batch["goal_mask"] 382 | else: 383 | rgb_front = batch["rgb_front"] 384 | 385 | gt_rgb = ( 386 | rgb_front[batch_idx] 387 | .cpu() 388 | .numpy() 389 | .transpose((0, 1, 2, 3)) 390 | .astype(np.float32) 391 | / 255.0 392 | ) 393 | 394 | gt_rgb = gt_rgb[:1].reshape(-1, *gt_rgb.shape[2:]) 395 | gt_reproduced = None 396 | preds_rgb = None 397 | if include_targets: 398 | if "target_rgb_front" in batch: 399 | targets_rgb = ( 400 | batch["target_rgb_front"][batch_idx] 401 | .cpu() 402 | .numpy() 403 | .transpose((1, 0, 2, 3)) 404 | .astype(np.float32) 405 | / 255.0 406 | ) 407 | targets_rgb = targets_rgb.squeeze(1) 408 | 409 | waypoints = batch["frc_wps"][batch_idx, 0].cpu().numpy().copy() 410 | 411 | waypoints[:, -1] *= -1 412 | 413 | waypoints_wrld = ( 414 | point_to_canvas_coordinates_rel_to_center(waypoints, height=-1.6) 415 | / 2 416 | ) 417 | 418 | speed = batch["frc_speed"][batch_idx, 0].cpu().item() 419 | goals = batch["frc_goal"][batch_idx, 0].cpu() 420 | goals = ( 421 | point_to_canvas_coordinates_rel_to_center(goals, height=-1.6) / 2 422 | ) 423 | 424 | targets_rgb_pil = Image.fromarray((targets_rgb * 255).astype(np.uint8)) 425 | from PIL import ImageDraw, ImageFont 426 | 427 | draw = ImageDraw.Draw(targets_rgb_pil) 428 | 429 | try: 430 | font = ImageFont.truetype( 431 | "/usr/share/fonts/dejavu/DejaVuSans.ttf", 15 432 | ) 433 | except IOError: 434 | font = ImageFont.load_default() 435 | 436 | # Get center of the image 437 | width, height = targets_rgb_pil.size 438 | 439 | # Calculate the position for the text 440 | position = (0, height // 2) 441 | # if x > y: 442 | # break 443 | # Draw the text at the calculated position with the specified angle 444 | draw.text( 445 | position, 446 | f"Speed: {speed:.2f}\n Waypoints: {waypoints}\n Goals: {goals}", 447 | font=font, 448 | fill=(255, 255, 255, 128), 449 | anchor="lm", 450 | ) 451 | 452 | targets_rgb = np.array(targets_rgb_pil) 453 | 454 | draw_points_on_camera( 455 | targets_rgb, 456 | waypoints_wrld, 457 | color=(255, 0, 0), 458 | first_origin=(width // 2, height // 2), 459 | radius=4, 460 | ) 461 | draw_points_on_camera( 462 | targets_rgb, 463 | goals, 464 | color=(200, 200, 0), 465 | first_origin=(width // 2, height 
// 2), 466 | radius=2, 467 | ) 468 | 469 | targets_rgb = targets_rgb.astype(np.float32) / 255.0 470 | 471 | if gt_reproduced is not None: 472 | gt_reproduced = gt_reproduced.reshape( 473 | gt_reproduced.shape[0], -1, gt_reproduced.shape[-1] 474 | ) 475 | 476 | rows = [x for x in [gt_rgb, gt_reproduced, targets_rgb, preds_rgb] if x is not None] 477 | canvas_unit_size = gt_rgb.shape[0] 478 | 479 | canvas = np.zeros((canvas_unit_size * len(rows), gt_rgb.shape[1], 3)) 480 | 481 | canvas = np.concatenate(rows, axis=0) 482 | 483 | # Convert canvas of 0-1 floats to 0-255 uint8 484 | canvas = (canvas * 255).astype(np.uint8) 485 | # Rgb to bgr 486 | canvas = canvas[:, :, ::-1] 487 | 488 | return canvas, canvas_unit_size 489 | 490 | 491 | def visualize_input_from_batch( 492 | batch, 493 | batch_idx, 494 | batch_outputs, 495 | labels, 496 | save_dir, 497 | save_affix, 498 | model, 499 | save_prefix, 500 | do_write=True, 501 | supp_action=None, 502 | use_target_mask_if_available=True, 503 | ): 504 | canvas, img_size = get_bev_canvas( 505 | batch=batch, 506 | batch_idx=batch_idx, 507 | model=model, 508 | batch_outputs=batch_outputs, 509 | labels=labels, 510 | use_target_mask_if_available=use_target_mask_if_available, 511 | ) 512 | 513 | canvas_to_reuse = canvas 514 | 515 | # Add a black copy of the canvas 516 | black = np.zeros_like(canvas_to_reuse) 517 | 518 | canvas_to_reuse = np.concatenate([canvas_to_reuse, black], axis=0) 519 | 520 | actions = batch["action"][batch_idx].cpu().numpy() 521 | if supp_action is not None: 522 | supp_action = supp_action[batch_idx].cpu().numpy() 523 | 524 | action_canvas = get_action_canvas( 525 | actions, 526 | ["path", "waypoints"], 527 | size=img_size, 528 | bev_canvas=canvas_to_reuse, 529 | copy_bev_canvas=True, 530 | bev_crop_type="dualfront", 531 | pix_per_meter=3, 532 | supp_action=supp_action, 533 | ) 534 | 535 | goal = batch["goal"][batch_idx].cpu() 536 | goal_canvas = get_goal_canvas( 537 | goal, 538 | ["target_point"], 539 | size=img_size, 540 | bev_canvas=action_canvas, 541 | copy_bev_canvas=True, 542 | bev_crop_type="dualfront", 543 | pix_per_meter=3, 544 | ) 545 | 546 | if not do_write: 547 | return goal_canvas 548 | 549 | impath = os.path.join( 550 | save_dir, 551 | f"{save_prefix}_predictions" if save_prefix else "", 552 | "epoch_{}.png".format(save_affix), 553 | ) 554 | 555 | # Save as epoch_{epoch}.png in the log directory 556 | cv2.imwrite( 557 | impath, 558 | goal_canvas, 559 | ) 560 | 561 | return impath 562 | -------------------------------------------------------------------------------- /carformer/requirements.txt: -------------------------------------------------------------------------------- 1 | PyYAML 2 | torch>=2.0.0 3 | transformers>4.30 4 | tokenizers 5 | plotly 6 | omegaconf==2.3.0 7 | hydra-core==1.2.0 8 | mpi4py==3.1.4 9 | #deepspeed==0.10.1 10 | -------------------------------------------------------------------------------- /carformer/setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name="carformer", 5 | version="0.2", 6 | author="Shadi Hamdan et al.", 7 | author_email="shamdan17@ku.edu.tr", 8 | description="Transformers for Sequential Decision making in Autonomous Driving", 9 | long_description="Transformers for Sequential Decision making in Autonomous Driving", 10 | long_description_content_type="text", 11 | url="https://github.com/Shamdan17/ETA", 12 | project_urls={ 13 | "Bug Tracker": "https://github.com/Shamdan17/ETA/issues", 14 | }, 15 
| classifiers=[ 16 | "Programming Language :: Python :: 3", 17 | "License :: OSI Approved :: MIT License", 18 | "Operating System :: OS Independent", 19 | ], 20 | packages=setuptools.find_packages(where="."), 21 | python_requires=">=3.7", 22 | ) 23 | -------------------------------------------------------------------------------- /carformer/train.py: -------------------------------------------------------------------------------- 1 | from carformer.ponderer import Ponderer 2 | from carformer.ponderer_lit import PondererLit 3 | 4 | from carformer.utils import seed_everything 5 | import os 6 | from carformer.config import config_init, CarformerConfig 7 | from omegaconf import OmegaConf 8 | import hydra 9 | from skit.distributed.deepspeed_utils import ZeroLightModelCheckpoint 10 | 11 | config_init() 12 | 13 | 14 | @hydra.main(version_base="1.1", config_path="carformer/config", config_name="config") 15 | def main(cfg): 16 | seed_everything(cfg.seed) 17 | 18 | args = cfg 19 | 20 | if args.ckpt_path is not None: 21 | print("Loading model from checkpoint") 22 | cfg.save_dir = os.path.dirname(args.ckpt_path) 23 | 24 | config = CarformerConfig.from_hydra(args) 25 | 26 | print(config) 27 | 28 | model = Ponderer(config) 29 | 30 | model = PondererLit(model).train() 31 | 32 | ds_cfg = OmegaConf.to_container(cfg.deepspeed, resolve=True) 33 | 34 | folder = cfg.save_dir 35 | 36 | from lightning import Trainer 37 | from lightning.pytorch.loggers import WandbLogger 38 | from lightning.pytorch.callbacks import ( 39 | ModelCheckpoint, 40 | LearningRateMonitor, 41 | ) 42 | from lightning.pytorch.strategies import DeepSpeedStrategy 43 | 44 | callbacks = [LearningRateMonitor(logging_interval="step")] 45 | 46 | if (not args.overfit) or args.force_save: 47 | callbacks.extend( 48 | [ 49 | ModelCheckpoint( 50 | dirpath=os.path.join(folder), 51 | filename="last_model", 52 | every_n_epochs=0, 53 | save_last=True, 54 | ), 55 | ] 56 | ) 57 | 58 | callbacks.extend( 59 | [ 60 | ZeroLightModelCheckpoint( 61 | dirpath=os.path.join(folder, "epochs"), 62 | filename="{epoch}", 63 | monitor="val/loss/wp_loss", 64 | mode="min", 65 | every_n_epochs=args.save_every, 66 | save_top_k=-1, 67 | save_weights_only=True, 68 | start_save_epoch=args.start_saving_epoch, 69 | ) 70 | ] 71 | ) 72 | 73 | if args.wandb_tag: 74 | tags = [args.wandb_tag] 75 | else: 76 | tags = [] 77 | 78 | if not (args.overfit or args.debug) or args.force_log: 79 | log = True 80 | else: 81 | log = False 82 | 83 | trainer = Trainer( 84 | accelerator="gpu" if not args.cpu else "cpu", 85 | devices=args.gpus, 86 | num_nodes=args.nodes, 87 | strategy=(DeepSpeedStrategy(config=ds_cfg)), 88 | max_epochs=args.hyperparams.num_epochs, 89 | logger=( 90 | WandbLogger( 91 | project=args.logging.project, 92 | entity=args.logging.entity, 93 | mode=args.logging.mode, 94 | tags=tags, 95 | offline=True, 96 | save_code=True, 97 | ) 98 | if log 99 | else None 100 | ), 101 | callbacks=callbacks, 102 | accumulate_grad_batches=args.hyperparams.gradient_accumulation_steps, 103 | use_distributed_sampler=False, 104 | overfit_batches=args.overfit_batches if args.overfit else 0.0, 105 | limit_val_batches=0 if args.overfit else 1.0, 106 | log_every_n_steps=1 if args.overfit else 25, 107 | enable_checkpointing=False if args.overfit else True, 108 | ) 109 | 110 | trainer.fit(model, ckpt_path=args.ckpt_path) 111 | 112 | 113 | if __name__ == "__main__": 114 | main() 115 | -------------------------------------------------------------------------------- /docs/TRAIN_EVAL.md: 
-------------------------------------------------------------------------------- 1 | ## Training Setup 2 | ### Bench2Drive data 3 | Download the [Bench2Drive](https://github.com/Thinklab-SJTU/Bench2Drive) dataset and unzip all the directories to a single folder. The dataset should have a structure similar to the following: 4 | 5 | ``` 6 | MainFolder/ 7 | HazardAtSideLaneTwoWays_Town12_Route1133_Weather15/ 8 | ParkingCutIn_Town12_Route765_Weather11/ 9 | ``` 10 | 11 | ### Libraries 12 | Install the [skit](https://github.com/Shamdan17/skit) toolkit for distributed sampling and training wrappers. More dependencies are listed in the `requirements.txt` file. Install them using pip: 13 | 14 | ```bash 15 | pip install -r requirements.txt 16 | ``` 17 | 18 | ### Config files 19 | Under `carformer/config/user`, create a yaml file, using `example.yaml` as a template. 20 | 21 | ```yaml 22 | dataset_dir: /PATH/TO/DATASET/FOLDER/ 23 | working_dir: /PATH/TO/carformer/ 24 | 25 | wandb_entity: WANDBUSERNAME 26 | ``` 27 | 28 | Then, update the username in the Makefile to match the name of the config file you just created. 29 | 30 | ### Training 31 | 32 | To train the base and async models on 8 nodes with 4 GPUs each, run the following commands from the `carformer` directory. We run these commands inside a SLURM job; invoking make directly will not parallelize training correctly across multiple nodes, so adapt the commands accordingly if you are not working in a SLURM environment. 33 | 34 | ```bash 35 | make ETA_base_model_s42 36 | make ETA_async_model_s42 37 | ``` 38 | 39 | ## Evaluation Setup 40 | 41 | ### Bench2Drive 42 | 43 | [Bench2Drive](https://github.com/Thinklab-SJTU/Bench2Drive) is required for evaluation. Please follow the "Eval Tools" section of its documentation. 44 | 45 | ### File setup: 46 | 47 | Only follow these steps **AFTER** Bench2Drive is set up following the Bench2Drive instructions. Please place the files found in "misc" and "team_code" into the Bench2Drive repo following the structure below: 48 | 49 | ``` 50 | Bench2Drive\ 51 | assets\ 52 | docs\ 53 | leaderboard\ 54 | leaderboard\ 55 | --> Copy "leaderboard_evaluator_local.py" from the "misc" folder here 56 | scripts\ 57 | --> Copy "run_eval_leaderboard.py" from the "misc" folder here 58 | team_code\ 59 | --> Copy files from "team_code" folder here 60 | scenario_runner\ 61 | tools\ 62 | ``` 63 | 64 | ### Config file setup: 65 | 66 | For evaluation, you need to update the config files under `team_code/config`: 67 | 68 | ```yaml 69 | working_dir: /path/to/Bench2Drive 70 | b2d_path: /path/to/Bench2Drive 71 | ``` 72 | 73 | ### Download checkpoints: 74 | 75 | Checkpoint uploading is in progress. 76 | 77 | ### Running the evaluation: 78 | 79 | **Note:** Unlike Bench2Drive's setup, this evaluation code requires a separately running CARLA instance. It will NOT start CARLA for you. 80 | 81 | #### Running CARLA 82 | You can run CARLA on port 30000 persistently (it restarts 10 seconds after a crash) with the following command: 83 | 84 | ```bash 85 | while : ; do ./CarlaUE4.sh -carla-rpc-port=30000 ; sleep 10 ; done 86 | ``` 87 | 88 | #### Running the evaluation 89 | From the Bench2Drive directory, run the following command to evaluate MODEL_NAME using the CARLA instance at port 30000. Update `user` to your username, point `experiments.agent_root` and `experiments.root_path` at the checkpoint directory, and set `experiments.epoch_num` to the epoch you want to evaluate.
90 | 91 | ```bash 92 | python leaderboard/scripts/run_eval_leaderboard.py user=shadi port=30000 trafficManagerPort=20000 experiments=Ponderer viz=0 experiments.ponderer_model_name=MODEL_NAME checkpoint_file=results.json experiments.agent_root=/PATH/TO/CHECKPOINT/MODEL_NAME experiments.root_path=/PATH/TO/CHECKPOINT/MODEL_NAME/ experiments.runnickname=NICKNAMEHERE resume=0 experiments.epoch_num=37 93 | ``` 94 | -------------------------------------------------------------------------------- /misc/leaderboard_evaluator_local.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Copyright (c) 2018-2019 Intel Corporation. 3 | # authors: German Ros (german.ros@intel.com), Felipe Codevilla (felipe.alcm@gmail.com) 4 | # 5 | # This work is licensed under the terms of the MIT license. 6 | # For a copy, see . 7 | 8 | """ 9 | CARLA Challenge Evaluator Routes 10 | 11 | Provisional code to evaluate Autonomous Agents for the CARLA Autonomous Driving challenge 12 | """ 13 | from __future__ import print_function 14 | 15 | import traceback 16 | import argparse 17 | from argparse import RawTextHelpFormatter 18 | from distutils.version import LooseVersion 19 | import importlib 20 | import os 21 | import pkg_resources 22 | import sys 23 | import carla 24 | import signal 25 | 26 | from srunner.scenariomanager.carla_data_provider import * 27 | from srunner.scenariomanager.timer import GameTime 28 | from srunner.scenariomanager.watchdog import Watchdog 29 | 30 | from leaderboard.scenarios.scenario_manager import ScenarioManager 31 | from leaderboard.scenarios.route_scenario import RouteScenario 32 | from leaderboard.envs.sensor_interface import SensorConfigurationInvalid 33 | from leaderboard.autoagents.agent_wrapper import ( 34 | AgentError, 35 | validate_sensor_configuration, 36 | TickRuntimeError, 37 | ) 38 | from leaderboard.utils.statistics_manager import StatisticsManager, FAILURE_MESSAGES 39 | from leaderboard.utils.route_indexer import RouteIndexer 40 | import atexit 41 | import subprocess 42 | import time 43 | import random 44 | from datetime import datetime 45 | 46 | sensors_to_icons = { 47 | "sensor.camera.rgb": "carla_camera", 48 | "sensor.lidar.ray_cast": "carla_lidar", 49 | "sensor.other.radar": "carla_radar", 50 | "sensor.other.gnss": "carla_gnss", 51 | "sensor.other.imu": "carla_imu", 52 | "sensor.opendrive_map": "carla_opendrive_map", 53 | "sensor.speedometer": "carla_speedometer", 54 | } 55 | 56 | import socket 57 | 58 | 59 | def find_free_port(starting_port): 60 | port = starting_port 61 | while True: 62 | try: 63 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 64 | s.bind(("localhost", port)) 65 | return port 66 | except OSError: 67 | port += 1 68 | 69 | 70 | def get_weather_id(weather_conditions, args, config): 71 | from xml.etree import ElementTree as ET 72 | 73 | tree = ET.parse(os.path.join(args.b2d_path, "leaderboard/data/weather.xml")) 74 | root = tree.getroot() 75 | 76 | def conditions_match(weather, conditions): 77 | for key, value in weather: 78 | if key == "route_percentage": 79 | continue 80 | if str(getattr(conditions, key)) != value: 81 | return False 82 | return True 83 | 84 | for case in root.findall("case"): 85 | weather = case[0].items() 86 | if conditions_match(weather, weather_conditions): 87 | return case.items()[0][1] 88 | return None 89 | 90 | 91 | class LeaderboardEvaluator(object): 92 | """ 93 | Main class of the Leaderboard. 
Everything is handled from here, 94 | from parsing the given files, to preparing the simulation, to running the route. 95 | """ 96 | 97 | # Tunable parameters 98 | client_timeout = 300.0 # in seconds 99 | frame_rate = 20.0 # in Hz 100 | 101 | def __init__(self, args, statistics_manager): 102 | """ 103 | Setup CARLA client and world 104 | Setup ScenarioManager 105 | """ 106 | self.world = None 107 | self.manager = None 108 | self.sensors = None 109 | self.sensors_initialized = False 110 | self.sensor_icons = [] 111 | self.agent_instance = None 112 | self.route_scenario = None 113 | 114 | self.statistics_manager = statistics_manager 115 | 116 | # This is the ROS1 bridge server instance. This is not encapsulated inside the ROS1 agent because the same 117 | # instance is used on all the routes (i.e., the server is not restarted between routes). This is done 118 | # to avoid reconnection issues between the server and the roslibpy client. 119 | self._ros1_server = None 120 | 121 | # Setup the simulation 122 | self.client, self.client_timeout, self.traffic_manager = self._setup_simulation( 123 | args 124 | ) 125 | 126 | dist = pkg_resources.get_distribution("carla") 127 | if dist.version != "leaderboard": 128 | if LooseVersion(dist.version) < LooseVersion("0.9.10"): 129 | raise ImportError( 130 | "CARLA version 0.9.10.1 or newer required. CARLA version found: {}".format( 131 | dist 132 | ) 133 | ) 134 | 135 | # Load agent 136 | module_name = os.path.basename(args.agent).split(".")[0] 137 | sys.path.insert(0, os.path.dirname(args.agent)) 138 | self.module_agent = importlib.import_module(module_name) 139 | 140 | # Create the ScenarioManager 141 | self.manager = ScenarioManager( 142 | args.timeout, self.statistics_manager, args.debug 143 | ) 144 | 145 | # Time control for summary purposes 146 | self._start_time = GameTime.get_time() 147 | self._end_time = None 148 | 149 | # Prepare the agent timer 150 | self._agent_watchdog = None 151 | # signal.signal(signal.SIGINT, self._signal_handler) 152 | 153 | self._client_timed_out = False 154 | 155 | def _signal_handler(self, signum, frame): 156 | """ 157 | Terminate scenario ticking when receiving a signal interrupt. 
158 | Either the agent initialization watchdog is triggered, or the runtime one at scenario manager 159 | """ 160 | if self._agent_watchdog and not self._agent_watchdog.get_status(): 161 | raise RuntimeError( 162 | "Timeout: Agent took longer than {}s to setup".format( 163 | self.client_timeout 164 | ) 165 | ) 166 | elif self.manager: 167 | self.manager.signal_handler(signum, frame) 168 | 169 | def __del__(self): 170 | """ 171 | Cleanup and delete actors, ScenarioManager and CARLA world 172 | """ 173 | if hasattr(self, "manager") and self.manager: 174 | del self.manager 175 | if hasattr(self, "world") and self.world: 176 | del self.world 177 | 178 | def _get_running_status(self): 179 | """ 180 | returns: 181 | bool: False if watchdog exception occured, True otherwise 182 | """ 183 | if self._agent_watchdog: 184 | return self._agent_watchdog.get_status() 185 | return False 186 | 187 | def _cleanup(self): 188 | """ 189 | Remove and destroy all actors 190 | """ 191 | CarlaDataProvider.cleanup() 192 | 193 | if self._agent_watchdog: 194 | self._agent_watchdog.stop() 195 | 196 | try: 197 | if self.agent_instance: 198 | self.agent_instance.destroy() 199 | self.agent_instance = None 200 | except Exception as e: 201 | print("\n\033[91mFailed to stop the agent:", flush=True) 202 | print(f"\n{traceback.format_exc()}\033[0m", flush=True) 203 | 204 | if self.route_scenario: 205 | self.route_scenario.remove_all_actors() 206 | self.route_scenario = None 207 | if self.statistics_manager: 208 | self.statistics_manager.remove_scenario() 209 | 210 | if self.manager: 211 | self._client_timed_out = not self.manager.get_running_status() 212 | self.manager.cleanup() 213 | 214 | # Make sure no sensors are left streaming 215 | alive_sensors = self.world.get_actors().filter("*sensor*") 216 | for sensor in alive_sensors: 217 | sensor.stop() 218 | sensor.destroy() 219 | 220 | def _setup_simulation(self, args): 221 | """ 222 | Prepares the simulation by getting the client, and setting up the world and traffic manager settings 223 | """ 224 | self.carla_path = os.environ["CARLA_ROOT"] 225 | 226 | attempts = 0 227 | num_max_restarts = 20 228 | while attempts < num_max_restarts: 229 | try: 230 | client = carla.Client(args.host, args.port) 231 | if args.timeout: 232 | client_timeout = args.timeout 233 | client.set_timeout(client_timeout) 234 | 235 | settings = carla.WorldSettings( 236 | synchronous_mode=True, 237 | fixed_delta_seconds=1.0 / self.frame_rate, 238 | deterministic_ragdolls=True, 239 | spectator_as_ego=False, 240 | ) 241 | client.get_world().apply_settings(settings) 242 | print(f"load_world success , attempts={attempts}", flush=True) 243 | break 244 | except Exception as e: 245 | print(f"load_world failed , attempts={attempts}", flush=True) 246 | print(e, flush=True) 247 | attempts += 1 248 | time.sleep(5) 249 | attempts = 0 250 | num_max_restarts = 40 251 | while attempts < num_max_restarts: 252 | try: 253 | args.traffic_manager_port = find_free_port(args.traffic_manager_port) 254 | traffic_manager = client.get_trafficmanager(args.traffic_manager_port) 255 | traffic_manager.set_synchronous_mode(True) 256 | traffic_manager.set_hybrid_physics_mode(True) 257 | print(f"traffic_manager init success, try_time={attempts}", flush=True) 258 | break 259 | except Exception as e: 260 | print(f"traffic_manager init fail, try_time={attempts}", flush=True) 261 | print(e, flush=True) 262 | attempts += 1 263 | time.sleep(5) 264 | return client, client_timeout, traffic_manager 265 | 266 | def _reset_world_settings(self): 267 | 
""" 268 | Changes the modified world settings back to asynchronous 269 | """ 270 | # Has simulation failed? 271 | if self.world and self.manager and not self._client_timed_out: 272 | # Reset to asynchronous mode 273 | self.world.tick() # TODO: Make sure all scenario actors have been destroyed 274 | settings = self.world.get_settings() 275 | settings.synchronous_mode = False 276 | settings.fixed_delta_seconds = None 277 | settings.deterministic_ragdolls = False 278 | settings.spectator_as_ego = True 279 | self.world.apply_settings(settings) 280 | 281 | # Make the TM back to async 282 | self.traffic_manager.set_synchronous_mode(False) 283 | self.traffic_manager.set_hybrid_physics_mode(False) 284 | 285 | def _load_and_wait_for_world(self, args, town): 286 | """ 287 | Load a new CARLA world without changing the settings and provide data to CarlaDataProvider 288 | """ 289 | self.world = self.client.load_world(town, reset_settings=False) 290 | 291 | # Large Map settings are always reset, for some reason 292 | settings = self.world.get_settings() 293 | settings.tile_stream_distance = 650 294 | settings.actor_active_distance = 650 295 | self.world.apply_settings(settings) 296 | 297 | self.world.reset_all_traffic_lights() 298 | CarlaDataProvider.set_client(self.client) 299 | CarlaDataProvider.set_traffic_manager_port(args.traffic_manager_port) 300 | CarlaDataProvider.set_world(self.world) 301 | 302 | # This must be here so that all route repetitions use the same 'unmodified' seed 303 | self.traffic_manager.set_random_device_seed(args.traffic_manager_seed) 304 | 305 | # Wait for the world to be ready 306 | self.world.tick() 307 | 308 | map_name = CarlaDataProvider.get_map().name.split("/")[-1] 309 | if map_name != town: 310 | raise Exception( 311 | "The CARLA server uses the wrong map!" 312 | " This scenario requires the use of map {}".format(town) 313 | ) 314 | 315 | def _register_statistics(self, route_index, entry_status, crash_message=""): 316 | """ 317 | Computes and saves the route statistics 318 | """ 319 | print("\033[1m> Registering the route statistics\033[0m", flush=True) 320 | self.statistics_manager.save_entry_status(entry_status) 321 | self.statistics_manager.compute_route_statistics( 322 | route_index, 323 | self.manager.scenario_duration_system, 324 | self.manager.scenario_duration_game, 325 | crash_message, 326 | ) 327 | 328 | def _load_and_run_scenario(self, args, config): 329 | """ 330 | Load and run the scenario given by config. 331 | 332 | Depending on what code fails, the simulation will either stop the route and 333 | continue from the next one, or report a crash and stop. 
334 | """ 335 | crash_message = "" 336 | entry_status = "Started" 337 | 338 | print( 339 | "\n\033[1m========= Preparing {} (repetition {}) =========\033[0m".format( 340 | config.name, config.repetition_index 341 | ), 342 | flush=True, 343 | ) 344 | 345 | # Prepare the statistics of the route 346 | route_name = f"{config.name}_rep{config.repetition_index}" 347 | scenario_name = config.scenario_configs[0].name 348 | town_name = str(config.town) 349 | weather_id = get_weather_id(config.weather[0][1], args, config) 350 | currentDateAndTime = datetime.now() 351 | currentTime = currentDateAndTime.strftime("%m_%d_%H_%M_%S") 352 | save_name = ( 353 | f"{route_name}_{town_name}_{scenario_name}_{weather_id}_{currentTime}" 354 | ) 355 | self.statistics_manager.create_route_data( 356 | route_name, scenario_name, weather_id, save_name, town_name, config.index 357 | ) 358 | 359 | print("\033[1m> Loading the world\033[0m", flush=True) 360 | 361 | # Load the world and the scenario 362 | try: 363 | self._load_and_wait_for_world(args, config.town) 364 | self.route_scenario = RouteScenario( 365 | world=self.world, config=config, debug_mode=args.debug 366 | ) 367 | self.statistics_manager.set_scenario(self.route_scenario) 368 | 369 | except Exception: 370 | # The scenario is wrong -> set the ejecution to crashed and stop 371 | print("\n\033[91mThe scenario could not be loaded:", flush=True) 372 | print(f"\n{traceback.format_exc()}\033[0m", flush=True) 373 | 374 | entry_status, crash_message = FAILURE_MESSAGES["Simulation"] 375 | self._register_statistics(config.index, entry_status, crash_message) 376 | self._cleanup() 377 | return True 378 | 379 | print("\033[1m> Setting up the agent\033[0m", flush=True) 380 | 381 | # Set up the user's agent, and the timer to avoid freezing the simulation 382 | try: 383 | self._agent_watchdog = Watchdog(args.timeout) 384 | self._agent_watchdog.start() 385 | agent_class_name = getattr(self.module_agent, "get_entry_point")() 386 | agent_class_obj = getattr(self.module_agent, agent_class_name) 387 | 388 | # Start the ROS1 bridge server only for ROS1 based agents. 
389 |             if (
390 |                 getattr(agent_class_obj, "get_ros_version")() == 1
391 |                 and self._ros1_server is None
392 |             ):
393 |                 from leaderboard.autoagents.ros1_agent import ROS1Server
394 | 
395 |                 self._ros1_server = ROS1Server()
396 |                 self._ros1_server.start()
397 | 
398 |             self.agent_instance = agent_class_obj(args.host, args.port, args.debug)
399 |             self.agent_instance.set_global_plan(
400 |                 self.route_scenario.gps_route, self.route_scenario.route
401 |             )
402 |             # args.agent_config = args.agent_config + "+" + save_name
403 |             # self.agent_instance.setup(args.agent_config)
404 |             self.agent_instance.setup(args)
405 | 
406 |             # Check and store the sensors
407 |             if not self.sensors:
408 |                 self.sensors = self.agent_instance.sensors()
409 |                 track = self.agent_instance.track
410 | 
411 |                 # validate_sensor_configuration(self.sensors, track, args.track)
412 | 
413 |                 self.sensor_icons = [
414 |                     sensors_to_icons[sensor["type"]] for sensor in self.sensors
415 |                 ]
416 |                 self.statistics_manager.save_sensors(self.sensor_icons)
417 |                 self.statistics_manager.write_statistics()
418 | 
419 |                 self.sensors_initialized = True
420 | 
421 |             self._agent_watchdog.stop()
422 |             self._agent_watchdog = None
423 | 
424 |         except SensorConfigurationInvalid as e:
425 |             # The sensors are invalid -> set the execution to rejected and stop
426 |             print("\n\033[91mThe sensor's configuration used is invalid:", flush=True)
427 |             print(f"{e}\033[0m\n", flush=True)
428 | 
429 |             entry_status, crash_message = FAILURE_MESSAGES["Sensors"]
430 |             self._register_statistics(config.index, entry_status, crash_message)
431 |             self._cleanup()
432 |             return True
433 | 
434 |         except Exception as e:
435 |             # The agent setup has failed -> start the next route
436 |             print("\n\033[91mCould not set up the required agent:", flush=True)
437 |             print(f"\n{traceback.format_exc()}\033[0m", flush=True)
438 |             print(f"{e}\033[0m\n", flush=True)
439 | 
440 |             entry_status, crash_message = FAILURE_MESSAGES["Agent_init"]
441 |             self._register_statistics(config.index, entry_status, crash_message)
442 |             self._cleanup()
443 |             return True
444 | 
445 |         print("\033[1m> Running the route\033[0m", flush=True)
446 | 
447 |         # Run the scenario
448 |         try:
449 |             # Load scenario and run it
450 |             if args.record:
451 |                 self.client.start_recorder(
452 |                     "{}/{}_rep{}.log".format(
453 |                         args.record, config.name, config.repetition_index
454 |                     )
455 |                 )
456 |             self.manager.load_scenario(
457 |                 self.route_scenario,
458 |                 self.agent_instance,
459 |                 config.index,
460 |                 config.repetition_index,
461 |             )
462 |             self.manager.tick_count = 0
463 |             self.manager.run_scenario()
464 | 
465 |         except AgentError:
466 |             # The agent has failed -> stop the route
467 |             print("\n\033[91mStopping the route, the agent has crashed:", flush=True)
468 |             print(f"\n{traceback.format_exc()}\033[0m")
469 | 
470 |             entry_status, crash_message = FAILURE_MESSAGES["Agent_runtime"]
471 | 
472 |         except KeyboardInterrupt:
473 |             return True
474 | 
475 |         except TickRuntimeError:
476 |             entry_status, crash_message = "Started", "TickRuntime"
477 | 
478 |         except Exception:
479 |             print("\n\033[91mError during the simulation:", flush=True)
480 |             print(f"\n{traceback.format_exc()}\033[0m", flush=True)
481 | 
482 |             entry_status, crash_message = FAILURE_MESSAGES["Simulation"]
483 | 
484 |         # Stop the scenario
485 |         try:
486 |             print("\033[1m> Stopping the route\033[0m", flush=True)
487 |             self.manager.stop_scenario()
488 |             self._register_statistics(config.index, entry_status, crash_message)
489 | 
490 |             if args.record:
491 |                 self.client.stop_recorder()
492 | 
493 |             self._cleanup()
494 | 
495 |         except Exception:
496 |             print(
497 |                 "\n\033[91mFailed to stop the scenario, the statistics might be empty:",
498 |                 flush=True,
499 |             )
500 |             print(f"\n{traceback.format_exc()}\033[0m", flush=True)
501 | 
502 |             _, crash_message = FAILURE_MESSAGES["Simulation"]
503 | 
504 |         # If the simulation crashed, stop the leaderboard; otherwise, move on to the next route
505 |         return crash_message == "Simulation crashed"
506 | 
507 |     def run(self, args):
508 |         """
509 |         Run the challenge mode
510 |         """
511 |         route_indexer = RouteIndexer(args.routes, args.repetitions, args.routes_subset)
512 | 
513 |         if args.resume:
514 |             resume = route_indexer.validate_and_resume(args.checkpoint)
515 |         else:
516 |             resume = False
517 | 
518 |         if resume:
519 |             self.statistics_manager.add_file_records(args.checkpoint)
520 |         else:
521 |             self.statistics_manager.clear_records()
522 |         self.statistics_manager.save_progress(route_indexer.index, route_indexer.total)
523 |         self.statistics_manager.write_statistics()
524 | 
525 |         crashed = False
526 |         t1 = time.time()
527 |         while route_indexer.peek() and not crashed:
528 | 
529 |             # Run the scenario
530 |             config = route_indexer.get_next_config()
531 |             crashed = self._load_and_run_scenario(args, config)
532 |             # Save the progress and write the route statistics
533 |             self.statistics_manager.save_progress(
534 |                 route_indexer.index, route_indexer.total
535 |             )
536 |             self.statistics_manager.write_statistics()
537 |             if crashed:
538 |                 print(
539 |                     f"{route_indexer.index} crash, [{route_indexer.index}/{route_indexer.total}], please restart",
540 |                     flush=True,
541 |                 )
542 |                 break
543 | 
544 |         # Shutdown ROS1 bridge server if necessary
545 |         if self._ros1_server is not None:
546 |             self._ros1_server.shutdown()
547 | 
548 |         # Go back to asynchronous mode
549 |         self._reset_world_settings()
550 | 
551 |         if not crashed:
552 |             # Save global statistics
553 |             print(f"cost time={time.time()-t1}", flush=True)
554 |             print("\033[1m> Registering the global statistics\033[0m", flush=True)
555 |             self.statistics_manager.compute_global_statistics()
556 |             self.statistics_manager.validate_and_write_statistics(
557 |                 self.sensors_initialized, crashed
558 |             )
559 | 
560 |         return crashed
561 | 
562 | 
563 | def main_eval(arguments):
564 |     # def main():
565 |     #     description = (
566 |     #         "CARLA AD Leaderboard Evaluation: evaluate your Agent in CARLA scenarios\n"
567 |     #     )
568 | 
569 |     #     # general parameters
570 |     #     parser = argparse.ArgumentParser(
571 |     #         description=description, formatter_class=RawTextHelpFormatter
572 |     #     )
573 |     #     parser.add_argument(
574 |     #         "--host", default="localhost", help="IP of the host server (default: localhost)"
575 |     #     )
576 |     #     parser.add_argument(
577 |     #         "--port", default=2000, type=int, help="TCP port to listen to (default: 2000)"
578 |     #     )
579 |     #     parser.add_argument(
580 |     #         "--traffic-manager-port",
581 |     #         default=8000,
582 |     #         type=int,
583 |     #         help="Port to use for the TrafficManager (default: 8000)",
584 |     #     )
585 |     #     parser.add_argument(
586 |     #         "--traffic-manager-seed",
587 |     #         default=0,
588 |     #         type=int,
589 |     #         help="Seed used by the TrafficManager (default: 0)",
590 |     #     )
591 |     #     parser.add_argument("--debug", type=int, help="Run with debug output", default=0)
592 |     #     parser.add_argument(
593 |     #         "--record",
594 |     #         type=str,
595 |     #         default="",
596 |     #         help="Use CARLA recording feature to create a recording of the scenario",
597 |     #     )
598 |     #     parser.add_argument(
599 |     #         "--timeout",
600 |     #         default=600.0,
601 |     #         type=float,
602 |     #         help="Set the CARLA client timeout value in seconds",
603 |     #     )
604 | 
605 |     #     # simulation setup
606 |     #     parser.add_argument(
607 |     #         "--routes", required=True, help="Name of the routes file to be executed."
608 |     #     )
609 |     #     parser.add_argument(
610 |     #         "--routes-subset", default="", type=str, help="Execute a specific set of routes"
611 |     #     )
612 |     #     parser.add_argument(
613 |     #         "--repetitions", type=int, default=1, help="Number of repetitions per route."
614 |     #     )
615 | 
616 |     #     # agent-related options
617 |     #     parser.add_argument(
618 |     #         "-a",
619 |     #         "--agent",
620 |     #         type=str,
621 |     #         help="Path to Agent's py file to evaluate",
622 |     #         required=True,
623 |     #     )
624 |     #     parser.add_argument(
625 |     #         "--agent-config",
626 |     #         type=str,
627 |     #         help="Path to Agent's configuration file",
628 |     #         default="",
629 |     #     )
630 | 
631 |     #     parser.add_argument(
632 |     #         "--track", type=str, default="SENSORS", help="Participation track: SENSORS, MAP"
633 |     #     )
634 |     #     parser.add_argument(
635 |     #         "--resume",
636 |     #         type=bool,
637 |     #         default=False,
638 |     #         help="Resume execution from last checkpoint?",
639 |     #     )
640 |     #     parser.add_argument(
641 |     #         "--checkpoint",
642 |     #         type=str,
643 |     #         default="./simulation_results.json",
644 |     #         help="Path to checkpoint used for saving statistics and resuming",
645 |     #     )
646 |     #     parser.add_argument(
647 |     #         "--debug-checkpoint",
648 |     #         type=str,
649 |     #         default="./live_results.txt",
650 |     #         help="Path to checkpoint used for saving live results",
651 |     #     )
652 |     #     parser.add_argument("--gpu-rank", type=int, default=0)
653 |     #     arguments = parser.parse_args()
654 | 
655 |     statistics_manager = StatisticsManager(
656 |         arguments.checkpoint, arguments.debug_checkpoint
657 |     )
658 |     leaderboard_evaluator = LeaderboardEvaluator(arguments, statistics_manager)
659 |     crashed = leaderboard_evaluator.run(arguments)
660 | 
661 |     del leaderboard_evaluator
662 | 
663 | 
664 | def main():
665 |     description = (
666 |         "CARLA AD Leaderboard Evaluation: evaluate your Agent in CARLA scenarios\n"
667 |     )
668 | 
669 |     # general parameters
670 |     parser = argparse.ArgumentParser(
671 |         description=description, formatter_class=RawTextHelpFormatter
672 |     )
673 |     parser.add_argument(
674 |         "--host", default="localhost", help="IP of the host server (default: localhost)"
675 |     )
676 |     parser.add_argument(
677 |         "--port", default=2000, type=int, help="TCP port to listen to (default: 2000)"
678 |     )
679 |     parser.add_argument(
680 |         "--traffic-manager-port",
681 |         default=8000,
682 |         type=int,
683 |         help="Port to use for the TrafficManager (default: 8000)",
684 |     )
685 |     parser.add_argument(
686 |         "--traffic-manager-seed",
687 |         default=0,
688 |         type=int,
689 |         help="Seed used by the TrafficManager (default: 0)",
690 |     )
691 |     parser.add_argument("--debug", type=int, help="Run with debug output", default=0)
692 |     parser.add_argument(
693 |         "--record",
694 |         type=str,
695 |         default="",
696 |         help="Use CARLA recording feature to create a recording of the scenario",
697 |     )
698 |     parser.add_argument(
699 |         "--timeout",
700 |         default=600.0,
701 |         type=float,
702 |         help="Set the CARLA client timeout value in seconds",
703 |     )
704 | 
705 |     # simulation setup
706 |     parser.add_argument(
707 |         "--routes", required=True, help="Name of the routes file to be executed."
708 |     )
709 |     parser.add_argument(
710 |         "--routes-subset", default="", type=str, help="Execute a specific set of routes"
711 |     )
712 |     parser.add_argument(
713 |         "--repetitions", type=int, default=1, help="Number of repetitions per route."
714 |     )
715 | 
716 |     # agent-related options
717 |     parser.add_argument(
718 |         "-a",
719 |         "--agent",
720 |         type=str,
721 |         help="Path to Agent's py file to evaluate",
722 |         required=True,
723 |     )
724 |     parser.add_argument(
725 |         "--agent-config",
726 |         type=str,
727 |         help="Path to Agent's configuration file",
728 |         default="",
729 |     )
730 | 
731 |     parser.add_argument(
732 |         "--track", type=str, default="SENSORS", help="Participation track: SENSORS, MAP"
733 |     )
734 |     parser.add_argument(
735 |         "--resume",
736 |         type=bool,
737 |         default=False,
738 |         help="Resume execution from last checkpoint?",
739 |     )
740 |     parser.add_argument(
741 |         "--checkpoint",
742 |         type=str,
743 |         default="./simulation_results.json",
744 |         help="Path to checkpoint used for saving statistics and resuming",
745 |     )
746 |     parser.add_argument(
747 |         "--debug-checkpoint",
748 |         type=str,
749 |         default="./live_results.txt",
750 |         help="Path to checkpoint used for saving live results",
751 |     )
752 |     parser.add_argument("--gpu-rank", type=int, default=0)
753 |     arguments = parser.parse_args()
754 | 
755 |     statistics_manager = StatisticsManager(
756 |         arguments.checkpoint, arguments.debug_checkpoint
757 |     )
758 |     leaderboard_evaluator = LeaderboardEvaluator(arguments, statistics_manager)
759 |     crashed = leaderboard_evaluator.run(arguments)
760 | 
761 |     del leaderboard_evaluator
762 | 
763 |     if crashed:
764 |         sys.exit(-1)
765 |     else:
766 |         sys.exit(0)
767 | 
768 | 
769 | if __name__ == "__main__":
770 |     main()
771 | 
--------------------------------------------------------------------------------
/misc/run_eval_leaderboard.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from leaderboard.leaderboard_evaluator_local import main
3 | import argparse
4 | import os
5 | import sys
6 | import hydra
7 | from pathlib import Path
8 | from omegaconf import DictConfig, OmegaConf
9 | 
10 | 
11 | @hydra.main(config_path="../../Bench2DriveZoo/team_code/config", config_name="config")
12 | def main(cfg):
13 |     print(OmegaConf.to_yaml(cfg))
14 | 
15 |     cfg_org = cfg.copy()
16 |     cfg = cfg.experiments
17 | 
18 |     print(cfg_org.eval.routes)
19 |     print(cfg_org.checkpoint)
20 |     print("Working directory : {}".format(os.getcwd()))
21 |     print(f"Save gifs: {cfg_org.save_explainability_viz}")
22 | 
23 |     # create result folder
24 |     Path(cfg_org.checkpoint).parent.mkdir(parents=True, exist_ok=True)
25 | 
26 |     os.environ["CUDA_VISIBLE_DEVICES"] = f"{cfg_org.CUDA_VISIBLE_DEVICES}"
27 | 
28 |     arg_dict0 = OmegaConf.to_container(cfg_org.eval, resolve=True)
29 |     arg_dict1 = OmegaConf.to_container(cfg, resolve=True)
30 |     arg_dict2 = OmegaConf.to_container(cfg_org, resolve=True)
31 |     arg_dict1.update(arg_dict2)
32 |     arg_dict1.update(arg_dict0)
33 |     args = argparse.Namespace(**arg_dict1)
34 | 
35 |     from leaderboard import leaderboard_evaluator_local
36 |     import numpy as np
37 | 
38 |     # np.warnings.filterwarnings("error", category=np.VisibleDeprecationWarning)
39 | 
40 |     leaderboard_evaluator_local.main_eval(args)
41 | 
42 | 
43 | if __name__ == "__main__":
44 |     main()
45 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | annotated-types==0.7.0
2 | antlr4-python3-runtime==4.9.3
3 | asttokens==2.4.1
4 | certifi
5 | cffi
6 | charset-normalizer
7 | click==8.1.7
8 | comm
9 | contourpy==1.1.1
10 | coverage==7.6.1
11 | cycler==0.12.1
12 | debugpy
13 | decorator
14 | deepspeed==0.14.4
15 | dictor==0.1.12
16 | docker-pycreds==0.4.0
17 | dotmap==1.3.30
18 | einops==0.8.0
19 | elementpath==1.3.3
20 | ephem==4.1.5
21 | executing==2.0.1
22 | filelock==3.0.12
23 | filterpy==1.4.5
24 | flash-attn==2.6.3
25 | fonttools==4.53.1
26 | fsspec==2024.6.1
27 | gitdb==4.0.11
28 | GitPython==3.1.43
29 | h2
30 | hjson==3.1.0
31 | hpack==4.0.0
32 | huggingface-hub==0.24.5
33 | hydra-core==1.3.2
34 | hyperframe
35 | idna
36 | imageio==2.35.1
37 | imgaug==0.4.0
38 | importlib_metadata
39 | importlib_resources==6.4.0
40 | ipdb==0.13.13
41 | ipykernel==6.29.5
42 | ipython==8.12.3
43 | jedi
44 | Jinja2
45 | joblib==1.4.2
46 | jsonpickle==4.0.2
47 | jupyter_client
48 | jupyter_core
49 | kiwisolver==1.4.5
50 | laspy==2.5.4
51 | lazy_loader==0.4
52 | line_profiler==4.2.0
53 | MarkupSafe
54 | matplotlib==3.5.3
55 | matplotlib-inline
56 | mpmath
57 | natsort==8.4.0
58 | nest_asyncio
59 | networkx
60 | ninja==1.11.1.1
61 | numpy
62 | nvidia-ml-py==12.555.43
63 | omegaconf==2.3.0
64 | opencv-python==4.2.0.32
65 | packaging==24.1
66 | pandas==2.0.3
67 | parso
68 | pexpect
69 | pickleshare
70 | Pillow
71 | platformdirs==4.2.2
72 | plotly==5.23.0
73 | prompt_toolkit==3.0.47
74 | protobuf==5.27.3
75 | psutil
76 | ptyprocess
77 | pure_eval==0.2.3
78 | py-cpuinfo==9.0.0
79 | py-trees==0.8.3
80 | pycparser
81 | pydantic==2.8.2
82 | pydantic_core==2.20.1
83 | pydot==3.0.1
84 | pygame==2.6.0
85 | Pygments
86 | pyparsing==3.1.2
87 | PySocks
88 | python-dateutil
89 | pytz==2024.1
90 | PyWavelets==1.4.1
91 | PyYAML
92 | pyzmq
93 | rdp==0.8
94 | regex==2024.7.24
95 | requests
96 | safetensors==0.4.4
97 | scikit-image==0.21.0
98 | scikit-learn==1.3.2
99 | scipy==1.10.1
100 | sentencepiece==0.2.0
101 | sentry-sdk==2.12.0
102 | setproctitle==1.3.3
103 | Shapely==1.7.1
104 | simple_watchdog_timer==0.1.1
105 | six
106 | smmap==5.0.1
107 | stack-data==0.6.3
108 | sympy
109 | tabulate==0.9.0
110 | tenacity==9.0.0
111 | threadpoolctl==3.5.0
112 | tifffile==2023.7.10
113 | timm==1.0.8
114 | tokenizers==0.19.1
115 | tomli==2.0.1
116 | torch==2.4.0
117 | torchaudio==2.4.0
118 | torchvision==0.19.0
119 | tornado
120 | tqdm==4.66.5
121 | traitlets
122 | transformers==4.43.4
123 | transforms3d==0.4.2
124 | triton==3.0.0
125 | typing_extensions
126 | tzdata==2024.1
127 | ujson==5.10.0
128 | urllib3
129 | wandb==0.17.5
130 | wcwidth
131 | xmlschema==1.0.18
132 | zipp==3.20.2
133 | zstandard==0.23.0
--------------------------------------------------------------------------------
/team_code/config/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 |   - user: example
3 |   - experiments: eta
4 |   - eval: b2d
5 | 
6 | host: localhost
7 | port: 2000
8 | trafficManagerPort: 8000
9 | traffic_manager_port: ${trafficManagerPort}
10 | trafficManagerSeed: 0
11 | dataProviderSeed: 0
12 | debug: 0
13 | viz: 0
14 | viz_interval: 50
15 | viz_max: 2000
16 | viz_extra: 0
17 | record: ''
18 | timeout: 600.0
19 | 
20 | hydra:
21 |   run:
22 |     dir: ${experiments.agent_root}/${save_path}/${experiments.runnickname}
23 |   job:
24 |     config:
25 |       override_dirname:
26 |         exclude_keys:
27 |           - eval
28 |           - experiments
29 |           - experiments.wanderer_model_name
30 |           - experiments.ponderer_model_name
31 |           - port
32 |           - trafficManagerPort
33 |           - experiments.epoch_num
34 |           - experiments.use_gru_output
35 |           - user
36 |           - traffic_manager_seed
37 |         kv_sep: '='
38 |         item_sep: '_'
39 |     env_set:
40 |       OMP_NUM_THREADS: 1
41 | 
42 | repetitions: 1
43 | track: SENSORS
44 | resume: 1
45 | save_path: evallogs
46 | log_save_path: result_logs
47 | checkpoint_file: results.json
48 | debug_checkpoint: ${hydra:run.dir}/debug/${checkpoint_file}
49 | checkpoint: ${hydra:run.dir}/${checkpoint_file}
50 | traffic_manager_seed: 0
51 | 
52 | 
53 | DEBUG_CHALLENGE: 0
54 | CUDA_VISIBLE_DEVICES: 0
55 | SEED_OFFSET: 0
56 | 
--------------------------------------------------------------------------------
/team_code/config/eval/b2d.yaml:
--------------------------------------------------------------------------------
1 | BENCHMARK: b2d
2 | BENCHMARKNICKNAME: b2dfull
3 | route_rel: leaderboard/data/bench2drive220.xml
4 | routes: ${user.working_dir}/${eval.route_rel}
--------------------------------------------------------------------------------
/team_code/config/experiments/eta.yaml:
--------------------------------------------------------------------------------
1 | BASE_PORT: 30000
2 | BASE_TM_PORT: 50000
3 | IS_BENCH2DRIVE: True
4 | BASE_ROUTES: leaderboard/data/bench2drive220
5 | 
6 | BASE_CHECKPOINT_ENDPOINT: endpoint
7 | SAVE_PATH: ./dummy
8 | PLANNER_TYPE: only_traj
9 | 
10 | agent: ${user.working_dir}/Bench2DriveZoo/team_code/eta_agent.py
11 | runnickname:
12 | root_path: ${user.working_dir}
13 | ponderer_model_name:
14 | agent_root: ${root_path}/${experiments.ponderer_model_name}
15 | ghost_agent_root: null
16 | ghost_only_viz: False
17 | 
18 | epoch_num: 37
19 | 
20 | creep_delay: 100
21 | creep_duration: 10
22 | use_creep: True
23 | 
24 | 
25 | steer_damping: 0.5
26 | use_gt_frc: null
27 | 
28 | # Bench2Drive related
29 | routes_subset:
30 | b2d_path: ${user.b2d_path}
--------------------------------------------------------------------------------
/team_code/config/user/example.yaml:
--------------------------------------------------------------------------------
1 | working_dir: /home/shadi/research/
2 | b2d_path: /home/shadi/research/ETA/Bench2Drive
--------------------------------------------------------------------------------
/team_code/planner.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import deque
3 | 
4 | import numpy as np
5 | import math
6 | 
7 | EARTH_RADIUS_EQUA = 6378137.0
8 | 
9 | DEBUG = False
10 | 
11 | 
12 | class Plotter(object):
13 |     def __init__(self, size):
14 |         self.size = size
15 |         self.clear()
16 |         self.title = str(self.size)
17 | 
18 |     def clear(self):
19 |         from PIL import Image, ImageDraw
20 | 
21 |         self.img = Image.fromarray(np.zeros((self.size, self.size, 3), dtype=np.uint8))
22 |         self.draw = ImageDraw.Draw(self.img)
23 | 
24 |     def dot(self, pos, node, color=(255, 255, 255), label=None, r=2):
25 |         x, y = 5.5 * (pos - node)
26 |         x += self.size / 2
27 |         y += self.size / 2
28 | 
29 |         self.draw.ellipse((x - r, y - r, x + r, y + r), color)
30 | 
31 |         # If a label is provided, write it
32 |         if label:
33 |             self.draw.text((x, y), label, fill=color)
34 | 
35 |     def show(self):
36 |         if not DEBUG:
37 |             return
38 | 
39 |         import cv2
40 | 
41 |         cv2.imshow(self.title, cv2.cvtColor(np.array(self.img), cv2.COLOR_BGR2RGB))
42 |         cv2.waitKey(1)
43 | 
44 | 
45 | class RoutePlanner(object):
46 |     def __init__(
47 |         self, min_distance, max_distance, debug_size=256, lat_ref=42.0, lon_ref=2.0
48 |     ):
49 |         self.route = deque()
50 |         self.min_distance = min_distance
51 |         self.max_distance = max_distance
52 | 
53 |         # self.mean = np.array([49.0, 8.0]) # for carla 9.9
54 |         # self.scale = np.array([111324.60662786, 73032.1570362]) # for carla 9.9
55 |         self.mean = np.array([0.0, 0.0]) # for carla 9.10
56 |         self.scale = np.array([111324.60662786, 111319.490945]) # for carla 9.10
57 | 
58 |         self.debug = Plotter(debug_size)
59 |         # self.lat_ref, self.lon_ref = self._get_latlon_ref()
60 |         self.lat_ref = lat_ref
61 |         self.lon_ref = lon_ref
62 | 
63 |     def set_route(self, global_plan, gps=False, global_plan_world=None):
64 |         self.route.clear()
65 | 
66 |         if global_plan_world:
67 |             for (pos, cmd), (pos_world, _) in zip(global_plan, global_plan_world):
68 |                 if gps:
69 |                     pos = self.gps_to_location(np.array([pos["lat"], pos["lon"]]))
70 |                     # pos -= self.mean
71 |                     # pos *= self.scale
72 |                 else:
73 |                     pos = np.array([pos.location.x, pos.location.y])
74 |                     # pos -= self.mean
75 | 
76 |                 self.route.append((pos, cmd, pos_world))
77 |         else:
78 |             for pos, cmd in global_plan:
79 |                 if gps:
80 |                     pos = self.gps_to_location(np.array([pos["lat"], pos["lon"]]))
81 |                     # pos -= self.mean
82 |                     # pos *= self.scale
83 |                 else:
84 |                     pos = np.array([pos.location.x, pos.location.y])
85 |                     # pos -= self.mean
86 | 
87 |                 self.route.append((pos, cmd))
88 | 
89 |     def run_step(self, gps, return_both=False):
90 |         self.debug.clear()
91 | 
92 |         if len(self.route) == 1:
93 |             if return_both:
94 |                 return self.route[0], self.route[0]
95 |             return self.route[0]
96 | 
97 |         to_pop = 0
98 |         farthest_in_range = -np.inf
99 |         cumulative_distance = 0.0
100 | 
101 |         for i in range(1, len(self.route)):
102 |             if cumulative_distance > self.max_distance:
103 |                 break
104 | 
105 |             cumulative_distance += np.linalg.norm(
106 |                 self.route[i][0] - self.route[i - 1][0]
107 |             )
108 |             distance = np.linalg.norm(self.route[i][0] - gps)
109 | 
110 |             if distance <= self.min_distance and distance > farthest_in_range:
111 |                 farthest_in_range = distance
112 |                 to_pop = i
113 | 
114 |             r = 255 * int(distance > self.min_distance)
115 |             g = 255 * int(self.route[i][1].value == 4)
116 |             b = 255
117 |             self.debug.dot(
118 |                 gps,
119 |                 self.route[i][0],
120 |                 (r, g, b),
121 |                 label=f"{distance:.2f}_{to_pop}_{farthest_in_range:.2f}",
122 |             )
123 | 
124 |         for _ in range(to_pop):
125 |             if len(self.route) > 2:
126 |                 self.route.popleft()
127 | 
128 |         self.debug.dot(gps, self.route[0][0], (0, 255, 0))
129 |         self.debug.dot(gps, self.route[1][0], (255, 0, 0))
130 |         self.debug.dot(gps, gps, (0, 0, 255))
131 |         self.debug.show()
132 | 
133 |         if return_both:
134 |             return (
135 |                 (self.route[1], self.route[2])
136 |                 if len(self.route) > 2
137 |                 else (self.route[1], self.route[1])
138 |             )
139 | 
140 |         return self.route[1]
141 | 
142 |     def gps_to_location(self, gps):
143 |         # gps content: numpy array: [lat, lon]
144 |         lat, lon = gps
145 |         scale = math.cos(self.lat_ref * math.pi / 180.0)
146 |         my = math.log(math.tan((lat + 90) * math.pi / 360.0)) * (
147 |             EARTH_RADIUS_EQUA * scale
148 |         )
149 |         mx = (lon * (math.pi * EARTH_RADIUS_EQUA * scale)) / 180.0
150 |         y = (
151 |             scale
152 |             * EARTH_RADIUS_EQUA
153 |             * math.log(math.tan((90.0 + self.lat_ref) * math.pi / 360.0))
154 |             - my
155 |         )
156 |         x = mx - scale * self.lon_ref * math.pi * EARTH_RADIUS_EQUA / 180.0
157 |         return np.array([x, y])
158 | 
--------------------------------------------------------------------------------
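
A note on the GPS conversion in /team_code/planner.py: gps_to_location applies the same Mercator-style projection CARLA uses for its GNSS sensor, so the configured reference point (lat_ref, lon_ref) maps to the local origin. The snippet below is a minimal sanity-check sketch, not part of the repository; it assumes planner.py is importable from the current working directory and that the Python dependencies listed above are installed.

    import numpy as np
    from planner import RoutePlanner

    # min_distance and max_distance are arbitrary illustrative values here.
    planner = RoutePlanner(min_distance=4.0, max_distance=50.0, lat_ref=42.0, lon_ref=2.0)

    # The reference latitude/longitude should project to (approximately) the local origin.
    origin = planner.gps_to_location(np.array([42.0, 2.0]))
    print(origin)  # expected: values close to [0.0, 0.0]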