├── .github
│   └── FUNDING.yml
├── LICENSE
├── README.md
├── assets
│   ├── overview.png
│   └── results.png
├── carformer
│   ├── .gitignore
│   ├── Makefile
│   ├── carformer
│   │   ├── __init__.py
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   ├── backbone
│   │   │   │   ├── llama-base.yaml
│   │   │   │   ├── llama-micro.yaml
│   │   │   │   └── llama.yaml
│   │   │   ├── carformer_config.py
│   │   │   ├── config.yaml
│   │   │   ├── dataset
│   │   │   │   └── b2d.yaml
│   │   │   ├── deepspeed
│   │   │   │   ├── basic.yaml
│   │   │   │   ├── default.yaml
│   │   │   │   ├── disable.yaml
│   │   │   │   ├── zero2.yaml
│   │   │   │   └── zero3.yaml
│   │   │   ├── experiments
│   │   │   │   ├── init_rgb_from_ckpt.yaml
│   │   │   │   ├── onlyfrc.yaml
│   │   │   │   ├── overfit.yaml
│   │   │   │   └── use_light_encoder.yaml
│   │   │   ├── hyperparams
│   │   │   │   ├── default.yaml
│   │   │   │   ├── optimizer
│   │   │   │   │   ├── adam.yaml
│   │   │   │   │   ├── adamw.yaml
│   │   │   │   │   └── sgd.yaml
│   │   │   │   └── scheduler
│   │   │   │       ├── cosine_annealing.yaml
│   │   │   │       └── linear_warmup.yaml
│   │   │   ├── logging
│   │   │   │   └── wandb.yaml
│   │   │   ├── training
│   │   │   │   ├── action
│   │   │   │   │   ├── path.yaml
│   │   │   │   │   └── waypoints.yaml
│   │   │   │   ├── bev
│   │   │   │   │   ├── rgb.yaml
│   │   │   │   │   └── rgb_backbone
│   │   │   │   │       ├── internvl2-4b.yaml
│   │   │   │   │       ├── internvl2-76b.yaml
│   │   │   │   │       ├── internvl2pt5-8b.yaml
│   │   │   │   │       └── llava1pt6.yaml
│   │   │   │   ├── default.yaml
│   │   │   │   ├── goal
│   │   │   │   │   ├── default_goal.yaml
│   │   │   │   │   ├── dual_target_point.yaml
│   │   │   │   │   └── target_point.yaml
│   │   │   │   ├── quantized.yaml
│   │   │   │   ├── reward
│   │   │   │   │   └── reward.yaml
│   │   │   │   └── state
│   │   │   │       ├── lights.yaml
│   │   │   │       └── speed.yaml
│   │   │   └── user
│   │   │       └── example.yaml
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── data.py
│   │   │   ├── data_parser.py
│   │   │   ├── data_utils.py
│   │   │   ├── utils.py
│   │   │   └── wrapper.py
│   │   ├── perception
│   │   │   ├── __init__.py
│   │   │   └── rgb.py
│   │   ├── ponderer.py
│   │   ├── ponderer_lit.py
│   │   ├── utils
│   │   │   ├── __init__.py
│   │   │   ├── distributed.py
│   │   │   ├── distributedsampler.py
│   │   │   └── utils.py
│   │   └── visualization
│   │       └── visutils.py
│   ├── requirements.txt
│   ├── setup.py
│   └── train.py
├── docs
│   └── TRAIN_EVAL.md
├── misc
│   ├── leaderboard_evaluator_local.py
│   └── run_eval_leaderboard.py
├── requirements.txt
└── team_code
    ├── config
    │   ├── config.yaml
    │   ├── eval
    │   │   └── b2d.yaml
    │   ├── experiments
    │   │   └── eta.yaml
    │   └── user
    │       └── example.yaml
    ├── eta_agent.py
    └── planner.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 | github: [OpenDriveLab] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
3 | patreon: # Replace with a single Patreon username
4 | open_collective: # Replace with a single Open Collective username
5 | ko_fi: # Replace with a single Ko-fi username
6 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
7 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
8 | liberapay: # Replace with a single Liberapay username
9 | issuehunt: # Replace with a single IssueHunt username
10 | otechie: # Replace with a single Otechie username
11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
12 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Shadi Hamdan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🤔 ETA
2 |
3 | ## Highlight
4 | We propose "**E**fficiency through **T**hinking **A**head" (ETA), an asynchronous dual-system in which a large model pre-processes information from past frames while a small model processes the current frame, enabling real-time decisions with strong performance.
5 |
6 |
7 | ## News
8 | - **`[2025/06/10]`** [ETA](https://arxiv.org/abs/2506.07725) paper and code release!
9 |
10 | ## Results
11 |
12 |
13 |
14 | ## Table of Contents
15 | 1. [Highlights](#highlight)
16 | 2. [News](#news)
17 | 3. [Results](#results)
18 | 4. [Getting Started](#gettingstarted)
19 | - [Training](docs/TRAIN_EVAL.md#trainingsetup)
20 | - [Evaluation](docs/TRAIN_EVAL.md#evalsetup)
21 | 5. [TODO List](#todolist)
22 | 6. [License and Citation](#licenseandcitation)
23 | 7. [Other Resources](#otherresources)
24 |
25 | ## Getting Started
26 |
27 | To get started with ETA:
28 | - Training
29 | - [Download Bench2Drive Data](docs/TRAIN_EVAL.md#b2ddata)
30 | - [Setup other prerequisites](docs/TRAIN_EVAL.md#trainingsetup)
31 | - [Start training](docs/TRAIN_EVAL.md#training)
32 | - Evaluation
33 | - [Setup Bench2Drive](docs/TRAIN_EVAL.md#evalsetup)
34 | - [Setup files and configs](docs/TRAIN_EVAL.md#evalfilesetup)
35 | - [Download checkpoints](docs/TRAIN_EVAL.md#evalcheckpoints)
36 | - [Run evaluation](docs/TRAIN_EVAL.md#runeval)
37 |
38 | ## TODO List
39 | - [x] ETA Training code
40 | - [x] ETA Evaluation
41 | - [x] Inference Code
42 | - [ ] Checkpoints
43 |
44 | ## Acknowledgements
45 |
46 | This codebase builds on open-source code from [CARLA Garage](https://github.com/autonomousvision/carla_garage) and [Bench2DriveZoo](https://github.com/Thinklab-SJTU/Bench2DriveZoo/), among others. We thank the authors for their contributions. This project is funded by the European Union (ERC, ENSURE, 101116486) with additional compute support from Leonardo Booster (EuroHPC Joint Undertaking, EHPC-AI-2024A01-060). Views and opinions expressed are however those of the author(s) only and do not necessarily reflect those of the European Union or the European Research Council. Neither the European Union nor the granting authority can be held responsible for them. This study is also supported by the National Natural Science Foundation of China (62206172) and the Shanghai Committee of Science and Technology (23YF1462000).
47 |
48 | ## License and Citation
49 | This project is released under the [MIT License](LICENSE). If you find our project useful for your research, please consider citing our paper with the following BibTeX:
50 |
51 |
52 | ```bibtex
53 | @article{hamdan2025eta,
54 | title={ETA: Efficiency through Thinking Ahead, A Dual Approach to Self-Driving with Large Models},
55 | author={Hamdan, Shadi and Sima, Chonghao and Yang, Zetong and Li, Hongyang and G{\"u}ney, Fatma},
56 | journal={arXiv preprint arXiv:2506.07725},
57 | year={2025}
58 | }
59 | ```
60 |
--------------------------------------------------------------------------------
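The asynchronous dual system described in the Highlight above is easiest to picture as a control loop in which the large model never sits on the critical path of the current step. The sketch below is purely conceptual; `drive_step`, `large_model`, `small_model`, and `planner` are illustrative names and do not correspond to classes or functions in this repository.

```python
# Conceptual sketch of ETA's "thinking ahead" loop (illustrative only, not repo code).
def drive_step(frame_t, cached_slow_features, large_model, small_model, planner):
    # The small model processes the *current* frame on the critical path...
    fast_features = small_model(frame_t)
    # ...and is combined with features the large model already computed from the
    # *previous* frame, so the expensive model never delays the current decision.
    action = planner(fast_features, cached_slow_features)
    # The large model is launched on the current frame; in the actual system this
    # runs asynchronously and its output is consumed at the next step.
    cached_slow_features = large_model(frame_t)
    return action, cached_slow_features
```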
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenDriveLab/ETA/cdbe5e1174cf33028d5cc26921f4fb72bf9b92f5/assets/overview.png
--------------------------------------------------------------------------------
/assets/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenDriveLab/ETA/cdbe5e1174cf33028d5cc26921f4fb72bf9b92f5/assets/results.png
--------------------------------------------------------------------------------
/carformer/.gitignore:
--------------------------------------------------------------------------------
1 | # Carformer visualizations
2 | visualizations/**/
3 |
4 | # Byte-compiled / optimized / DLL files
5 | __pycache__/
6 | *.py[cod]
7 | *$py.class
8 |
9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | pip-wheel-metadata/
27 | share/python-wheels/
28 | *.egg-info/
29 | .installed.cfg
30 | *.egg
31 | MANIFEST
32 |
33 | # PyInstaller
34 | # Usually these files are written by a python script from a template
35 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
36 | *.manifest
37 | *.spec
38 |
39 | # Installer logs
40 | pip-log.txt
41 | pip-delete-this-directory.txt
42 |
43 | # Unit test / coverage reports
44 | htmlcov/
45 | .tox/
46 | .nox/
47 | .coverage
48 | .coverage.*
49 | .cache
50 | nosetests.xml
51 | coverage.xml
52 | *.cover
53 | *.py,cover
54 | .hypothesis/
55 | .pytest_cache/
56 |
57 | # Translations
58 | *.mo
59 | *.pot
60 |
61 | # Django stuff:
62 | *.log
63 | local_settings.py
64 | db.sqlite3
65 | db.sqlite3-journal
66 |
67 | # Flask stuff:
68 | instance/
69 | .webassets-cache
70 |
71 | # Scrapy stuff:
72 | .scrapy
73 |
74 | # Sphinx documentation
75 | docs/_build/
76 |
77 | # PyBuilder
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
--------------------------------------------------------------------------------
/carformer/Makefile:
--------------------------------------------------------------------------------
1 | # By default echo the help message
2 | .DEFAULT_GOAL := help
3 |
4 | help: ## Display this help message
5 | @echo "Usage: make [target]"
6 | @echo ""
7 | @echo "Targets:"
8 | @awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z0-9_-]+:.*?## / {printf " %-20s %s\n", $$1, $$2}' $(MAKEFILE_LIST)
9 |
10 | ETA_base_model_s42:
11 | OMP_NUM_THREADS=6 python train.py deepspeed=zero2 hyperparams.batch_size=5 hyperparams.num_epochs=40 user=example wandb_tag=$@ dataset.dataset_path_rel=B2D-base gpus=4 nodes=8 training.parallel_dataset_init=False backbone=llama-base training.weighted_sampling=True training/goal=dual_target_point training.normalize_goal=True training.bucket_weights.type=preferturns training.rgb_crop.crop_size=896 start_saving_epoch=30 seed=42
12 |
13 | ETA_async_model_s42:
14 | OMP_NUM_THREADS=6 python train.py deepspeed=zero2 hyperparams.batch_size=5 hyperparams.num_epochs=40 user=example wandb_tag=$@ dataset.dataset_path_rel=B2D-base gpus=4 nodes=8 training.parallel_dataset_init=False training.weighted_sampling=True training/goal=dual_target_point training.normalize_goal=True training.bucket_weights.type=preferturns training.use_future_vehicle_forecast=True training.use_predicted_latent_with_gap=True training.action_gap=1 training.future_horizon=1 training.ema_enabled=False +experiments=[use_light_encoder] training.ignore_past_for_length=True training.rgb_crop.crop_size=896 seed=42
15 |
--------------------------------------------------------------------------------
/carformer/carformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenDriveLab/ETA/cdbe5e1174cf33028d5cc26921f4fb72bf9b92f5/carformer/carformer/__init__.py
--------------------------------------------------------------------------------
/carformer/carformer/config/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from .carformer_config import CarformerConfig
4 |
5 | _filename_ascii_strip_re = re.compile(r"[^A-Za-z0-9_.,=-]")
6 |
7 |
8 | # Adapted from https://tedboy.github.io/flask/_modules/werkzeug/utils.html#secure_filename
9 | # Changed in order to work without importing flask
10 | def sanitize_shorten_ckppath(path):
11 | path = path.replace("use_light_encoder", "ltenc")
12 | # ema_every_steps = getattr(self.ponderer.config.training, "ema_every_steps", 1)
13 | # ema_start = getattr(self.ponderer.config.training, "ema_start", 0)
14 | # ema_end_epoch = getattr(self.ponderer.config.training, "ema_end_epoch", -1)
15 | path = path.replace("gen_mask_for_action", "amsk")
16 | path = path.replace("turns", "trn")
17 | path = path.replace("optimizer", "opt")
18 | path = path.replace("ema_every_steps", "emaint")
19 | path = path.replace("ema_start", "emast")
20 | path = path.replace("ema_end_epoch", "emagnd")
21 | path = path.replace("backbone", "bb")
22 | path = path.replace("training.rgb_crop.crop_size", "trncrp")
23 | path = path.replace("training.ema_decay", "emadc")
24 | path = path.replace("training/", "")
25 | path = path.replace("bev/", "")
26 | path = path.replace("llava1pt6", "lv16")
27 | path = path.replace("rgb_front", "rgb")
28 | path = path.replace("enabled", "T")
29 | path = path.replace("deepspeed", "ds")
30 | path = path.replace("use_gt_frc_only", "gtfrconly")
31 | path = path.replace("combofrc", "cfrc")
32 | path = path.replace("ignore_past_for_length", "igpastln")
33 | path = path.replace("zero_out_frc_branch", "zfbrc")
34 | path = path.replace("use_gt_frc", "gtfrc")
35 | path = path.replace("use_past_horizon", "up")
36 | path = path.replace("backbone.", "")
37 | path = path.replace("optimizer.", "")
38 | path = path.replace("kwargs.", "")
39 | path = path.replace("weight_decay", "")
40 | path = path.replace("hyperparams.", "")
41 | path = path.replace("hyperparams/", "")
42 | path = path.replace("training.", "")
43 | path = path.replace("use_real_latent_ratio", "rltrtio")
44 | path = path.replace("past_horizon", "phrz")
45 | path = path.replace("llama-", "lma-")
46 | path = path.replace("micro", "mc")
47 | path = path.replace("mini", "mn")
48 | path = path.replace("future_horizon", "hrz")
49 | path = path.replace("future_horizon", "ftr")
50 | path = path.replace("ema_enabled", "ema")
51 | path = path.replace("use_predicted_latent_with_gap", "prdltnt")
52 | path = path.replace("bucket_weights.type", "bktype")
53 | path = path.replace("preferturns", "pturns")
54 | path = path.replace("prefer", "p")
55 | path = path.replace("+experiments", "exps")
56 | path = path.replace("num_epochs", "eps")
57 | path = path.replace("batch_size", "bs")
58 | path = path.replace("False", "F")
59 | path = path.replace("True", "T")
60 | path = path.replace("forecast_steps", "frc_steps")
61 | path = path.replace("loss_params.", "")
62 | path = path.replace("action", "actn")
63 | path = path.replace("forecast", "frc")
64 | path = path.replace("classification", "cls")
65 | path = path.replace("state", "stt")
66 | path = path.replace("dataset", "dts")
67 | path = path.replace("subsample_ratio", "smplrtio")
68 | path = path.replace("reconstruction", "rcns")
69 | path = path.replace("wandb_tag", "wnb")
70 | path = path.replace("gradient_accumulation_steps", "gacc")
71 | path = path.replace("dropout", "dp")
72 | path = path.replace("use_future_vehicle_forcast", "dofrc")
73 | path = path.replace("normalize_goal", "nrmgoal")
74 | path = path.replace("weighted_sampling", "wtsmpl")
75 | path = path.replace("use_future_vehicle_frc", "dofrc")
76 | path = path.replace("light_select_layer", "slclr")
77 | path = path.replace("use_light_encoder_backbone", "ltenc")
78 | path = path.replace("actn_gap", "agap")
79 |
80 | for sep in os.path.sep, os.path.altsep:
81 | if sep:
82 | path = path.replace(sep, " ")
83 | path = str(_filename_ascii_strip_re.sub("", "_".join(path.split()))).strip("._")
84 | path = path.replace("training_bev_rgb_backbonergb_backbone", "rgbbb")
85 | path = path.replace("training_goal", "gl")
86 | path = path.replace("dual_target_point", "2tp")
87 | path = path.replace("rgb_backbone", "rgbb")
88 | return path
89 |
90 |
91 | def config_init():
92 | from dataclasses import dataclass
93 |
94 | from omegaconf import MISSING, OmegaConf
95 |
96 | import hydra
97 | from hydra.core.config_store import ConfigStore
98 |
99 | cs = ConfigStore.instance()
100 |
101 | def merge_keys(*cfg):
102 | all_keys = set()
103 | for c in cfg:
104 | all_keys.update([str(x) for x in c.keys()])
105 | return "-".join(sorted(all_keys))
106 |
107 | # If has key, return True, else return False
108 | def has_key(cfg, key):
109 | return key in cfg
110 |
111 | def get_key(cfg, key, *args):
112 | # If args is not empty, recurse after getting the key
113 | if len(args) > 0:
114 | if key in cfg:
115 | try:
116 | return get_key(cfg[key], *args)
117 | except KeyError:
118 | return get_key(cfg[key], "reward")
119 | raise KeyError(f"Key {args} not found in {cfg[key]}")
120 | else:
121 | raise KeyError(f"Key {key} not found in {cfg}")
122 |
123 | if key in cfg:
124 | return cfg[key]
125 | else:
126 | raise KeyError(f"Key {key} not found in {cfg}")
127 |
128 | def resolve_quantizer_path(keys, data_format, quantizer_dict):
129 | # quantizer_key = "plant" if plant_data else "legacy"
130 | if data_format == "plant":
131 | quantizer_key = "plant"
132 | else:
133 | quantizer_key = "legacy"
134 | return get_key(quantizer_dict, quantizer_key, keys)
135 |
136 | OmegaConf.register_new_resolver("merge_keys", merge_keys)
137 |
138 | OmegaConf.register_new_resolver("has_key", has_key)
139 |
140 | OmegaConf.register_new_resolver("get_key", get_key)
141 | OmegaConf.register_new_resolver("eval", eval)
142 |
143 | bool_to_str = lambda x: "true" if x else "false"
144 | OmegaConf.register_new_resolver("bool_to_str", bool_to_str)
145 | OmegaConf.register_new_resolver("resolve_quantizer_path", resolve_quantizer_path)
146 | OmegaConf.register_new_resolver("sanitize", sanitize_shorten_ckppath)
147 |
--------------------------------------------------------------------------------
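For reference, the snippet below shows, outside of Hydra, how the `eval` and `merge_keys` resolvers registered in `config_init()` behave; the same interpolation patterns appear later in `training/action/waypoints.yaml` and `training/default.yaml`. This is a standalone illustration rather than code from the repository, and the inline `merge_keys` lambda simply mirrors the closure defined above.

```python
from omegaconf import OmegaConf

# Mirror the resolvers registered in config_init() (standalone illustration).
OmegaConf.register_new_resolver("eval", eval)
OmegaConf.register_new_resolver(
    "merge_keys",
    lambda *cfgs: "-".join(sorted({str(k) for c in cfgs for k in c.keys()})),
)

cfg = OmegaConf.create(
    {
        "training": {
            "num_waypoints": 10,
            # Same pattern as width in training/action/waypoints.yaml
            "action": {"waypoints": {"width": "${eval:'${training.num_waypoints} * 2'}"}},
            "state": {"speed": {"name": "speed"}},
            "bev": {"rgb_front": {"name": "rgb_front"}},
            # Same pattern as state_type in training/default.yaml
            "state_type": "${merge_keys:${training.state}, ${training.bev}}",
        }
    }
)

print(cfg.training.action.waypoints.width)  # 20
print(cfg.training.state_type)              # rgb_front-speed
```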
/carformer/carformer/config/backbone/llama-base.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - llama
3 | architectures:
4 | - LlamaForCausalLM
5 | hidden_size: 768
6 | intermediate_size: 2048
7 | max_position_embeddings: 2048
8 | num_attention_heads: 16
9 | num_hidden_layers: 36
10 | vocab_size: 2
11 | init_from_lm_ckp: false
12 | init_name_or_path:
--------------------------------------------------------------------------------
/carformer/carformer/config/backbone/llama-micro.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - llama
3 | architectures:
4 | - LlamaForCausalLM
5 | hidden_size: 768
6 | intermediate_size: 2048
7 | max_position_embeddings: 2048
8 | num_attention_heads: 16
9 | num_hidden_layers: 12
10 | vocab_size: 2
11 | init_from_lm_ckp: false
12 | init_name_or_path:
--------------------------------------------------------------------------------
/carformer/carformer/config/backbone/llama.yaml:
--------------------------------------------------------------------------------
1 | architectures:
2 | - LlamaForCausalLM
3 | bos_token_id: 1
4 | eos_token_id: 2
5 | hidden_act: silu
6 | hidden_size: 2048
7 | initializer_range: 0.02
8 | intermediate_size: 5632
9 | max_position_embeddings: 2048
10 | model_type: llama
11 | num_attention_heads: 32
12 | num_hidden_layers: 22
13 | num_key_value_heads: 4
14 | pretraining_tp: 1
15 | rms_norm_eps: 1.0e-05
16 | rope_scaling:
17 | tie_word_embeddings: false
18 | torch_dtype: float32
19 | use_cache: true
20 | vocab_size: 32000
21 | attention_dropout: 0.0
--------------------------------------------------------------------------------
/carformer/carformer/config/carformer_config.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """ CarFormer configuration, adapted from GPT-2 configuration """
17 | from collections import OrderedDict
18 | from typing import Any, List, Mapping, Optional
19 |
20 | from transformers import PreTrainedTokenizer, TensorType, is_torch_available
21 |
22 | from transformers.configuration_utils import PretrainedConfig
23 | from transformers.utils import logging
24 | from transformers import CONFIG_MAPPING
25 |
26 | import os
27 | import yaml
28 | from typing import Union
29 | from transformers import AutoConfig
30 |
31 | logger = logging.get_logger(__name__)
32 |
33 |
34 | class CarformerConfig(PretrainedConfig):
35 | """
36 | This is the configuration class to store the configuration of a [`Wanderer`] model or [`Ponderer`] model. It is used to
37 | instantiate the model according to the specified arguments, defining the model architecture.
38 |
39 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
40 | documentation from [`PretrainedConfig`] for more information.
41 |
42 | Args:
43 | vocab_size (`int`, *optional*, defaults to 50257):
44 | Vocabulary size of the GPT-2 model. Defines the number of different tokens that can be represented by the
45 | `inputs_ids` passed when calling [`GPT2Model`] or [`TFGPT2Model`].
46 | n_positions (`int`, *optional*, defaults to 1024):
47 | The maximum sequence length that this model might ever be used with. Typically set this to something large
48 | just in case (e.g., 512 or 1024 or 2048).
49 | n_embd (`int`, *optional*, defaults to 768):
50 | Dimensionality of the embeddings and hidden states.
51 | n_layer (`int`, *optional*, defaults to 12):
52 | Number of hidden layers in the Transformer encoder.
53 | n_head (`int`, *optional*, defaults to 12):
54 | Number of attention heads for each attention layer in the Transformer encoder.
55 | n_inner (`int`, *optional*, defaults to None):
56 | Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
57 | activation_function (`str`, *optional*, defaults to `"gelu"`):
58 | Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
59 | resid_pdrop (`float`, *optional*, defaults to 0.1):
60 | The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
61 |         embd_pdrop (`int`, *optional*, defaults to 0.1):
62 | The dropout ratio for the embeddings.
63 | attn_pdrop (`float`, *optional*, defaults to 0.1):
64 | The dropout ratio for the attention.
65 | layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
66 | The epsilon to use in the layer normalization layers.
67 | initializer_range (`float`, *optional*, defaults to 0.02):
68 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
69 | summary_type (`string`, *optional*, defaults to `"cls_index"`):
70 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
71 | [`TFGPT2DoubleHeadsModel`].
72 |
73 | Has to be one of the following options:
74 |
75 | - `"last"`: Take the last token hidden state (like XLNet).
76 | - `"first"`: Take the first token hidden state (like BERT).
77 | - `"mean"`: Take the mean of all tokens hidden states.
78 | - `"cls_index"`: Supply a Tensor of classification token position (like GPT/GPT-2).
79 | - `"attn"`: Not implemented now, use multi-head attention.
80 | summary_use_proj (`bool`, *optional*, defaults to `True`):
81 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
82 | [`TFGPT2DoubleHeadsModel`].
83 |
84 | Whether or not to add a projection after the vector extraction.
85 | summary_activation (`str`, *optional*):
86 | Argument used when doing sequence summary. Used in for the multiple choice head in
87 | [`GPT2DoubleHeadsModel`].
88 |
89 | Pass `"tanh"` for a tanh activation to the output, any other value will result in no activation.
90 | summary_proj_to_labels (`bool`, *optional*, defaults to `True`):
91 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
92 | [`TFGPT2DoubleHeadsModel`].
93 |
94 | Whether the projection outputs should have `config.num_labels` or `config.hidden_size` classes.
95 | summary_first_dropout (`float`, *optional*, defaults to 0.1):
96 | Argument used when doing sequence summary, used in the models [`GPT2DoubleHeadsModel`] and
97 | [`TFGPT2DoubleHeadsModel`].
98 |
99 | The dropout ratio to be used after the projection and activation.
100 | scale_attn_weights (`bool`, *optional*, defaults to `True`):
101 |             Scale attention weights by dividing by sqrt(hidden_size).
102 | use_cache (`bool`, *optional*, defaults to `True`):
103 | Whether or not the model should return the last key/values attentions (not used by all models).
104 | scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
105 | Whether to additionally scale attention weights by `1 / layer_idx + 1`.
106 | reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
107 | Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
108 | dot-product/softmax to float() when training with mixed precision.
109 |
110 | Example:
111 |
112 | ```python
113 | >>> from transformers import GPT2Config, GPT2Model
114 |
115 | >>> # Initializing a GPT2 configuration
116 | >>> configuration = GPT2Config()
117 |
118 | >>> # Initializing a model (with random weights) from the configuration
119 | >>> model = GPT2Model(configuration)
120 |
121 | >>> # Accessing the model configuration
122 | >>> configuration = model.config
123 | ```"""
124 |
125 | model_type = "gpt2"
126 | keys_to_ignore_at_inference = ["past_key_values"]
127 | attribute_map = {
128 | "hidden_size": "hidden_size",
129 | "max_position_embeddings": "max_position_embeddings",
130 | "num_attention_heads": "num_attention_heads",
131 | "num_hidden_layers": "num_hidden_layers",
132 | }
133 |
134 | def __init__(
135 | self,
136 | **kwargs,
137 | ):
138 | # import ipdb; ipdb.set_trace()
139 | # Load config from dict
140 | # import ipdb; ipdb.set_trace()
141 | # self.backbone_config = PretrainedConfig.from_dict(backbone)
142 | if "backbone" in kwargs:
143 | backbone = kwargs.pop("backbone")
144 | if "model_type" in backbone:
145 | backbone_cfg_cls = CONFIG_MAPPING[backbone["model_type"]]
146 | self.backbone = backbone_cfg_cls.from_dict(backbone)
147 | else:
148 | # import ipdb; ipdb.set_trace()
149 | self.backbone = AutoConfig.from_pretrained(
150 | backbone["init_name_or_path"]
151 | )
152 | elif "init_name_or_path" in kwargs:
153 | self.backbone = AutoConfig.from_pretrained(kwargs["init_name_or_path"])
154 | # import ipdb; ipdb.set_trace()
155 | super().__init__(**kwargs)
156 | if hasattr(self, "n_embd"):
157 | self.hidden_size = self.n_embd
158 |
159 | if hasattr(self, "training"):
160 | if "use_future_vehicle_forcast" in self.training:
161 | print("Fixing typo in config")
162 | self.training["use_future_vehicle_forecast"] = self.training[
163 | "use_future_vehicle_forcast"
164 | ]
165 |
166 | self.dotify()
167 |
168 | def dotify(self):
169 | # Convert all dict attributes to dotdict
170 | # from omegaconf.dictconfig import DictConfig
171 | from dotmap import DotMap as ddict
172 |
173 | # Iterate over all attributes
174 | for key, value in self.__dict__.items():
175 | if isinstance(value, dict):
176 | self.__dict__[key] = ddict(value, _dynamic=False)
177 |
178 | def dedotify(self):
179 | # Convert all dotdict attributes to dict
180 | # Iterate over all attributes
181 | from dotmap import DotMap as ddict
182 |
183 | for key, value in self.__dict__.items():
184 | if isinstance(value, ddict):
185 | self.__dict__[key] = dict(value)
186 |
187 | def __getattr__(self, attr: str, default=None):
188 | # Backward compatibility with old attribute names
189 | # Raise warning on each access to deprecated attributes
190 | # print(attr)
191 | if attr == "backbone":
192 | raise AttributeError(f"Attribute {attr} not found")
193 | # print(attr)
194 |
195 | if getattr(self, "backbone", None) is not None:
196 | # import warnings
197 | # print(f"attr {attr} not found, activating deprecated failsafe")
198 | # warnings.warn(
199 | # f"Attribute {attr} will be deprecated in favor of backbone.{attr}. Please update your code.",
200 | # DeprecationWarning,
201 | # )
202 | return getattr(self.backbone, attr)
203 | else:
204 | raise AttributeError(f"Attribute {attr} not found")
205 |
206 | @staticmethod
207 | def from_hydra(hydra_cfg: Any) -> "CarformerConfig":
208 | from omegaconf import OmegaConf
209 |
210 | return CarformerConfig.from_dict(
211 | OmegaConf.to_container(hydra_cfg, resolve=True)
212 | )
213 |
214 | def save_pretrained(
215 | self,
216 | save_directory: Union[str, os.PathLike],
217 | push_to_hub: bool = False,
218 | **kwargs,
219 | ):
220 | """
221 | Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
222 | [`~PretrainedConfig.from_pretrained`] class method.
223 |
224 | Args:
225 | save_directory (`str` or `os.PathLike`):
226 | Directory where the configuration JSON file will be saved (will be created if it does not exist).
227 | push_to_hub (`bool`, *optional*, defaults to `False`):
228 | Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
229 | repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
230 | namespace).
231 | kwargs:
232 | Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
233 | """
234 | if os.path.isfile(save_directory):
235 | raise AssertionError(
236 | f"Provided path ({save_directory}) should be a directory, not a file"
237 | )
238 |
239 | os.makedirs(save_directory, exist_ok=True)
240 |
241 | # Call super to save config.json
242 | super().save_pretrained(save_directory, push_to_hub=push_to_hub, **kwargs)
243 |
--------------------------------------------------------------------------------
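A minimal usage sketch of the class above, assuming the package is installed via the provided `setup.py`. `CarformerConfig.from_hydra()` is just `from_dict()` applied to the resolved Hydra config, so the dict below stands in for a resolved config; the backbone fields follow `carformer/config/backbone/llama.yaml`-style values and the rest is illustrative.

```python
from carformer.config import CarformerConfig

# Stand-in for OmegaConf.to_container(hydra_cfg, resolve=True); values are illustrative.
cfg = CarformerConfig.from_dict(
    {
        "backbone": {
            "model_type": "llama",   # routed through CONFIG_MAPPING["llama"]
            "hidden_size": 768,
            "num_hidden_layers": 12,
            "num_attention_heads": 16,
            "vocab_size": 2,
        },
        "training": {"future_horizon": 1},
    }
)

print(cfg.backbone.hidden_size)     # 768
print(cfg.hidden_size)              # also 768: __getattr__ falls back to the backbone config
print(cfg.training.future_horizon)  # DotMap access after dotify()
```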
/carformer/carformer/config/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - user: shadi
3 | - backbone: llama-mini
4 | - training: quantized
5 | - hyperparams: default
6 | - dataset: b2d
7 | - logging: wandb
8 | - deepspeed: disable
9 | - _self_
10 |
11 | hydra:
12 | run:
13 | dir: checkpoints/${expname}/${sanitize:'${hydra.job.override_dirname}_bev=${training.bev_type}'}/${now:%Y-%m-%d_%H-%M-%S}
14 | job:
15 | chdir: False
16 | config:
17 | override_dirname:
18 | exclude_keys:
19 | - expname
20 | - training/bev
21 | - num_workers
22 | - dataset.dataset_path_rel
23 | - user
24 | - preload_in_memory
25 | - augmentable_preloader
26 | - preload
27 | - visualize
28 | - amp
29 | - multi_gpu_strategy
30 | - user.dataset_dir
31 | - cpu
32 | - gpus
33 | - ckpt_path
34 | - deepspeed
35 | - training.parallel_dataset_init
36 | - force_save
37 | - visualize_start_epoch
38 | - force_log
39 | - overfit_batches
40 | - start_saving_epoch
41 | - nodes
42 | - logging.track_weight_changes
43 | kv_sep: '='
44 | item_sep: '_'
45 |
46 | seed: 1234
47 | debug: False
48 | visualize: True
49 | visualize_start_epoch: -1
50 | visualize_interval: 1
51 | overfit: 0
52 | cpu: False
53 | gpus: ${oc.decode:${oc.env:WORLD_SIZE,1}}
54 | multi_gpu_strategy: ddp
55 | amp: True
56 | num_workers: 20
57 | early_stopping: False
58 | early_stopping_patience: 5
59 | early_stopping_metric: action_classification_loss
60 |
61 | save_every: 1
62 | start_saving_epoch: -1
63 | force_save: False
64 | force_log: False
65 |
66 | expname: TRAINING
67 | wandb_name: training_PlanT_${hydra:job.override_dirname}
68 | wandb_tag:
69 | save_dir: ${hydra:run.dir}
70 | cp_command_dir: ${user.working_dir}/cpcommands
71 |
72 | data_dir: ${user.dataset_dir}/${dataset.dataset_path_rel} # Path to the data directory and name of data folder
73 | preload: True
74 | preload_in_memory: False
75 | augmentable_preloader: False
76 | cache_dir: ${user.working_dir}/.cache/
77 | wipe_cache: False
78 |
79 | use_deepspeed: False
80 | ckpt_path: null
81 | gradient_checkpointing: False
82 | overfit_batches: 1
83 | nodes: 1
--------------------------------------------------------------------------------
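The `hydra.run.dir` entry above pipes `hydra.job.override_dirname` (overrides joined with `item_sep: '_'` and `kv_sep: '='`) through the `sanitize` resolver, i.e. `sanitize_shorten_ckppath()` from `carformer/config/__init__.py`. A rough standalone illustration with a made-up override string:

```python
from carformer.config import sanitize_shorten_ckppath

# Hypothetical override_dirname as Hydra would build it with kv_sep '=' and item_sep '_'.
overrides = "backbone=llama-base_training.ema_enabled=True_bev=rgb_front"
print(sanitize_shorten_ckppath(overrides))
# -> an abbreviated, filesystem-safe name, e.g. "bb=lma-base_ema_T=T_bev=rgb"
```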
/carformer/carformer/config/dataset/b2d.yaml:
--------------------------------------------------------------------------------
1 | dataset_path_rel: B2D-base
2 | data_format: "b2d"
3 | subsample_ratio: 1.0
4 | fps: 10
--------------------------------------------------------------------------------
/carformer/carformer/config/deepspeed/basic.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default
3 | train_micro_batch_size_per_gpu: ${hyperparams.batch_size}
4 | gradient_accumulation_steps: ${hyperparams.gradient_accumulation_steps}
5 | optimizer:
6 | type: ${hyperparams.optimizer.name}
7 | params:
8 | lr: 5e-5
9 | betas:
10 | - 0.9
11 | - 0.999
12 | eps: 1e-6
13 | weight_decay: ${hyperparams.optimizer.kwargs.weight_decay}
14 | bf16:
15 | enabled: true
16 | zero_optimization: false
--------------------------------------------------------------------------------
/carformer/carformer/config/deepspeed/default.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | use_deepspeed: true
3 | train_batch_size: ${deepspeed.train_micro_batch_size_per_gpu}
4 |
--------------------------------------------------------------------------------
/carformer/carformer/config/deepspeed/disable.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | use_deepspeed: false
--------------------------------------------------------------------------------
/carformer/carformer/config/deepspeed/zero2.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default
3 | bf16:
4 | enabled: true
5 | zero_optimization:
6 | stage: 2
7 | overlap_comm: true
8 | contiguous_gradients: true
9 | sub_group_size: 1000000000
10 | reduce_bucket_size: 500000000
11 | stage3_prefetch_bucket_size: 500000000
12 | stage3_param_persistence_threshold: 1000000
13 | stage3_max_live_parameters: 1000000000
14 | stage3_max_reuse_distance: 1000000000
15 | stage3_gather_16bit_weights_on_model_save: false
16 | # optimizer:
17 | # type: AdamW
18 | # params:
19 | # lr: 5e-5
20 | # betas:
21 | # - 0.9
22 | # - 0.999
23 | # eps: 1e-6
24 | # weight_decay: ${hyperparams.optimizer.kwargs.weight_decay}
25 | gradient_clipping: 1
26 | steps_per_print: 5
27 | train_micro_batch_size_per_gpu: ${hyperparams.batch_size}
28 | zero_allow_untested_optimizer: true
--------------------------------------------------------------------------------
/carformer/carformer/config/deepspeed/zero3.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default
3 | bf16:
4 | enabled: true
5 | zero_optimization:
6 | stage: 3
7 | overlap_comm: true
8 | contiguous_gradients: true
9 | sub_group_size: 1000000000
10 | reduce_bucket_size: 500000000
11 | stage3_prefetch_bucket_size: 500000000
12 | stage3_param_persistence_threshold: 1000000
13 | stage3_max_live_parameters: 1000000000
14 | stage3_max_reuse_distance: 1000000000
15 | stage3_gather_16bit_weights_on_model_save: false
16 | # optimizer:
17 | # type: AdamW
18 | # params:
19 | # lr: 5e-5
20 | # betas:
21 | # - 0.9
22 | # - 0.999
23 | # eps: 1e-6
24 | # weight_decay: ${hyperparams.optimizer.kwargs.weight_decay}
25 | gradient_clipping: 1
26 | steps_per_print: 5
27 | train_micro_batch_size_per_gpu: ${hyperparams.batch_size}
28 | stage3_gather_16bit_weights_on_model_save: true
29 |
--------------------------------------------------------------------------------
/carformer/carformer/config/experiments/init_rgb_from_ckpt.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | training:
3 | rgb_backbone:
4 | init_from_ckpt:
5 | enabled: true
--------------------------------------------------------------------------------
/carformer/carformer/config/experiments/onlyfrc.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | training:
3 | loss_params:
4 | wp_loss: 0.0
5 | path_loss: 0.0
6 | mask_loss: 0.0
7 | state_forecast: 0.5
--------------------------------------------------------------------------------
/carformer/carformer/config/experiments/overfit.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | hydra:
4 | run:
5 | dir: ./checkpoints/debug/tmp_debug
6 |
7 | overfit: 1
8 |
9 | wandb_tag: overfit
10 |
11 | training:
12 | max_instances: ${eval:${hyperparams.batch_size} * 10}
13 | splits:
14 | train: val
15 |
16 | dataset:
17 | subsample_ratio: 1.0 # Since we are already using max_instances, we don't need to subsample.
18 |
19 | backbone:
20 | attn_pdrop: 0.0
21 | resid_pdrop: 0.0
22 | embd_pdrop: 0.0
--------------------------------------------------------------------------------
/carformer/carformer/config/experiments/use_light_encoder.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | defaults:
3 | - /training/bev/rgb_backbone@training.light_rgb_backbone: llava1pt6
4 | - /backbone@backbone: llama-micro
5 | - _self_
6 |
7 | light_select_layer: 8
8 | training:
9 | utilize_fast_current_latent: True
10 | use_light_as_query: False
11 | light_rgb_backbone:
12 | select_layer: ${light_select_layer}
13 | ema_enabled: False
14 | model_path: ${user.working_dir}/${training.light_rgb_backbone.model_path_rel}
15 |
--------------------------------------------------------------------------------
/carformer/carformer/config/hyperparams/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - optimizer: adamw
3 | - scheduler: linear_warmup
4 | debug: False
5 | lr: 3e-5
6 | max_grad_norm: 1.0
7 | batch_size: 128
8 | gradient_accumulation_steps: 1
9 | num_epochs: 200
10 | patience: 40
--------------------------------------------------------------------------------
/carformer/carformer/config/hyperparams/optimizer/adam.yaml:
--------------------------------------------------------------------------------
1 | name: Adam
2 | kwargs:
3 | weight_decay: 0.0
--------------------------------------------------------------------------------
/carformer/carformer/config/hyperparams/optimizer/adamw.yaml:
--------------------------------------------------------------------------------
1 | name: AdamW
2 | kwargs:
3 | weight_decay: 1e-4
--------------------------------------------------------------------------------
/carformer/carformer/config/hyperparams/optimizer/sgd.yaml:
--------------------------------------------------------------------------------
1 | name: SGD
2 | kwargs:
3 | weight_decay: 1e-4
--------------------------------------------------------------------------------
/carformer/carformer/config/hyperparams/scheduler/cosine_annealing.yaml:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenDriveLab/ETA/cdbe5e1174cf33028d5cc26921f4fb72bf9b92f5/carformer/carformer/config/hyperparams/scheduler/cosine_annealing.yaml
--------------------------------------------------------------------------------
/carformer/carformer/config/hyperparams/scheduler/linear_warmup.yaml:
--------------------------------------------------------------------------------
1 | name: linear
2 | warmup_ratio: 0.05
3 | kwargs:
4 | num_warmup_steps: ${eval:'int(${hyperparams.num_epochs} * ${hyperparams.scheduler.warmup_ratio})'}
5 | num_training_steps: ${hyperparams.num_epochs}
--------------------------------------------------------------------------------
/carformer/carformer/config/logging/wandb.yaml:
--------------------------------------------------------------------------------
1 | project: ponderer
2 | entity: ${user.wandb_entity}
3 | mode: online
--------------------------------------------------------------------------------
/carformer/carformer/config/training/action/path.yaml:
--------------------------------------------------------------------------------
1 | path:
2 | name: path
3 | width: ${eval:'${training.num_path} * 2'}
4 |
--------------------------------------------------------------------------------
/carformer/carformer/config/training/action/waypoints.yaml:
--------------------------------------------------------------------------------
1 | waypoints:
2 | name: waypoints
3 | width: ${eval:'${training.num_waypoints} * 2'}
4 | future_horizon: ${training.num_waypoints}
5 |
--------------------------------------------------------------------------------
/carformer/carformer/config/training/bev/rgb.yaml:
--------------------------------------------------------------------------------
1 | # @package training
2 | defaults:
3 | - _self_
4 | - rgb_backbone: llava1pt6
5 |
6 | bev:
7 | rgb_front:
8 | name: rgb_front
9 |
10 | rgb_crop:
11 | type: dualcenter
12 | crop_size: ${eval:'${training.rgb_backbone.input_size} * 2'}
13 | resize: ${training.rgb_backbone.input_size}
14 |
15 |
16 | tokenized_state: False
17 | object_level: False
18 | ema_enabled: False
19 | ema_decay: 0.992
20 | ema_every_steps: 1
21 | ema_start: 0
22 | ema_end_epoch: -1
--------------------------------------------------------------------------------
/carformer/carformer/config/training/bev/rgb_backbone/internvl2-4b.yaml:
--------------------------------------------------------------------------------
1 | model_path_rel: bin/rgb/internvl2-4b-visionenc/
2 | model_path: ${user.working_dir}/${training.rgb_backbone.model_path_rel}
3 |
4 | frozen: false
5 | outputs:
6 | whole: false
7 | patches: true
8 |
9 | projection_dim: ${backbone.hidden_size}
10 |
11 | input_size: 448
12 | select_layer: -1
13 | try_to_truncate_layers: true
14 | downsample: True
15 |
16 | ema_enabled: ${training.ema_enabled}
--------------------------------------------------------------------------------
/carformer/carformer/config/training/bev/rgb_backbone/internvl2-76b.yaml:
--------------------------------------------------------------------------------
1 | model_path_rel: bin/rgb/internvl2-76b-visionenc/
2 | model_path: ${user.working_dir}/${training.rgb_backbone.model_path_rel}
3 |
4 | frozen: false
5 | outputs:
6 | whole: false
7 | patches: true
8 |
9 | projection_dim: ${backbone.hidden_size}
10 |
11 | input_size: 448
12 | select_layer: -1
13 | try_to_truncate_layers: true
14 | downsample: True
15 |
16 | ema_enabled: ${training.ema_enabled}
--------------------------------------------------------------------------------
/carformer/carformer/config/training/bev/rgb_backbone/internvl2pt5-8b.yaml:
--------------------------------------------------------------------------------
1 | model_path_rel: bin/rgb/internvl2pt5-8b/
2 | model_path: ${user.working_dir}/${training.rgb_backbone.model_path_rel}
3 |
4 | frozen: false
5 | outputs:
6 | whole: false
7 | patches: true
8 |
9 | projection_dim: ${backbone.hidden_size}
10 |
11 | input_size: 448
12 | select_layer: -1
13 | try_to_truncate_layers: true
14 | downsample: True
15 |
16 | ema_enabled: ${training.ema_enabled}
17 |
--------------------------------------------------------------------------------
/carformer/carformer/config/training/bev/rgb_backbone/llava1pt6.yaml:
--------------------------------------------------------------------------------
1 | model_path_rel: bin/rgb/llava-v1.6-vicuna-visionenc
2 | model_path: ${user.working_dir}/${training.rgb_backbone.model_path_rel}
3 |
4 | frozen:
5 | model: False
6 | ema_model: False
7 | projector: False
8 |
9 | outputs:
10 | whole: false
11 | patches: true
12 |
13 | projection_dim: ${backbone.hidden_size}
14 |
15 | input_size: 336
16 | select_layer: -2
17 | dropout_attn: 0.0
18 | masking_rate: 0.0
19 |
20 | try_to_truncate_layers: true
21 |
22 | override_kwargs:
23 | attention_dropout: ${training.rgb_backbone.dropout_attn}
24 |
25 | downsample: True
26 | ema_enabled: ${training.ema_enabled}
27 |
28 | init_from_ckpt:
29 | enabled: False
30 | ckpt_path_rel: bin/rgb/llava-v14-micro.pt
31 | ckpt_path: ${user.working_dir}/${training.rgb_backbone.init_from_ckpt.ckpt_path_rel}
32 | projector: true
33 | ema_model: true
34 | model: true
35 | freeze: true
--------------------------------------------------------------------------------
/carformer/carformer/config/training/default.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - action:
3 | - path
4 | - steer
5 | - goal:
6 | - dual_target_point
7 | - reward:
8 | - reward
9 | - bev:
10 | - rgb
11 | - state:
12 | - speed
13 | - _self_
14 | loss_params:
15 | default:
16 | classification: 0
17 | action:
18 | classification: 1
19 | reconstruction: 1
20 |
21 | action_type: ${merge_keys:${training.action}}
22 | state_type: ${merge_keys:${training.state}, ${training.bev}}
23 | non_bev_state_type: ${merge_keys:${training.state}}
24 | bev_type: ${merge_keys:${training.bev}}
25 | goal_type: ${merge_keys:${training.goal}}
26 | reward_type: ${merge_keys:${training.reward}}
27 | condition_on_goal: true
28 | goal_conditioning_type: local
29 | max_token_types: 4
30 | quantized: false
31 |
32 | # Data settings
33 | integrate_rewards_to_go: false
34 | context_length: 1
35 | frame_stride: 5
36 | inter_window_stride: 2
37 | skip_noisy: true
38 | trim_first_and_last: true
39 | trim_count: 1
40 | max_instances: -1
41 | drop_last: true
42 | future_horizon: 0
43 | past_horizon: 0
44 | use_future_ego_waypoints: ${has_key:${training.action},waypoints}
45 | use_future_vehicle_forecast: false
46 | forecast_steps: 1
47 | include_noisy_in_action: false
48 | splits:
49 | train: train
50 | val: val
51 | dynamic_batching: true # false, whether or not to crop padding
52 | weighted_sampling: false # false, whether or not to sample based on class weights
53 | bucket_weights:
54 | type: uniform
55 | total_ratio: 0.6
56 | weights:
57 | - 1.0 # general
58 | - 1.0 # acc_scratch
59 | - 2.0 # acc_light_pedal
60 | - 2.0 # acc_medium_pedal
61 | - 1.0 # acc_heavy_pedal
62 | - 1.0 # acc_brake
63 | - 1.0 # acc_coast
64 | - 3.0 # steer_right
65 | - 3.0 # steer_left
66 | - 1.0 # vehicle_hazard_front
67 | - 1.0 # vehicle_hazard_back
68 | - 1.0 # vehicle_hazard_side
69 | - 1.0 # stop_sign
70 | - 1.0 # red_light
71 | - 1.0 # swerving
72 | - 1.0 # pedestrian
73 |
74 | get_weight_reduce_fn: mean
75 | get_noisy_reduce_fn: last
76 |
77 | # Training settings
78 | split_ratio: 0.8
79 |
80 | # Caching
81 | dataset_caching:
82 | enabled: True
83 | cache_metadata: True
84 | cache_slow_attributes: True
85 | cache_dir: ${cache_dir}
86 | parallel_dataset_init: True
87 | parallel_dataset_workers: 16
--------------------------------------------------------------------------------
/carformer/carformer/config/training/goal/default_goal.yaml:
--------------------------------------------------------------------------------
1 | # @package training
2 | goal_conditioning_type: local
3 | goal_continuous: True
--------------------------------------------------------------------------------
/carformer/carformer/config/training/goal/dual_target_point.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default_goal
3 | dual_target_point:
4 | name: dual_target_point
5 | width: 4
6 | mean: [[ 5.1162534 , -0.1575937 ],[26.005814 , -0.09633584]]
7 | std: [[ 0.8992543, 0.9467049],[15.428605 , 6.614513 ]]
--------------------------------------------------------------------------------
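The `mean`/`std` entries above are per-point statistics for the near and far target points (compare `Parser.get_goal()` later in this section and `training.normalize_goal=True` in the Makefile targets). Below is a small sketch of how such statistics would be applied, assuming standard z-score normalization; the exact place this happens in the training code is not shown in this section.

```python
import numpy as np

# Statistics copied from dual_target_point.yaml above.
mean = np.array([[5.1162534, -0.1575937], [26.005814, -0.09633584]])
std = np.array([[0.8992543, 0.9467049], [15.428605, 6.614513]])

# Hypothetical near/far target points in the ego frame.
dual_target_point = np.array([[6.0, 0.2], [30.0, -1.5]])

normalized = (dual_target_point - mean) / std  # assumed z-score normalization
recovered = normalized * std + mean            # inverse transform
```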
/carformer/carformer/config/training/goal/target_point.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default_goal
3 | target_point:
4 | name: target_point
5 | width: 2
6 | mean: [5.1162534, -0.1575937]
7 | std: [0.8992543, 0.9467049]
--------------------------------------------------------------------------------
/carformer/carformer/config/training/quantized.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - default
3 | - override action:
4 | - waypoints
5 | - path
6 | - override goal:
7 | - target_point
8 | quantized: true
9 |
10 | loss_params:
11 | wp_loss: 1.0
12 | path_loss: 1.0
13 | mask_loss: 0.0625
14 | state_forecast: 0.5
15 | bev: {} # Nothing
16 |
17 | # For waypoints
18 | num_waypoints: 10
19 | num_path: 20
20 | future_horizon: ${training.num_waypoints}
21 | # Forecasting
22 | use_future_vehicle_forecast: false
23 | vae_target_supervision: null
24 |
25 | use_predicted_latent_with_gap: false
26 | pred_latent_layers: 2
27 | pred_latent_ffn_hidden: 2048
28 | pred_latent_ffn_dropout: 0.1
29 | pred_latent_post_mlp: false
30 | pred_latent_use_metadata: true
31 | use_real_latent_ratio: 0.0
32 |
33 | # action gap (for future action offset)
34 | action_gap:
35 |
36 | # Supplemental supervision
37 | gen_masks_for_action: True
38 |
39 | normalize_goal: True
40 | create_goal_mask: True
41 |
42 | use_gt_frc: false
43 | zero_out_frc_branch: false
44 | use_gt_frc_only: false
45 | ignore_past_for_length: false
46 |
47 | use_past_horizon_states: false
--------------------------------------------------------------------------------
/carformer/carformer/config/training/reward/reward.yaml:
--------------------------------------------------------------------------------
1 | reward:
2 | name: reward
--------------------------------------------------------------------------------
/carformer/carformer/config/training/state/lights.yaml:
--------------------------------------------------------------------------------
1 | lights:
2 | name: lights
--------------------------------------------------------------------------------
/carformer/carformer/config/training/state/speed.yaml:
--------------------------------------------------------------------------------
1 | speed:
2 | name: speed
--------------------------------------------------------------------------------
/carformer/carformer/config/user/example.yaml:
--------------------------------------------------------------------------------
1 | dataset_dir: /leonardo_work/shamdan/
2 | working_dir: /leonardo/home/userexternal/shamdan0/research/Carformer
3 |
4 | wandb_entity: shadihamdan
5 |
--------------------------------------------------------------------------------
/carformer/carformer/data/__init__.py:
--------------------------------------------------------------------------------
1 | from .data_parser import *
2 | from .data_utils import *
3 | from .data import *
4 | from .wrapper import *
5 | from .utils import *
6 |
--------------------------------------------------------------------------------
/carformer/carformer/data/data_parser.py:
--------------------------------------------------------------------------------
1 | # import orjson as json
2 | import json
3 | import gzip
4 | import os
5 | from PIL import Image
6 | import numpy as np
7 | from .data_utils import (
8 | iterative_line_interpolation,
9 | get_hazard_directions,
10 | is_walker_hazard,
11 | )
12 |
13 | home_dir = os.environ["HOME"]
14 | path_cache = {}
15 |
16 |
17 | class Parser:
18 | def __init__(
19 | self,
20 | root_path,
21 | state_type,
22 | action_type,
23 | reward_type,
24 | goal_type,
25 | folder_to_ext=None,
26 | size=None,
27 | cache_dir=None,
28 | ):
29 | self.root_path = root_path
30 | self.state_type = state_type
31 | self.action_type = action_type
32 | self.reward_type = reward_type
33 | self.goal_type = goal_type
34 |
35 | self.states = state_type.split("-")
36 | self.actions = action_type.split("-")
37 | self.rewards = reward_type.split("-")
38 | self.goals = goal_type.split("-")
39 |
40 | if folder_to_ext is None:
41 | self._folder_to_ext = {}
42 | # Check the extension of the first file in every folder in path
43 | for folder in os.listdir(self.root_path):
44 | folder_path = os.path.join(self.root_path, folder)
45 | if os.path.isdir(folder_path):
46 | fl = os.listdir(folder_path)
47 | # Get extension
48 | if len(fl) > 0:
49 | ext = os.path.splitext(fl[0])[1]
50 | self._folder_to_ext[folder] = ext
51 | else:
52 | self._folder_to_ext = folder_to_ext
53 |
54 | if size is not None:
55 | self.length = size
56 | else:
57 | # check if self.root_path/anno/00000.json.gz exists. If not, set length to 0 and return
58 | sanity_file = os.path.join(self.root_path, "anno", "00000.json.gz")
59 | if not os.path.exists(sanity_file):
60 | self.length = 0
61 | return
62 | self.length = len(os.listdir(os.path.join(self.root_path, "anno")))
63 |
64 | self.cache_dir = cache_dir
65 | self.path_cache = None
66 |
67 | def get_state(
68 | self,
69 | idx,
70 | preprocessing_functions=None,
71 | filtering_functions=None,
72 | skip_keys=None,
73 | ):
74 | ts_prefix = str(idx).zfill(5)
75 | state_dict = self.gzip_json_load(os.path.join("anno", f"{ts_prefix}.json.gz"))
76 |
77 | state = {}
78 |
79 | for s in self.states:
80 | if skip_keys is not None and s in skip_keys:
81 | continue
82 | if "rgb" in s:
83 | rgb = Image.open(
84 | os.path.join(self.root_path, "camera", s, f"{ts_prefix}.jpg")
85 | )
86 |
87 | speed = state_dict["speed"]
88 |
89 | action = np.mean(state_dict["bounding_boxes"][0]["world2ego"])
90 |
91 | state[s] = rgb
92 | elif s == "speed":
93 | state["speed"] = state_dict["speed"]
94 | else:
95 | raise ValueError(f"State type {s} not recognized")
96 |
97 | if preprocessing_functions is not None and s in preprocessing_functions:
98 | state[s] = preprocessing_functions[s](state[s])
99 |
100 | return state
101 |
102 | def get_action(self, idx, include_noise=False, skip_keys=None):
103 | ts_prefix = str(idx).zfill(5)
104 | state_dict = self.gzip_json_load(os.path.join("anno", f"{ts_prefix}.json.gz"))
105 |
106 | action = {}
107 |
108 | for a in self.actions:
109 | if skip_keys is not None and a in skip_keys:
110 | continue
111 | if a == "waypoints":
112 | # Get the current ego matrix from the measurement dict
113 | if skip_keys is not None and "ego_matrix" in skip_keys:
114 | continue
115 | action["ego_matrix"] = state_dict["bounding_boxes"][0]["world2ego"]
116 | elif a == "path":
117 | # Default to 20 waypoints like CarLLaVA
118 | ego_matrix = state_dict["bounding_boxes"][0]["world2ego"]
119 | pointer_idx = idx
120 | points = []
121 | if self.path_cache is not None:
122 | relevant_points = self.path_cache[pointer_idx:]
123 |
124 | relevant_points_inv = np.linalg.inv(relevant_points)
125 | ego_matrix = np.asarray(ego_matrix)
126 | points = np.einsum(
127 | "Xi, BiY -> BXY", ego_matrix, relevant_points_inv
128 | )[:, :2, 3]
129 |
130 | # Multiply by -1 to get the correct y coordinate
131 | points[:, 1] = -points[:, 1]
132 | else:
133 | while True:
134 | print("Rebuilding path cache")
135 | path = os.path.join(
136 | self.root_path,
137 | "anno",
138 | f"{str(pointer_idx).zfill(5)}.json.gz",
139 | )
140 |
141 | if path in path_cache:
142 | cur_point = path_cache[path]
143 | else:
144 | if not os.path.exists(path):
145 | break
146 |
147 | cur_dct = self.gzip_json_load(
148 | os.path.join(
149 | "anno", f"{str(pointer_idx).zfill(5)}.json.gz"
150 | )
151 | )
152 | cur_point = cur_dct["bounding_boxes"][0]["world2ego"]
153 |
154 | x, y = np.dot(ego_matrix, np.linalg.inv(cur_point))[:2, 3]
155 |
156 | points.append((x, -y))
157 | pointer_idx += 1
158 |
159 | path = iterative_line_interpolation(points)
160 |
161 | action["path"] = path
162 | else:
163 | raise ValueError(f"Action type {a} not recognized")
164 |
165 | if include_noise:
166 | return action, False
167 | else:
168 | return action
169 |
170 | def get_noisy(self, idx):
171 | return False
172 |
173 | def get_reward(self, idx, skip_keys=None):
174 | rewards = {}
175 |
176 | for r in self.rewards:
177 | if skip_keys is not None and r in skip_keys:
178 | continue
179 | rewards[r] = 0.0
180 |
181 | return rewards
182 |
183 | def get_goal(self, idx, skip_keys=None):
184 | ts_prefix = str(idx).zfill(5)
185 | state_dict = self.gzip_json_load(os.path.join("anno", f"{ts_prefix}.json.gz"))
186 |
187 | goal = {}
188 |
189 | for g in self.goals:
190 | if skip_keys is not None and g in skip_keys:
191 | continue
192 | if g == "target_point":
193 | # Get the target point from the state
194 | ego = state_dict["bounding_boxes"][0]
195 | theta = ego["rotation"][-1] * np.pi / 180
196 |
197 | command_near_xy = np.array(
198 | [
199 | state_dict["x_command_near"] - state_dict["x"],
200 | -state_dict["y_command_near"] + state_dict["y"],
201 | ]
202 | )
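                # Rotate the world-frame offset to the near command point into the
                # ego frame (y was negated above to match this loader's convention).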
203 | rotation_matrix = np.array(
204 | [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
205 | )
206 | local_command_xy = rotation_matrix @ command_near_xy
207 | # command_far_xy = np.array([state_dict["x_command_far"]-state_dict['x'],-state_dict["y_command_far"]+state_dict['y']])
208 | # local_command_far_xy = rotation_matrix @ command_far_xy
209 |
210 | goal["target_point"] = local_command_xy
211 | elif g == "dual_target_point":
212 | # Get the target point from the state
213 | # local_command_point = np.array(state_dict["target_point"])
214 | # goal["target_point"] = local_command_point
215 | ego = state_dict["bounding_boxes"][0]
216 | theta = ego["rotation"][-1] * np.pi / 180
217 |
218 | command_near_xy = np.array(
219 | [
220 | state_dict["x_command_near"] - state_dict["x"],
221 | -state_dict["y_command_near"] + state_dict["y"],
222 | ]
223 | )
224 | rotation_matrix = np.array(
225 | [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
226 | )
227 | local_command_xy = rotation_matrix @ command_near_xy
228 | command_far_xy = np.array(
229 | [
230 | state_dict["x_command_far"] - state_dict["x"],
231 | -state_dict["y_command_far"] + state_dict["y"],
232 | ]
233 | )
234 | local_command_far_xy = rotation_matrix @ command_far_xy
235 |
236 | goal["dual_target_point"] = np.stack(
237 | [local_command_xy, local_command_far_xy], axis=0
238 | )
239 | else:
240 | raise ValueError(f"Goal type {g} not recognized")
241 |
242 | return goal
243 |
244 | def get_size(self):
245 | return self.length
246 |
247 | def get_weight(self, idx):
248 | return self.get_buckets(idx)
249 |
250 | def get_buckets(self, idx):
251 | # Get which bucket the current idx belongs to
252 | assert self.buckets is not None, "Buckets not initialized"
253 |
254 | return self.buckets[idx]
255 |
256 | @staticmethod
257 | def get_folder_to_ext(dir):
258 | folder_to_ext = {}
259 | for folder in os.listdir(dir):
260 | folder_path = os.path.join(dir, folder)
261 | if os.path.isdir(folder_path):
262 | fl = os.listdir(folder_path)
263 | # Get extension
264 | if len(fl) > 0:
265 | ext = os.path.splitext(fl[0])[1]
266 | folder_to_ext[folder] = ext
267 | else:
268 | raise ValueError(f"Folder {folder} is empty in directory {dir}")
269 |
270 | return folder_to_ext
271 |
272 | def gzip_json_load(self, rel_file_path, root_path=None):
273 | if root_path is None:
274 | root_path = self.root_path
275 | with gzip.open(os.path.join(root_path, rel_file_path), "r") as f:
276 | return json.loads(f.read().decode("utf-8"))
277 |
278 | def path_is_cached(self, cache_dir):
279 | return os.path.exists(
280 | os.path.join(cache_dir, f"{self.root_path.split('/')[-1]}.json.gz")
281 | )
282 |
283 |     def bucket_is_cached(self, bucket_cache_dir):
284 |         return os.path.exists(
285 |             os.path.join(bucket_cache_dir, f"{self.root_path.split('/')[-1]}.npz")
286 |         )
287 |
288 | def cache_buckets(self, bucket_cache_dir):
289 | if self.bucket_is_cached(bucket_cache_dir):
290 | return
291 |
292 | bucket_names, all_buckets = self.get_all_buckets()
293 |
294 | cache_path = os.path.join(
295 | bucket_cache_dir, f"{self.root_path.split('/')[-1]}.npz"
296 | )
297 |
298 | np.savez(cache_path, bucket_names=bucket_names, buckets=all_buckets)
299 |
300 | def cache_path(self, cache_dir):
301 | if self.path_is_cached(cache_dir):
302 | self.load_path_cache(cache_dir)
303 | return
304 |
305 | if "path" in self.action_type:
306 | all_paths = self.get_all_paths()
307 |
308 | cache_path = os.path.join(
309 | cache_dir, f"{self.root_path.split('/')[-1]}.json.gz"
310 | )
311 |
312 | with gzip.open(cache_path, "w") as f:
313 | f.write(json.dumps(all_paths).encode("utf-8"))
314 |
315 | def get_all_paths(self):
316 | if "path" not in self.action_type:
317 | return None
318 | all_egos = []
319 | for i in range(self.length):
320 | ts_prefix = str(i).zfill(5)
321 | state_dict = self.gzip_json_load(
322 | os.path.join("anno", f"{ts_prefix}.json.gz")
323 | )
324 | cur_point = state_dict["bounding_boxes"][0]["world2ego"]
325 | all_egos.append(cur_point)
326 | return all_egos
327 |
328 | def get_all_buckets(self):
329 |         # Buckets are used for weighted (bucketed) sampling: every frame gets a
330 |         # binary membership vector over the bucket names defined below
331 |         # (acceleration, steering, vehicle hazards, stop signs, red lights, ...).
333 |         all_state_dicts = []
337 | 
338 | for i in range(self.length):
339 | ts_prefix = str(i).zfill(5)
340 | state_dict = self.gzip_json_load(
341 | os.path.join("anno", f"{ts_prefix}.json.gz")
342 | )
343 | all_state_dicts.append(state_dict)
344 |
345 | swerving_scenarios = [
346 | "Accident",
347 | "BlockedIntersection",
348 | "ConstructionObstacle",
349 | "HazardAtSideLane",
350 | "ParkedObstacle",
351 | "VehicleOpensDoorTwoWays",
352 | ]
353 |
354 | data_is_swerving = any([x in self.root_path for x in swerving_scenarios])
355 |
356 | all_buckets = []
357 | general_bucket_name = ["general"]
358 | acceleration_bucket_names = [
359 | "acc_scratch",
360 | "acc_light_pedal",
361 | "acc_medium_pedal",
362 | "acc_heavy_pedal",
363 | "acc_brake",
364 | "acc_coast",
365 | ]
366 | steer_bucket_names = [
367 | "steer_right",
368 | "steer_left",
369 | ]
370 | vehicle_hazard_bucket_names = [
371 | "vehicle_hazard_front",
372 | "vehicle_hazard_back",
373 | "vehicle_hazard_side",
374 | ]
375 | stop_sign_bucket_names = ["stop_sign"]
376 | red_light_bucket_names = ["red_light"]
377 | swerving_bucket_names = ["swerving"]
378 | pedestrian_bucket_names = ["pedestrian"]
379 |
380 | for i in range(self.length):
381 | acceleration = all_state_dicts[i]["throttle"]
382 | brake = all_state_dicts[i]["brake"]
383 | speed = all_state_dicts[i]["speed"]
384 |
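            # Membership flags ordered to match acceleration_bucket_names:
            # starting from a standstill, light / medium / heavy throttle,
            # braking, and coasting.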
385 | acceleration_bucket = [
386 | 1 if (acceleration > 0.2 and brake < 1.0 and speed < 0.05) else 0,
387 | 1 if (acceleration > 0.2 and acceleration < 0.5) else 0,
388 | 1 if (acceleration > 0.5 and acceleration < 0.9) else 0,
389 | 1 if (acceleration > 0.9) else 0,
390 | 1 if (brake > 0.2) else 0,
391 | 1 if (acceleration < 0.2 and brake < 1.0) else 0,
392 | ]
393 |
394 | steer = all_state_dicts[i]["steer"]
395 |
396 | steer_bucket = [
397 | 1 if (steer > 0.2) else 0,
398 | 1 if (steer < -0.2) else 0,
399 | ]
400 |
401 | # other_vehicles = all_state_dicts[i]["vehicle_hazard"]
402 |
403 | vehicle_hazard_bucket = get_hazard_directions(
404 | all_state_dicts[i]["bounding_boxes"]
405 | )
406 | # print(i, vehicle_hazard_bucket)
407 |
408 | vehicle_hazard_bucket = [
409 | (
410 | 1 if any([x < 30 for x in vehicle_hazard_bucket]) else 0
411 | ), # Heading from front
412 | (
413 | 1 if any([x > 150 for x in vehicle_hazard_bucket]) else 0
414 | ), # Heading from back
415 | (
416 | 1 if any([x > 30 and x < 150 for x in vehicle_hazard_bucket]) else 0
417 | ), # Heading from side
418 | ]
419 |
420 | if data_is_swerving and abs(steer) > 0.1:
421 | swerving_bucket = 1
422 | else:
423 | swerving_bucket = 0
424 |
425 | all_objs = all_state_dicts[i]["bounding_boxes"]
426 |
427 | stopsigns = [
428 | x
429 | for x in all_objs
430 | if x["class"] == "traffic_sign" and x["type_id"] == "traffic.stop"
431 | ]
432 |
433 | stop_sign_bucket = 1 if any([x["affects_ego"] for x in stopsigns]) else 0
434 |
435 | redtrafficlights = [
436 | x for x in all_objs if x["class"] == "traffic_light" and x["state"] == 0
437 | ]
438 |
439 | red_light_bucket = (
440 | 1 if any([x["affects_ego"] for x in redtrafficlights]) else 0
441 | )
442 |
443 | pedestrian_hazards = is_walker_hazard(all_objs)
444 | pedestrian_hazard_bucket = 1 if pedestrian_hazards else 0
445 |
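            # The leading 1 is the "general" bucket every frame belongs to; the
            # remaining entries follow the order of bucket_names below.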
446 | instance_buckets = (
447 | [1]
448 | + acceleration_bucket
449 | + steer_bucket
450 | + vehicle_hazard_bucket
451 | + [
452 | stop_sign_bucket,
453 | red_light_bucket,
454 | swerving_bucket,
455 | pedestrian_hazard_bucket,
456 | ]
457 | )
458 |
459 | all_buckets.append(instance_buckets)
460 |
461 | bucket_names = (
462 | general_bucket_name
463 | + acceleration_bucket_names
464 | + steer_bucket_names
465 | + vehicle_hazard_bucket_names
466 | + stop_sign_bucket_names
467 | + red_light_bucket_names
468 | + swerving_bucket_names
469 | + pedestrian_bucket_names
470 | )
471 |
472 | all_buckets = np.asarray(all_buckets)
473 |
474 | return bucket_names, all_buckets
475 |
476 | def validate_cache(self, path_cache_dir):
477 |         if "path" not in self.action_type:
478 | return True
479 |
480 | to_compare = self.load_path_cache(path_cache_dir)
481 |
482 | all_egos = self.get_all_paths()
483 |
484 | return np.allclose(np.asarray(all_egos), np.asarray(to_compare))
485 |
486 | def load_path_cache(self, path_cache_dir):
487 |         if "path" not in self.action_type:
488 |             return None
489 | 
490 |         cache_path = os.path.join(
491 |             path_cache_dir, f"{self.root_path.split('/')[-1]}.json.gz"
492 |         )
493 |         if not os.path.exists(cache_path):
494 |             return None
495 | 
496 |         path_cache = self.gzip_json_load(cache_path, root_path="")
497 |         self.path_cache = np.asarray(path_cache)
498 |         return self.path_cache
499 | 
500 | def load_bucket_cache(self, bucket_cache_dir):
501 | cache_path = os.path.join(
502 | bucket_cache_dir, f"{self.root_path.split('/')[-1]}.npz"
503 | )
504 | if not os.path.exists(cache_path):
505 | return None
506 |
507 | bucket_cache = np.load(cache_path)
508 |
509 | bucket_names = bucket_cache["bucket_names"]
510 | # cast names to string
511 | bucket_names = [str(x) for x in bucket_names]
512 | buckets = bucket_cache["buckets"]
513 |
514 | self.buckets = buckets
515 | self.bucket_names = bucket_names
516 |
--------------------------------------------------------------------------------
/carformer/carformer/data/data_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from torchvision.transforms import Compose, Resize, CenterCrop
4 |
5 |
6 | def get_rgb_preprocessing_function_from_config(config):
7 | return get_rgb_preprocessing_function(
8 | config.rgb_crop.type,
9 | config.rgb_crop.crop_size,
10 | config.rgb_crop.resize,
11 | )
12 |
13 |
14 | def get_rgb_preprocessing_function(crop_type, crop_size, resize):
15 | transforms = []
16 |
17 | if crop_type == "dualcenter":
18 | # Width: 2xcrop_size, height: crop_size
19 | # Then resize to 2xresize, resize
20 | transforms.append(CenterCrop((crop_size, 2 * crop_size)))
21 | if resize > 0:
22 | transforms.append(Resize((resize, 2 * resize)))
23 | elif crop_type == "center":
24 | transforms.append(CenterCrop(crop_size))
25 | if resize > 0:
26 | transforms.append(Resize(resize))
27 | else:
28 | raise NotImplementedError
29 |
30 | transform = Compose(transforms)
31 |
32 | return transform
33 |
34 |
35 | def get_virtual_lidar_to_vehicle_transform():
36 | # This is a fake lidar coordinate
37 | T = np.eye(4)
38 | T[0, 3] = 1.3
39 | T[1, 3] = 0.0
40 | T[2, 3] = 2.5
41 | return T
42 |
43 |
44 | def get_vehicle_to_virtual_lidar_transform():
45 | return np.linalg.inv(get_virtual_lidar_to_vehicle_transform())
46 |
47 |
48 | def transform_waypoints(waypoints):
49 | """transform waypoints to be origin at ego_matrix"""
50 | vehicle_matrix = np.array(waypoints[0])
51 | for i in range(
52 | 0, len(waypoints)
53 | ): # TODO: Start from 1 because 0 is ego vehicle initial position
54 | matrix = np.array(waypoints[i])
55 | waypoints[i] = vehicle_matrix @ np.linalg.inv(matrix)
56 |
57 | return waypoints
58 |
59 |
60 | ########################################################################
61 |
62 | ############## WARPING
63 |
64 | ########################################################################
65 |
66 |
67 | def compute_relative_transform(origin, current, pix_per_meter=5):
68 | result = torch.bmm(torch.linalg.inv(origin), (current))
69 |
70 | return result
71 |
72 |
73 | def get_affine_grid_transform(origin, current, inp_size=400, pix_per_meter=5):
74 | relative_transform = compute_relative_transform(origin, current, pix_per_meter)
75 |
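    # Normalize the metric translation to the [-1, 1] grid coordinates expected by
    # affine_grid, and swap x/y to match (row, col) ordering.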
76 | translation = relative_transform[:, :2, 3:] / ((inp_size / 2) / pix_per_meter)
77 | translation[:, [0, 1]] = translation[:, [1, 0]]
78 |
79 | affine_grid_transform = torch.cat(
80 | (torch.transpose(relative_transform[:, :2, :2], 2, 1), translation), axis=2
81 | )
82 |
83 | # rot x, y. dont take height.
84 | # affine_grid_transform = torch.from_numpy(affine_grid_transform).float()
85 |
86 | return affine_grid_transform
87 |
88 |
89 | def warp_sequence(x, ego_matrices, mode="nearest", spatial_extent=None):
90 | """
91 | Batch-compatible warping function.
92 |
93 | Warps a sequence based on the first frame using ego vehicle transformation matrices.
94 | """
95 | sequence_length = x.shape[1]
96 | if sequence_length == 1:
97 | return x
98 |
99 | out = [x[:, 0]]
100 |
101 | # print('X. SHAPE ', x.shape)
102 | base_frame = ego_matrices[:, 0] # torch.from_numpy()
103 |
104 | for t in range(1, sequence_length):
105 | curr_frame = ego_matrices[:, t] # torch.from_numpy()
106 | aff_grid = get_affine_grid_transform(
107 | base_frame, curr_frame, inp_size=x.shape[-1], pix_per_meter=5
108 | ) # .unsqueeze(0)
109 |
110 | grid = torch.nn.functional.affine_grid(
111 | aff_grid, size=x[:, 0].shape, align_corners=False
112 | )
113 |
114 | warped_bev = torch.nn.functional.grid_sample(
115 | (x[:, t]),
116 | grid.float(),
117 | mode="nearest",
118 | padding_mode="zeros",
119 | align_corners=False,
120 | )
121 |
122 | out.append(warped_bev)
123 |
124 | return torch.stack(out, 1)
125 |
126 |
127 | ### Forecast utils
128 | def extract_forecast_targets(
129 | timesteps,
130 | context_length,
131 | future_horizon,
132 | forecast_steps=1,
133 | use_slots=True,
134 | object_level=True,
135 | ):
136 | assert len(timesteps) == context_length + future_horizon
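    # Each context frame gets the RGB frame forecast_steps ahead of it as its
    # forecasting target.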
137 | for i in range(context_length):
138 | timesteps[i].state["target_rgb_front"] = timesteps[i + forecast_steps].state[
139 | "rgb_front"
140 | ]
141 |
142 |
143 | def circle_line_segment_intersection(
144 | circle_center, circle_radius, pt1, pt2, full_line=True, tangent_tol=1e-9
145 | ):
146 | """Find the points at which a circle intersects a line-segment. This can happen at 0, 1, or 2 points.
147 |
148 | :param circle_center: The (x, y) location of the circle center
149 | :param circle_radius: The radius of the circle
150 | :param pt1: The (x, y) location of the first point of the segment
151 | :param pt2: The (x, y) location of the second point of the segment
152 | :param full_line: True to find intersections along full line - not just in the segment.
153 | False will just return intersections within the segment.
154 | :param tangent_tol: Numerical tolerance at which we decide the intersections are close enough to consider it a
155 | tangent
156 | :return Sequence[Tuple[float, float]]: A list of length 0, 1, or 2, where each element is a point at which the
157 | circle intercepts a line segment.
158 |
159 | Note: We follow: http://mathworld.wolfram.com/Circle-LineIntersection.html
160 | Credit: https://stackoverflow.com/a/59582674/9173068
161 | """
162 |
163 | if np.linalg.norm(pt1 - pt2) < 0.000000001:
164 | # print('Problem')
165 | return []
166 |
167 | (p1x, p1y), (p2x, p2y), (cx, cy) = pt1, pt2, circle_center
168 | (x1, y1), (x2, y2) = (p1x - cx, p1y - cy), (p2x - cx, p2y - cy)
169 | dx, dy = (x2 - x1), (y2 - y1)
170 | dr = (dx**2 + dy**2) ** 0.5
171 | big_d = x1 * y2 - x2 * y1
172 | discriminant = circle_radius**2 * dr**2 - big_d**2
173 |
174 | if discriminant < 0: # No intersection between circle and line
175 | return []
176 | else: # There may be 0, 1, or 2 intersections with the segment
177 | # This makes sure the order along the segment is correct
178 | # intersections = [(cx + (big_d * dy + sign * (-1 if dy < 0 else 1) * dx * discriminant**.5) / dr**2,
179 | # cy + (-big_d * dx + sign * abs(dy) * discriminant**.5) / dr**2)
180 | # for sign in ((1, -1) if dy < 0 else (-1, 1))]
181 |
182 | # Write explicitly to avoid iteration
183 | if dy < 0:
184 | sign_1 = 1
185 | sign_2 = -1
186 | else:
187 | sign_1 = -1
188 | sign_2 = 1
189 |
190 | intersections = [
191 | (
192 | cx + (big_d * dy + sign_1 * sign_2 * dx * discriminant**0.5) / dr**2,
193 | cy + (-big_d * dx + sign_1 * abs(dy) * discriminant**0.5) / dr**2,
194 | ),
195 | (
196 | cx + (big_d * dy + sign_2 * sign_2 * dx * discriminant**0.5) / dr**2,
197 | cy + (-big_d * dx + sign_2 * abs(dy) * discriminant**0.5) / dr**2,
198 | ),
199 | ]
200 |
201 | if (
202 | not full_line
203 | ): # If only considering the segment, filter out intersections that do not fall within the segment
204 | fraction_along_segment = [
205 | (xi - p1x) / dx if abs(dx) > abs(dy) else (yi - p1y) / dy
206 | for xi, yi in intersections
207 | ]
208 | intersections = [
209 | pt
210 | for pt, frac in zip(intersections, fraction_along_segment)
211 | if 0 <= frac <= 1
212 | ]
213 | # If line is tangent to circle, return just one point (as both intersections have same location)
214 | if len(intersections) == 2 and abs(discriminant) <= tangent_tol:
215 | return [intersections[0]]
216 | else:
217 | return intersections
218 |
219 |
222 |
223 | def iterative_line_interpolation(route, num_points=20):
224 | if not isinstance(route, np.ndarray):
225 | route = np.array(route)
226 |
227 | interpolated_route_points = []
228 |
229 | min_distance = 0.5
230 | last_interpolated_point = np.array([0.0, 0.0])
231 | current_route_index = 0
232 | current_point = route[current_route_index]
233 | last_point = route[current_route_index]
234 |
235 | while len(interpolated_route_points) < num_points:
236 | # First point should be min_distance away from the vehicle.
237 | dist = np.linalg.norm(current_point - last_interpolated_point)
238 | if dist < min_distance:
239 | current_route_index += 1
240 | last_point = current_point
241 |
242 | if current_route_index < route.shape[0]:
243 | current_point = route[current_route_index]
244 | intersection = circle_line_segment_intersection(
245 | circle_center=last_interpolated_point,
246 | circle_radius=min_distance,
247 | pt1=last_point,
248 | pt2=current_point,
249 | full_line=False,
250 | )
251 |
252 | else: # We hit the end of the input route. We extrapolate the last 2 points
253 | current_point = route[-1]
254 | last_point = route[-2]
255 | intersection = circle_line_segment_intersection(
256 | circle_center=last_interpolated_point,
257 | circle_radius=min_distance,
258 | pt1=last_point,
259 | pt2=current_point,
260 | full_line=True,
261 | )
262 |
263 | # 3 cases: 0 intersection, 1 intersection, 2 intersection
264 | if len(intersection) > 1: # 2 intersections
265 | # Take the one that is closer to current point
266 | point_1 = np.array(intersection[0])
267 | point_2 = np.array(intersection[1])
268 | direction = current_point - last_point
269 | dot_p1_to_last = np.dot(point_1, direction)
270 | dot_p2_to_last = np.dot(point_2, direction)
271 |
272 | if dot_p1_to_last > dot_p2_to_last:
273 | intersection_point = point_1
274 | else:
275 | intersection_point = point_2
276 | add_point = True
277 | elif len(intersection) == 1: # 1 Intersections
278 | intersection_point = np.array(intersection[0])
279 | add_point = True
280 | else: # 0 Intersection
281 | add_point = False
282 |
283 | if add_point:
284 | last_interpolated_point = intersection_point
285 | interpolated_route_points.append(intersection_point)
286 | min_distance = 1.0 # After the first point we want each point to be 1 m away from the last.
287 |
288 | interpolated_route_points = np.array(interpolated_route_points)
289 |
290 | return interpolated_route_points
291 |
292 |
293 | def interpolate_waypoints(wps, num_points=10):
294 | assert len(wps) > 1, "Need at least 2 waypoints to interpolate"
295 | last_vector = wps[-1] - wps[-2]
296 |
297 | if len(wps) >= num_points:
298 | return wps[:num_points]
299 |
300 | interpolated_points = []
301 |     for i in range(num_points - len(wps)):
302 | # interpolated_points.append(wps[i] + last_vector * (i + 1))
303 | interpolated_points.append((-99999, -99999))
304 |
305 | return np.concatenate([wps, np.stack(interpolated_points)], axis=0)
306 |
307 |
308 | class iterative_intepolator:
309 | def __init__(self, num_points=20):
310 | self.num_points = num_points
311 | self.min_distance = 0.5
312 | self.last_interpolated_point = np.array([0.0, 0.0])
313 | self.current_route_index = 0
314 | self.current_point = None
315 | self.last_point = None
316 | self.interpolated_route_points = []
317 | self.last_inputs = []
318 |
319 | def __call__(self, point):
320 | if self.current_point is None:
321 | self.current_point = point
322 | self.last_point = point
323 |
324 | if point is not None:
325 | self.last_inputs.append(point)
326 | self.last_inputs = self.last_inputs[-2:]
327 |
328 | dist = np.linalg.norm(self.current_point - self.last_interpolated_point)
329 | if dist < self.min_distance:
330 | # self.current_route_index += 1
331 | self.last_point = self.current_point
332 | return len(self)
333 |
334 | self.current_point = np.asarray(point)
335 | intersection = circle_line_segment_intersection(
336 | circle_center=self.last_interpolated_point,
337 | circle_radius=self.min_distance,
338 | pt1=self.last_point,
339 | pt2=self.current_point,
340 | full_line=False,
341 | )
342 |         else:
343 |             self.current_point = np.asarray(self.last_inputs[-1])
344 |             self.last_point = np.asarray(self.last_inputs[-2])
345 |             intersection = circle_line_segment_intersection(
346 |                 circle_center=self.last_interpolated_point,
347 |                 circle_radius=self.min_distance,
348 |                 pt1=self.last_point,
349 |                 pt2=self.current_point,
350 |                 full_line=True,
351 |             )
352 | 
353 |         if len(intersection) > 1:  # 2 intersections
354 |             # Take the one that is closer to current point
355 |             point_1 = np.array(intersection[0])
356 |             point_2 = np.array(intersection[1])
357 |             direction = self.current_point - self.last_point
358 | dot_p1_to_last = np.dot(point_1, direction)
359 | dot_p2_to_last = np.dot(point_2, direction)
360 |
361 | if dot_p1_to_last > dot_p2_to_last:
362 | intersection_point = point_1
363 | else:
364 | intersection_point = point_2
365 | elif len(intersection) == 1: # 1 Intersections
366 | intersection_point = np.array(intersection[0])
367 | else: # 0 Intersection
368 | return len(self)
369 |
370 | self.last_interpolated_point = intersection_point
371 | self.interpolated_route_points.append(intersection_point)
372 | self.min_distance = 1.0 # After the first point we want each point to be 1 m away from the last.
373 |
374 | def get_interpolated_points(self):
375 | if len(self.interpolated_route_points) < self.num_points:
376 | for i in range(len(self.interpolated_route_points), self.num_points):
377 | self(None)
378 |
379 | print(len(self))
380 |
381 | return self.interpolated_route_points
382 |
383 | def __len__(self):
384 | return self.num_points
385 |
386 |
387 | # Adapted from https://github.com/OpenDriveLab/TCP/blob/9ec4db0f0424801cdd607f1de930290830c5e88e/leaderboard/team_code/auto_pilot.py#L339
388 | def get_hazard_directions(vehicle_list):
389 | ego_vehicles = [x for x in vehicle_list if x["class"] == "ego_vehicle"]
390 |
391 | if len(ego_vehicles) == 0:
392 | return []
393 |
394 | if len(ego_vehicles) > 1:
395 | print("More than one ego vehicle found")
396 | return []
397 |
398 | ego_vehicle = ego_vehicles[0]
399 |
400 | z = ego_vehicle["location"][1]
401 |
402 | o1 = _orientation(ego_vehicle["rotation"][-1])
403 | p1 = np.asarray(ego_vehicle["location"][:2])
404 | s1 = max(2, 3.0 * ego_vehicle["speed"]) # increases the threshold distance
405 | v1_hat = o1
406 | v1 = s1 * v1_hat
407 |
408 | hazard_directions = []
409 |
410 | for target_vehicle in vehicle_list:
411 | if target_vehicle["class"] == "ego_vehicle":
412 | continue
413 |
414 | if target_vehicle.get("base_type", None) != "car":
415 | continue
416 |
417 | o2 = _orientation(target_vehicle["rotation"][-1])
418 | p2 = np.asarray(target_vehicle["location"][:2])
419 | s2 = max(5.0, 2.0 * target_vehicle["speed"])
420 | v2_hat = o2
421 | v2 = s2 * v2_hat
422 |
423 | p2_p1 = p2 - p1
424 | distance = np.linalg.norm(p2_p1)
425 | p2_p1_hat = p2_p1 / (distance + 1e-4)
426 |
427 | angle_to_car = np.degrees(np.arccos(v1_hat.dot(p2_p1_hat)))
428 |
429 | angle_between_heading = np.degrees(np.arccos(np.clip(o1.dot(o2), -1, 1)))
430 |
431 | # print()
432 | angle_from_ego = np.degrees(np.arccos(v2_hat.dot(p2_p1_hat)))
433 |
434 | # to consider -ve angles too
435 | angle_to_car = min(angle_to_car, 360.0 - angle_to_car)
436 | angle_between_heading = min(
437 | angle_between_heading, 360.0 - angle_between_heading
438 | )
439 |
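        # Keep only vehicles roughly aligned with the ego heading (or nearly
        # head-on and close), within a 30 degree cone in front of the ego, and
        # inside the speed-scaled distance threshold s1.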
440 | if angle_between_heading > 60.0 and not (angle_to_car < 15 and distance < s1):
441 | continue
442 | elif angle_to_car > 30.0:
443 | continue
444 | elif distance > s1:
445 | continue
446 |
447 | # print(target_vehicle["type_id"], target_vehicle["distance"], target_vehicle["color"])
448 | # print("s1", s1, "dist", distance, "tgt_dist", target_vehicle["distance"])
449 | # print(angle_to_car, angle_between_heading, distance, s1)
450 | # print("angle from ego", angle_from_ego)
451 |
452 | hazard_directions.append(angle_from_ego)
453 |
454 | return hazard_directions
455 |
456 |
457 | def _orientation(yaw):
458 | return np.float32([np.cos(np.radians(yaw)), np.sin(np.radians(yaw))])
459 |
460 |
461 | def get_collision(p1, v1, p2, v2):
462 | A = np.stack([v1, -v2], 1)
463 | b = p2 - p1
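    # Solve p1 + t1 * v1 == p2 + t2 * v2 for (t1, t2); a collision is flagged when
    # both parameters lie in [0, 4]. A near-singular A means the velocities are
    # (anti-)parallel, so no crossing point is computed.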
464 |
465 | if abs(np.linalg.det(A)) < 1e-3:
466 | return False, None
467 |
468 | x = np.linalg.solve(A, b)
469 | collides = all(x >= 0) and all(x <= 4) # how many seconds until collision
470 |
471 | return collides, p1 + x[0] * v1
472 |
473 |
474 | def is_walker_hazard(objects_list):
475 | ego_vehicles = [x for x in objects_list if x["class"] == "ego_vehicle"]
476 | if len(ego_vehicles) == 0:
477 | print("No ego vehicle found")
478 | return False
479 |
480 | ego_vehicle = ego_vehicles[0]
481 |
482 | z = ego_vehicle["location"][-1]
483 |
484 | walkers = [x for x in objects_list if x["class"] == "walker"]
485 | p1 = np.asarray(ego_vehicle["location"][:2])
486 | v1 = 10.0 * _orientation(ego_vehicle["rotation"][-1])
487 |
488 | for walker in walkers:
489 | v2_hat = _orientation(walker["rotation"][-1])
490 | s2 = walker["speed"]
491 |
492 | if s2 < 0.05:
493 | v2_hat *= s2
494 |
495 | p2 = -3.0 * v2_hat + np.asarray(walker["location"][:2])
496 | v2 = 8.0 * v2_hat
497 |
498 | collides, collision_point = get_collision(p1, v1, p2, v2)
499 |
500 | if collides:
501 | return True
502 |
503 | return False
504 |
--------------------------------------------------------------------------------
/carformer/carformer/data/utils.py:
--------------------------------------------------------------------------------
1 | from carformer.data import (
2 | B2DSequenceDataset,
3 | DatasetPreloader,
4 | InMemoryDatasetPreloader,
5 | )
6 | import os
7 |
8 |
9 | def get_datasets(config, model=None, return_all=False, splits=["train", "val"]):
10 | if return_all:
11 | return _get_entire_dataset(
12 | config.data_dir,
13 | config.training,
14 | config.dataset.data_format,
15 | preload=config.preload,
16 | preload_in_memory=config.preload_in_memory,
17 | wipe_cache=config.wipe_cache,
18 | cache_dir=config.cache_dir,
19 | model=model,
20 | )
21 |
22 | return _get_datasets(
23 | config.data_dir,
24 | config.training,
25 | config.dataset.data_format,
26 | preload=config.preload,
27 | preload_in_memory=config.preload_in_memory,
28 | wipe_cache=config.wipe_cache,
29 | cache_dir=config.cache_dir,
30 | model=model,
31 | splits=splits,
32 | )
33 |
34 |
35 | def _get_datasets(
36 | data_dir,
37 | train_cfg,
38 | data_format,
39 | preload=False,
40 | preload_in_memory=False,
41 | wipe_cache=False,
42 | cache_dir="",
43 | model=None,
44 | splits=["train", "val"],
45 | ):
46 |     # Passing the old is_plant boolean instead of a data_format string is deprecated
47 | if isinstance(data_format, bool):
48 | print(
49 | "Warning: is_plant boolean is deprecated, please use the data_format field in the dataset configuration"
50 | )
51 |
52 | if data_format == "plant":
53 | data_module = PlantSequenceDataset
54 | elif data_format == "b2d":
55 | data_module = B2DSequenceDataset
56 | elif data_format == "pdm":
57 | data_module = PDMSequenceDataset
58 | elif data_format == "tf":
59 | raise ValueError("Transfuser dataset not supported")
60 | # data_module = SequenceDataset
61 | else:
62 | raise ValueError(f"Invalid data format {data_format}")
63 |
64 | if "train" in splits:
65 | train_dataset = data_module(
66 | data_dir,
67 | train_cfg.splits.train,
68 | train_cfg,
69 | )
70 | else:
71 | train_dataset = None
72 |
73 | if "val" in splits:
74 | val_dataset = data_module(
75 | data_dir,
76 | train_cfg.splits.val,
77 | train_cfg,
78 | )
79 | else:
80 | val_dataset = None
81 |
82 | if preload:
83 | assert cache_dir != "", "Cache dir must be specified if preloading is enabled"
84 |
85 | preloader = (
86 | DatasetPreloader if not preload_in_memory else InMemoryDatasetPreloader
87 | )
88 | args = []
89 |
90 | if train_dataset is not None:
91 | train_dataset = preloader(
92 | train_dataset,
93 | os.path.join(cache_dir, train_dataset.get_parametrized_dirname()),
94 | *args,
95 | wipe_cache=wipe_cache,
96 | )
97 |
98 | if val_dataset is not None:
99 | val_dataset = preloader(
100 | val_dataset,
101 | os.path.join(cache_dir, val_dataset.get_parametrized_dirname()),
102 | *args,
103 | wipe_cache=wipe_cache,
104 | )
105 | if train_dataset is not None:
106 | train_dataset.load_state()
107 |
108 | if val_dataset is not None:
109 | val_dataset.load_state()
110 |
111 | return train_dataset, val_dataset
112 |
113 |
114 | def _get_entire_dataset(
115 | data_dir,
116 | train_cfg,
117 |     data_format,
118 | preload=False,
119 | preload_in_memory=False,
120 | wipe_cache=False,
121 | cache_dir="",
122 | model=None,
123 | ):
124 | if data_format == "b2d":
125 | data_module = B2DSequenceDataset
126 | else:
127 | raise ValueError(f"Invalid data format {data_format}")
128 |
129 | all_dataset = data_module(
130 | data_dir,
131 | "all",
132 | train_cfg,
133 | )
134 |
135 | if preload:
136 | assert cache_dir != "", "Cache dir must be specified if preloading is enabled"
137 |
138 | preloader = (
139 | DatasetPreloader if not preload_in_memory else InMemoryDatasetPreloader
140 | )
141 |
142 | args = []
143 |
144 | all_dataset = preloader(
145 | all_dataset,
146 | os.path.join(cache_dir, all_dataset.get_parametrized_dirname()),
147 | *args,
148 | wipe_cache=wipe_cache,
149 | )
150 |
151 | all_dataset.load_state()
152 |
153 | return all_dataset
154 |
--------------------------------------------------------------------------------
/carformer/carformer/data/wrapper.py:
--------------------------------------------------------------------------------
1 | # A wrapper around a dataset
2 | # First, it loads batches, 1 instance at a time, then writes it into a disk cache
3 | # Future calls to load batches will load from the cache instead of the original dataset
4 | # This is to avoid the overhead of loading from disk every time
5 |
6 | from skit import DatasetPreloader, InMemoryDatasetPreloader, AugmentableDatasetPreloader
7 |
--------------------------------------------------------------------------------
/carformer/carformer/perception/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OpenDriveLab/ETA/cdbe5e1174cf33028d5cc26921f4fb72bf9b92f5/carformer/carformer/perception/__init__.py
--------------------------------------------------------------------------------
/carformer/carformer/perception/rgb.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import yaml
4 | from transformers import AutoModel, AutoImageProcessor, AutoConfig, CONFIG_MAPPING
5 |
6 | try:
7 | from transformers.models.auto.image_processing_auto import (
8 | image_processor_class_from_name,
9 | )
10 | except ImportError:
11 | from transformers.models.auto.image_processing_auto import (
12 | get_image_processor_class_from_name as image_processor_class_from_name,
13 | )
14 |
15 | rgb_models = {}
16 |
17 |
18 | class RGBEncoder(nn.Module):
19 | def __init__(self, config):
20 | super(RGBEncoder, self).__init__()
21 | # Load config
22 | if isinstance(config, str):
23 | with open(config, "r") as file:
24 | try:
25 | config = yaml.safe_load(file)
26 | self.config = config
27 | except yaml.YAMLError as exc:
28 | print(exc)
29 | raise exc
30 | else:
31 | self.config = config
32 |
33 | pretrained_path = self.config.get("model_path", None)
34 | pretrained_override_kwargs = self.config.get("override_kwargs", {})
35 |
36 | self.ema_enabled = self.config.get("ema_enabled", False)
37 | self.keep_time_dim = self.config.get("keep_time_dim", False)
38 |
39 | if pretrained_path is not None:
40 | print(
41 | "Overriding main model config using kwargs: ",
42 | pretrained_override_kwargs,
43 | )
44 | self.cfg = AutoConfig.from_pretrained(
45 | pretrained_path, **pretrained_override_kwargs
46 | )
47 |
48 | self.processor = AutoImageProcessor.from_pretrained(pretrained_path)
49 | self.model = AutoModel.from_pretrained(
50 | pretrained_path, **pretrained_override_kwargs
51 | )
52 |
53 | self.config["override_kwargs"] = {}
54 | self.config["model_config"] = self.cfg.to_dict()
55 | self.config["processor_config"] = self.processor.to_dict()
56 |
57 | if self.ema_enabled:
58 | self.ema_model = AutoModel.from_pretrained(
59 | pretrained_path, **pretrained_override_kwargs
60 | )
61 | else:
62 | self.ema_model = None
63 | else:
64 | assert (
65 | "model_config" in self.config
66 | ), "model_config must be provided if model_path is not provided"
67 | assert (
68 | "processor_config" in self.config
69 | ), "processor_config must be provided if model_path is not provided"
70 |
71 | model_type = self.config["model_config"]["model_type"]
72 | image_processor_type = self.config["processor_config"][
73 | "image_processor_type"
74 | ]
75 |
76 | self.cfg = CONFIG_MAPPING[model_type].from_dict(self.config["model_config"])
77 | self.processor = image_processor_class_from_name(
78 | image_processor_type
79 | ).from_dict(self.config["processor_config"])
80 | self.model = AutoModel.from_config(self.cfg)
81 |
82 | if self.ema_enabled:
83 | self.ema_model = AutoModel.from_config(self.cfg)
84 | else:
85 | self.ema_model = None
86 |
87 | self.projector = nn.Linear(
88 | self.cfg.hidden_size, self.config.projection_dim
89 | ) # self.cfg: AutoConfig, backbone specific. self.config: user provided config, contains main backbone hidden dim
90 |
91 | self.select_layer = self.config.get("select_layer", None)
92 | # Dummy rgb encoder utils
93 | self.dummy_output = self.config.get("dummy_output", False)
94 |
95 | self.move_channels_last_and_flatten = self.config.get(
96 | "move_channels_last_and_flatten", False
97 | )
98 | self.ignore_cls = self.config.get("ignore_cls", True)
99 |
100 | if self.select_layer is None:
101 | self.select_layer = -2
102 | print("Warning: select_layer not provided, using default -2")
103 |
104 | init_config = self.config.get("init_from_ckpt", None)
105 | if init_config is not None and init_config["enabled"]:
106 | init_config = init_config.copy()
107 | self.config["init_from_ckpt"]["enabled"] = False # Do not init again
108 | ckpt = torch.load(init_config["ckpt_path"], map_location="cpu")
109 | if init_config["model"]:
110 | mdl_dct = {
111 | k[len("model.") :]: v
112 | for k, v in ckpt.items()
113 | if k.startswith("model.")
114 | }
115 |
116 | self.model.load_state_dict(mdl_dct, strict=True)
117 |
118 | if init_config["ema_model"] and self.ema_enabled:
119 | ema_dct = {
120 | k[len("model.") :]: v
121 | for k, v in ckpt.items()
122 | if k.startswith("model.")
123 | }
124 | self.ema_model.load_state_dict(ema_dct, strict=True)
125 |
126 | if init_config["projector"]:
127 | proj_dct = {
128 | k[len("projector.") :]: v
129 | for k, v in ckpt.items()
130 | if k.startswith("projector.")
131 | }
132 | self.projector.load_state_dict(proj_dct, strict=True)
133 |
134 | self.config["frozen"]["projector"] = init_config["projector"]
135 | self.config["frozen"]["model"] = init_config["model"]
136 | self.config["frozen"]["ema_model"] = init_config["ema_model"]
137 |
138 | self.try_to_truncate_layers = self.config.get("try_to_truncate_layers", False)
139 |
140 | if self.try_to_truncate_layers:
141 | import transformers
142 |
143 | if isinstance(
144 | self.model, transformers.models.clip.modeling_clip.CLIPVisionModel
145 | ):
146 | layer_to_select = self.select_layer
147 | if layer_to_select < 0:
148 | layer_to_select = (
149 | len(self.model.vision_model.encoder.layers) + layer_to_select
150 | )
151 |
152 | self.model.vision_model.encoder.layers = (
153 | self.model.vision_model.encoder.layers[: layer_to_select + 1]
154 | )
155 |
156 | if self.ema_enabled:
157 | self.ema_model.vision_model.encoder.layers = (
158 | self.ema_model.vision_model.encoder.layers[
159 | : layer_to_select + 1
160 | ]
161 | )
162 |
163 | self.select_layer = -1
164 | else:
165 | print(
166 | "Warning: try_to_truncate_layers is set to True, but model is not CLIP. "
167 | "This will not truncate any layers."
168 | )
169 |
170 | frozen = self.config.get("frozen", False)
171 | self.disable_ema_update = False
172 | if isinstance(frozen, bool):
173 | if frozen:
174 | self.model.requires_grad_(False)
175 | self.model.eval()
176 | self.frozen = True
177 | else:
178 | self.frozen = False
179 | self.frozen_dict = {}
180 | else:
181 | if frozen["model"]:
182 | self.model.requires_grad_(False)
183 | self.model.eval()
184 | if frozen["ema_model"]:
185 | if not self.ema_enabled:
186 | self.model.requires_grad_(False)
187 | self.model.eval()
188 | else:
189 | self.ema_model.requires_grad_(False)
190 | self.ema_model.eval()
191 | self.disable_ema_update = True
192 | if frozen["projector"]:
193 | self.projector.requires_grad_(False)
194 | self.frozen = False
195 | self.frozen_dict = {
196 | "model": frozen["model"]
197 | or (frozen["ema_model"] and not self.ema_enabled),
198 | "ema_model": frozen["ema_model"],
199 | "projector": frozen["projector"],
200 | }
201 |
202 | self.output_whole = self.config["outputs"]["whole"]
203 | self.output_patches = self.config["outputs"]["patches"]
204 |
205 | # center crop of width self.config.input_size*2, height self.config.input_size
206 | def center_crop(x):
207 | wdesired = self.config.input_size * 2
208 | hdesired = self.config.input_size
209 | wcur = x.shape[-2]
210 | hcur = x.shape[-3]
211 |
212 | return x[
213 | ...,
214 | (hcur - hdesired) // 2 : (hcur + hdesired) // 2,
215 | (wcur - wdesired) // 2 : (wcur + wdesired) // 2,
216 | :,
217 | ]
218 |
219 | self.center_crop = center_crop
220 |
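        # Learned embedding added in encode(), presumably to let the model tell the
        # two image patches of the wide crop apart.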
221 | self.camera_embeddings = nn.Embedding(2, self.cfg.hidden_size)
222 |
223 | self.masking_rate = self.config.get("masking_rate", 0.0)
224 |
225 | self.do_downsample = self.config.get("downsample", False)
226 | self.downsample_type = self.config.get("downsample_type", "conv")
227 |
228 | if self.do_downsample:
229 | if self.downsample_type == "conv":
230 | self.downsample = nn.Conv2d(
231 | self.cfg.hidden_size,
232 | self.cfg.hidden_size,
233 | kernel_size=2,
234 | stride=2,
235 | padding=0,
236 | bias=False,
237 | )
238 | elif self.downsample_type == "avgpool":
239 | self.downsample = nn.AvgPool2d(2)
240 | else:
241 | raise ValueError("Invalid downsample type")
242 | else:
243 | self.downsample = None
244 |
245 | # Display warning that output_whole, patches is ignored currently and will be implemented later
246 | print(
247 | "Warning: output_whole, output_patches is ignored currently and will be implemented later"
248 | )
249 |
250 | def forward(self, x, y=None, ema=False):
251 | # Dims: B, T, NPatch, 3, H, W
252 | if self.frozen:
253 | self.model.eval()
254 |
255 | with torch.no_grad():
256 | return self.encode(x, y, ema=ema)
257 | else:
258 | return self.encode(x, y, ema=ema)
259 |
260 | def encode(self, x, y=None, ema=False):
261 | # print("x shape: ", x.shape, "y shape: ", y.shape if y is not None else None)
262 | # x = self.center_crop(x) # TODO: This is too hacky, fix this
263 | original_shape = x.shape
264 | if len(original_shape) == 5:
265 | T = original_shape[1]
266 | x = x.reshape(
267 | original_shape[0] * original_shape[1], *original_shape[2:]
268 | ) # B*T, H, W, 3
269 | else:
270 | T = 1
271 | # Encode
272 | prepped_x = self.processor(
273 | list(x.squeeze(1).permute(0, 3, 1, 2).half()), return_tensors="pt"
274 | )["pixel_values"]
275 | prepped_y = (
276 | self.processor(
277 | list(y.squeeze(1).permute(0, 3, 1, 2).half()),
278 | return_tensors="pt",
279 | do_normalize=False,
280 | do_convert_rgb=False,
281 | )["pixel_values"]
282 | if y is not None
283 | else None
284 | )
285 |
286 | # Prep the inputs before feeding to the vision encoders
287 | # Remove first patch, which is the entire image
288 | prepped_x = prepped_x[:, 1:]
289 | if self.frozen_dict.get("ema_model", False) and self.ema_enabled:
290 | self.ema_model.eval()
291 |
292 | if self.frozen_dict.get("model", False):
293 | self.model.eval()
294 |
295 | if prepped_y is not None:
296 | prepped_y = prepped_y[:, 1:]
297 |
298 | B, Npatch = prepped_x.shape[:2]
299 | By = prepped_y.shape[0] if prepped_y is not None else None
300 |
301 | if self.dummy_output:
302 | return self.projector(
303 | torch.zeros(B, Npatch, self.cfg.hidden_size)
304 | .to(self.model.device)
305 | .to(self.model.dtype)
306 | )
307 |
308 | # Flatten, encode
309 | if self.keep_time_dim:
310 | prepped_x = prepped_x.reshape(B // T, T, *prepped_x.shape[1:])
311 | # swap dim 1 and 2
312 | prepped_x = prepped_x.permute(0, 2, 1, 3, 4, 5)
313 |
314 | if prepped_y is not None:
315 | prepped_y = prepped_y.reshape(By, 1, *prepped_y.shape[1:])
316 | prepped_y = prepped_y.permute(0, 2, 1, 3, 4, 5)
317 |
318 | elif ema and self.ema_enabled:
319 | vision_feats = self.ema_model(
320 | prepped_x.flatten(0, 1)
321 | .to(self.ema_model.device)
322 | .to(self.ema_model.dtype),
323 | output_hidden_states=True,
324 | )["hidden_states"]
325 | else:
326 |
327 | vision_feats = self.model(
328 | prepped_x.flatten(0, 1).to(self.model.device).to(self.model.dtype),
329 | output_hidden_states=True,
330 | )["hidden_states"]
331 |
332 | if prepped_y is not None:
333 | with torch.no_grad():
334 | if hasattr(self.model, "vision_model"):
335 | vision_patch_labels = (
336 | self.model.vision_model.embeddings.patch_embedding(
337 | prepped_y.flatten(0, 1)
338 | .to(self.model.device)
339 | .to(self.model.dtype)
340 | )
341 | )
342 | else:
343 | raise ValueError(
344 | "Model does not have a patcher, cannot process labels"
345 | )
346 |
347 | if self.do_downsample:
348 | vision_patch_labels = self.downsample(vision_patch_labels)
349 |
350 | vision_patch_labels = vision_patch_labels.abs().sum(1).flatten(1, 2)
351 |
352 | if self.ignore_cls:
353 | # Vision feat to use, default from llama, ignore cls
354 | vision_feats = vision_feats[self.select_layer][:, 1:]
355 | else:
356 | assert y is None, "y is not None, but ignore_cls is False"
357 | assert self.do_downsample is False, "Cannot downsample with cls token"
358 | vision_feats = vision_feats[self.select_layer]
359 |
360 | if self.move_channels_last_and_flatten:
361 | vision_feats = vision_feats.permute(0, 2, 3, 1).flatten(1, 2)
362 |
363 | if self.do_downsample:
364 | vision_feat_side_size = int(vision_feats.shape[1] ** 0.5)
365 |
366 | vision_feats = vision_feats.view(
367 | B * Npatch, vision_feat_side_size, vision_feat_side_size, -1
368 | )
369 | # Permute hidden dim to channel dim
370 | vision_feats = vision_feats.permute(0, 3, 1, 2)
371 |
372 | vision_feats = self.downsample(vision_feats)
373 |
374 | vision_feats = vision_feats.permute(0, 2, 3, 1).flatten(
375 | 1, 2
376 | ) # B*Npatch, H, W, C -> B*Npatch, H'*W', C
377 |
378 | vision_feats = vision_feats.view(
379 | B // T if self.keep_time_dim else B, Npatch, *vision_feats.shape[1:]
380 | )
381 | vision_patch_labels = (
382 | vision_patch_labels.view(By, Npatch, *vision_patch_labels.shape[1:])
383 | if prepped_y is not None
384 | else None
385 | )
386 | vision_feats = vision_feats + self.camera_embeddings(
387 | torch.tensor([0, 1], device=vision_feats.device)
388 | ).unsqueeze(0).unsqueeze(-2)
389 |
390 | # Merge Npatch with vectors per image
391 | vision_feats = vision_feats.flatten(1, 2)
392 | if prepped_y is not None:
393 | vision_patch_labels = vision_patch_labels.flatten(1, 2)
394 |
395 | if self.masking_rate > 0.0 and self.training:
396 | mask = (
397 | torch.rand(B, vision_feats.shape[1], 1, device=vision_feats.device)
398 | > self.masking_rate
399 | )
400 | vision_feats = vision_feats * mask
401 |
402 | # Project to required dim (2048)
403 | if self.frozen_dict.get("projector", False):
404 | self.projector.eval()
405 |
406 | vision_feats = self.projector(vision_feats)
407 |
408 | if len(original_shape) == 5:
409 | if self.keep_time_dim:
410 | vision_feats = vision_feats.reshape(
411 | original_shape[0], -1, *vision_feats.shape[1:]
412 | )
413 | else:
414 | vision_feats = vision_feats.reshape(
415 | original_shape[0], original_shape[1], *vision_feats.shape[1:]
416 | )
417 | vision_patch_labels = (
418 | vision_patch_labels.reshape(
419 | original_shape[0], 1, *vision_patch_labels.shape[1:]
420 | )
421 | if prepped_y is not None
422 | else None
423 | )
424 | if prepped_y is not None:
425 | return vision_feats, vision_patch_labels
426 |
427 | return vision_feats
428 |
429 | def decode(self, z):
430 | raise NotImplementedError
431 |
432 | def interpret(self, z):
433 | raise NotImplementedError
434 |
--------------------------------------------------------------------------------
/carformer/carformer/ponderer_lit.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | import cv2
5 | import lightning as L
6 | import numpy as np
7 | import torch
8 | import wandb
9 | from torch.utils.data import DataLoader
10 |
11 | from carformer.data import get_datasets
12 | from carformer.ponderer import Ponderer
13 | from carformer.utils import (
14 | WeightedDistributedSampler,
15 | )
16 | from carformer.utils.distributed import get_rank
17 | from carformer.visualization.visutils import (
18 | visualize_input_from_batch,
19 | visualize_trajectory_action_predictions,
20 | )
21 |
22 |
23 | class PondererLit(L.LightningModule):
24 | def __init__(self, ponderer):
25 | super(PondererLit, self).__init__()
26 | self.ponderer = ponderer
27 |
28 | self.params = ponderer.cfg
29 |
30 | self.save_hyperparameters(ponderer.cfg.to_dict())
31 |
32 | print("Save dir: ", self.params.save_dir)
33 |
34 | def setup(self, stage):
35 | if stage != "fit":
36 | return
37 |
38 | folder = self.params.save_dir
39 |
40 | if self.trainer.is_global_zero:
41 | os.makedirs(folder, exist_ok=True)
42 | os.makedirs(os.path.join(folder, "train_predictions"), exist_ok=True)
43 | os.makedirs(os.path.join(folder, "val_predictions"), exist_ok=True)
44 |
45 | # Save the model config and training args
46 | self.params.save_pretrained(folder)
47 | # Save all training arguments
48 | env_vars = []
49 |
50 | vars_to_save = [
51 | "HF_ENDPOINT",
52 | "OMP_NUM_THREADS",
53 | "MKL_NUM_THREADS",
54 | ]
55 |
56 | for var in vars_to_save:
57 | if var in os.environ:
58 | env_vars.append(f"{var}={os.environ[var]}")
59 |
60 | # Detect if ddp or not
61 | if "WORLD_SIZE" in os.environ:
62 | run_command = f"torchrun --nproc_per_node={os.environ['WORLD_SIZE']}"
63 | else:
64 | run_command = "python"
65 |
66 | # Save the command used to run the training into the file {folder}/command.sh
67 | with open(os.path.join(folder, "command.sh"), "w") as f:
68 | all_args = env_vars + [run_command] + sys.argv
69 | f.write(" ".join(all_args))
70 | f.write("\n")
71 | # If wandb is used, log the url to the run
72 | try:
73 | url = wandb.run.get_url()
74 | f.write(f"Wandb URL: {url}")
75 | f.write("\n")
76 |                 except Exception:
77 | pass
78 |
79 | def training_step(self, batch, batch_idx):
80 | preds, loss, pred_labels, preprocessed_inputs = self.ponderer(
81 | batch, return_labels=True, return_inputs=True
82 | )
83 |
84 | loss = {k: v.mean() for k, v in loss.items()}
85 |
86 | self.log_dict(
87 | {f"train/loss/{k}": v for k, v in loss.items()},
88 | sync_dist=True,
89 | prog_bar=True,
90 | )
91 | rank = get_rank()
92 |
93 | world_size = self.trainer.world_size
94 | # import ipdb; ipdb.set_trace()
95 |
96 | if (
97 | batch_idx < 16
98 | and self.params.visualize
99 | and self.trainer.current_epoch > self.params.visualize_start_epoch
100 | and (self.trainer.current_epoch % self.params.visualize_interval == 0)
101 | and rank < 4 # only log from first 4 ranks
102 | ):
103 | for idd in range(max(self.params.hyperparams.batch_size // 8, 1)):
104 | # import ipdb; ipdb.set_trace()
105 |
106 | ac_val = batch["action"][0][0, 20][0].item()
107 |
108 | impath = visualize_input_from_batch(
109 | batch,
110 | idd,
111 | preds,
112 | pred_labels,
113 | self.params.save_dir,
114 | f"ep{self.trainer.current_epoch}_st{rank*8 + idd + world_size*4*batch_idx}_ac_{ac_val:.2f}",
115 | self.ponderer,
116 | "train",
117 | )
118 | dir_preds = os.path.join(self.params.save_dir, "train_predictions")
119 |
120 | os.makedirs(dir_preds, exist_ok=True)
121 |
122 | path2 = visualize_trajectory_action_predictions(
123 | batch,
124 | preds,
125 | labels=pred_labels,
126 | save_dir=dir_preds,
127 | model=self.ponderer,
128 | save_suffix="ep{}_st{}".format(
129 | self.trainer.current_epoch,
130 | rank * 8 + idd + world_size * 4 * batch_idx,
131 | ),
132 | save_idx=idd,
133 | action_source="transformer-regression",
134 | visualize_gt=True,
135 | )
136 |             if preds["bev"] is not None:
137 | pred_heatmaps = preds["bev"].float().detach().cpu().numpy()[idd]
138 |
139 | label_heatmaps = (
140 | pred_labels["bev_mask"].float().detach().cpu().numpy()[idd]
141 | )
142 |
143 | size_per_patch = label_heatmaps.shape[-1] // 2
144 |
145 | patch_1 = (
146 | label_heatmaps[:size_per_patch].reshape(
147 | int(size_per_patch**0.5), int(size_per_patch**0.5)
148 | )
149 | > 0.5
150 | ) * 255
151 | patch_2 = (
152 | label_heatmaps[size_per_patch:].reshape(
153 | int(size_per_patch**0.5), int(size_per_patch**0.5)
154 | )
155 | > 0.5
156 | ) * 255
157 |
158 | patch_1_pred = (
159 | pred_heatmaps[:size_per_patch].reshape(
160 | int(size_per_patch**0.5), int(size_per_patch**0.5)
161 | )
162 | > 0.5
163 | ) * 255
164 | patch_2_pred = (
165 | pred_heatmaps[size_per_patch:].reshape(
166 | int(size_per_patch**0.5), int(size_per_patch**0.5)
167 | )
168 | > 0.5
169 | ) * 255
170 |
171 | patches = np.concatenate([patch_1, patch_2], axis=1)
172 | patches_pred = np.concatenate([patch_1_pred, patch_2_pred], axis=1)
173 |
174 | # Make pred red, label green
175 | patch = np.stack(
176 | [patches, patches_pred, np.zeros_like(patches)], axis=-1
177 | )
178 |
179 | cv2.imwrite(
180 | os.path.join(
181 | dir_preds,
182 | f"ep{self.trainer.current_epoch}_st{rank*8 + idd + world_size*4*batch_idx}_patch.png",
183 | ),
184 | patch,
185 | )
186 |
187 | return loss
188 |
189 | def validation_step(self, batch, batch_idx):
190 | preds, loss, pred_labels, preprocessed_inputs = self.ponderer(
191 | batch, return_labels=True, return_inputs=True
192 | )
193 |
194 | loss = {k: v.mean() for k, v in loss.items()}
195 |
196 | self.log_dict({f"val/loss/{k}": v for k, v in loss.items()}, sync_dist=True)
197 | rank = get_rank()
198 |
199 | world_size = self.trainer.world_size
200 |
201 | if (
202 | batch_idx < 16
203 | and self.params.visualize
204 | and self.trainer.current_epoch > self.params.visualize_start_epoch
205 | and (self.trainer.current_epoch % self.params.visualize_interval == 0)
206 | and rank < 4 # only log from first 4 ranks
207 | ):
208 | for idd in range(max(self.params.hyperparams.batch_size // 8, 1)):
209 | impath = visualize_input_from_batch(
210 | batch,
211 | idd,
212 | preds,
213 | pred_labels,
214 | self.params.save_dir,
215 | f"ep{self.trainer.current_epoch}_st{rank*8 + idd + world_size*4*batch_idx}",
216 | self.ponderer,
217 | "val",
218 | )
219 | dir_preds = os.path.join(self.params.save_dir, "val_predictions")
220 |
221 | os.makedirs(dir_preds, exist_ok=True)
222 |
223 | path2 = visualize_trajectory_action_predictions(
224 | batch,
225 | preds,
226 | labels=pred_labels,
227 | save_dir=dir_preds,
228 | model=self.ponderer,
229 | save_suffix="ep{}_st{}".format(
230 | self.trainer.current_epoch,
231 | rank * 8 + idd + world_size * 4 * batch_idx,
232 | ),
233 | save_idx=idd,
234 | action_source="transformer-regression",
235 | )
236 |             if preds["bev"] is not None:
237 | pred_heatmaps = preds["bev"].float().detach().cpu().numpy()[idd]
238 |
239 | label_heatmaps = (
240 | pred_labels["bev_mask"].float().detach().cpu().numpy()[idd]
241 | )
242 |
243 | size_per_patch = label_heatmaps.shape[-1] // 2
244 |
245 | patch_1 = (
246 | label_heatmaps[:size_per_patch].reshape(
247 | int(size_per_patch**0.5), int(size_per_patch**0.5)
248 | )
249 | > 0.5
250 | ) * 255
251 | patch_2 = (
252 | label_heatmaps[size_per_patch:].reshape(
253 | int(size_per_patch**0.5), int(size_per_patch**0.5)
254 | )
255 | > 0.5
256 | ) * 255
257 |
258 | patch_1_pred = (
259 | pred_heatmaps[:size_per_patch].reshape(
260 | int(size_per_patch**0.5), int(size_per_patch**0.5)
261 | )
262 | > 0.5
263 | ) * 255
264 | patch_2_pred = (
265 | pred_heatmaps[size_per_patch:].reshape(
266 | int(size_per_patch**0.5), int(size_per_patch**0.5)
267 | )
268 | > 0.5
269 | ) * 255
270 |
271 | patches = np.concatenate([patch_1, patch_2], axis=1)
272 | patches_pred = np.concatenate([patch_1_pred, patch_2_pred], axis=1)
273 |
274 | # Make pred red, label green
275 | patch = np.stack(
276 | [patches, patches_pred, np.zeros_like(patches)], axis=-1
277 | )
278 |
279 | cv2.imwrite(
280 | os.path.join(
281 | dir_preds,
282 | f"ep{self.trainer.current_epoch}_st{rank*8 + idd + world_size*4*batch_idx}_patch.png",
283 | ),
284 | patch,
285 | )
286 |
287 | return loss
288 |
289 | def configure_optimizers(self):
290 | return_dict = {}
291 |
292 | params = self.parameters()
293 |
294 | opt_kwargs = self.params.hyperparams.optimizer.kwargs
295 |
296 | if "weight_decay" in opt_kwargs:
297 | # Pop weight decay from optimizer kwargs
298 | weight_decay = opt_kwargs.pop("weight_decay")
299 |
300 | # Create optimizer groups
301 | optim_groups = self.create_optimizer_groups(weight_decay)
302 | else:
303 | optim_groups = params
304 |
305 | opt = getattr(torch.optim, self.params.hyperparams.optimizer.name)(
306 | optim_groups,
307 | lr=self.params.hyperparams.lr,
308 | **self.params.hyperparams.optimizer.kwargs,
309 | )
310 | stepping_batches = self.trainer.estimated_stepping_batches
311 |
312 | from torch.optim.lr_scheduler import (
313 | CosineAnnealingWarmRestarts,
314 | )
315 |
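        # Cosine schedule with warm restarts: the first cycle spans roughly 1/30th
        # of the estimated total steps, and each subsequent cycle doubles in length.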
316 | scheduler = CosineAnnealingWarmRestarts(
317 | opt, T_0=stepping_batches // 30, T_mult=2, eta_min=1e-6
318 | )
319 |
320 | return_dict["optimizer"] = opt
321 |
322 | return_dict["lr_scheduler"] = {
323 | "scheduler": scheduler,
324 | "interval": "step",
325 | }
326 |
327 | return return_dict
328 |
329 | def train_dataloader(self):
330 | train_dataset, _ = get_datasets(self.params, self.ponderer, splits=["train"])
331 | subsample_ratio = self.params.dataset.subsample_ratio
332 | if self.params.training.weighted_sampling:
333 | initial_weights = train_dataset.getweights()
334 |
335 | self.sample_weights = initial_weights
336 |
337 | bucket_names = train_dataset.get_bucket_names()
338 | self.bucket_names = bucket_names
339 |
340 | bucket_config = self.params.training.bucket_weights
341 |
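            # "uniform": normalize each bucket's membership by its frequency and
            # average, so every bucket contributes equally; "preferturns":
            # weight the normalized buckets by the configured bucket weights.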
342 | if bucket_config.type == "uniform":
343 | weights = np.asarray(initial_weights)
344 | weights = (weights / weights.sum(0)).mean(-1)
345 | subsample_ratio *= bucket_config.total_ratio
346 |             elif bucket_config.type == "preferturns":
347 | weights = np.asarray(initial_weights)
348 | weights = (
349 | (weights / weights.sum(0)) * np.asarray(bucket_config.weights)
350 | ).sum(-1) / np.asarray(bucket_config.weights).sum()
351 | subsample_ratio *= bucket_config.total_ratio
352 | else:
353 | raise NotImplementedError(
354 | f"Bucketing type {bucket_config.type} not implemented yet"
355 | )
356 | else:
357 | weights = None
358 |
359 | sampler = WeightedDistributedSampler(
360 | train_dataset,
361 | subsample_ratio,
362 |             shuffle=not self.params.overfit,
363 | weights=weights,
364 | )
365 |
366 | train_loader = DataLoader(
367 | train_dataset,
368 | batch_size=self.params.hyperparams.batch_size,
369 | sampler=sampler,
370 | num_workers=self.params.num_workers,
371 | )
372 |
373 | self.train_loader = train_loader
374 |
375 | return train_loader
376 |
377 | def val_dataloader(self):
378 | _, val_dataset = get_datasets(self.params, self.ponderer, splits=["val"])
379 |
380 | val_sampler = WeightedDistributedSampler(
381 | val_dataset,
382 | shuffle=False,
383 | )
384 |
385 | val_loader = DataLoader(
386 | val_dataset,
387 | batch_size=self.params.hyperparams.batch_size,
388 | sampler=val_sampler,
389 | num_workers=self.params.num_workers,
390 | )
391 |
392 | return val_loader
393 |
394 | @staticmethod
395 | def from_pretrained(path, epoch=30, deepspeed="auto", apply_deepspeed=True):
396 | assert os.path.exists(path), "Path {} does not exist".format(path)
397 |
398 | if deepspeed == "auto":
399 | # Infer whether the checkpoint is in deepspeed format
400 | if os.path.exists(os.path.join(path, "zero_to_fp32.py")):
401 | deepspeed = True
402 | else:
403 | deepspeed = False
404 |
405 | config_path = os.path.join(path, "config.json")
406 |
407 | assert os.path.exists(config_path), "Config path {} does not exist".format(
408 | config_path
409 | )
410 |
411 | config = CarformerConfig.from_pretrained(config_path)
412 |
413 | # Fix quantizer paths if needed
414 | import inspect
415 |
416 | import carformer
417 |
418 | carformer_path = os.path.dirname(inspect.getfile(carformer))
419 | # Go one layer up
420 | carformer_path = os.path.dirname(carformer_path)
421 |
422 | # Replace encoder path
423 | if "rgb_backbone" in config.training:
424 | if config.training.rgb_backbone.frozen:
425 | config.training.rgb_backbone.model_path = os.path.join(
426 | carformer_path, config.training.rgb_backbone.model_path_rel
427 | )
428 | else:
429 | # Do not load path, the weights and config should already exist.
430 | config.training.rgb_backbone.model_path = None
431 |
432 | model = Ponderer(config)
433 |
434 | if epoch in ["last", "best"]:
435 | checkpoint_path = os.path.join(path, "{}_model.pt".format(epoch))
436 | else:
437 | if epoch is not None:
438 | if deepspeed:
439 | checkpoint_path = os.path.join(
440 | path, "{}".format(epoch)
441 | ) # Deepspeed style
442 | else:
443 | checkpoint_path = os.path.join(
444 | path, "epochs", "epoch_{}.pt".format(epoch)
445 | )
446 | else:
447 | checkpoint_path = None
448 |
449 |         if deepspeed:
450 | import deepspeed as ds
451 |
452 | arg_dict = {}
453 | if checkpoint_path is not None:
454 | arg_dict["checkpoint"] = checkpoint_path
455 |
456 | print("Loading args ", arg_dict)
457 | model = ds.init_inference(model, arg_dict)
458 | else:
459 | checkpoint = torch.load(checkpoint_path, map_location="cpu")
460 | dtype = next(iter(checkpoint["model"].values())).dtype
461 | model.to(dtype)
462 | model.load_state_dict(checkpoint["model"], strict=True)
463 | if apply_deepspeed:
464 | import deepspeed as ds
465 |
466 | return model
467 |
468 | def on_train_epoch_start(self):
469 | # Only do this on rank 0
470 | if not self.trainer.is_global_zero:
471 | return
472 |
473 | if self.params.training.weighted_sampling:
474 | weights = self.sample_weights
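            # Trigger index generation so sampler.last_indices holds this epoch's draw
            # for the statistics computed below.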
475 | self.train_loader.sampler.__iter__()
476 |
477 | indices = self.train_loader.sampler.last_indices
478 |
479 | bucket_names = self.bucket_names
480 |
481 | dist = weights[indices].sum(0) / (
482 | self.train_loader.sampler.num_samples
483 | * self.train_loader.sampler.num_replicas
484 | )
485 | baseline_dist = weights.sum(0) / len(weights)
486 |
487 | print("Weighted sampling statistics, epoch: ", self.trainer.current_epoch)
488 | for bucket_name, dist_val, baseline_val in zip(
489 | bucket_names, dist, baseline_dist
490 | ):
491 | print(
492 | f"Bucket {bucket_name}: {dist_val*100:.2f}% (Baseline: {baseline_val*100:.2f}%)",
493 | end="\t",
494 | )
495 |
496 | def create_optimizer_groups(self, weight_decay):
497 | """
498 | This long function is unfortunately doing something very simple and is
499 | being very defensive:
500 | We are separating out all parameters of the model into two buckets:
501 | those that will experience
502 | weight decay for regularization and those that won't
503 | (biases, and layernorm/embedding weights).
504 | We are then returning the optimizer groups.
505 | """
506 |
507 | # separate out all parameters to those that will and won't experience
508 | # regularizing weight decay
509 | decay = set()
510 | no_decay = set()
511 | whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d)
512 | blacklist_weight_modules = (
513 | torch.nn.LayerNorm,
514 | torch.nn.Embedding,
515 | torch.nn.BatchNorm2d,
516 | )
517 | for mn, m in self.named_modules():
518 | for pn, _ in m.named_parameters():
519 | fpn = f"{mn}.{pn}" if mn else pn # full param name
520 | # if "attn_pool" in fpn:
521 | # import ipdb; ipdb.set_trace()
522 | # print(fpn)
523 | # if len(no_decay&decay) > 0:
524 | # import ipdb; ipdb.set_trace()
525 | # fpn = pn # full param name
526 | if pn.endswith("bias"):
527 | # all biases will not be decayed
528 | no_decay.add(fpn)
529 | elif "attn_pool" in fpn or "vqvae" in fpn:
530 | no_decay.add(fpn)
531 | elif "norm.weight" in fpn:
532 | no_decay.add(fpn)
533 | elif pn.endswith("weight") and isinstance(m, whitelist_weight_modules):
534 | # weights of whitelist modules will be weight decayed
535 | decay.add(fpn)
536 | elif pn.endswith("weight") and isinstance(m, blacklist_weight_modules):
537 | # weights of blacklist modules will NOT be weight decayed
538 | no_decay.add(fpn)
539 | elif (
540 | pn.endswith("weight") and "conv." in pn
541 | ): # Add decay for convolutional layers.
542 | decay.add(fpn)
543 | elif pn.endswith("weight") and ".bn" in pn: # No decay for batch norms.
544 | no_decay.add(fpn)
545 | elif pn.endswith("weight") and ".ln" in pn: # No decay for layer norms.
546 | no_decay.add(fpn)
547 | elif (
548 | pn.endswith("weight") and "downsample.0.weight" in pn
549 | ): # Conv2D layer with stride 2
550 | decay.add(fpn)
551 | elif pn.endswith("weight") and "downsample.1.weight" in pn: # BN layer
552 | no_decay.add(fpn)
553 | elif pn.endswith("weight") and ".attn" in pn: # Attention linear layers
554 | decay.add(fpn)
555 | elif (
556 | pn.endswith("weight") and "channel_to_" in pn
557 | ): # Convolutional layers for channel change
558 | decay.add(fpn)
559 | elif pn.endswith("weight") and ".mlp" in pn: # MLP linear layers
560 | decay.add(fpn)
561 | elif (
562 | pn.endswith("weight") and "target_speed_network" in pn
563 | ): # MLP linear layers
564 | decay.add(fpn)
565 | elif (
566 | pn.endswith("weight") and "join." in pn and not ".norm" in pn
567 | ): # MLP layers
568 | decay.add(fpn)
569 | elif (
570 | pn.endswith("weight") and "join." in pn and ".norm" in pn
571 | ): # Norm layers
572 | no_decay.add(fpn)
573 | elif pn.endswith("weight") and "layernorm" in pn: # Norm layers
574 | no_decay.add(fpn)
575 | elif pn.endswith("weight") and ".norm" in fpn: # Norm layers
576 | no_decay.add(fpn)
577 | elif "class_embedding" in fpn: # cls embeds
578 | no_decay.add(fpn)
579 | elif pn.endswith("_ih") or pn.endswith("_hh"):
580 | # all recurrent weights will not be decayed
581 | no_decay.add(fpn)
582 | elif pn.endswith("_emb") or "_token" in pn:
583 | no_decay.add(fpn)
584 | elif pn.endswith("_embed"):
585 | no_decay.add(fpn)
586 | elif "pos_embed" in pn:
587 | no_decay.add(fpn)
588 | elif "patch_embed" in pn:
589 | decay.add(fpn)
590 | elif "bias_ih_l0" in pn or "bias_hh_l0" in pn:
591 | no_decay.add(fpn)
592 | elif "weight_ih_l0" in pn or "weight_hh_l0" in pn:
593 | decay.add(fpn)
594 | elif "_query" in pn or "weight_hh_l0" in pn:
595 | no_decay.add(fpn)
596 | elif "proj_weight" in pn:
597 | decay.add(fpn)
598 | elif "valid_bev_pixels" in pn:
599 | no_decay.add(fpn)
600 | elif "ls1" in pn or "ls2" in pn:
601 | no_decay.add(fpn)
602 | elif "position_embedding" in pn:
603 | no_decay.add(fpn)
604 |
605 | # validate that we considered every parameter
606 | param_dict = dict(self.named_parameters())
607 | inter_params = decay & no_decay
608 | union_params = decay | no_decay
609 | assert (
610 | len(inter_params) == 0
611 | ), f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
612 | assert len(param_dict.keys() - union_params) == 0, (
613 | f"parameters {str(param_dict.keys() - union_params)} were not "
614 | f"separated into either decay/no_decay set!"
615 | )
616 |
617 | # create the pytorch optimizer object
618 | optim_groups = [
619 | {
620 | "params": [param_dict[pn] for pn in sorted(list(decay))],
621 | "weight_decay": weight_decay,
622 | },
623 | {
624 | "params": [param_dict[pn] for pn in sorted(list(no_decay))],
625 | "weight_decay": 0.0,
626 | },
627 | ]
628 | return optim_groups
629 |
630 | def on_train_batch_end(self, *args, **kwargs):
631 |
632 | ema_enabled = self.ponderer.bev_encoder.ema_enabled
633 | disable_ema_updates = self.ponderer.bev_encoder.disable_ema_update
634 | if disable_ema_updates:
635 | return
636 |
637 | if ema_enabled:
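            # EMA updates run every `ema_every_steps` optimizer steps once `ema_start` has been
            # passed, and stop after `ema_end_epoch` (or never stop, if it is -1).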
638 | ema_every_steps = getattr(
639 | self.ponderer.config.training, "ema_every_steps", 1
640 | )
641 | ema_start = getattr(self.ponderer.config.training, "ema_start", 0)
642 | ema_end_epoch = getattr(self.ponderer.config.training, "ema_end_epoch", -1)
643 |
644 | if self.global_step % ema_every_steps == 0 and self.global_step > ema_start:
645 | # Stop if end epoch is reached
646 | if ema_end_epoch == -1 or self.trainer.current_epoch < ema_end_epoch:
647 | ema_decay = self.ponderer.config.training.ema_decay
648 |
649 | vision_encoder = self.ponderer.bev_encoder
650 |
651 | model = vision_encoder.model
652 |
653 | ema_model = vision_encoder.ema_model
654 | self.ponderer.ema_update(model, ema_model, beta=ema_decay)
655 |
--------------------------------------------------------------------------------
/carformer/carformer/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import *
2 | from .distributedsampler import *
3 |
--------------------------------------------------------------------------------
/carformer/carformer/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # DistributedDataParallel required imports
4 | import torch.distributed as dist
5 | import torch.multiprocessing as mp
6 | from torch.nn.parallel import DistributedDataParallel as DDP
7 | from torch.distributed import init_process_group, destroy_process_group
8 | import os # For os.environ
9 |
10 |
11 | def disable_printing(is_master):
12 | """
13 | This function disables printing when not in master process
14 | """
15 | import builtins as __builtin__
16 |
17 | builtin_print = __builtin__.print
18 |
19 | def print(*args, **kwargs):
20 | force = kwargs.pop("force", False)
21 | if is_master or force:
22 | builtin_print(*args, **kwargs)
23 |
24 | __builtin__.print = print
25 |
26 |
27 | # For distributed training, we need to use the following function
28 | # This function sets up the environment for distributed training, and lets every process know
29 | # which process it is (rank) and how many processes there are (world_size)
30 | def ddp_setup(args, dist_url="env://", use_deepspeed=False):
31 | # We get the rank and world size from the environment variables
32 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
33 | rank = int(os.environ["RANK"])
34 | world_size = int(os.environ["WORLD_SIZE"])
35 | if args.gpus != world_size:
36 | # Warn
37 | print(
38 | "WARNING: args.gpus != os.environ['WORLD_SIZE']",
39 | flush=True,
40 | )
41 | args.gpus = int(os.environ["WORLD_SIZE"])
42 | else:
43 | raise ValueError("Rank and world size must be set. Please make sure you are using torchrun")
44 |
45 | # Single gpu:
46 | # python train_hydra_ds.py +deepspeed=zero3
47 | # 8 gpus, with 10 cpus per GPU:
48 | # OMP_NUM_THREADS=10 torchrun --nproc_per_node=8 train_hydra_ds.py +deepspeed=zero3
49 |
50 |
51 | print(
52 | "| Initializing process with rank {} out of {} processes |".format(
53 | rank, world_size
54 | ),
55 | flush=True,
56 | )
57 |
58 | # This is a useful hack I like to do sometimes. This DISABLES printing on nodes that are not rank 0, to make the output cleaner
59 | disable_printing(rank == 0)
60 | torch.cuda.set_device(rank)
61 |
62 | if use_deepspeed:
63 | import deepspeed as ds
64 | ds.init_distributed(dist_backend="nccl")
65 | else:
66 | init_process_group(
67 | # backend="nccl", # just GPU, commented out in order to use both CPU and GPU
68 | init_method=dist_url,
69 | rank=rank,
70 | world_size=args.gpus,
71 | )
72 |
73 |
74 | def is_dist_avail_and_initialized():
75 | if not dist.is_available():
76 | return False
77 | if not dist.is_initialized():
78 | return False
79 | return True
80 |
81 |
82 | def get_rank():
83 | if not is_dist_avail_and_initialized():
84 | return 0
85 | return dist.get_rank()
86 |
87 |
88 | def is_main_process():
89 | return get_rank() == 0
90 |
91 |
92 | def save_on_master(*args, **kwargs):
93 | if is_main_process():
94 | torch.save(*args, **kwargs)
95 |
--------------------------------------------------------------------------------
/carformer/carformer/utils/distributedsampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import math
4 | import numpy as np
5 | import warnings
6 |
7 |
8 | class WeightedDistributedSampler(torch.utils.data.Sampler):
9 | """Sampler that restricts data loading to a subset of the dataset.
10 | It is especially useful in conjunction with
11 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
12 | process can pass a DistributedSampler instance as a DataLoader sampler,
13 | and load a subset of the original dataset that is exclusive to it.
14 | .. note::
15 | Dataset is assumed to be of constant size.
16 | Arguments:
17 | dataset: Dataset used for sampling.
18 | num_replicas (optional): Number of processes participating in
19 | distributed training.
20 |         weights (optional): per-sample weights; when given, indices are drawn by multinomial sampling with replacement.
21 | rank (optional): Rank of the current process within num_replicas.
22 |         shuffle (optional): If true, sampler will shuffle the indices (default: False)
23 | split_strategy (optional):
24 | Assuming 4 gpus of index 0-3:
25 | interleaved: 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3
26 | partition: 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3
27 | interleaved is default behavior, closer behavior to using single device
28 | partition is suitable for cases where you want to split the read operations to different disks, ideally used with shuffle=False for preprocessing purposes
29 | """
30 |
31 | def __init__(
32 | self,
33 | dataset,
34 | subsample_ratio=1.0,
35 | num_replicas=None,
36 | rank=None,
37 | weights=None,
38 | shuffle=False,
39 | split_strategy="interleaved"
40 | ):
41 | # If weights, make sure they are a torch tensor
42 | if weights is not None:
43 | weights = torch.tensor(weights, dtype=torch.float)
44 | if num_replicas is None:
45 | if not dist.is_available():
46 | # Default to single process
47 | num_replicas = 1
48 | else:
49 | if dist.is_initialized():
50 | num_replicas = dist.get_world_size()
51 | else:
52 | # Warn user that num_replicas is set to 1
53 | warnings.warn(
54 | "DistributedSampler is initialized without "
55 | "dist being initialized. Set num_replicas "
56 | "to 1 instead."
57 | )
58 | num_replicas = 1
59 | if rank is None:
60 | if not dist.is_available():
61 | # Default to single process
62 | rank = 0
63 | else:
64 | if dist.is_initialized():
65 | rank = dist.get_rank()
66 | else:
67 | # Warn user that rank is set to 0
68 | warnings.warn(
69 | "DistributedSampler is initialized without "
70 | "dist being initialized. Set rank "
71 | "to 0 instead."
72 | )
73 | rank = 0
74 |
75 | self.dataset = dataset
76 | self.num_replicas = num_replicas
77 | self.rank = rank
78 | self.epoch = 0
79 | self.num_samples = int(
80 | math.ceil(len(self.dataset) * 1.0 / self.num_replicas * subsample_ratio)
81 | )
82 | self.split_strategy = split_strategy
83 | assert self.split_strategy in ["interleaved", "partition"]
84 | self.total_size = self.num_samples * self.num_replicas
85 | self.shuffle = shuffle
86 | self.weights = weights
87 |
88 | self.last_indices = None
89 | self.iter_counter = 0
90 |
91 | def __iter__(self):
92 | # exit(0)
93 | if self.iter_counter > 2 and (self.shuffle or self.weights is not None):
94 | print(("="*10+"\n")*2)
95 | print("WARNING: Iterating over the sampler using the same epoch more than 2 times. Double check if the sampler is being reset properly.")
96 | print(("="*10+"\n")*2)
97 |
98 | # deterministically shuffle based on epoch
99 | g = torch.Generator()
100 | g.manual_seed(self.epoch)
101 | if self.weights is not None:
102 | # if self.epoch == 0:
103 | print("Generating indices with weights for epoch {}".format(self.epoch))
104 | indices = torch.multinomial(
105 | self.weights, self.total_size, replacement=True, generator=g
106 | )
107 | self.last_indices = indices
108 | indices = indices.tolist()
109 | # Print histogram of indices
110 | # if self.epoch == 0:
111 | # hist = np.histogram(
112 | # np.asarray(indices), bins=range(0, len(self.dataset) + 1)
113 | # )
114 | # print(hist)
115 | else:
116 | if self.shuffle:
117 | indices = torch.randperm(len(self.dataset), generator=g)
118 | self.last_indices = indices
119 | indices = indices.tolist()
120 | else:
121 | indices = list(range(len(self.dataset)))
122 | self.last_indices = indices
123 |
124 | # add extra samples to make it evenly divisible
125 | if len(indices) < self.total_size:
126 | indices += indices[: (self.total_size - len(indices))]
127 | assert len(indices) == self.total_size
128 |
129 | # subsample
130 | if self.split_strategy == "interleaved":
131 | indices = indices[self.rank : self.total_size : self.num_replicas]
132 | elif self.split_strategy == "partition":
133 |             indices = indices[(self.total_size // self.num_replicas) * self.rank : (self.total_size // self.num_replicas) * (self.rank + 1)]
134 | else:
135 | raise ValueError("split_strategy must be one of 'interleaved' or 'partition'")
136 | # print(self.rank, indices[:100])
137 | assert len(indices) == self.num_samples
138 |
139 | return iter(indices)
140 |
141 | def set_subsample_ratio(self, subsample_ratio):
142 | self.num_samples = int(
143 | math.ceil(len(self.dataset) * 1.0 / self.num_replicas * subsample_ratio)
144 | )
145 | self.total_size = self.num_samples * self.num_replicas
146 |
147 | def __len__(self):
148 | return self.num_samples
149 |
150 | def set_epoch(self, epoch):
151 | self.epoch = epoch
152 | self.iter_counter = 0
153 |
154 |
155 | # # Bin sampler
156 | # class BinDistributedSampler(torch.utils.data.Sampler):
157 | # """Sampler that restricts data loading to a subset of the dataset.
158 | # It is especially useful in conjunction with
159 | # :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
160 | # process can pass a DistributedSampler instance as a DataLoader sampler,
161 | # and load a subset of the original dataset that is exclusive to it.
162 | # .. note::
163 | # Dataset is assumed to be of constant size.
164 | # Arguments:
165 | # dataset: Dataset used for sampling.
166 | # num_replicas (optional): Number of processes participating in
167 | # distributed training.
168 | # rank (optional): Rank of the current process within num_replicas.
169 | # shuffle (optional): If true (default), sampler will shuffle the indices
170 | # """
171 |
172 | # def __init__(
173 | # self,
174 | # dataset,
175 | # subsample_ratio=1.0,
176 | # num_replicas=None,
177 | # rank=None,
178 | # shuffle=False,
179 | # ):
180 | # if num_replicas is None:
181 | # if not dist.is_available():
182 | # # Default to single process
183 | # num_replicas = 1
184 | # else:
185 | # if dist.is_initialized():
186 | # num_replicas = dist.get_world_size()
187 | # else:
188 | # # Warn user that num_replicas is set to 1
189 | # warnings.warn(
190 | # "DistributedSampler is initialized without "
191 | # "dist being initialized. Set num_replicas "
192 | # "to 1 instead."
193 | # )
194 | # num_replicas = 1
195 | # if rank is None:
196 | # if not dist.is_available():
197 | # # Default to single process
198 | # rank = 0
199 | # else:
200 | # if dist.is_initialized():
201 | # rank = dist.get_rank()
202 | # else:
203 | # # Warn user that rank is set to 0
204 | # warnings.warn(
205 | # "DistributedSampler is initialized without "
206 | # "dist being initialized. Set rank "
207 | # "to 0 instead."
208 | # )
209 | # rank = 0
210 |
211 | # self.dataset = dataset
212 | # self.num_replicas = num_replicas
213 | # self.rank = rank
214 | # self.epoch = 0
215 | # self.num_samples = int(
216 | # math.ceil(len(self.dataset) * 1.0 / self.num_replicas * subsample_ratio)
217 | # )
218 | # self.total_size = self.num_samples * self.num_replicas
219 | # self.shuffle = shuffle
220 |
221 | # def __iter__(self):
222 | # # deterministically shuffle based on epoch
223 | # g = torch.Generator()
224 | # g.manual_seed(self
225 |
--------------------------------------------------------------------------------
/carformer/carformer/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from enum import IntEnum
4 | from collections import defaultdict
5 | from collections.abc import MutableMapping
6 | import random
7 |
8 |
9 | def seed_everything(seed):
10 | random.seed(seed)
11 | np.random.seed(seed)
12 | torch.manual_seed(seed)
13 | torch.cuda.manual_seed(seed)
14 |
15 |
16 | class TokenTypeIDs(IntEnum):
17 | GOAL = 4
18 | STATE = 0
19 | BEV = 1
20 | ACTION = 2
21 | REWARD = 3
22 | EOS = -1
23 |
24 |
25 | def deinterleave(interleaved_tensors, interleaved_token_type_ids, axis=1):
26 | """
27 | Deinterleave tensors along the specified axis.
28 | Args:
29 | interleaved_tensors: A tensor of shape (..., step_width, ...) to deinterleave
30 | interleaved_token_type_ids: A tensor of shape (batch, step_width) containing the token type ids
31 | axis: The axis to deinterleave along
32 | Returns:
33 | A dictionary of deinterleaved tensors, keyed by token type id
34 | """
35 | results_by_tokentype = defaultdict(list)
36 |
37 | for tokenType in TokenTypeIDs:
38 | max_axis_length = 0
39 | for batch_idx in range(interleaved_token_type_ids.shape[0]):
40 | # Get the indices of the unique token type ids
41 | token_type_id_indices = torch.where(
42 | interleaved_token_type_ids[batch_idx] == tokenType
43 | )[0]
44 |
45 | # Get the tensor corresponding to the token type id
46 | results_by_tokentype[tokenType].append(
47 | torch.index_select(
48 | interleaved_tensors[batch_idx].unsqueeze(0),
49 | axis,
50 | token_type_id_indices,
51 | )
52 | )
53 | max_axis_length = max(
54 | max_axis_length, results_by_tokentype[tokenType][-1].shape[axis]
55 | )
56 |
57 | # Pad the tensors to the max length
58 | output_shape = list(results_by_tokentype[tokenType][0].shape)
59 | for i in range(len(results_by_tokentype[tokenType])):
60 | output_shape[axis] = (
61 | max_axis_length - results_by_tokentype[tokenType][i].shape[axis]
62 | )
63 | if output_shape[axis] > 0:
64 | results_by_tokentype[tokenType][i] = torch.cat(
65 | [
66 | results_by_tokentype[tokenType][i],
67 | torch.zeros(
68 | output_shape, dtype=results_by_tokentype[tokenType][i].dtype
69 | ).to(results_by_tokentype[tokenType][i].device),
70 | ],
71 | axis=axis,
72 | )
73 |
74 | # Concatenate the tensors
75 | results_by_tokentype[tokenType] = torch.cat(
76 | results_by_tokentype[tokenType], axis=0
77 | )
78 |
79 | return results_by_tokentype
80 |
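# Example (illustrative): for a batch of one sequence with token type ids
# [STATE, BEV, ACTION, STATE, BEV, ACTION] and hidden states of shape (1, 6, d),
# deinterleave returns a dict mapping STATE, BEV, and ACTION each to a (1, 2, d) tensor,
# zero-padding whenever a token type occurs a different number of times across batch elements.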
81 |
82 | # Normalize version compatible with torch tensors
83 | def normalize_angle_torch(x):
84 | x = x % (2 * np.pi) # force in range [0, 2 pi)
85 | x = torch.where(x > np.pi, x - 2 * np.pi, x) # move to [-pi, pi)
86 | return x
87 |
88 |
89 | def unwrap_model(model):
90 | if hasattr(model, "module"):
91 | return model.module
92 | else:
93 | return model
94 |
95 |
96 | # Flatten backbone config nested dicts into a single dict
97 | # a: {b: c} -> {"a.b": c}
98 | def flatten_dict(d, parent_key="", sep="."):
99 | items = []
100 | for k, v in d.items():
101 | new_key = parent_key + sep + k if parent_key else k
102 | if isinstance(v, MutableMapping):
103 | items.extend(flatten_dict(v, new_key, sep=sep).items())
104 | else:
105 | items.append((new_key, v))
106 | return dict(items)
107 |
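# Example (illustrative): flatten_dict({"a": {"b": 1, "c": {"d": 2}}}) == {"a.b": 1, "a.c.d": 2}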
108 |
109 | def fuzzy_extract_state_dict_from_checkpoint(checkpoint, consume="ponderer."):
110 | if "model" in checkpoint:
111 | checkpoint = checkpoint["model"]
112 |
113 | if consume: # Remove the prefix from the keys
114 | checkpoint = {
115 | (k[len(consume) :] if k.startswith(consume) else k): v
116 | for k, v in checkpoint.items()
117 | }
118 |
119 | return checkpoint
120 |
--------------------------------------------------------------------------------
/carformer/carformer/visualization/visutils.py:
--------------------------------------------------------------------------------
1 | from carformer.utils import TokenTypeIDs
2 | from PIL import Image, ImageDraw, ImageFilter
3 | import torch
4 | import os
5 | import numpy as np
6 | import cv2
7 |
8 |
9 | def extract_target_point_from_trajectory_goals(goals, goal_type):
10 | if "dual_target_point" in goal_type:
11 | goal_idx = goal_type.index("dual_target_point")
12 | elif "target_point" in goal_type:
13 | goal_idx = goal_type.index("target_point")
14 | else:
15 | raise ValueError(
16 | f"Unknown goal type {goal_type}. 'target_point' or 'dual_target_point' must be in the goal types."
17 | )
18 |
19 | return goals.reshape(-1, 2).numpy()
20 |
21 |
22 | def draw_points_on_camera(
23 | camera,
24 | points,
25 | color=(0, 255, 255),
26 | first_origin=None,
27 | radius=4,
28 | blur=-1,
29 | ):
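    # Draws each point as a filled circle on the BGR image `camera`, treating the points as
    # pixel offsets relative to `first_origin`; when blur > 0, the circles are drawn on a
    # transparent overlay that is Gaussian-blurred and alpha-composited onto the image.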
30 | rgb = Image.fromarray(cv2.cvtColor(camera, cv2.COLOR_BGR2RGB))
31 |
32 | if blur > 0:
33 | # Start with a blank image and draw the points on it, blur it and then overlay it on the original image
34 | rgb_tmp = Image.new("RGBA", rgb.size, tuple([*color, 0]))
35 | draw = ImageDraw.Draw(rgb_tmp)
36 | else:
37 | draw = ImageDraw.Draw(rgb)
38 |
39 | for i, pt in enumerate(points):
40 | x, y = pt[:2].astype(int)
41 |
42 | y = first_origin[-1] + y
43 | x = first_origin[0] + x
44 | if x < 0 or x >= rgb.size[0] or y < 0 or y >= rgb.size[-1]:
45 | continue
46 | draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color)
47 |
48 | # point_list = first_origin + point_list * pix_per_meter
49 |
50 | # point_list[0] += i * skip_size
51 |
52 | # for j, point in enumerate(point_list):
53 | # cv2.circle(
54 | # img=canvas,
55 | # center=tuple(point.astype(np.int32)),
56 | # radius=np.rint(radius).astype(int),
57 | # color=tuple(c * (color_decay**j) for c in color),
58 | # thickness=cv2.FILLED,
59 | # )
60 |
61 | if blur > 0:
62 | rgb_tmp = rgb_tmp.filter(ImageFilter.GaussianBlur(blur))
63 | if rgb.mode == "RGB":
64 | rgb = rgb.convert("RGBA")
65 |
66 | rgb = Image.alpha_composite(rgb, rgb_tmp)
67 |
68 | # convert back to RGB
69 | rgb = rgb.convert("RGB")
70 |
71 | # draw.ellipse([first_origin[0]-radius, first_origin[-1]-radius, first_origin[0]+radius, first_origin[-1]+radius], fill=color)
72 |
73 | # Return cv2 format array
74 | return cv2.cvtColor(np.array(rgb), cv2.COLOR_RGB2BGR)
75 |
76 |
77 | def get_action_canvas(
78 | action,
79 | action_types,
80 | size=192,
81 | bev_canvas=None,
82 | copy_bev_canvas=False,
83 | bev_crop_type="front",
84 | pix_per_meter=5,
85 | supp_action=None,
86 | ):
87 | if "path" in action_types:
88 | return get_action_waypoints_path(
89 | action,
90 | action_types,
91 | size=size,
92 | bev_canvas=bev_canvas,
93 | copy_bev_canvas=copy_bev_canvas,
94 | bev_crop_type=bev_crop_type,
95 | pix_per_meter=pix_per_meter,
96 | supp_action=supp_action,
97 | )
98 | else:
99 | raise ValueError(
100 | f"Unknown action type {action_types}. 'path' must be in the action types."
101 | )
102 |
103 |
104 | def get_action_waypoints_path(
105 | action,
106 | action_types,
107 | size=192,
108 | bev_canvas=None,
109 | copy_bev_canvas=False,
110 | bev_crop_type="front",
111 | pix_per_meter=5,
112 | supp_action=None,
113 | ):
114 | if bev_canvas is None:
115 | bev_canvas = np.zeros((size, size * action.shape[0], 3), dtype=np.uint8)
116 | else:
117 | if copy_bev_canvas:
118 | bev_canvas = bev_canvas.copy()
119 |
120 | waypoints = action[0, 20:].copy()
121 |
122 | waypoints[:, -1] *= -1
123 |
124 | path = action[0, :20]
125 |
126 | waypoints = point_to_canvas_coordinates_rel_to_center(waypoints, height=-1.6) / 2
127 | path = point_to_canvas_coordinates_rel_to_center(path, height=-1.6) / 2
128 |
129 | if bev_crop_type == "dualfront":
130 | origin = (size, size // 2)
131 | elif bev_crop_type == "front":
132 | origin = (size // 2, size)
133 | else:
134 | origin = (size // 2, size // 2)
135 |
136 | bev_canvas = draw_points_on_camera(
137 | bev_canvas, waypoints, color=(0, 0, 255), first_origin=origin, radius=5
138 | )
139 |
140 | bev_canvas = draw_points_on_camera(
141 | bev_canvas, path, color=(255, 255, 0), first_origin=origin, radius=2
142 | )
143 | if supp_action is not None:
144 | supp_waypoints = supp_action[0, 20:].copy()
145 | supp_waypoints[:, -1] *= -1
146 | supp_path = supp_action[0, :20]
147 |
148 | supp_waypoints = (
149 | point_to_canvas_coordinates_rel_to_center(supp_waypoints, height=-1.6) / 2
150 | )
151 | supp_path = (
152 | point_to_canvas_coordinates_rel_to_center(supp_path, height=-1.6) / 2
153 | )
154 |
155 | bev_canvas = draw_points_on_camera(
156 | bev_canvas, supp_waypoints, color=(0, 0, 155), first_origin=origin, radius=3
157 | )
158 | bev_canvas = draw_points_on_camera(
159 | bev_canvas, supp_path, color=(155, 0, 155), first_origin=origin, radius=1
160 | )
161 |
162 | if not copy_bev_canvas:
163 | raise ValueError("Not implemented")
164 | return bev_canvas
165 |
166 |
167 | def point_to_canvas_coordinates_rel_to_center(
168 | points, original_size=(1600, 800), height=-1
169 | ): # shape: Nx2
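    # Projects points given as (forward, lateral) offsets in metres at a fixed height onto the
    # front-camera image plane via the hard-coded pinhole intrinsics below, clamping near-zero
    # depths to avoid division by zero before shifting into image-centred pixel coordinates.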
170 | rgb_front_intrinsic = np.asarray(
171 | [
172 | [1142.5184053936916, 0.0, 800.0],
173 | [0.0, 1142.5184053936916, 450.0],
174 | [0.0, 0.0, 1.0],
175 | ]
176 | )
177 |
178 | points = np.stack(
179 | [points[:, 1], np.ones_like(points[:, 0]) * height, points[:, 0]], axis=-1
180 | )
181 | # points = np.concatenate([points, np.ones((points.shape[0], 1))], axis=1)
182 | # If any of points[:, 2:3] are between 1e-4 and -1e-4, then we set them to plus 1e-4
183 | mask = np.logical_and(points[:, 2:3] < 1e-4, points[:, 2:3] > -1e-4)
184 |
185 | points[:, 2:3][mask] = 1e-4
186 |
187 | points = points / points[:, 2:3]
188 | points = (rgb_front_intrinsic @ points.T[:3, :]).T
189 |
190 | points = -points
191 |
192 | points[:, 0] += original_size[0] // 2
193 | points[:, 1] += original_size[1] // 2
194 |
195 | return points
196 |
197 |
198 | def get_goal_canvas(
199 | commands,
200 | command_types,
201 | size=192,
202 | color=(255, 0, 255),
203 | bev_canvas=None,
204 | copy_bev_canvas=False,
205 | bev_crop_type="front",
206 | pix_per_meter=5,
207 | ):
208 | if "command" in command_types:
209 | return get_command_canvas(commands, size=size, color=color)
210 | else:
211 | return get_targetpoint_canvas(
212 | commands,
213 | command_types,
214 | size=size,
215 | color=color,
216 | bev_canvas=bev_canvas,
217 | copy_bev_canvas=copy_bev_canvas,
218 | bev_crop_type=bev_crop_type,
219 | pix_per_meter=pix_per_meter,
220 | )
221 |
222 |
223 | def get_targetpoint_canvas(
224 | commands,
225 | command_types,
226 | color=(255, 0, 255),
227 | size=192,
228 | bev_canvas=None,
229 | copy_bev_canvas=False,
230 | bev_crop_type="front",
231 | pix_per_meter=5,
232 | ):
233 | if bev_canvas is None:
234 | bev_canvas = np.zeros((size, size * commands.shape[0], 3), dtype=np.uint8)
235 | else:
236 | if copy_bev_canvas:
237 | bev_canvas = bev_canvas.copy()
238 |
239 | # For now, assert size of action is 8 and no other action types
240 | assert len(command_types) == 1 and (
241 | command_types[0] == "target_point" or command_types[0] == "dual_target_point"
242 | )
243 |
244 | target_points = extract_target_point_from_trajectory_goals(commands, command_types)
245 | # print(target_points)
246 | target_points = (
247 | point_to_canvas_coordinates_rel_to_center(target_points, height=-1.6) / 2
248 | )
249 |
250 | # print(target_points)
251 | # canvas = draw_points_on_camera(canvas, target_points, color=(255, 0, 255), first_origin=origin, radius=10, blur=8)
252 |
253 | if bev_crop_type == "dualfront":
254 | origin = (size, size // 2)
255 | elif bev_crop_type == "front":
256 | origin = (size // 2, size)
257 | else:
258 | origin = (size // 2, size // 2)
259 |
260 | # Draw the target points
261 | bev_canvas = draw_points_on_camera(
262 | bev_canvas, target_points, color=color, first_origin=origin, radius=12, blur=8
263 | )
264 |
265 | if not copy_bev_canvas:
266 | return None
267 | return bev_canvas
268 |
269 |
270 | def visualize_trajectory_action_predictions(
271 | batch,
272 | batch_outputs,
273 | save_dir=None,
274 | labels=None,
275 | save_idx=0,
276 | action_source="transformer-regression",
277 | save_suffix="",
278 | do_write=True,
279 | visualize_gt=True,
280 | *args,
281 | **kwargs,
282 | ):
283 | # Move batch to cpu
284 | # batch = {k: v.cpu() for k, v in batch.items() if isinstance(v, torch.Tensor)}
285 | # Copy batch shallowly
286 |     batch = {k: v for k, v in batch.items()}
    supp_action = None  # Default so the call below is safe for action sources that do not set it
287 |
288 | # Convert labels to empty dict if None
289 | if labels is None:
290 | labels = {}
291 |
292 | if action_source == "transformer":
293 | if "output_dict" in batch_outputs:
294 | batch_outputs = batch_outputs["output_dict"]
295 | batch_outputs = {k: v for k, v in batch_outputs.items()}
296 |
297 | # Change the batch action to be the predicted action
298 | pred_actions = batch_outputs[TokenTypeIDs.ACTION].to(batch["action"].device)
299 | elif action_source == "gru":
300 | if "waypoints" in batch_outputs:
301 | pred_actions = batch_outputs["waypoints"].to(batch["action"].device)
302 | else:
303 | pred_actions = batch_outputs.to(batch["action"].device)
304 | elif action_source == "transformer-regression":
305 | waypoints = batch_outputs["action"]["wps"].float().detach() # 1x10x2
306 | path = batch_outputs["action"]["path"].float().detach() # 1x20x2
307 |
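        # The head predicts per-step offsets; the cumulative sums below convert them into
        # positions in the ego frame (assumption: offset interpretation inferred from the cumsum).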
308 | waypoints = waypoints.cumsum(-2)
309 | path = path.cumsum(-2)
310 | if len(waypoints.shape) == 3:
311 | waypoints = waypoints.unsqueeze(1)
312 | path = path.unsqueeze(1)
313 |
314 | if visualize_gt:
315 | supp_action = batch["action"] # [save_idx].unsqueeze(0)
316 | else:
317 | supp_action = None
318 |
319 | batch["action"] = torch.zeros(
320 | (waypoints.shape[0], 1, 30, 2), dtype=path.dtype, device=path.device
321 | )
322 | batch["action"][:, :, :20, :] = path
323 | batch["action"][:, :, 20:, :] = waypoints
324 | else:
325 | raise ValueError(f"Unknown action source {action_source}")
326 |
327 | # pred_actions_len = pred_actions.shape[0]
328 |
329 | # batch["action"] = torch.cat(
330 | # (
331 | # batch["action"],
332 | # pred_actions.reshape(
333 | # batch["action"].shape[0], -1, batch["action"].shape[-1]
334 | # ),
335 | # ),
336 | # dim=1,
337 | # )
338 | # import ipdb; ipdb.set_trace()
339 | # return visualize_trajectory(batch, save_dir, save_idx=save_idx, *args, **kwargs)
340 | return visualize_input_from_batch(
341 | batch=batch,
342 | batch_idx=save_idx,
343 | batch_outputs=batch_outputs,
344 | labels=labels,
345 | save_dir=save_dir,
346 |         save_affix="pred_{}".format(save_suffix),
347 | model=kwargs.get("model", None),
348 | save_prefix=None,
349 | do_write=do_write,
350 | supp_action=supp_action,
351 | )
352 |
353 |
354 | def get_bev_canvas(
355 | batch,
356 | batch_idx,
357 | model=None,
358 | batch_outputs=None,
359 | labels=None,
360 | include_targets=True,
361 | use_target_mask_if_available=True,
362 | ):
363 | if model is not None:
364 | if "rgb_front" in batch:
365 | bev_mode = "rgb_front"
366 | else:
367 | raise ValueError(
368 | "Model is provided but no rgb_front in batch. "
369 | "This is unexpected, please check your model and batch."
370 | )
371 | else:
372 | if "rgb_front" in batch:
373 | bev_mode = "rgb_front"
374 | else:
375 | raise ValueError("No rgb in batch")
376 |
377 | targets_rgb = None
378 | if bev_mode == "rgb_front":
379 | # shape: BxTxHxWx3
380 | if use_target_mask_if_available and "goal_mask" in batch:
381 | rgb_front = batch["goal_mask"]
382 | else:
383 | rgb_front = batch["rgb_front"]
384 |
385 | gt_rgb = (
386 | rgb_front[batch_idx]
387 | .cpu()
388 | .numpy()
389 | .transpose((0, 1, 2, 3))
390 | .astype(np.float32)
391 | / 255.0
392 | )
393 |
394 | gt_rgb = gt_rgb[:1].reshape(-1, *gt_rgb.shape[2:])
395 | gt_reproduced = None
396 | preds_rgb = None
397 | if include_targets:
398 | if "target_rgb_front" in batch:
399 | targets_rgb = (
400 | batch["target_rgb_front"][batch_idx]
401 | .cpu()
402 | .numpy()
403 | .transpose((1, 0, 2, 3))
404 | .astype(np.float32)
405 | / 255.0
406 | )
407 | targets_rgb = targets_rgb.squeeze(1)
408 |
409 | waypoints = batch["frc_wps"][batch_idx, 0].cpu().numpy().copy()
410 |
411 | waypoints[:, -1] *= -1
412 |
413 | waypoints_wrld = (
414 | point_to_canvas_coordinates_rel_to_center(waypoints, height=-1.6)
415 | / 2
416 | )
417 |
418 | speed = batch["frc_speed"][batch_idx, 0].cpu().item()
419 | goals = batch["frc_goal"][batch_idx, 0].cpu()
420 | goals = (
421 | point_to_canvas_coordinates_rel_to_center(goals, height=-1.6) / 2
422 | )
423 |
424 | targets_rgb_pil = Image.fromarray((targets_rgb * 255).astype(np.uint8))
425 | from PIL import ImageDraw, ImageFont
426 |
427 | draw = ImageDraw.Draw(targets_rgb_pil)
428 |
429 | try:
430 | font = ImageFont.truetype(
431 | "/usr/share/fonts/dejavu/DejaVuSans.ttf", 15
432 | )
433 | except IOError:
434 | font = ImageFont.load_default()
435 |
436 | # Get center of the image
437 | width, height = targets_rgb_pil.size
438 |
439 | # Calculate the position for the text
440 | position = (0, height // 2)
441 | # if x > y:
442 | # break
443 | # Draw the text at the calculated position with the specified angle
444 | draw.text(
445 | position,
446 | f"Speed: {speed:.2f}\n Waypoints: {waypoints}\n Goals: {goals}",
447 | font=font,
448 | fill=(255, 255, 255, 128),
449 | anchor="lm",
450 | )
451 |
452 | targets_rgb = np.array(targets_rgb_pil)
453 |
454 | draw_points_on_camera(
455 | targets_rgb,
456 | waypoints_wrld,
457 | color=(255, 0, 0),
458 | first_origin=(width // 2, height // 2),
459 | radius=4,
460 | )
461 | draw_points_on_camera(
462 | targets_rgb,
463 | goals,
464 | color=(200, 200, 0),
465 | first_origin=(width // 2, height // 2),
466 | radius=2,
467 | )
468 |
469 | targets_rgb = targets_rgb.astype(np.float32) / 255.0
470 |
471 | if gt_reproduced is not None:
472 | gt_reproduced = gt_reproduced.reshape(
473 | gt_reproduced.shape[0], -1, gt_reproduced.shape[-1]
474 | )
475 |
476 | rows = [x for x in [gt_rgb, gt_reproduced, targets_rgb, preds_rgb] if x is not None]
477 | canvas_unit_size = gt_rgb.shape[0]
478 |
479 | canvas = np.zeros((canvas_unit_size * len(rows), gt_rgb.shape[1], 3))
480 |
481 | canvas = np.concatenate(rows, axis=0)
482 |
483 | # Convert canvas of 0-1 floats to 0-255 uint8
484 | canvas = (canvas * 255).astype(np.uint8)
485 | # Rgb to bgr
486 | canvas = canvas[:, :, ::-1]
487 |
488 | return canvas, canvas_unit_size
489 |
490 |
491 | def visualize_input_from_batch(
492 | batch,
493 | batch_idx,
494 | batch_outputs,
495 | labels,
496 | save_dir,
497 | save_affix,
498 | model,
499 | save_prefix,
500 | do_write=True,
501 | supp_action=None,
502 | use_target_mask_if_available=True,
503 | ):
504 | canvas, img_size = get_bev_canvas(
505 | batch=batch,
506 | batch_idx=batch_idx,
507 | model=model,
508 | batch_outputs=batch_outputs,
509 | labels=labels,
510 | use_target_mask_if_available=use_target_mask_if_available,
511 | )
512 |
513 | canvas_to_reuse = canvas
514 |
515 | # Add a black copy of the canvas
516 | black = np.zeros_like(canvas_to_reuse)
517 |
518 | canvas_to_reuse = np.concatenate([canvas_to_reuse, black], axis=0)
519 |
520 | actions = batch["action"][batch_idx].cpu().numpy()
521 | if supp_action is not None:
522 | supp_action = supp_action[batch_idx].cpu().numpy()
523 |
524 | action_canvas = get_action_canvas(
525 | actions,
526 | ["path", "waypoints"],
527 | size=img_size,
528 | bev_canvas=canvas_to_reuse,
529 | copy_bev_canvas=True,
530 | bev_crop_type="dualfront",
531 | pix_per_meter=3,
532 | supp_action=supp_action,
533 | )
534 |
535 | goal = batch["goal"][batch_idx].cpu()
536 | goal_canvas = get_goal_canvas(
537 | goal,
538 | ["target_point"],
539 | size=img_size,
540 | bev_canvas=action_canvas,
541 | copy_bev_canvas=True,
542 | bev_crop_type="dualfront",
543 | pix_per_meter=3,
544 | )
545 |
546 | if not do_write:
547 | return goal_canvas
548 |
549 | impath = os.path.join(
550 | save_dir,
551 | f"{save_prefix}_predictions" if save_prefix else "",
552 | "epoch_{}.png".format(save_affix),
553 | )
554 |
555 | # Save as epoch_{epoch}.png in the log directory
556 | cv2.imwrite(
557 | impath,
558 | goal_canvas,
559 | )
560 |
561 | return impath
562 |
--------------------------------------------------------------------------------
/carformer/requirements.txt:
--------------------------------------------------------------------------------
1 | PyYAML
2 | torch>=2.0.0
3 | transformers>4.30
4 | tokenizers
5 | plotly
6 | omegaconf==2.3.0
7 | hydra-core==1.2.0
8 | mpi4py==3.1.4
9 | #deepspeed==0.10.1
10 |
--------------------------------------------------------------------------------
/carformer/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | setuptools.setup(
4 | name="carformer",
5 | version="0.2",
6 | author="Shadi Hamdan et al.",
7 | author_email="shamdan17@ku.edu.tr",
8 | description="Transformers for Sequential Decision making in Autonomous Driving",
9 | long_description="Transformers for Sequential Decision making in Autonomous Driving",
10 |     long_description_content_type="text/plain",
11 | url="https://github.com/Shamdan17/ETA",
12 | project_urls={
13 | "Bug Tracker": "https://github.com/Shamdan17/ETA/issues",
14 | },
15 | classifiers=[
16 | "Programming Language :: Python :: 3",
17 | "License :: OSI Approved :: MIT License",
18 | "Operating System :: OS Independent",
19 | ],
20 | packages=setuptools.find_packages(where="."),
21 | python_requires=">=3.7",
22 | )
23 |
--------------------------------------------------------------------------------
/carformer/train.py:
--------------------------------------------------------------------------------
1 | from carformer.ponderer import Ponderer
2 | from carformer.ponderer_lit import PondererLit
3 |
4 | from carformer.utils import seed_everything
5 | import os
6 | from carformer.config import config_init, CarformerConfig
7 | from omegaconf import OmegaConf
8 | import hydra
9 | from skit.distributed.deepspeed_utils import ZeroLightModelCheckpoint
10 |
11 | config_init()
12 |
13 |
14 | @hydra.main(version_base="1.1", config_path="carformer/config", config_name="config")
15 | def main(cfg):
16 | seed_everything(cfg.seed)
17 |
18 | args = cfg
19 |
20 | if args.ckpt_path is not None:
21 | print("Loading model from checkpoint")
22 | cfg.save_dir = os.path.dirname(args.ckpt_path)
23 |
24 | config = CarformerConfig.from_hydra(args)
25 |
26 | print(config)
27 |
28 | model = Ponderer(config)
29 |
30 | model = PondererLit(model).train()
31 |
32 | ds_cfg = OmegaConf.to_container(cfg.deepspeed, resolve=True)
33 |
34 | folder = cfg.save_dir
35 |
36 | from lightning import Trainer
37 | from lightning.pytorch.loggers import WandbLogger
38 | from lightning.pytorch.callbacks import (
39 | ModelCheckpoint,
40 | LearningRateMonitor,
41 | )
42 | from lightning.pytorch.strategies import DeepSpeedStrategy
43 |
44 | callbacks = [LearningRateMonitor(logging_interval="step")]
45 |
46 | if (not args.overfit) or args.force_save:
47 | callbacks.extend(
48 | [
49 | ModelCheckpoint(
50 | dirpath=os.path.join(folder),
51 | filename="last_model",
52 | every_n_epochs=0,
53 | save_last=True,
54 | ),
55 | ]
56 | )
57 |
58 | callbacks.extend(
59 | [
60 | ZeroLightModelCheckpoint(
61 | dirpath=os.path.join(folder, "epochs"),
62 | filename="{epoch}",
63 | monitor="val/loss/wp_loss",
64 | mode="min",
65 | every_n_epochs=args.save_every,
66 | save_top_k=-1,
67 | save_weights_only=True,
68 | start_save_epoch=args.start_saving_epoch,
69 | )
70 | ]
71 | )
72 |
73 | if args.wandb_tag:
74 | tags = [args.wandb_tag]
75 | else:
76 | tags = []
77 |
78 | if not (args.overfit or args.debug) or args.force_log:
79 | log = True
80 | else:
81 | log = False
82 |
83 | trainer = Trainer(
84 | accelerator="gpu" if not args.cpu else "cpu",
85 | devices=args.gpus,
86 | num_nodes=args.nodes,
87 | strategy=(DeepSpeedStrategy(config=ds_cfg)),
88 | max_epochs=args.hyperparams.num_epochs,
89 | logger=(
90 | WandbLogger(
91 | project=args.logging.project,
92 | entity=args.logging.entity,
93 | mode=args.logging.mode,
94 | tags=tags,
95 | offline=True,
96 | save_code=True,
97 | )
98 | if log
99 | else None
100 | ),
101 | callbacks=callbacks,
102 | accumulate_grad_batches=args.hyperparams.gradient_accumulation_steps,
103 | use_distributed_sampler=False,
104 | overfit_batches=args.overfit_batches if args.overfit else 0.0,
105 | limit_val_batches=0 if args.overfit else 1.0,
106 | log_every_n_steps=1 if args.overfit else 25,
107 | enable_checkpointing=False if args.overfit else True,
108 | )
109 |
110 | trainer.fit(model, ckpt_path=args.ckpt_path)
111 |
112 |
113 | if __name__ == "__main__":
114 | main()
115 |
--------------------------------------------------------------------------------
/docs/TRAIN_EVAL.md:
--------------------------------------------------------------------------------
1 | ## Training Setup
2 | ### Bench2Drive data
3 | Download the [Bench2Drive](https://github.com/Thinklab-SJTU/Bench2Drive) dataset and unzip all the directories to a single folder. The dataset should have a structure similar to the following:
4 |
5 | ```
6 | MainFolder/
7 | HazardAtSideLaneTwoWays_Town12_Route1133_Weather15/
8 | ParkingCutIn_Town12_Route765_Weather11/
9 | ```
10 |
11 | ### Libraries
12 | Install the [skit](https://github.com/Shamdan17/skit) toolkit for distributed sampling and training wrappers. The remaining dependencies are listed in `requirements.txt`; install them with pip:
13 |
14 | ```bash
15 | pip install -r requirements.txt
16 | ```
17 |
18 | ### Config files
19 | Under `carformer/config/user`, create a yaml file modeled on the provided `example.yaml`:
20 |
21 | ```yaml
22 | dataset_dir: /PATH/TO/DATASET/FOLDER/
23 | working_dir: /PATH/TO/carformer/
24 |
25 | wandb_entity: WANDBUSERNAME
26 | ```
27 |
28 | Then, update the username in the Makefile so that it points to the config file you just created.
29 |
30 | ### Training
31 |
32 | To train the base and async models on 8 nodes with 4 GPUs each, run the following commands from the `carformer` directory. We run these commands inside a SLURM job; invoking make directly will not parallelize training across multiple nodes, so adapt the commands accordingly if you are not in a SLURM environment.
33 |
34 | ```bash
35 | make ETA_base_model_s42
36 | make ETA_async_model_s42
37 | ```
38 |
39 | ## Evaluation Setup
40 |
41 | ### Bench2Drive
42 |
43 | [Bench2Drive](https://github.com/Thinklab-SJTU/Bench2Drive) is required for evaluation. Please follow the "Eval Tools" section.
44 |
45 | ### File setup:
46 |
47 | Only follow these steps **AFTER** Bench2Drive is set up according to the Bench2Drive instructions. Place the files from the "misc" and "team_code" folders into the Bench2Drive repo as shown below:
48 |
49 | ```
50 | Bench2Drive\
51 | assets\
52 | docs\
53 | leaderboard\
54 | leaderboard\
55 | --> Copy "leaderboard_evaluator_local.py" from the "misc" folder here
56 | scripts\
57 | --> Copy "run_eval_leaderboard.py" from the "misc" folder here
58 | team_code\
59 | --> Copy files from "team_code" folder here
60 | scenario_runner\
61 | tools\
62 | ```
63 |
64 | ### Config file setup:
65 |
66 | For evaluation, update the config files under `team_code/config`:
67 |
68 | ```yaml
69 | working_dir: /path/to/Bench2Drive
70 | b2d_path: /path/to/Bench2Drive
71 | ```
72 |
73 | ### Download checkpoints:
74 |
75 | Checkpoint upload is in progress.
76 |
77 | ### Running the evaluation:
78 |
79 | **Note:** Unlike Bench2Drive's setup, this evaluation code requires a separate instance of CARLA running. It will NOT run CARLA for you.
80 |
81 | #### Running carla
82 | You can run CARLA on port 30000 persistently (restarts 10 seconds after crashing) with the following command:
83 |
84 | ```bash
85 | while : ; do ./CarlaUE4.sh -carla-rpc-port=30000 ; sleep 10 ; done
86 | ```
87 |
88 | #### Running the evaluation
89 | From the Bench2Drive directory, run the following command to evaluate MODEL_NAME using CARLA on port 30000. Set "user" to your username.
90 |
91 | ```bash
92 | python leaderboard/scripts/run_eval_leaderboard.py user=shadi port=30000 trafficManagerPort=20000 experiments=Ponderer viz=0 experiments.ponderer_model_name=MODEL_NAME checkpoint_file=results.json experiments.agent_root=/PATH/TO/CHECKPOINT/MODEL_NAME experiments.root_path=/PATH/TO/CHECKPOINT/MODEL_NAME/ experiments.runnickname=NICKNAMEHERE resume=0 experiments.epoch_num=37
93 | ```
94 |
--------------------------------------------------------------------------------
/misc/leaderboard_evaluator_local.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Copyright (c) 2018-2019 Intel Corporation.
3 | # authors: German Ros (german.ros@intel.com), Felipe Codevilla (felipe.alcm@gmail.com)
4 | #
5 | # This work is licensed under the terms of the MIT license.
6 | # For a copy, see <https://opensource.org/licenses/MIT>.
7 |
8 | """
9 | CARLA Challenge Evaluator Routes
10 |
11 | Provisional code to evaluate Autonomous Agents for the CARLA Autonomous Driving challenge
12 | """
13 | from __future__ import print_function
14 |
15 | import traceback
16 | import argparse
17 | from argparse import RawTextHelpFormatter
18 | from distutils.version import LooseVersion
19 | import importlib
20 | import os
21 | import pkg_resources
22 | import sys
23 | import carla
24 | import signal
25 |
26 | from srunner.scenariomanager.carla_data_provider import *
27 | from srunner.scenariomanager.timer import GameTime
28 | from srunner.scenariomanager.watchdog import Watchdog
29 |
30 | from leaderboard.scenarios.scenario_manager import ScenarioManager
31 | from leaderboard.scenarios.route_scenario import RouteScenario
32 | from leaderboard.envs.sensor_interface import SensorConfigurationInvalid
33 | from leaderboard.autoagents.agent_wrapper import (
34 | AgentError,
35 | validate_sensor_configuration,
36 | TickRuntimeError,
37 | )
38 | from leaderboard.utils.statistics_manager import StatisticsManager, FAILURE_MESSAGES
39 | from leaderboard.utils.route_indexer import RouteIndexer
40 | import atexit
41 | import subprocess
42 | import time
43 | import random
44 | from datetime import datetime
45 |
46 | sensors_to_icons = {
47 | "sensor.camera.rgb": "carla_camera",
48 | "sensor.lidar.ray_cast": "carla_lidar",
49 | "sensor.other.radar": "carla_radar",
50 | "sensor.other.gnss": "carla_gnss",
51 | "sensor.other.imu": "carla_imu",
52 | "sensor.opendrive_map": "carla_opendrive_map",
53 | "sensor.speedometer": "carla_speedometer",
54 | }
55 |
56 | import socket
57 |
58 |
59 | def find_free_port(starting_port):
60 | port = starting_port
61 | while True:
62 | try:
63 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
64 | s.bind(("localhost", port))
65 | return port
66 | except OSError:
67 | port += 1
68 |
69 |
70 | def get_weather_id(weather_conditions, args, config):
71 | from xml.etree import ElementTree as ET
72 |
73 | tree = ET.parse(os.path.join(args.b2d_path, "leaderboard/data/weather.xml"))
74 | root = tree.getroot()
75 |
76 | def conditions_match(weather, conditions):
77 | for key, value in weather:
78 | if key == "route_percentage":
79 | continue
80 | if str(getattr(conditions, key)) != value:
81 | return False
82 | return True
83 |
84 | for case in root.findall("case"):
85 | weather = case[0].items()
86 | if conditions_match(weather, weather_conditions):
87 | return case.items()[0][1]
88 | return None
89 |
90 |
91 | class LeaderboardEvaluator(object):
92 | """
93 | Main class of the Leaderboard. Everything is handled from here,
94 | from parsing the given files, to preparing the simulation, to running the route.
95 | """
96 |
97 | # Tunable parameters
98 | client_timeout = 300.0 # in seconds
99 | frame_rate = 20.0 # in Hz
100 |
101 | def __init__(self, args, statistics_manager):
102 | """
103 | Setup CARLA client and world
104 | Setup ScenarioManager
105 | """
106 | self.world = None
107 | self.manager = None
108 | self.sensors = None
109 | self.sensors_initialized = False
110 | self.sensor_icons = []
111 | self.agent_instance = None
112 | self.route_scenario = None
113 |
114 | self.statistics_manager = statistics_manager
115 |
116 | # This is the ROS1 bridge server instance. This is not encapsulated inside the ROS1 agent because the same
117 | # instance is used on all the routes (i.e., the server is not restarted between routes). This is done
118 | # to avoid reconnection issues between the server and the roslibpy client.
119 | self._ros1_server = None
120 |
121 | # Setup the simulation
122 | self.client, self.client_timeout, self.traffic_manager = self._setup_simulation(
123 | args
124 | )
125 |
126 | dist = pkg_resources.get_distribution("carla")
127 | if dist.version != "leaderboard":
128 | if LooseVersion(dist.version) < LooseVersion("0.9.10"):
129 | raise ImportError(
130 | "CARLA version 0.9.10.1 or newer required. CARLA version found: {}".format(
131 | dist
132 | )
133 | )
134 |
135 | # Load agent
136 | module_name = os.path.basename(args.agent).split(".")[0]
137 | sys.path.insert(0, os.path.dirname(args.agent))
138 | self.module_agent = importlib.import_module(module_name)
139 |
140 | # Create the ScenarioManager
141 | self.manager = ScenarioManager(
142 | args.timeout, self.statistics_manager, args.debug
143 | )
144 |
145 | # Time control for summary purposes
146 | self._start_time = GameTime.get_time()
147 | self._end_time = None
148 |
149 | # Prepare the agent timer
150 | self._agent_watchdog = None
151 | # signal.signal(signal.SIGINT, self._signal_handler)
152 |
153 | self._client_timed_out = False
154 |
155 | def _signal_handler(self, signum, frame):
156 | """
157 | Terminate scenario ticking when receiving a signal interrupt.
158 | Either the agent initialization watchdog is triggered, or the runtime one at scenario manager
159 | """
160 | if self._agent_watchdog and not self._agent_watchdog.get_status():
161 | raise RuntimeError(
162 | "Timeout: Agent took longer than {}s to setup".format(
163 | self.client_timeout
164 | )
165 | )
166 | elif self.manager:
167 | self.manager.signal_handler(signum, frame)
168 |
169 | def __del__(self):
170 | """
171 | Cleanup and delete actors, ScenarioManager and CARLA world
172 | """
173 | if hasattr(self, "manager") and self.manager:
174 | del self.manager
175 | if hasattr(self, "world") and self.world:
176 | del self.world
177 |
178 | def _get_running_status(self):
179 | """
180 | returns:
181 | bool: False if watchdog exception occured, True otherwise
182 | """
183 | if self._agent_watchdog:
184 | return self._agent_watchdog.get_status()
185 | return False
186 |
187 | def _cleanup(self):
188 | """
189 | Remove and destroy all actors
190 | """
191 | CarlaDataProvider.cleanup()
192 |
193 | if self._agent_watchdog:
194 | self._agent_watchdog.stop()
195 |
196 | try:
197 | if self.agent_instance:
198 | self.agent_instance.destroy()
199 | self.agent_instance = None
200 | except Exception as e:
201 | print("\n\033[91mFailed to stop the agent:", flush=True)
202 | print(f"\n{traceback.format_exc()}\033[0m", flush=True)
203 |
204 | if self.route_scenario:
205 | self.route_scenario.remove_all_actors()
206 | self.route_scenario = None
207 | if self.statistics_manager:
208 | self.statistics_manager.remove_scenario()
209 |
210 | if self.manager:
211 | self._client_timed_out = not self.manager.get_running_status()
212 | self.manager.cleanup()
213 |
214 | # Make sure no sensors are left streaming
215 | alive_sensors = self.world.get_actors().filter("*sensor*")
216 | for sensor in alive_sensors:
217 | sensor.stop()
218 | sensor.destroy()
219 |
220 | def _setup_simulation(self, args):
221 | """
222 |         Prepares the simulation by creating the client and configuring the world and traffic manager settings
223 | """
224 | self.carla_path = os.environ["CARLA_ROOT"]
225 |
226 | attempts = 0
227 | num_max_restarts = 20
228 | while attempts < num_max_restarts:
229 | try:
230 | client = carla.Client(args.host, args.port)
231 | if args.timeout:
232 | client_timeout = args.timeout
233 | client.set_timeout(client_timeout)
234 |
235 | settings = carla.WorldSettings(
236 | synchronous_mode=True,
237 | fixed_delta_seconds=1.0 / self.frame_rate,
238 | deterministic_ragdolls=True,
239 | spectator_as_ego=False,
240 | )
241 | client.get_world().apply_settings(settings)
242 |                 print(f"load_world success, attempts={attempts}", flush=True)
243 | break
244 | except Exception as e:
245 |                 print(f"load_world failed, attempts={attempts}", flush=True)
246 | print(e, flush=True)
247 | attempts += 1
248 | time.sleep(5)
249 | attempts = 0
250 | num_max_restarts = 40
251 | while attempts < num_max_restarts:
252 | try:
253 | args.traffic_manager_port = find_free_port(args.traffic_manager_port)
254 | traffic_manager = client.get_trafficmanager(args.traffic_manager_port)
255 | traffic_manager.set_synchronous_mode(True)
256 | traffic_manager.set_hybrid_physics_mode(True)
257 | print(f"traffic_manager init success, try_time={attempts}", flush=True)
258 | break
259 | except Exception as e:
260 | print(f"traffic_manager init fail, try_time={attempts}", flush=True)
261 | print(e, flush=True)
262 | attempts += 1
263 | time.sleep(5)
264 | return client, client_timeout, traffic_manager
265 |
266 | def _reset_world_settings(self):
267 | """
268 | Changes the modified world settings back to asynchronous
269 | """
270 | # Has simulation failed?
271 | if self.world and self.manager and not self._client_timed_out:
272 | # Reset to asynchronous mode
273 | self.world.tick() # TODO: Make sure all scenario actors have been destroyed
274 | settings = self.world.get_settings()
275 | settings.synchronous_mode = False
276 | settings.fixed_delta_seconds = None
277 | settings.deterministic_ragdolls = False
278 | settings.spectator_as_ego = True
279 | self.world.apply_settings(settings)
280 |
281 |             # Switch the traffic manager back to asynchronous mode
282 | self.traffic_manager.set_synchronous_mode(False)
283 | self.traffic_manager.set_hybrid_physics_mode(False)
284 |
285 | def _load_and_wait_for_world(self, args, town):
286 | """
287 | Load a new CARLA world without changing the settings and provide data to CarlaDataProvider
288 | """
289 | self.world = self.client.load_world(town, reset_settings=False)
290 |
291 | # Large Map settings are always reset, for some reason
292 | settings = self.world.get_settings()
293 | settings.tile_stream_distance = 650
294 | settings.actor_active_distance = 650
295 | self.world.apply_settings(settings)
296 |
297 | self.world.reset_all_traffic_lights()
298 | CarlaDataProvider.set_client(self.client)
299 | CarlaDataProvider.set_traffic_manager_port(args.traffic_manager_port)
300 | CarlaDataProvider.set_world(self.world)
301 |
302 | # This must be here so that all route repetitions use the same 'unmodified' seed
303 | self.traffic_manager.set_random_device_seed(args.traffic_manager_seed)
304 |
305 | # Wait for the world to be ready
306 | self.world.tick()
307 |
308 | map_name = CarlaDataProvider.get_map().name.split("/")[-1]
309 | if map_name != town:
310 | raise Exception(
311 | "The CARLA server uses the wrong map!"
312 | " This scenario requires the use of map {}".format(town)
313 | )
314 |
315 | def _register_statistics(self, route_index, entry_status, crash_message=""):
316 | """
317 | Computes and saves the route statistics
318 | """
319 | print("\033[1m> Registering the route statistics\033[0m", flush=True)
320 | self.statistics_manager.save_entry_status(entry_status)
321 | self.statistics_manager.compute_route_statistics(
322 | route_index,
323 | self.manager.scenario_duration_system,
324 | self.manager.scenario_duration_game,
325 | crash_message,
326 | )
327 |
328 | def _load_and_run_scenario(self, args, config):
329 | """
330 | Load and run the scenario given by config.
331 |
332 | Depending on what code fails, the simulation will either stop the route and
333 | continue from the next one, or report a crash and stop.
334 | """
335 | crash_message = ""
336 | entry_status = "Started"
337 |
338 | print(
339 | "\n\033[1m========= Preparing {} (repetition {}) =========\033[0m".format(
340 | config.name, config.repetition_index
341 | ),
342 | flush=True,
343 | )
344 |
345 | # Prepare the statistics of the route
346 | route_name = f"{config.name}_rep{config.repetition_index}"
347 | scenario_name = config.scenario_configs[0].name
348 | town_name = str(config.town)
349 | weather_id = get_weather_id(config.weather[0][1], args, config)
350 | currentDateAndTime = datetime.now()
351 | currentTime = currentDateAndTime.strftime("%m_%d_%H_%M_%S")
352 | save_name = (
353 | f"{route_name}_{town_name}_{scenario_name}_{weather_id}_{currentTime}"
354 | )
355 | self.statistics_manager.create_route_data(
356 | route_name, scenario_name, weather_id, save_name, town_name, config.index
357 | )
358 |
359 | print("\033[1m> Loading the world\033[0m", flush=True)
360 |
361 | # Load the world and the scenario
362 | try:
363 | self._load_and_wait_for_world(args, config.town)
364 | self.route_scenario = RouteScenario(
365 | world=self.world, config=config, debug_mode=args.debug
366 | )
367 | self.statistics_manager.set_scenario(self.route_scenario)
368 |
369 | except Exception:
370 |             # The scenario is wrong -> set the execution to crashed and stop
371 | print("\n\033[91mThe scenario could not be loaded:", flush=True)
372 | print(f"\n{traceback.format_exc()}\033[0m", flush=True)
373 |
374 | entry_status, crash_message = FAILURE_MESSAGES["Simulation"]
375 | self._register_statistics(config.index, entry_status, crash_message)
376 | self._cleanup()
377 | return True
378 |
379 | print("\033[1m> Setting up the agent\033[0m", flush=True)
380 |
381 | # Set up the user's agent, and the timer to avoid freezing the simulation
382 | try:
383 | self._agent_watchdog = Watchdog(args.timeout)
384 | self._agent_watchdog.start()
385 | agent_class_name = getattr(self.module_agent, "get_entry_point")()
386 | agent_class_obj = getattr(self.module_agent, agent_class_name)
387 |
388 | # Start the ROS1 bridge server only for ROS1 based agents.
389 | if (
390 | getattr(agent_class_obj, "get_ros_version")() == 1
391 | and self._ros1_server is None
392 | ):
393 | from leaderboard.autoagents.ros1_agent import ROS1Server
394 |
395 | self._ros1_server = ROS1Server()
396 | self._ros1_server.start()
397 |
398 | self.agent_instance = agent_class_obj(args.host, args.port, args.debug)
399 | self.agent_instance.set_global_plan(
400 | self.route_scenario.gps_route, self.route_scenario.route
401 | )
402 | # args.agent_config = args.agent_config + "+" + save_name
403 | # self.agent_instance.setup(args.agent_config)
404 | self.agent_instance.setup(args)
405 |
406 | # Check and store the sensors
407 | if not self.sensors:
408 | self.sensors = self.agent_instance.sensors()
409 | track = self.agent_instance.track
410 |
411 | # validate_sensor_configuration(self.sensors, track, args.track)
412 |
413 | self.sensor_icons = [
414 | sensors_to_icons[sensor["type"]] for sensor in self.sensors
415 | ]
416 | self.statistics_manager.save_sensors(self.sensor_icons)
417 | self.statistics_manager.write_statistics()
418 |
419 | self.sensors_initialized = True
420 |
421 | self._agent_watchdog.stop()
422 | self._agent_watchdog = None
423 |
424 | except SensorConfigurationInvalid as e:
425 |             # The sensors are invalid -> set the execution to rejected and stop
426 | print("\n\033[91mThe sensor's configuration used is invalid:", flush=True)
427 | print(f"{e}\033[0m\n", flush=True)
428 |
429 | entry_status, crash_message = FAILURE_MESSAGES["Sensors"]
430 | self._register_statistics(config.index, entry_status, crash_message)
431 | self._cleanup()
432 | return True
433 |
434 | except Exception as e:
435 | # The agent setup has failed -> start the next route
436 | print("\n\033[91mCould not set up the required agent:", flush=True)
437 | print(f"\n{traceback.format_exc()}\033[0m", flush=True)
438 | print(f"{e}\033[0m\n", flush=True)
439 |
440 | entry_status, crash_message = FAILURE_MESSAGES["Agent_init"]
441 | self._register_statistics(config.index, entry_status, crash_message)
442 | self._cleanup()
443 | return True
444 |
445 | print("\033[1m> Running the route\033[0m", flush=True)
446 |
447 | # Run the scenario
448 | try:
449 | # Load scenario and run it
450 | if args.record:
451 | self.client.start_recorder(
452 | "{}/{}_rep{}.log".format(
453 | args.record, config.name, config.repetition_index
454 | )
455 | )
456 | self.manager.load_scenario(
457 | self.route_scenario,
458 | self.agent_instance,
459 | config.index,
460 | config.repetition_index,
461 | )
462 | self.manager.tick_count = 0
463 | self.manager.run_scenario()
464 |
465 | except AgentError:
466 | # The agent has failed -> stop the route
467 | print("\n\033[91mStopping the route, the agent has crashed:", flush=True)
468 | print(f"\n{traceback.format_exc()}\033[0m")
469 |
470 | entry_status, crash_message = FAILURE_MESSAGES["Agent_runtime"]
471 |
472 | except KeyboardInterrupt:
473 | return True
474 |
475 | except TickRuntimeError:
476 | entry_status, crash_message = "Started", "TickRuntime"
477 |
478 | except Exception:
479 | print("\n\033[91mError during the simulation:", flush=True)
480 | print(f"\n{traceback.format_exc()}\033[0m", flush=True)
481 |
482 | entry_status, crash_message = FAILURE_MESSAGES["Simulation"]
483 |
484 | # Stop the scenario
485 | try:
486 | print("\033[1m> Stopping the route\033[0m", flush=True)
487 | self.manager.stop_scenario()
488 | self._register_statistics(config.index, entry_status, crash_message)
489 |
490 | if args.record:
491 | self.client.stop_recorder()
492 |
493 | self._cleanup()
494 |
495 | except Exception:
496 | print(
497 | "\n\033[91mFailed to stop the scenario, the statistics might be empty:",
498 | flush=True,
499 | )
500 | print(f"\n{traceback.format_exc()}\033[0m", flush=True)
501 |
502 | _, crash_message = FAILURE_MESSAGES["Simulation"]
503 |
504 |         # If the simulation crashed, stop the leaderboard; otherwise, move on to the next route
505 | return crash_message == "Simulation crashed"
506 |
507 | def run(self, args):
508 | """
509 | Run the challenge mode
510 | """
511 | route_indexer = RouteIndexer(args.routes, args.repetitions, args.routes_subset)
512 |
513 | if args.resume:
514 | resume = route_indexer.validate_and_resume(args.checkpoint)
515 | else:
516 | resume = False
517 |
518 | if resume:
519 | self.statistics_manager.add_file_records(args.checkpoint)
520 | else:
521 | self.statistics_manager.clear_records()
522 | self.statistics_manager.save_progress(route_indexer.index, route_indexer.total)
523 | self.statistics_manager.write_statistics()
524 |
525 | crashed = False
526 | t1 = time.time()
527 | while route_indexer.peek() and not crashed:
528 |
529 | # Run the scenario
530 | config = route_indexer.get_next_config()
531 | crashed = self._load_and_run_scenario(args, config)
532 | # Save the progress and write the route statistics
533 | self.statistics_manager.save_progress(
534 | route_indexer.index, route_indexer.total
535 | )
536 | self.statistics_manager.write_statistics()
537 | if crashed:
538 | print(
539 |                     f"Route {route_indexer.index} crashed [{route_indexer.index}/{route_indexer.total}], please restart",
540 | flush=True,
541 | )
542 | break
543 |
544 | # Shutdown ROS1 bridge server if necessary
545 | if self._ros1_server is not None:
546 | self._ros1_server.shutdown()
547 |
548 | # Go back to asynchronous mode
549 | self._reset_world_settings()
550 |
551 | if not crashed:
552 | # Save global statistics
553 | print(f"cost time={time.time()-t1}", flush=True)
554 | print("\033[1m> Registering the global statistics\033[0m", flush=True)
555 | self.statistics_manager.compute_global_statistics()
556 | self.statistics_manager.validate_and_write_statistics(
557 | self.sensors_initialized, crashed
558 | )
559 |
560 | return crashed
561 |
562 |
563 | def main_eval(arguments):
564 | # def main():
565 | # description = (
566 | # "CARLA AD Leaderboard Evaluation: evaluate your Agent in CARLA scenarios\n"
567 | # )
568 |
569 | # # general parameters
570 | # parser = argparse.ArgumentParser(
571 | # description=description, formatter_class=RawTextHelpFormatter
572 | # )
573 | # parser.add_argument(
574 | # "--host", default="localhost", help="IP of the host server (default: localhost)"
575 | # )
576 | # parser.add_argument(
577 | # "--port", default=2000, type=int, help="TCP port to listen to (default: 2000)"
578 | # )
579 | # parser.add_argument(
580 | # "--traffic-manager-port",
581 | # default=8000,
582 | # type=int,
583 | # help="Port to use for the TrafficManager (default: 8000)",
584 | # )
585 | # parser.add_argument(
586 | # "--traffic-manager-seed",
587 | # default=0,
588 | # type=int,
589 | # help="Seed used by the TrafficManager (default: 0)",
590 | # )
591 | # parser.add_argument("--debug", type=int, help="Run with debug output", default=0)
592 | # parser.add_argument(
593 | # "--record",
594 | # type=str,
595 | # default="",
596 | # help="Use CARLA recording feature to create a recording of the scenario",
597 | # )
598 | # parser.add_argument(
599 | # "--timeout",
600 | # default=600.0,
601 | # type=float,
602 | # help="Set the CARLA client timeout value in seconds",
603 | # )
604 |
605 | # # simulation setup
606 | # parser.add_argument(
607 | # "--routes", required=True, help="Name of the routes file to be executed."
608 | # )
609 | # parser.add_argument(
610 | # "--routes-subset", default="", type=str, help="Execute a specific set of routes"
611 | # )
612 | # parser.add_argument(
613 | # "--repetitions", type=int, default=1, help="Number of repetitions per route."
614 | # )
615 |
616 | # # agent-related options
617 | # parser.add_argument(
618 | # "-a",
619 | # "--agent",
620 | # type=str,
621 | # help="Path to Agent's py file to evaluate",
622 | # required=True,
623 | # )
624 | # parser.add_argument(
625 | # "--agent-config",
626 | # type=str,
627 | # help="Path to Agent's configuration file",
628 | # default="",
629 | # )
630 |
631 | # parser.add_argument(
632 | # "--track", type=str, default="SENSORS", help="Participation track: SENSORS, MAP"
633 | # )
634 | # parser.add_argument(
635 | # "--resume",
636 | # type=bool,
637 | # default=False,
638 | # help="Resume execution from last checkpoint?",
639 | # )
640 | # parser.add_argument(
641 | # "--checkpoint",
642 | # type=str,
643 | # default="./simulation_results.json",
644 | # help="Path to checkpoint used for saving statistics and resuming",
645 | # )
646 | # parser.add_argument(
647 | # "--debug-checkpoint",
648 | # type=str,
649 | # default="./live_results.txt",
650 | # help="Path to checkpoint used for saving live results",
651 | # )
652 | # parser.add_argument("--gpu-rank", type=int, default=0)
653 | # arguments = parser.parse_args()
654 |
655 | statistics_manager = StatisticsManager(
656 | arguments.checkpoint, arguments.debug_checkpoint
657 | )
658 | leaderboard_evaluator = LeaderboardEvaluator(arguments, statistics_manager)
659 | crashed = leaderboard_evaluator.run(arguments)
660 |
661 | del leaderboard_evaluator
662 |
663 |
664 | def main():
665 | description = (
666 | "CARLA AD Leaderboard Evaluation: evaluate your Agent in CARLA scenarios\n"
667 | )
668 |
669 | # general parameters
670 | parser = argparse.ArgumentParser(
671 | description=description, formatter_class=RawTextHelpFormatter
672 | )
673 | parser.add_argument(
674 | "--host", default="localhost", help="IP of the host server (default: localhost)"
675 | )
676 | parser.add_argument(
677 | "--port", default=2000, type=int, help="TCP port to listen to (default: 2000)"
678 | )
679 | parser.add_argument(
680 | "--traffic-manager-port",
681 | default=8000,
682 | type=int,
683 | help="Port to use for the TrafficManager (default: 8000)",
684 | )
685 | parser.add_argument(
686 | "--traffic-manager-seed",
687 | default=0,
688 | type=int,
689 | help="Seed used by the TrafficManager (default: 0)",
690 | )
691 | parser.add_argument("--debug", type=int, help="Run with debug output", default=0)
692 | parser.add_argument(
693 | "--record",
694 | type=str,
695 | default="",
696 | help="Use CARLA recording feature to create a recording of the scenario",
697 | )
698 | parser.add_argument(
699 | "--timeout",
700 | default=600.0,
701 | type=float,
702 | help="Set the CARLA client timeout value in seconds",
703 | )
704 |
705 | # simulation setup
706 | parser.add_argument(
707 | "--routes", required=True, help="Name of the routes file to be executed."
708 | )
709 | parser.add_argument(
710 | "--routes-subset", default="", type=str, help="Execute a specific set of routes"
711 | )
712 | parser.add_argument(
713 | "--repetitions", type=int, default=1, help="Number of repetitions per route."
714 | )
715 |
716 | # agent-related options
717 | parser.add_argument(
718 | "-a",
719 | "--agent",
720 | type=str,
721 | help="Path to Agent's py file to evaluate",
722 | required=True,
723 | )
724 | parser.add_argument(
725 | "--agent-config",
726 | type=str,
727 | help="Path to Agent's configuration file",
728 | default="",
729 | )
730 |
731 | parser.add_argument(
732 | "--track", type=str, default="SENSORS", help="Participation track: SENSORS, MAP"
733 | )
734 | parser.add_argument(
735 | "--resume",
736 | type=bool,
737 | default=False,
738 | help="Resume execution from last checkpoint?",
739 | )
740 | parser.add_argument(
741 | "--checkpoint",
742 | type=str,
743 | default="./simulation_results.json",
744 | help="Path to checkpoint used for saving statistics and resuming",
745 | )
746 | parser.add_argument(
747 | "--debug-checkpoint",
748 | type=str,
749 | default="./live_results.txt",
750 | help="Path to checkpoint used for saving live results",
751 | )
752 | parser.add_argument("--gpu-rank", type=int, default=0)
753 | arguments = parser.parse_args()
754 |
755 | statistics_manager = StatisticsManager(
756 | arguments.checkpoint, arguments.debug_checkpoint
757 | )
758 | leaderboard_evaluator = LeaderboardEvaluator(arguments, statistics_manager)
759 | crashed = leaderboard_evaluator.run(arguments)
760 |
761 | del leaderboard_evaluator
762 |
763 | if crashed:
764 | sys.exit(-1)
765 | else:
766 | sys.exit(0)
767 |
768 |
769 | if __name__ == "__main__":
770 | main()
771 |
--------------------------------------------------------------------------------
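
A hedged sketch of driving the evaluator above programmatically: main_eval() takes an already-built namespace instead of parsing sys.argv, so a minimal call only needs the attributes the evaluator actually reads (the same ones main() exposes as command-line flags). The import path, route file, and agent path below are placeholders, and a CARLA server must already be listening on the given port with CARLA_ROOT set in the environment.

import argparse
# Assumed import path; in practice the file is expected to be importable as a module
# (e.g. copied into the leaderboard package, as run_eval_leaderboard.py below suggests).
from leaderboard_evaluator_local import main_eval

args = argparse.Namespace(
    host="localhost", port=2000,
    traffic_manager_port=8000, traffic_manager_seed=0,
    debug=0, record="", timeout=600.0,
    routes="leaderboard/data/bench2drive220.xml",  # placeholder route file
    routes_subset="", repetitions=1,
    agent="team_code/eta_agent.py",                # placeholder agent path
    agent_config="", track="SENSORS", resume=False,
    checkpoint="./simulation_results.json",
    debug_checkpoint="./live_results.txt",
    gpu_rank=0,
)
main_eval(args)
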
/misc/run_eval_leaderboard.py:
--------------------------------------------------------------------------------
1 | import sys
2 | # from leaderboard.leaderboard_evaluator_local import main  # unused: shadowed by the Hydra main() defined below
3 | import argparse
4 | import os
5 | # import sys  # duplicate import; sys is already imported above
6 | import hydra
7 | from pathlib import Path
8 | from omegaconf import DictConfig, OmegaConf
9 |
10 |
11 | @hydra.main(config_path="../../Bench2DriveZoo/team_code/config", config_name="config")
12 | def main(cfg):
13 | print(OmegaConf.to_yaml(cfg))
14 |
15 | cfg_org = cfg.copy()
16 | cfg = cfg.experiments
17 |
18 | print(cfg_org.eval.routes)
19 | print(cfg_org.checkpoint)
20 | print("Working directory : {}".format(os.getcwd()))
21 | print(f"Save gifs: {cfg_org.save_explainability_viz}")
22 |
23 | # create result folder
24 | Path(cfg_org.checkpoint).parent.mkdir(parents=True, exist_ok=True)
25 |
26 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{cfg_org.CUDA_VISIBLE_DEVICES}"
27 |
28 | arg_dict0 = OmegaConf.to_container(cfg_org.eval, resolve=True)
29 | arg_dict1 = OmegaConf.to_container(cfg, resolve=True)
30 | arg_dict2 = OmegaConf.to_container(cfg_org, resolve=True)
31 | arg_dict1.update(arg_dict2)
32 | arg_dict1.update(arg_dict0)
33 | args = argparse.Namespace(**arg_dict1)
34 |
35 | from leaderboard import leaderboard_evaluator_local
36 | import numpy as np
37 |
38 | # np.warnings.filterwarnings("error", category=np.VisibleDeprecationWarning)
39 |
40 | leaderboard_evaluator_local.main_eval(args)
41 |
42 |
43 | if __name__ == "__main__":
44 | main()
45 |
--------------------------------------------------------------------------------
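
For reference, the wrapper above builds the final argparse.Namespace by layering three plain dictionaries: it starts from cfg.experiments, then applies the full top-level config, then cfg.eval, so eval keys take the highest precedence, followed by top-level keys, followed by experiments keys. A minimal sketch of that precedence with toy stand-in dictionaries (not the real config keys):

import argparse
from omegaconf import OmegaConf

# Toy stand-ins for cfg.experiments, the top-level cfg, and cfg.eval.
experiments = OmegaConf.create({"shared_key": "from_experiments", "epoch_num": 37})
top_level = OmegaConf.create({"shared_key": "from_top_level", "port": 2000})
eval_cfg = OmegaConf.create({"shared_key": "from_eval", "routes": "bench2drive220.xml"})

merged = OmegaConf.to_container(experiments, resolve=True)
merged.update(OmegaConf.to_container(top_level, resolve=True))  # top-level config overrides experiments
merged.update(OmegaConf.to_container(eval_cfg, resolve=True))   # eval overrides both
args = argparse.Namespace(**merged)
print(args.shared_key)  # "from_eval"
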
/requirements.txt:
--------------------------------------------------------------------------------
1 | annotated-types==0.7.0
2 | antlr4-python3-runtime==4.9.3
3 | asttokens==2.4.1
4 | certifi
5 | cffi
6 | charset-normalizer
7 | click==8.1.7
8 | comm
9 | contourpy==1.1.1
10 | coverage==7.6.1
11 | cycler==0.12.1
12 | debugpy
13 | decorator
14 | deepspeed==0.14.4
15 | dictor==0.1.12
16 | docker-pycreds==0.4.0
17 | dotmap==1.3.30
18 | einops==0.8.0
19 | elementpath==1.3.3
20 | ephem==4.1.5
21 | executing==2.0.1
22 | filelock==3.0.12
23 | filterpy==1.4.5
24 | flash-attn==2.6.3
25 | fonttools==4.53.1
26 | fsspec==2024.6.1
27 | gitdb==4.0.11
28 | GitPython==3.1.43
29 | h2
30 | hjson==3.1.0
31 | hpack==4.0.0
32 | huggingface-hub==0.24.5
33 | hydra-core==1.3.2
34 | hyperframe
35 | idna
36 | imageio==2.35.1
37 | imgaug==0.4.0
38 | importlib_metadata
39 | importlib_resources==6.4.0
40 | ipdb==0.13.13
41 | ipykernel==6.29.5
42 | ipython==8.12.3
43 | jedi
44 | Jinja2
45 | joblib==1.4.2
46 | jsonpickle==4.0.2
47 | jupyter_client
48 | jupyter_core
49 | kiwisolver==1.4.5
50 | laspy==2.5.4
51 | lazy_loader==0.4
52 | line_profiler==4.2.0
53 | MarkupSafe
54 | matplotlib==3.5.3
55 | matplotlib-inline
56 | mpmath
57 | natsort==8.4.0
58 | nest_asyncio
59 | networkx
60 | ninja==1.11.1.1
61 | numpy
62 | nvidia-ml-py==12.555.43
63 | omegaconf==2.3.0
64 | opencv-python==4.2.0.32
65 | packaging==24.1
66 | pandas==2.0.3
67 | parso
68 | pexpect
69 | pickleshare
70 | Pillow
71 | platformdirs==4.2.2
72 | plotly==5.23.0
73 | prompt_toolkit==3.0.47
74 | protobuf==5.27.3
75 | psutil
76 | ptyprocess
77 | pure_eval==0.2.3
78 | py-cpuinfo==9.0.0
79 | py-trees==0.8.3
80 | pycparser
81 | pydantic==2.8.2
82 | pydantic_core==2.20.1
83 | pydot==3.0.1
84 | pygame==2.6.0
85 | Pygments
86 | pyparsing==3.1.2
87 | PySocks
88 | python-dateutil
89 | pytz==2024.1
90 | PyWavelets==1.4.1
91 | PyYAML
92 | pyzmq
93 | rdp==0.8
94 | regex==2024.7.24
95 | requests
96 | safetensors==0.4.4
97 | scikit-image==0.21.0
98 | scikit-learn==1.3.2
99 | scipy==1.10.1
100 | sentencepiece==0.2.0
101 | sentry-sdk==2.12.0
102 | setproctitle==1.3.3
103 | Shapely==1.7.1
104 | simple_watchdog_timer==0.1.1
105 | six
106 | smmap==5.0.1
107 | stack-data==0.6.3
108 | sympy
109 | tabulate==0.9.0
110 | tenacity==9.0.0
111 | threadpoolctl==3.5.0
112 | tifffile==2023.7.10
113 | timm==1.0.8
114 | tokenizers==0.19.1
115 | tomli==2.0.1
116 | torch==2.4.0
117 | torchaudio==2.4.0
118 | torchvision==0.19.0
119 | tornado
120 | tqdm==4.66.5
121 | traitlets
122 | transformers==4.43.4
123 | transforms3d==0.4.2
124 | triton==3.0.0
125 | typing_extensions
126 | tzdata==2024.1
127 | ujson==5.10.0
128 | urllib3
129 | wandb==0.17.5
130 | wcwidth
131 | xmlschema==1.0.18
132 | zipp==3.20.2
133 | zstandard==0.23.0
--------------------------------------------------------------------------------
/team_code/config/config.yaml:
--------------------------------------------------------------------------------
1 | defaults:
2 | - user: example
3 | - experiments: eta
4 | - eval: b2d
5 |
6 | host: localhost
7 | port: 2000
8 | trafficManagerPort: 8000
9 | traffic_manager_port: ${trafficManagerPort}
10 | trafficManagerSeed: 0
11 | dataProviderSeed: 0
12 | debug: 0
13 | viz: 0
14 | viz_interval: 50
15 | viz_max: 2000
16 | viz_extra: 0
17 | record: ''
18 | timeout: 600.0
19 |
20 | hydra:
21 | run:
22 | dir: ${experiments.agent_root}/${save_path}/${experiments.runnickname}
23 | job:
24 | config:
25 | override_dirname:
26 | exclude_keys:
27 | - eval
28 | - experiments
29 | - experiments.wanderer_model_name
30 | - experiments.ponderer_model_name
31 | - port
32 | - trafficManagerPort
33 | - experiments.epoch_num
34 | - experiments.use_gru_output
35 | - user
36 | - traffic_manager_seed
37 | kv_sep: '='
38 | item_sep: '_'
39 | env_set:
40 | OMP_NUM_THREADS: 1
41 |
42 | repetitions: 1
43 | track: SENSORS
44 | resume: 1
45 | save_path: evallogs
46 | log_save_path: result_logs
47 | checkpoint_file: results.json
48 | debug_checkpoint: ${hydra:run.dir}/debug/${checkpoint_file}
49 | checkpoint: ${hydra:run.dir}/${checkpoint_file}
50 | traffic_manager_seed: 0
51 |
52 |
53 | DEBUG_CHALLENGE: 0
54 | CUDA_VISIBLE_DEVICES: 0
55 | SEED_OFFSET: 0
56 |
--------------------------------------------------------------------------------
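
One way to inspect how this configuration composes (the defaults list plus command-line overrides) without launching an evaluation is Hydra's compose API. A sketch, assuming it is run from a script at the repository root; note that interpolations such as ${hydra:run.dir} only resolve inside a @hydra.main run, so they are left unresolved here:

from hydra import compose, initialize
from omegaconf import OmegaConf

# config_path is resolved relative to the calling script, not the working directory.
with initialize(version_base=None, config_path="team_code/config"):
    cfg = compose(config_name="config", overrides=["port=2002", "repetitions=2"])

print(cfg.port)                # 2002
print(OmegaConf.to_yaml(cfg))  # interpolations like ${hydra:run.dir} are printed unresolved
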
/team_code/config/eval/b2d.yaml:
--------------------------------------------------------------------------------
1 | BENCHMARK: b2d
2 | BENCHMARKNICKNAME: b2dfull
3 | route_rel: leaderboard/data/bench2drive220.xml
4 | routes: ${user.working_dir}/${eval.route_rel}
--------------------------------------------------------------------------------
/team_code/config/experiments/eta.yaml:
--------------------------------------------------------------------------------
1 | BASE_PORT: 30000
2 | BASE_TM_PORT: 50000
3 | IS_BENCH2DRIVE: True
4 | BASE_ROUTES: leaderboard/data/bench2drive220
5 |
6 | BASE_CHECKPOINT_ENDPOINT: endpoint
7 | SAVE_PATH: ./dummy
8 | PLANNER_TYPE: only_traj
9 |
10 | agent: ${user.working_dir}/Bench2DriveZoo/team_code/eta_agent.py
11 | runnickname:
12 | root_path: ${user.working_dir}
13 | ponderer_model_name:
14 | agent_root: ${root_path}/${experiments.ponderer_model_name}
15 | ghost_agent_root: null
16 | ghost_only_viz: False
17 |
18 | epoch_num: 37
19 |
20 | creep_delay: 100
21 | creep_duration: 10
22 | use_creep: True
23 |
24 |
25 | steer_damping: 0.5
26 | use_gt_frc: null
27 |
28 | # Bench2Drive related
29 | routes_subset:
30 | b2d_path: ${user.b2d_path}
--------------------------------------------------------------------------------
/team_code/config/user/example.yaml:
--------------------------------------------------------------------------------
1 | working_dir: /home/shadi/research/
2 | b2d_path: /home/shadi/research/ETA/Bench2Drive
--------------------------------------------------------------------------------
/team_code/planner.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import deque
3 |
4 | import numpy as np
5 | import math
6 |
7 | EARTH_RADIUS_EQUA = 6378137.0
8 |
9 | DEBUG = False
10 |
11 |
12 | class Plotter(object):
13 | def __init__(self, size):
14 | self.size = size
15 | self.clear()
16 | self.title = str(self.size)
17 |
18 | def clear(self):
19 | from PIL import Image, ImageDraw
20 |
21 | self.img = Image.fromarray(np.zeros((self.size, self.size, 3), dtype=np.uint8))
22 | self.draw = ImageDraw.Draw(self.img)
23 |
24 | def dot(self, pos, node, color=(255, 255, 255), label=None, r=2):
25 | x, y = 5.5 * (pos - node)
26 | x += self.size / 2
27 | y += self.size / 2
28 |
29 | self.draw.ellipse((x - r, y - r, x + r, y + r), color)
30 |
31 | # If a label is provided, write it
32 | if label:
33 | self.draw.text((x, y), label, fill=color)
34 |
35 | def show(self):
36 | if not DEBUG:
37 | return
38 |
39 | import cv2
40 |
41 | cv2.imshow(self.title, cv2.cvtColor(np.array(self.img), cv2.COLOR_BGR2RGB))
42 | cv2.waitKey(1)
43 |
44 |
45 | class RoutePlanner(object):
46 | def __init__(
47 | self, min_distance, max_distance, debug_size=256, lat_ref=42.0, lon_ref=2.0
48 | ):
49 | self.route = deque()
50 | self.min_distance = min_distance
51 | self.max_distance = max_distance
52 |
53 | # self.mean = np.array([49.0, 8.0]) # for carla 9.9
54 | # self.scale = np.array([111324.60662786, 73032.1570362]) # for carla 9.9
55 | self.mean = np.array([0.0, 0.0]) # for carla 9.10
56 | self.scale = np.array([111324.60662786, 111319.490945]) # for carla 9.10
57 |
58 | self.debug = Plotter(debug_size)
59 | # self.lat_ref, self.lon_ref = self._get_latlon_ref()
60 | self.lat_ref = lat_ref
61 | self.lon_ref = lon_ref
62 |
63 | def set_route(self, global_plan, gps=False, global_plan_world=None):
64 | self.route.clear()
65 |
66 | if global_plan_world:
67 |             for (pos, cmd), (pos_world, _) in zip(global_plan, global_plan_world):
68 | if gps:
69 | pos = self.gps_to_location(np.array([pos["lat"], pos["lon"]]))
70 | # pos -= self.mean
71 | # pos *= self.scale
72 | else:
73 | pos = np.array([pos.location.x, pos.location.y])
74 | # pos -= self.mean
75 |
76 |                 self.route.append((pos, cmd, pos_world))
77 | else:
78 | for pos, cmd in global_plan:
79 | if gps:
80 | pos = self.gps_to_location(np.array([pos["lat"], pos["lon"]]))
81 | # pos -= self.mean
82 | # pos *= self.scale
83 | else:
84 | pos = np.array([pos.location.x, pos.location.y])
85 | # pos -= self.mean
86 |
87 | self.route.append((pos, cmd))
88 |
89 | def run_step(self, gps, return_both=False):
90 | self.debug.clear()
91 |
92 | if len(self.route) == 1:
93 | if return_both:
94 | return self.route[0], self.route[0]
95 | return self.route[0]
96 |
97 | to_pop = 0
98 | farthest_in_range = -np.inf
99 | cumulative_distance = 0.0
100 |
101 | for i in range(1, len(self.route)):
102 | if cumulative_distance > self.max_distance:
103 | break
104 |
105 | cumulative_distance += np.linalg.norm(
106 | self.route[i][0] - self.route[i - 1][0]
107 | )
108 | distance = np.linalg.norm(self.route[i][0] - gps)
109 |
110 | if distance <= self.min_distance and distance > farthest_in_range:
111 | farthest_in_range = distance
112 | to_pop = i
113 |
114 | r = 255 * int(distance > self.min_distance)
115 | g = 255 * int(self.route[i][1].value == 4)
116 | b = 255
117 | self.debug.dot(
118 | gps,
119 | self.route[i][0],
120 | (r, g, b),
121 | label=f"{distance:.2f}_{to_pop}_{farthest_in_range:.2f}",
122 | )
123 |
124 | for _ in range(to_pop):
125 | if len(self.route) > 2:
126 | self.route.popleft()
127 |
128 | self.debug.dot(gps, self.route[0][0], (0, 255, 0))
129 | self.debug.dot(gps, self.route[1][0], (255, 0, 0))
130 | self.debug.dot(gps, gps, (0, 0, 255))
131 | self.debug.show()
132 |
133 | if return_both:
134 | return (
135 | (self.route[1], self.route[2])
136 | if len(self.route) > 2
137 | else (self.route[1], self.route[1])
138 | )
139 |
140 | return self.route[1]
141 |
142 | def gps_to_location(self, gps):
143 |         # gps content: numpy array: [lat, lon]
144 | lat, lon = gps
145 | scale = math.cos(self.lat_ref * math.pi / 180.0)
146 | my = math.log(math.tan((lat + 90) * math.pi / 360.0)) * (
147 | EARTH_RADIUS_EQUA * scale
148 | )
149 | mx = (lon * (math.pi * EARTH_RADIUS_EQUA * scale)) / 180.0
150 | y = (
151 | scale
152 | * EARTH_RADIUS_EQUA
153 | * math.log(math.tan((90.0 + self.lat_ref) * math.pi / 360.0))
154 | - my
155 | )
156 | x = mx - scale * self.lon_ref * math.pi * EARTH_RADIUS_EQUA / 180.0
157 | return np.array([x, y])
158 |
--------------------------------------------------------------------------------
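
As a quick sanity check on the projection in gps_to_location above: the reference point (lat_ref, lon_ref) maps to the local origin, x grows eastward with longitude, and y decreases as latitude increases. A small standalone sketch that re-derives the same formula (rather than importing the repo code) under the default lat_ref=42.0, lon_ref=2.0:

import math
import numpy as np

EARTH_RADIUS_EQUA = 6378137.0

def gps_to_location(gps, lat_ref=42.0, lon_ref=2.0):
    # Same Mercator-style projection as RoutePlanner.gps_to_location above.
    lat, lon = gps
    scale = math.cos(lat_ref * math.pi / 180.0)
    my = math.log(math.tan((lat + 90) * math.pi / 360.0)) * (EARTH_RADIUS_EQUA * scale)
    mx = (lon * (math.pi * EARTH_RADIUS_EQUA * scale)) / 180.0
    y = scale * EARTH_RADIUS_EQUA * math.log(math.tan((90.0 + lat_ref) * math.pi / 360.0)) - my
    x = mx - scale * lon_ref * math.pi * EARTH_RADIUS_EQUA / 180.0
    return np.array([x, y])

print(gps_to_location([42.0, 2.0]))    # ~[0, 0]: the reference point is the local origin
print(gps_to_location([42.0, 2.001]))  # x ~ +83 m: roughly cos(42 deg) * 111.3 km per degree of longitude
print(gps_to_location([42.001, 2.0]))  # y ~ -111 m: y decreases as latitude increases
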