├── .gitignore ├── LICENSE ├── README.md ├── docs └── quick_start.md ├── projects ├── configs │ ├── sparsedrive_small_stage1.py │ └── sparsedrive_small_stage2.py └── mmdet3d_plugin │ ├── __init__.py │ ├── apis │ ├── __init__.py │ ├── mmdet_train.py │ ├── test.py │ └── train.py │ ├── core │ ├── box3d.py │ └── evaluation │ │ ├── __init__.py │ │ └── eval_hooks.py │ ├── datasets │ ├── __init__.py │ ├── builder.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── map │ │ │ ├── AP.py │ │ │ ├── distance.py │ │ │ └── vector_eval.py │ │ ├── motion │ │ │ ├── motion_eval_uniad.py │ │ │ └── motion_utils.py │ │ └── planning │ │ │ └── planning_eval.py │ ├── map_utils │ │ ├── nuscmap_extractor.py │ │ └── utils.py │ ├── nuscenes_3d_dataset.py │ ├── pipelines │ │ ├── __init__.py │ │ ├── augment.py │ │ ├── loading.py │ │ ├── transform.py │ │ └── vectorize.py │ ├── samplers │ │ ├── __init__.py │ │ ├── distributed_sampler.py │ │ ├── group_in_batch_sampler.py │ │ ├── group_sampler.py │ │ └── sampler.py │ └── utils.py │ ├── models │ ├── __init__.py │ ├── attention.py │ ├── base_target.py │ ├── blocks.py │ ├── detection3d │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── detection3d_blocks.py │ │ ├── detection3d_head.py │ │ ├── losses.py │ │ └── target.py │ ├── grid_mask.py │ ├── instance_bank.py │ ├── map │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── loss.py │ │ ├── map_blocks.py │ │ ├── match_cost.py │ │ └── target.py │ ├── motion │ │ ├── __init__.py │ │ ├── decoder.py │ │ ├── instance_queue.py │ │ ├── motion_blocks.py │ │ ├── motion_planning_head.py │ │ └── target.py │ ├── sparsedrive.py │ └── sparsedrive_head.py │ └── ops │ ├── __init__.py │ ├── deformable_aggregation.py │ ├── setup.py │ └── src │ ├── deformable_aggregation.cpp │ └── deformable_aggregation_cuda.cu ├── requirement.txt ├── resources ├── legend.png ├── motion_planner.png ├── overview.png ├── sdc_car.png └── sparse_perception.png ├── scripts ├── create_data.sh ├── kmeans.sh ├── test.sh ├── train.sh └── visualize.sh └── tools ├── benchmark.py ├── data_converter ├── __init__.py └── nuscenes_converter.py ├── dist_test.sh ├── dist_train.sh ├── fuse_conv_bn.py ├── kmeans ├── kmeans_det.py ├── kmeans_map.py ├── kmeans_motion.py └── kmeans_plan.py ├── test.py ├── train.py └── visualization ├── bev_render.py ├── cam_render.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.npy 3 | *.pth 4 | *.whl 5 | *.swp 6 | 7 | data/ 8 | ckpt/ 9 | work_dirs*/ 10 | dist_test/ 11 | vis/ 12 | val/ 13 | lib/ 14 | 15 | *.egg-info 16 | build/ 17 | __pycache__/ 18 | *.so 19 | 20 | job_scripts/ 21 | temp_ops/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 swc-17 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SparseDrive: End-to-End Autonomous Driving via Sparse Scene Representation
2 | 
3 | https://github.com/swc-17/SparseDrive/assets/64842878/867276dc-7c19-4e01-9a8e-81c4ed844745
4 | 
5 | ## News
6 | * **`17 March, 2025`:** SparseDrive is accepted by ICRA 2025.
7 | * **`24 June, 2024`:** We reorganize the code for better readability. Code & Models are released.
8 | * **`31 May, 2024`:** We release the SparseDrive paper on [arXiv](https://arxiv.org/abs/2405.19620). Code & Models will be released in June, 2024. Please stay tuned!
9 | 
10 | 
11 | ## Introduction
12 | > SparseDrive is a Sparse-Centric paradigm for end-to-end autonomous driving.
13 | - We explore sparse scene representation for end-to-end autonomous driving and propose a Sparse-Centric paradigm named SparseDrive, which unifies multiple tasks with a sparse instance representation.
14 | - We revisit the great similarity shared between motion prediction and planning, which leads to a parallel design for the motion planner. We further propose a hierarchical planning selection strategy incorporating a collision-aware rescore module to boost planning performance (a toy sketch is given below).
15 | - On the challenging nuScenes benchmark, SparseDrive surpasses previous SOTA methods on all metrics, especially the safety-critical collision rate, while achieving much higher training and inference efficiency.
16 | 
17 | 
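The idea behind the collision-aware rescore can be pictured with a minimal sketch. Everything below (function name, array shapes, the placeholder distance threshold) is illustrative only and is not the actual SparseDrive implementation; the real planner lives under `projects/mmdet3d_plugin/models/motion/`, and the box-level collision test used for evaluation is in `projects/mmdet3d_plugin/datasets/evaluation/planning/planning_eval.py`.

```python
import numpy as np

def collision_aware_rescore(plan_trajs, plan_scores, agent_centers, penalty=1.0):
    """Toy sketch: penalize candidate ego trajectories that pass too close to
    predicted agents, then select the best remaining mode.

    plan_trajs:    (M, T, 2) candidate future waypoints for M planning modes
    plan_scores:   (M,)      confidence score of each mode
    agent_centers: (N, 2)    predicted agent positions (a real system would use
                             full boxes over the planning horizon, not centers)
    """
    rescored = plan_scores.copy()
    for i, traj in enumerate(plan_trajs):
        # minimum distance from any ego waypoint to any agent center
        d = np.linalg.norm(traj[:, None, :] - agent_centers[None, :, :], axis=-1)
        if d.size and d.min() < 2.0:   # crude collision proxy (placeholder radius)
            rescored[i] -= penalty     # down-weight unsafe modes
    return plan_trajs[int(np.argmax(rescored))]
```

The hierarchical part of the selection strategy (e.g. conditioning the candidate set on the driving command, as reflected by `gt_ego_fut_cmd` in the evaluation code) is omitted here for brevity.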
18 | [figure: resources/overview.png]
22 | 
25 | Overview of SparseDrive. SparseDrive first encodes multi-view images into feature maps,
26 | then learns sparse scene representation through symmetric sparse perception, and finally performs
27 | motion prediction and planning in a parallel manner. An instance memory queue is devised for
28 | temporal modeling.
29 | 
30 | 
31 | [figure: resources/sparse_perception.png]
35 | 
38 | Model architecture of symmetric sparse perception, which unifies detection, tracking and
39 | online mapping in a symmetric structure.
40 | 
41 | 
42 | [figure: resources/motion_planner.png]
46 | 
49 | Model structure of parallel motion planner, which performs motion prediction and planning
50 | simultaneously and outputs a safe planning trajectory.
51 | 
52 | 53 | ## Results in paper 54 | 55 | - Comprehensive results for all tasks on [nuScenes](https://github.com/nutonomy/nuscenes-devkit). 56 | 57 | | Method | NDS | AMOTA | minADE (m) | L2 (m) Avg | Col. (%) Avg | Training Time (h) | FPS | 58 | | :---: | :---:| :---: | :---: | :---: | :---: | :---: | :---: | 59 | | UniAD | 0.498 | 0.359 | 0.71 | 0.73 | 0.61 | 144 | 1.8 | 60 | | SparseDrive-S | 0.525 | 0.386 | 0.62 | 0.61 | 0.08 | **20** | **9.0** | 61 | | SparseDrive-B | **0.588** | **0.501** | **0.60** | **0.58** | **0.06** | 30 | 7.3 | 62 | 63 | - Open-loop planning results on [nuScenes](https://github.com/nutonomy/nuscenes-devkit). 64 | 65 | | Method | L2 (m) 1s | L2 (m) 2s | L2 (m) 3s | L2 (m) Avg | Col. (%) 1s | Col. (%) 2s | Col. (%) 3s | Col. (%) Avg | FPS | 66 | | :---: | :---: | :---: | :---: | :---:| :---: | :---: | :---: | :---: | :---: | 67 | | UniAD | 0.45 | 0.70 | 1.04 | 0.73 | 0.62 | 0.58 | 0.63 | 0.61 | 1.8 | 68 | | VAD | 0.41 | 0.70 | 1.05 | 0.72 | 0.03 | 0.19 | 0.43 | 0.21 |4.5 | 69 | | SparseDrive-S | **0.29** | 0.58 | 0.96 | 0.61 | 0.01 | 0.05 | 0.18 | 0.08 | **9.0** | 70 | | SparseDrive-B | **0.29** | **0.55** | **0.91** | **0.58** | **0.01** | **0.02** | **0.13** | **0.06** | 7.3 | 71 | 72 | ## Results of released checkpoint 73 | We found that some collision cases were not taken into consideration in our previous code, so we re-implement the evaluation metric for collision rate in released code and provide updated results. 74 | 75 | ## Main results 76 | | Model | config | ckpt | log | det: NDS | mapping: mAP | track: AMOTA |track: AMOTP | motion: EPA_car |motion: minADE_car| motion: minFDE_car | motion: MissRate_car | planning: CR | planning: L2 | 77 | | :---: | :---: | :---: | :---: | :---: | :---:|:---:|:---: | :---: | :----: | :----: | :----: | :----: | :----: | 78 | | Stage1 |[cfg](projects/configs/sparsedrive_small_stage1.py)|[ckpt](https://github.com/swc-17/SparseDrive/releases/download/v1.0/sparsedrive_stage1.pth)|[log](https://github.com/swc-17/SparseDrive/releases/download/v1.0/sparsedrive_stage1_log.txt)|0.5260|0.5689|0.385|1.260| | | | | | | 79 | | Stage2 |[cfg](projects/configs/sparsedrive_small_stage2.py)|[ckpt](https://github.com/swc-17/SparseDrive/releases/download/v1.0/sparsedrive_stage2.pth)|[log](https://github.com/swc-17/SparseDrive/releases/download/v1.0/sparsedrive_stage2_log.txt)|0.5257|0.5656|0.372|1.248|0.492|0.61|0.95|0.133|0.097%|0.61| 80 | 81 | ## Detailed results for planning 82 | | Method | L2 (m) 1s | L2 (m) 2s | L2 (m) 3s | L2 (m) Avg | Col. (%) 1s | Col. (%) 2s | Col. (%) 3s | Col. (%) Avg | 83 | | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | 84 | | UniAD | 0.45 | 0.70 | 1.04 | 0.73 | 0.66 | 0.66 | 0.72 | 0.68 | 85 | | UniAD-wo-post-optim | 0.32 | 0.58 | 0.94 | 0.61 | 0.17 | 0.27 | 0.42 | 0.29 | 86 | | VAD | 0.41 | 0.70 | 1.05 | 0.72 | 0.03 | 0.21 | 0.49 | 0.24 | 87 | | SparseDrive-S | 0.30 | 0.58 | 0.95 | 0.61 | 0.01 | 0.05 | 0.23 | 0.10 | 88 | 89 | 90 | ## Quick Start 91 | [Quick Start](docs/quick_start.md) 92 | 93 | ## Citation 94 | If you find SparseDrive useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. 
95 | ```
96 | @article{sun2024sparsedrive,
97 |   title={SparseDrive: End-to-End Autonomous Driving via Sparse Scene Representation},
98 |   author={Sun, Wenchao and Lin, Xuewu and Shi, Yining and Zhang, Chuang and Wu, Haoran and Zheng, Sifa},
99 |   journal={arXiv preprint arXiv:2405.19620},
100 |   year={2024}
101 | }
102 | ```
103 | 
104 | ## Acknowledgement
105 | - [Sparse4D](https://github.com/HorizonRobotics/Sparse4D)
106 | - [UniAD](https://github.com/OpenDriveLab/UniAD)
107 | - [VAD](https://github.com/hustvl/VAD)
108 | - [StreamPETR](https://github.com/exiawsh/StreamPETR)
109 | - [StreamMapNet](https://github.com/yuantianyuan01/StreamMapNet)
110 | - [mmdet3d](https://github.com/open-mmlab/mmdetection3d)
111 | 
112 | 
--------------------------------------------------------------------------------
/docs/quick_start.md:
--------------------------------------------------------------------------------
1 | # Quick Start
2 | 
3 | ### Set up a new virtual environment
4 | ```bash
5 | conda create -n sparsedrive python=3.8 -y
6 | conda activate sparsedrive
7 | ```
8 | 
9 | ### Install dependency packages
10 | ```bash
11 | sparsedrive_path="path/to/sparsedrive"
12 | cd ${sparsedrive_path}
13 | pip3 install --upgrade pip
14 | pip3 install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116
15 | pip3 install -r requirement.txt
16 | ```
17 | 
18 | ### Compile the deformable_aggregation CUDA op
19 | ```bash
20 | cd projects/mmdet3d_plugin/ops
21 | python3 setup.py develop
22 | cd ../../../
23 | ```
24 | 
25 | ### Prepare the data
26 | Download the [NuScenes dataset](https://www.nuscenes.org/nuscenes#download) and the CAN bus expansion, put the CAN bus expansion in /path/to/nuscenes, and create symbolic links.
27 | ```bash
28 | cd ${sparsedrive_path}
29 | mkdir data
30 | ln -s path/to/nuscenes ./data/nuscenes
31 | ```
32 | 
33 | Pack the meta-information and labels of the dataset, and generate the required pkl files to data/infos. Note that we also generate map_annos in data_converter, with a default roi_size of (30, 60); if you want a different range, modify roi_size in tools/data_converter/nuscenes_converter.py.
34 | ```bash
35 | sh scripts/create_data.sh
36 | ```
37 | 
38 | ### Generate anchors by K-means
39 | Generated anchors are saved to data/kmeans and can be visualized in vis/kmeans.
40 | ```bash
41 | sh scripts/kmeans.sh
42 | ```
43 | 
44 | 
45 | ### Download pre-trained weights
46 | Download the required backbone [pre-trained weights](https://download.pytorch.org/models/resnet50-19c8e357.pth). 
47 | ```bash 48 | mkdir ckpt 49 | wget https://download.pytorch.org/models/resnet50-19c8e357.pth -O ckpt/resnet50-19c8e357.pth 50 | ``` 51 | 52 | ### Commence training and testing 53 | ```bash 54 | # train 55 | sh scripts/train.sh 56 | 57 | # test 58 | sh scripts/test.sh 59 | ``` 60 | 61 | ### Visualization 62 | ``` 63 | sh scripts/visualize.sh 64 | ``` 65 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import * 2 | from .models import * 3 | from .apis import * 4 | from .core.evaluation import * 5 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/apis/__init__.py: -------------------------------------------------------------------------------- 1 | from .train import custom_train_model 2 | from .mmdet_train import custom_train_detector 3 | 4 | # from .test import custom_multi_gpu_test 5 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/apis/mmdet_train.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------- 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | # --------------------------------------------- 4 | # Modified by Zhiqi Li 5 | # --------------------------------------------- 6 | import random 7 | import warnings 8 | 9 | import numpy as np 10 | import torch 11 | import torch.distributed as dist 12 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 13 | from mmcv.runner import ( 14 | HOOKS, 15 | DistSamplerSeedHook, 16 | EpochBasedRunner, 17 | Fp16OptimizerHook, 18 | OptimizerHook, 19 | build_optimizer, 20 | build_runner, 21 | get_dist_info, 22 | ) 23 | from mmcv.utils import build_from_cfg 24 | 25 | from mmdet.core import EvalHook 26 | 27 | from mmdet.datasets import build_dataset, replace_ImageToTensor 28 | from mmdet.utils import get_root_logger 29 | import time 30 | import os.path as osp 31 | from projects.mmdet3d_plugin.datasets.builder import build_dataloader 32 | from projects.mmdet3d_plugin.core.evaluation.eval_hooks import ( 33 | CustomDistEvalHook, 34 | ) 35 | from projects.mmdet3d_plugin.datasets import custom_build_dataset 36 | 37 | 38 | def custom_train_detector( 39 | model, 40 | dataset, 41 | cfg, 42 | distributed=False, 43 | validate=False, 44 | timestamp=None, 45 | meta=None, 46 | ): 47 | logger = get_root_logger(cfg.log_level) 48 | 49 | # prepare data loaders 50 | 51 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset] 52 | # assert len(dataset)==1s 53 | if "imgs_per_gpu" in cfg.data: 54 | logger.warning( 55 | '"imgs_per_gpu" is deprecated in MMDet V2.0. 
' 56 | 'Please use "samples_per_gpu" instead' 57 | ) 58 | if "samples_per_gpu" in cfg.data: 59 | logger.warning( 60 | f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and ' 61 | f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"' 62 | f"={cfg.data.imgs_per_gpu} is used in this experiments" 63 | ) 64 | else: 65 | logger.warning( 66 | 'Automatically set "samples_per_gpu"="imgs_per_gpu"=' 67 | f"{cfg.data.imgs_per_gpu} in this experiments" 68 | ) 69 | cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu 70 | 71 | if "runner" in cfg: 72 | runner_type = cfg.runner["type"] 73 | else: 74 | runner_type = "EpochBasedRunner" 75 | data_loaders = [ 76 | build_dataloader( 77 | ds, 78 | cfg.data.samples_per_gpu, 79 | cfg.data.workers_per_gpu, 80 | # cfg.gpus will be ignored if distributed 81 | len(cfg.gpu_ids), 82 | dist=distributed, 83 | seed=cfg.seed, 84 | nonshuffler_sampler=dict( 85 | type="DistributedSampler" 86 | ), # dict(type='DistributedSampler'), 87 | runner_type=runner_type, 88 | ) 89 | for ds in dataset 90 | ] 91 | 92 | # put model on gpus 93 | if distributed: 94 | find_unused_parameters = cfg.get("find_unused_parameters", False) 95 | # Sets the `find_unused_parameters` parameter in 96 | # torch.nn.parallel.DistributedDataParallel 97 | model = MMDistributedDataParallel( 98 | model.cuda(), 99 | device_ids=[torch.cuda.current_device()], 100 | broadcast_buffers=False, 101 | find_unused_parameters=find_unused_parameters, 102 | ) 103 | 104 | else: 105 | model = MMDataParallel( 106 | model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids 107 | ) 108 | 109 | # build runner 110 | optimizer = build_optimizer(model, cfg.optimizer) 111 | 112 | if "runner" not in cfg: 113 | cfg.runner = { 114 | "type": "EpochBasedRunner", 115 | "max_epochs": cfg.total_epochs, 116 | } 117 | warnings.warn( 118 | "config is now expected to have a `runner` section, " 119 | "please set `runner` in your config.", 120 | UserWarning, 121 | ) 122 | else: 123 | if "total_epochs" in cfg: 124 | assert cfg.total_epochs == cfg.runner.max_epochs 125 | 126 | runner = build_runner( 127 | cfg.runner, 128 | default_args=dict( 129 | model=model, 130 | optimizer=optimizer, 131 | work_dir=cfg.work_dir, 132 | logger=logger, 133 | meta=meta, 134 | ), 135 | ) 136 | 137 | # an ugly workaround to make .log and .log.json filenames the same 138 | runner.timestamp = timestamp 139 | 140 | # fp16 setting 141 | fp16_cfg = cfg.get("fp16", None) 142 | if fp16_cfg is not None: 143 | optimizer_config = Fp16OptimizerHook( 144 | **cfg.optimizer_config, **fp16_cfg, distributed=distributed 145 | ) 146 | elif distributed and "type" not in cfg.optimizer_config: 147 | optimizer_config = OptimizerHook(**cfg.optimizer_config) 148 | else: 149 | optimizer_config = cfg.optimizer_config 150 | 151 | # register hooks 152 | runner.register_training_hooks( 153 | cfg.lr_config, 154 | optimizer_config, 155 | cfg.checkpoint_config, 156 | cfg.log_config, 157 | cfg.get("momentum_config", None), 158 | ) 159 | 160 | # register profiler hook 161 | # trace_config = dict(type='tb_trace', dir_name='work_dir') 162 | # profiler_config = dict(on_trace_ready=trace_config) 163 | # runner.register_profiler_hook(profiler_config) 164 | 165 | if distributed: 166 | if isinstance(runner, EpochBasedRunner): 167 | runner.register_hook(DistSamplerSeedHook()) 168 | 169 | # register eval hooks 170 | if validate: 171 | # Support batch_size > 1 in validation 172 | val_samples_per_gpu = cfg.data.val.pop("samples_per_gpu", 1) 173 | if val_samples_per_gpu > 1: 174 | assert False 175 | # Replace 
'ImageToTensor' to 'DefaultFormatBundle' 176 | cfg.data.val.pipeline = replace_ImageToTensor( 177 | cfg.data.val.pipeline 178 | ) 179 | val_dataset = custom_build_dataset(cfg.data.val, dict(test_mode=True)) 180 | 181 | val_dataloader = build_dataloader( 182 | val_dataset, 183 | samples_per_gpu=val_samples_per_gpu, 184 | workers_per_gpu=cfg.data.workers_per_gpu, 185 | dist=distributed, 186 | shuffle=False, 187 | nonshuffler_sampler=dict(type="DistributedSampler"), 188 | ) 189 | eval_cfg = cfg.get("evaluation", {}) 190 | eval_cfg["by_epoch"] = cfg.runner["type"] != "IterBasedRunner" 191 | eval_cfg["jsonfile_prefix"] = osp.join( 192 | "val", 193 | cfg.work_dir, 194 | time.ctime().replace(" ", "_").replace(":", "_"), 195 | ) 196 | eval_hook = CustomDistEvalHook if distributed else EvalHook 197 | runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) 198 | 199 | # user-defined hooks 200 | if cfg.get("custom_hooks", None): 201 | custom_hooks = cfg.custom_hooks 202 | assert isinstance( 203 | custom_hooks, list 204 | ), f"custom_hooks expect list type, but got {type(custom_hooks)}" 205 | for hook_cfg in cfg.custom_hooks: 206 | assert isinstance(hook_cfg, dict), ( 207 | "Each item in custom_hooks expects dict type, but got " 208 | f"{type(hook_cfg)}" 209 | ) 210 | hook_cfg = hook_cfg.copy() 211 | priority = hook_cfg.pop("priority", "NORMAL") 212 | hook = build_from_cfg(hook_cfg, HOOKS) 213 | runner.register_hook(hook, priority=priority) 214 | 215 | if cfg.resume_from: 216 | runner.resume(cfg.resume_from) 217 | elif cfg.load_from: 218 | runner.load_checkpoint(cfg.load_from) 219 | runner.run(data_loaders, cfg.workflow) 220 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/apis/test.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------- 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | # --------------------------------------------- 4 | # Modified by Zhiqi Li 5 | # --------------------------------------------- 6 | import os.path as osp 7 | import pickle 8 | import shutil 9 | import tempfile 10 | import time 11 | 12 | import mmcv 13 | import torch 14 | import torch.distributed as dist 15 | from mmcv.image import tensor2imgs 16 | from mmcv.runner import get_dist_info 17 | 18 | from mmdet.core import encode_mask_results 19 | 20 | 21 | import mmcv 22 | import numpy as np 23 | import pycocotools.mask as mask_util 24 | 25 | 26 | def custom_encode_mask_results(mask_results): 27 | """Encode bitmap mask to RLE code. Semantic Masks only 28 | Args: 29 | mask_results (list | tuple[list]): bitmap mask results. 30 | In mask scoring rcnn, mask_results is a tuple of (segm_results, 31 | segm_cls_score). 32 | Returns: 33 | list | tuple: RLE encoded mask. 34 | """ 35 | cls_segms = mask_results 36 | num_classes = len(cls_segms) 37 | encoded_mask_results = [] 38 | for i in range(len(cls_segms)): 39 | encoded_mask_results.append( 40 | mask_util.encode( 41 | np.array( 42 | cls_segms[i][:, :, np.newaxis], order="F", dtype="uint8" 43 | ) 44 | )[0] 45 | ) # encoded with RLE 46 | return [encoded_mask_results] 47 | 48 | 49 | def custom_multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): 50 | """Test model with multiple gpus. 51 | This method tests model with multiple gpus and collects the results 52 | under two different modes: gpu and cpu modes. 
By setting 'gpu_collect=True' 53 | it encodes results to gpu tensors and use gpu communication for results 54 | collection. On cpu mode it saves the results on different gpus to 'tmpdir' 55 | and collects them by the rank 0 worker. 56 | Args: 57 | model (nn.Module): Model to be tested. 58 | data_loader (nn.Dataloader): Pytorch data loader. 59 | tmpdir (str): Path of directory to save the temporary results from 60 | different gpus under cpu mode. 61 | gpu_collect (bool): Option to use either gpu or cpu to collect results. 62 | Returns: 63 | list: The prediction results. 64 | """ 65 | model.eval() 66 | bbox_results = [] 67 | mask_results = [] 68 | dataset = data_loader.dataset 69 | rank, world_size = get_dist_info() 70 | if rank == 0: 71 | prog_bar = mmcv.ProgressBar(len(dataset)) 72 | time.sleep(2) # This line can prevent deadlock problem in some cases. 73 | have_mask = False 74 | for i, data in enumerate(data_loader): 75 | with torch.no_grad(): 76 | result = model(return_loss=False, rescale=True, **data) 77 | # encode mask results 78 | if isinstance(result, dict): 79 | if "bbox_results" in result.keys(): 80 | bbox_result = result["bbox_results"] 81 | batch_size = len(result["bbox_results"]) 82 | bbox_results.extend(bbox_result) 83 | if ( 84 | "mask_results" in result.keys() 85 | and result["mask_results"] is not None 86 | ): 87 | mask_result = custom_encode_mask_results( 88 | result["mask_results"] 89 | ) 90 | mask_results.extend(mask_result) 91 | have_mask = True 92 | else: 93 | batch_size = len(result) 94 | bbox_results.extend(result) 95 | 96 | if rank == 0: 97 | for _ in range(batch_size * world_size): 98 | prog_bar.update() 99 | 100 | # collect results from all ranks 101 | if gpu_collect: 102 | bbox_results = collect_results_gpu(bbox_results, len(dataset)) 103 | if have_mask: 104 | mask_results = collect_results_gpu(mask_results, len(dataset)) 105 | else: 106 | mask_results = None 107 | else: 108 | bbox_results = collect_results_cpu(bbox_results, len(dataset), tmpdir) 109 | tmpdir = tmpdir + "_mask" if tmpdir is not None else None 110 | if have_mask: 111 | mask_results = collect_results_cpu( 112 | mask_results, len(dataset), tmpdir 113 | ) 114 | else: 115 | mask_results = None 116 | 117 | if mask_results is None: 118 | return bbox_results 119 | return {"bbox_results": bbox_results, "mask_results": mask_results} 120 | 121 | 122 | def collect_results_cpu(result_part, size, tmpdir=None): 123 | rank, world_size = get_dist_info() 124 | # create a tmp dir if it is not specified 125 | if tmpdir is None: 126 | MAX_LEN = 512 127 | # 32 is whitespace 128 | dir_tensor = torch.full( 129 | (MAX_LEN,), 32, dtype=torch.uint8, device="cuda" 130 | ) 131 | if rank == 0: 132 | mmcv.mkdir_or_exist(".dist_test") 133 | tmpdir = tempfile.mkdtemp(dir=".dist_test") 134 | tmpdir = torch.tensor( 135 | bytearray(tmpdir.encode()), dtype=torch.uint8, device="cuda" 136 | ) 137 | dir_tensor[: len(tmpdir)] = tmpdir 138 | dist.broadcast(dir_tensor, 0) 139 | tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip() 140 | else: 141 | mmcv.mkdir_or_exist(tmpdir) 142 | # dump the part result to the dir 143 | mmcv.dump(result_part, osp.join(tmpdir, f"part_{rank}.pkl")) 144 | dist.barrier() 145 | # collect all parts 146 | if rank != 0: 147 | return None 148 | else: 149 | # load results of all parts from tmp dir 150 | part_list = [] 151 | for i in range(world_size): 152 | part_file = osp.join(tmpdir, f"part_{i}.pkl") 153 | part_list.append(mmcv.load(part_file)) 154 | # sort the results 155 | ordered_results = [] 156 | """ 
157 | bacause we change the sample of the evaluation stage to make sure that 158 | each gpu will handle continuous sample, 159 | """ 160 | # for res in zip(*part_list): 161 | for res in part_list: 162 | ordered_results.extend(list(res)) 163 | # the dataloader may pad some samples 164 | ordered_results = ordered_results[:size] 165 | # remove tmp dir 166 | shutil.rmtree(tmpdir) 167 | return ordered_results 168 | 169 | 170 | def collect_results_gpu(result_part, size): 171 | collect_results_cpu(result_part, size) 172 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/apis/train.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------- 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | # --------------------------------------------- 4 | # Modified by Zhiqi Li 5 | # --------------------------------------------- 6 | 7 | from .mmdet_train import custom_train_detector 8 | # from mmseg.apis import train_segmentor 9 | from mmdet.apis import train_detector 10 | 11 | 12 | def custom_train_model( 13 | model, 14 | dataset, 15 | cfg, 16 | distributed=False, 17 | validate=False, 18 | timestamp=None, 19 | meta=None, 20 | ): 21 | """A function wrapper for launching model training according to cfg. 22 | 23 | Because we need different eval_hook in runner. Should be deprecated in the 24 | future. 25 | """ 26 | if cfg.model.type in ["EncoderDecoder3D"]: 27 | assert False 28 | else: 29 | custom_train_detector( 30 | model, 31 | dataset, 32 | cfg, 33 | distributed=distributed, 34 | validate=validate, 35 | timestamp=timestamp, 36 | meta=meta, 37 | ) 38 | 39 | 40 | def train_model( 41 | model, 42 | dataset, 43 | cfg, 44 | distributed=False, 45 | validate=False, 46 | timestamp=None, 47 | meta=None, 48 | ): 49 | """A function wrapper for launching model training according to cfg. 50 | 51 | Because we need different eval_hook in runner. Should be deprecated in the 52 | future. 53 | """ 54 | train_detector( 55 | model, 56 | dataset, 57 | cfg, 58 | distributed=distributed, 59 | validate=validate, 60 | timestamp=timestamp, 61 | meta=meta, 62 | ) 63 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/core/box3d.py: -------------------------------------------------------------------------------- 1 | X, Y, Z, W, L, H, SIN_YAW, COS_YAW, VX, VY, VZ = list(range(11)) # undecoded 2 | CNS, YNS = 0, 1 # centerness and yawness indices in quality 3 | YAW = 6 # decoded 4 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval_hooks import CustomDistEvalHook -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/core/evaluation/eval_hooks.py: -------------------------------------------------------------------------------- 1 | # Note: Considering that MMCV's EvalHook updated its interface in V1.3.16, 2 | # in order to avoid strong version dependency, we did not directly 3 | # inherit EvalHook but BaseDistEvalHook. 
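# For reference, the hook defined below is registered during training in
# projects/mmdet3d_plugin/apis/mmdet_train.py, roughly as:
#     eval_cfg = cfg.get("evaluation", {})
#     eval_hook = CustomDistEvalHook if distributed else EvalHook
#     runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
# so any `evaluation = dict(...)` options in the config are forwarded here.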
4 | 5 | import bisect 6 | import os.path as osp 7 | 8 | import mmcv 9 | import torch.distributed as dist 10 | from mmcv.runner import DistEvalHook as BaseDistEvalHook 11 | from mmcv.runner import EvalHook as BaseEvalHook 12 | from torch.nn.modules.batchnorm import _BatchNorm 13 | from mmdet.core.evaluation.eval_hooks import DistEvalHook 14 | 15 | 16 | def _calc_dynamic_intervals(start_interval, dynamic_interval_list): 17 | assert mmcv.is_list_of(dynamic_interval_list, tuple) 18 | 19 | dynamic_milestones = [0] 20 | dynamic_milestones.extend( 21 | [dynamic_interval[0] for dynamic_interval in dynamic_interval_list] 22 | ) 23 | dynamic_intervals = [start_interval] 24 | dynamic_intervals.extend( 25 | [dynamic_interval[1] for dynamic_interval in dynamic_interval_list] 26 | ) 27 | return dynamic_milestones, dynamic_intervals 28 | 29 | 30 | class CustomDistEvalHook(BaseDistEvalHook): 31 | def __init__(self, *args, dynamic_intervals=None, **kwargs): 32 | super(CustomDistEvalHook, self).__init__(*args, **kwargs) 33 | self.use_dynamic_intervals = dynamic_intervals is not None 34 | if self.use_dynamic_intervals: 35 | ( 36 | self.dynamic_milestones, 37 | self.dynamic_intervals, 38 | ) = _calc_dynamic_intervals(self.interval, dynamic_intervals) 39 | 40 | def _decide_interval(self, runner): 41 | if self.use_dynamic_intervals: 42 | progress = runner.epoch if self.by_epoch else runner.iter 43 | step = bisect.bisect(self.dynamic_milestones, (progress + 1)) 44 | # Dynamically modify the evaluation interval 45 | self.interval = self.dynamic_intervals[step - 1] 46 | 47 | def before_train_epoch(self, runner): 48 | """Evaluate the model only at the start of training by epoch.""" 49 | self._decide_interval(runner) 50 | super().before_train_epoch(runner) 51 | 52 | def before_train_iter(self, runner): 53 | self._decide_interval(runner) 54 | super().before_train_iter(runner) 55 | 56 | def _do_evaluate(self, runner): 57 | """perform evaluation and save ckpt.""" 58 | # Synchronization of BatchNorm's buffer (running_mean 59 | # and running_var) is not supported in the DDP of pytorch, 60 | # which may cause the inconsistent performance of models in 61 | # different ranks, so we broadcast BatchNorm's buffers 62 | # of rank 0 to other ranks to avoid this. 
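        # (torch.distributed.broadcast(tensor, src=0) copies rank 0's buffer into
        # every other rank's buffer in place, so all ranks evaluate with identical
        # BatchNorm statistics.)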
63 | if self.broadcast_bn_buffer: 64 | model = runner.model 65 | for name, module in model.named_modules(): 66 | if ( 67 | isinstance(module, _BatchNorm) 68 | and module.track_running_stats 69 | ): 70 | dist.broadcast(module.running_var, 0) 71 | dist.broadcast(module.running_mean, 0) 72 | 73 | if not self._should_evaluate(runner): 74 | return 75 | 76 | tmpdir = self.tmpdir 77 | if tmpdir is None: 78 | tmpdir = osp.join(runner.work_dir, ".eval_hook") 79 | 80 | from projects.mmdet3d_plugin.apis.test import ( 81 | custom_multi_gpu_test, 82 | ) # to solve circlur import 83 | 84 | results = custom_multi_gpu_test( 85 | runner.model, 86 | self.dataloader, 87 | tmpdir=tmpdir, 88 | gpu_collect=self.gpu_collect, 89 | ) 90 | if runner.rank == 0: 91 | print("\n") 92 | runner.log_buffer.output["eval_iter_num"] = len(self.dataloader) 93 | 94 | key_score = self.evaluate(runner, results) 95 | 96 | if self.save_best: 97 | self._save_ckpt(runner, key_score) 98 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .nuscenes_3d_dataset import NuScenes3DDataset 2 | from .builder import * 3 | from .pipelines import * 4 | from .samplers import * 5 | 6 | __all__ = [ 7 | 'NuScenes3DDataset', 8 | "custom_build_dataset", 9 | ] 10 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/builder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import platform 3 | import random 4 | from functools import partial 5 | 6 | import numpy as np 7 | from mmcv.parallel import collate 8 | from mmcv.runner import get_dist_info 9 | from mmcv.utils import Registry, build_from_cfg 10 | from torch.utils.data import DataLoader 11 | 12 | from mmdet.datasets.samplers import GroupSampler 13 | from projects.mmdet3d_plugin.datasets.samplers import ( 14 | GroupInBatchSampler, 15 | DistributedGroupSampler, 16 | DistributedSampler, 17 | build_sampler 18 | ) 19 | 20 | 21 | def build_dataloader( 22 | dataset, 23 | samples_per_gpu, 24 | workers_per_gpu, 25 | num_gpus=1, 26 | dist=True, 27 | shuffle=True, 28 | seed=None, 29 | shuffler_sampler=None, 30 | nonshuffler_sampler=None, 31 | runner_type="EpochBasedRunner", 32 | **kwargs 33 | ): 34 | """Build PyTorch DataLoader. 35 | In distributed training, each GPU/process has a dataloader. 36 | In non-distributed training, there is only one dataloader for all GPUs. 37 | Args: 38 | dataset (Dataset): A PyTorch dataset. 39 | samples_per_gpu (int): Number of training samples on each GPU, i.e., 40 | batch size of each GPU. 41 | workers_per_gpu (int): How many subprocesses to use for data loading 42 | for each GPU. 43 | num_gpus (int): Number of GPUs. Only used in non-distributed training. 44 | dist (bool): Distributed training/test or not. Default: True. 45 | shuffle (bool): Whether to shuffle the data at every epoch. 46 | Default: True. 47 | kwargs: any keyword argument to be used to initialize DataLoader 48 | Returns: 49 | DataLoader: A PyTorch dataloader. 
50 | """ 51 | rank, world_size = get_dist_info() 52 | batch_sampler = None 53 | if runner_type == 'IterBasedRunner': 54 | print("Use GroupInBatchSampler !!!") 55 | batch_sampler = GroupInBatchSampler( 56 | dataset, 57 | samples_per_gpu, 58 | world_size, 59 | rank, 60 | seed=seed, 61 | ) 62 | batch_size = 1 63 | sampler = None 64 | num_workers = workers_per_gpu 65 | elif dist: 66 | # DistributedGroupSampler will definitely shuffle the data to satisfy 67 | # that images on each GPU are in the same group 68 | if shuffle: 69 | print("Use DistributedGroupSampler !!!") 70 | sampler = build_sampler( 71 | shuffler_sampler 72 | if shuffler_sampler is not None 73 | else dict(type="DistributedGroupSampler"), 74 | dict( 75 | dataset=dataset, 76 | samples_per_gpu=samples_per_gpu, 77 | num_replicas=world_size, 78 | rank=rank, 79 | seed=seed, 80 | ), 81 | ) 82 | else: 83 | sampler = build_sampler( 84 | nonshuffler_sampler 85 | if nonshuffler_sampler is not None 86 | else dict(type="DistributedSampler"), 87 | dict( 88 | dataset=dataset, 89 | num_replicas=world_size, 90 | rank=rank, 91 | shuffle=shuffle, 92 | seed=seed, 93 | ), 94 | ) 95 | 96 | batch_size = samples_per_gpu 97 | num_workers = workers_per_gpu 98 | else: 99 | # assert False, 'not support in bevformer' 100 | print("WARNING!!!!, Only can be used for obtain inference speed!!!!") 101 | sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None 102 | batch_size = num_gpus * samples_per_gpu 103 | num_workers = num_gpus * workers_per_gpu 104 | 105 | init_fn = ( 106 | partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) 107 | if seed is not None 108 | else None 109 | ) 110 | 111 | data_loader = DataLoader( 112 | dataset, 113 | batch_size=batch_size, 114 | sampler=sampler, 115 | batch_sampler=batch_sampler, 116 | num_workers=num_workers, 117 | collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), 118 | pin_memory=False, 119 | worker_init_fn=init_fn, 120 | **kwargs 121 | ) 122 | 123 | return data_loader 124 | 125 | 126 | def worker_init_fn(worker_id, num_workers, rank, seed): 127 | # The seed of each worker equals to 128 | # num_worker * rank + worker_id + user_seed 129 | worker_seed = num_workers * rank + worker_id + seed 130 | np.random.seed(worker_seed) 131 | random.seed(worker_seed) 132 | 133 | 134 | # Copyright (c) OpenMMLab. All rights reserved. 
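# Illustrative usage of `build_dataloader` above (the dataset/config values here
# are placeholders, not taken from this repository's configs):
#
#     dataset = custom_build_dataset(cfg.data.train)
#     train_loader = build_dataloader(
#         dataset,
#         samples_per_gpu=2,      # per-GPU batch size
#         workers_per_gpu=2,      # dataloader workers per GPU
#         dist=False,             # single-process debugging run
#         shuffle=True,
#         seed=42,
#         runner_type="EpochBasedRunner",
#     )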
135 | import platform 136 | from mmcv.utils import Registry, build_from_cfg 137 | 138 | from mmdet.datasets import DATASETS 139 | from mmdet.datasets.builder import _concat_dataset 140 | 141 | if platform.system() != "Windows": 142 | # https://github.com/pytorch/pytorch/issues/973 143 | import resource 144 | 145 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 146 | base_soft_limit = rlimit[0] 147 | hard_limit = rlimit[1] 148 | soft_limit = min(max(4096, base_soft_limit), hard_limit) 149 | resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) 150 | 151 | OBJECTSAMPLERS = Registry("Object sampler") 152 | 153 | 154 | def custom_build_dataset(cfg, default_args=None): 155 | try: 156 | from mmdet3d.datasets.dataset_wrappers import CBGSDataset 157 | except: 158 | CBGSDataset = None 159 | from mmdet.datasets.dataset_wrappers import ( 160 | ClassBalancedDataset, 161 | ConcatDataset, 162 | RepeatDataset, 163 | ) 164 | 165 | if isinstance(cfg, (list, tuple)): 166 | dataset = ConcatDataset( 167 | [custom_build_dataset(c, default_args) for c in cfg] 168 | ) 169 | elif cfg["type"] == "ConcatDataset": 170 | dataset = ConcatDataset( 171 | [custom_build_dataset(c, default_args) for c in cfg["datasets"]], 172 | cfg.get("separate_eval", True), 173 | ) 174 | elif cfg["type"] == "RepeatDataset": 175 | dataset = RepeatDataset( 176 | custom_build_dataset(cfg["dataset"], default_args), cfg["times"] 177 | ) 178 | elif cfg["type"] == "ClassBalancedDataset": 179 | dataset = ClassBalancedDataset( 180 | custom_build_dataset(cfg["dataset"], default_args), 181 | cfg["oversample_thr"], 182 | ) 183 | elif cfg["type"] == "CBGSDataset": 184 | dataset = CBGSDataset( 185 | custom_build_dataset(cfg["dataset"], default_args) 186 | ) 187 | elif isinstance(cfg.get("ann_file"), (list, tuple)): 188 | dataset = _concat_dataset(cfg, default_args) 189 | else: 190 | dataset = build_from_cfg(cfg, DATASETS, default_args) 191 | 192 | return dataset 193 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swc-17/SparseDrive/52c4c05b6d446b710c8a12eb9fb19d698b33cb2b/projects/mmdet3d_plugin/datasets/evaluation/__init__.py -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/evaluation/map/AP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .distance import chamfer_distance, frechet_distance, chamfer_distance_batch 3 | from typing import List, Tuple, Union 4 | from numpy.typing import NDArray 5 | 6 | def average_precision(recalls, precisions, mode='area'): 7 | """Calculate average precision. 8 | 9 | Args: 10 | recalls (ndarray): shape (num_dets, ) 11 | precisions (ndarray): shape (num_dets, ) 12 | mode (str): 'area' or '11points', 'area' means calculating the area 13 | under precision-recall curve, '11points' means calculating 14 | the average precision of recalls at [0, 0.1, ..., 1] 15 | 16 | Returns: 17 | float: calculated average precision 18 | """ 19 | 20 | recalls = recalls[np.newaxis, :] 21 | precisions = precisions[np.newaxis, :] 22 | 23 | assert recalls.shape == precisions.shape and recalls.ndim == 2 24 | num_scales = recalls.shape[0] 25 | ap = 0. 
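    # The 'area' mode below is the standard interpolated AP: precision is first
    # made monotonically non-increasing (right to left), then
    #     AP = sum_k (recall[k+1] - recall[k]) * precision[k+1]
    # accumulated over the points where recall changes.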
26 | 27 | if mode == 'area': 28 | zeros = np.zeros((num_scales, 1), dtype=recalls.dtype) 29 | ones = np.ones((num_scales, 1), dtype=recalls.dtype) 30 | mrec = np.hstack((zeros, recalls, ones)) 31 | mpre = np.hstack((zeros, precisions, zeros)) 32 | for i in range(mpre.shape[1] - 1, 0, -1): 33 | mpre[:, i - 1] = np.maximum(mpre[:, i - 1], mpre[:, i]) 34 | 35 | ind = np.where(mrec[0, 1:] != mrec[0, :-1])[0] 36 | ap = np.sum( 37 | (mrec[0, ind + 1] - mrec[0, ind]) * mpre[0, ind + 1]) 38 | 39 | elif mode == '11points': 40 | for thr in np.arange(0, 1 + 1e-3, 0.1): 41 | precs = precisions[0, recalls[i, :] >= thr] 42 | prec = precs.max() if precs.size > 0 else 0 43 | ap += prec 44 | ap /= 11 45 | else: 46 | raise ValueError( 47 | 'Unrecognized mode, only "area" and "11points" are supported') 48 | 49 | return ap 50 | 51 | def instance_match(pred_lines: NDArray, 52 | scores: NDArray, 53 | gt_lines: NDArray, 54 | thresholds: Union[Tuple, List], 55 | metric: str='chamfer') -> List: 56 | """Compute whether detected lines are true positive or false positive. 57 | 58 | Args: 59 | pred_lines (array): Detected lines of a sample, of shape (M, INTERP_NUM, 2 or 3). 60 | scores (array): Confidence score of each line, of shape (M, ). 61 | gt_lines (array): GT lines of a sample, of shape (N, INTERP_NUM, 2 or 3). 62 | thresholds (list of tuple): List of thresholds. 63 | metric (str): Distance function for lines matching. Default: 'chamfer'. 64 | 65 | Returns: 66 | list_of_tp_fp (list): tp-fp matching result at all thresholds 67 | """ 68 | 69 | if metric == 'chamfer': 70 | distance_fn = chamfer_distance 71 | 72 | elif metric == 'frechet': 73 | distance_fn = frechet_distance 74 | 75 | else: 76 | raise ValueError(f'unknown distance function {metric}') 77 | 78 | num_preds = pred_lines.shape[0] 79 | num_gts = gt_lines.shape[0] 80 | 81 | # tp and fp 82 | tp_fp_list = [] 83 | tp = np.zeros((num_preds), dtype=np.float32) 84 | fp = np.zeros((num_preds), dtype=np.float32) 85 | 86 | # if there is no gt lines in this sample, then all pred lines are false positives 87 | if num_gts == 0: 88 | fp[...] 
= 1 89 | for thr in thresholds: 90 | tp_fp_list.append((tp.copy(), fp.copy())) 91 | return tp_fp_list 92 | 93 | if num_preds == 0: 94 | for thr in thresholds: 95 | tp_fp_list.append((tp.copy(), fp.copy())) 96 | return tp_fp_list 97 | 98 | assert pred_lines.shape[1] == gt_lines.shape[1], \ 99 | "sample points num should be the same" 100 | 101 | # distance matrix: M x N 102 | matrix = np.zeros((num_preds, num_gts)) 103 | 104 | # for i in range(num_preds): 105 | # for j in range(num_gts): 106 | # matrix[i, j] = distance_fn(pred_lines[i], gt_lines[j]) 107 | 108 | matrix = chamfer_distance_batch(pred_lines, gt_lines) 109 | # for each det, the min distance with all gts 110 | matrix_min = matrix.min(axis=1) 111 | 112 | # for each det, which gt is the closest to it 113 | matrix_argmin = matrix.argmin(axis=1) 114 | # sort all dets in descending order by scores 115 | sort_inds = np.argsort(-scores) 116 | 117 | # match under different thresholds 118 | for thr in thresholds: 119 | tp = np.zeros((num_preds), dtype=np.float32) 120 | fp = np.zeros((num_preds), dtype=np.float32) 121 | 122 | gt_covered = np.zeros(num_gts, dtype=bool) 123 | for i in sort_inds: 124 | if matrix_min[i] <= thr: 125 | matched_gt = matrix_argmin[i] 126 | if not gt_covered[matched_gt]: 127 | gt_covered[matched_gt] = True 128 | tp[i] = 1 129 | else: 130 | fp[i] = 1 131 | else: 132 | fp[i] = 1 133 | 134 | tp_fp_list.append((tp, fp)) 135 | 136 | return tp_fp_list -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/evaluation/map/distance.py: -------------------------------------------------------------------------------- 1 | from scipy.spatial import distance 2 | from numpy.typing import NDArray 3 | import torch 4 | 5 | def chamfer_distance(line1: NDArray, line2: NDArray) -> float: 6 | ''' Calculate chamfer distance between two lines. Make sure the 7 | lines are interpolated. 8 | 9 | Args: 10 | line1 (array): coordinates of line1 11 | line2 (array): coordinates of line2 12 | 13 | Returns: 14 | distance (float): chamfer distance 15 | ''' 16 | 17 | dist_matrix = distance.cdist(line1, line2, 'euclidean') 18 | dist12 = dist_matrix.min(-1).sum() / len(line1) 19 | dist21 = dist_matrix.min(-2).sum() / len(line2) 20 | 21 | return (dist12 + dist21) / 2 22 | 23 | def frechet_distance(line1: NDArray, line2: NDArray) -> float: 24 | ''' Calculate frechet distance between two lines. Make sure the 25 | lines are interpolated. 26 | 27 | Args: 28 | line1 (array): coordinates of line1 29 | line2 (array): coordinates of line2 30 | 31 | Returns: 32 | distance (float): frechet distance 33 | ''' 34 | 35 | raise NotImplementedError 36 | 37 | def chamfer_distance_batch(pred_lines, gt_lines): 38 | ''' Calculate chamfer distance between two group of lines. Make sure the 39 | lines are interpolated. 
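    Each entry [i, j] of the returned matrix equals the symmetric chamfer
    distance between pred_lines[i] and gt_lines[j]:
        CD = (mean_a min_b ||a - b|| + mean_b min_a ||a - b||) / 2,
    matching the single-pair `chamfer_distance` above.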
40 | 41 | Args: 42 | pred_lines (array or tensor): shape (m, num_pts, 2 or 3) 43 | gt_lines (array or tensor): shape (n, num_pts, 2 or 3) 44 | 45 | Returns: 46 | distance (array): chamfer distance 47 | ''' 48 | _, num_pts, coord_dims = pred_lines.shape 49 | 50 | if not isinstance(pred_lines, torch.Tensor): 51 | pred_lines = torch.tensor(pred_lines) 52 | if not isinstance(gt_lines, torch.Tensor): 53 | gt_lines = torch.tensor(gt_lines) 54 | dist_mat = torch.cdist(pred_lines.view(-1, coord_dims), 55 | gt_lines.view(-1, coord_dims), p=2) 56 | # (num_query*num_points, num_gt*num_points) 57 | dist_mat = torch.stack(torch.split(dist_mat, num_pts)) 58 | # (num_query, num_points, num_gt*num_points) 59 | dist_mat = torch.stack(torch.split(dist_mat, num_pts, dim=-1)) 60 | # (num_gt, num_q, num_pts, num_pts) 61 | 62 | dist1 = dist_mat.min(-1)[0].sum(-1) 63 | dist2 = dist_mat.min(-2)[0].sum(-1) 64 | 65 | dist_matrix = (dist1 + dist2).transpose(0, 1) / (2 * num_pts) 66 | 67 | return dist_matrix.numpy() -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/evaluation/planning/planning_eval.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from shapely.geometry import Polygon 6 | 7 | from mmcv.utils import print_log 8 | from mmdet.datasets import build_dataset, build_dataloader 9 | 10 | from projects.mmdet3d_plugin.datasets.utils import box3d_to_corners 11 | 12 | 13 | def check_collision(ego_box, boxes): 14 | ''' 15 | ego_box: tensor with shape [7], [x, y, z, w, l, h, yaw] 16 | boxes: tensor with shape [N, 7] 17 | ''' 18 | if boxes.shape[0] == 0: 19 | return False 20 | 21 | # follow uniad, add a 0.5m offset 22 | ego_box[0] += 0.5 * torch.cos(ego_box[6]) 23 | ego_box[1] += 0.5 * torch.sin(ego_box[6]) 24 | ego_corners_box = box3d_to_corners(ego_box.unsqueeze(0))[0, [0, 3, 7, 4], :2] 25 | corners_box = box3d_to_corners(boxes)[:, [0, 3, 7, 4], :2] 26 | ego_poly = Polygon([(point[0], point[1]) for point in ego_corners_box]) 27 | for i in range(len(corners_box)): 28 | box_poly = Polygon([(point[0], point[1]) for point in corners_box[i]]) 29 | collision = ego_poly.intersects(box_poly) 30 | if collision: 31 | return True 32 | 33 | return False 34 | 35 | def get_yaw(traj): 36 | start = traj[0] 37 | end = traj[-1] 38 | dist = torch.linalg.norm(end - start, dim=-1) 39 | if dist < 0.5: 40 | return traj.new_ones(traj.shape[0]) * np.pi / 2 41 | 42 | zeros = traj.new_zeros((1, 2)) 43 | traj_cat = torch.cat([zeros, traj], dim=0) 44 | yaw = traj.new_zeros(traj.shape[0]+1) 45 | yaw[..., 1:-1] = torch.atan2( 46 | traj_cat[..., 2:, 1] - traj_cat[..., :-2, 1], 47 | traj_cat[..., 2:, 0] - traj_cat[..., :-2, 0], 48 | ) 49 | yaw[..., -1] = torch.atan2( 50 | traj_cat[..., -1, 1] - traj_cat[..., -2, 1], 51 | traj_cat[..., -1, 0] - traj_cat[..., -2, 0], 52 | ) 53 | return yaw[1:] 54 | 55 | class PlanningMetric(): 56 | def __init__( 57 | self, 58 | n_future=6, 59 | compute_on_step: bool = False, 60 | ): 61 | self.W = 1.85 62 | self.H = 4.084 63 | 64 | self.n_future = n_future 65 | self.reset() 66 | 67 | def reset(self): 68 | self.obj_col = torch.zeros(self.n_future) 69 | self.obj_box_col = torch.zeros(self.n_future) 70 | self.L2 = torch.zeros(self.n_future) 71 | self.total = torch.tensor(0) 72 | 73 | def evaluate_single_coll(self, traj, fut_boxes): 74 | n_future = traj.shape[0] 75 | yaw = get_yaw(traj) 76 | ego_box = traj.new_zeros((n_future, 
7)) 77 | ego_box[:, :2] = traj 78 | ego_box[:, 3:6] = ego_box.new_tensor([self.H, self.W, 1.56]) 79 | ego_box[:, 6] = yaw 80 | collision = torch.zeros(n_future, dtype=torch.bool) 81 | 82 | for t in range(n_future): 83 | ego_box_t = ego_box[t].clone() 84 | boxes = fut_boxes[t][0].clone() 85 | collision[t] = check_collision(ego_box_t, boxes) 86 | return collision 87 | 88 | def evaluate_coll(self, trajs, gt_trajs, fut_boxes): 89 | B, n_future, _ = trajs.shape 90 | trajs = trajs * torch.tensor([-1, 1], device=trajs.device) 91 | gt_trajs = gt_trajs * torch.tensor([-1, 1], device=gt_trajs.device) 92 | 93 | obj_coll_sum = torch.zeros(n_future, device=trajs.device) 94 | obj_box_coll_sum = torch.zeros(n_future, device=trajs.device) 95 | 96 | assert B == 1, 'only supprt bs=1' 97 | for i in range(B): 98 | gt_box_coll = self.evaluate_single_coll(gt_trajs[i], fut_boxes) 99 | box_coll = self.evaluate_single_coll(trajs[i], fut_boxes) 100 | box_coll = torch.logical_and(box_coll, torch.logical_not(gt_box_coll)) 101 | 102 | obj_coll_sum += gt_box_coll.long() 103 | obj_box_coll_sum += box_coll.long() 104 | 105 | return obj_coll_sum, obj_box_coll_sum 106 | 107 | def compute_L2(self, trajs, gt_trajs, gt_trajs_mask): 108 | ''' 109 | trajs: torch.Tensor (B, n_future, 3) 110 | gt_trajs: torch.Tensor (B, n_future, 3) 111 | ''' 112 | return torch.sqrt((((trajs[:, :, :2] - gt_trajs[:, :, :2]) ** 2) * gt_trajs_mask).sum(dim=-1)) 113 | 114 | def update(self, trajs, gt_trajs, gt_trajs_mask, fut_boxes): 115 | assert trajs.shape == gt_trajs.shape 116 | trajs[..., 0] = - trajs[..., 0] 117 | gt_trajs[..., 0] = - gt_trajs[..., 0] 118 | L2 = self.compute_L2(trajs, gt_trajs, gt_trajs_mask) 119 | obj_coll_sum, obj_box_coll_sum = self.evaluate_coll(trajs[:,:,:2], gt_trajs[:,:,:2], fut_boxes) 120 | 121 | self.obj_col += obj_coll_sum 122 | self.obj_box_col += obj_box_coll_sum 123 | self.L2 += L2.sum(dim=0) 124 | self.total +=len(trajs) 125 | 126 | def compute(self): 127 | return { 128 | 'obj_col': self.obj_col / self.total, 129 | 'obj_box_col': self.obj_box_col / self.total, 130 | 'L2' : self.L2 / self.total 131 | } 132 | 133 | 134 | def planning_eval(results, eval_config, logger): 135 | dataset = build_dataset(eval_config) 136 | dataloader = build_dataloader( 137 | dataset, samples_per_gpu=1, workers_per_gpu=1, shuffle=False, dist=False) 138 | planning_metrics = PlanningMetric() 139 | for i, data in enumerate(tqdm(dataloader)): 140 | sdc_planning = data['gt_ego_fut_trajs'].cumsum(dim=-2).unsqueeze(1) 141 | sdc_planning_mask = data['gt_ego_fut_masks'].unsqueeze(-1).repeat(1, 1, 2).unsqueeze(1) 142 | command = data['gt_ego_fut_cmd'].argmax(dim=-1).item() 143 | fut_boxes = data['fut_boxes'] 144 | if not sdc_planning_mask.all(): ## for incomplete gt, we do not count this sample 145 | continue 146 | res = results[i] 147 | pred_sdc_traj = res['img_bbox']['final_planning'].unsqueeze(0) 148 | planning_metrics.update(pred_sdc_traj[:, :6, :2], sdc_planning[0,:, :6, :2], sdc_planning_mask[0,:, :6, :2], fut_boxes) 149 | 150 | planning_results = planning_metrics.compute() 151 | planning_metrics.reset() 152 | from prettytable import PrettyTable 153 | planning_tab = PrettyTable() 154 | metric_dict = {} 155 | 156 | planning_tab.field_names = [ 157 | "metrics", "0.5s", "1.0s", "1.5s", "2.0s", "2.5s", "3.0s", "avg"] 158 | for key in planning_results.keys(): 159 | value = planning_results[key].tolist() 160 | new_values = [] 161 | for i in range(len(value)): 162 | new_values.append(np.array(value[:i+1]).mean()) 163 | value = new_values 164 | avg = 
[value[1], value[3], value[5]] 165 | avg = sum(avg) / len(avg) 166 | value.append(avg) 167 | metric_dict[key] = avg 168 | row_value = [] 169 | row_value.append(key) 170 | for i in range(len(value)): 171 | if 'col' in key: 172 | row_value.append('%.3f' % float(value[i]*100) + '%') 173 | else: 174 | row_value.append('%.4f' % float(value[i])) 175 | planning_tab.add_row(row_value) 176 | 177 | print_log('\n'+str(planning_tab), logger=logger) 178 | return metric_dict 179 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/map_utils/nuscmap_extractor.py: -------------------------------------------------------------------------------- 1 | from shapely.geometry import LineString, box, Polygon 2 | from shapely import ops, strtree 3 | 4 | import numpy as np 5 | from nuscenes.map_expansion.map_api import NuScenesMap, NuScenesMapExplorer 6 | from nuscenes.eval.common.utils import quaternion_yaw 7 | from pyquaternion import Quaternion 8 | from .utils import split_collections, get_drivable_area_contour, \ 9 | get_ped_crossing_contour 10 | from numpy.typing import NDArray 11 | from typing import Dict, List, Tuple, Union 12 | 13 | class NuscMapExtractor(object): 14 | """NuScenes map ground-truth extractor. 15 | 16 | Args: 17 | data_root (str): path to nuScenes dataset 18 | roi_size (tuple or list): bev range 19 | """ 20 | def __init__(self, data_root: str, roi_size: Union[List, Tuple]) -> None: 21 | self.roi_size = roi_size 22 | self.MAPS = ['boston-seaport', 'singapore-hollandvillage', 23 | 'singapore-onenorth', 'singapore-queenstown'] 24 | 25 | self.nusc_maps = {} 26 | self.map_explorer = {} 27 | for loc in self.MAPS: 28 | self.nusc_maps[loc] = NuScenesMap( 29 | dataroot=data_root, map_name=loc) 30 | self.map_explorer[loc] = NuScenesMapExplorer(self.nusc_maps[loc]) 31 | 32 | # local patch in nuScenes format 33 | self.local_patch = box(-roi_size[0] / 2, -roi_size[1] / 2, 34 | roi_size[0] / 2, roi_size[1] / 2) 35 | 36 | def _union_ped(self, ped_geoms: List[Polygon]) -> List[Polygon]: 37 | ''' merge close ped crossings. 38 | 39 | Args: 40 | ped_geoms (list): list of Polygon 41 | 42 | Returns: 43 | union_ped_geoms (Dict): merged ped crossings 44 | ''' 45 | 46 | def get_rec_direction(geom): 47 | rect = geom.minimum_rotated_rectangle 48 | rect_v_p = np.array(rect.exterior.coords)[:3] 49 | rect_v = rect_v_p[1:]-rect_v_p[:-1] 50 | v_len = np.linalg.norm(rect_v, axis=-1) 51 | longest_v_i = v_len.argmax() 52 | 53 | return rect_v[longest_v_i], v_len[longest_v_i] 54 | 55 | tree = strtree.STRtree(ped_geoms) 56 | index_by_id = dict((id(pt), i) for i, pt in enumerate(ped_geoms)) 57 | 58 | final_pgeom = [] 59 | remain_idx = [i for i in range(len(ped_geoms))] 60 | for i, pgeom in enumerate(ped_geoms): 61 | 62 | if i not in remain_idx: 63 | continue 64 | # update 65 | remain_idx.pop(remain_idx.index(i)) 66 | pgeom_v, pgeom_v_norm = get_rec_direction(pgeom) 67 | final_pgeom.append(pgeom) 68 | 69 | for o in tree.query(pgeom): 70 | o_idx = index_by_id[id(o)] 71 | if o_idx not in remain_idx: 72 | continue 73 | 74 | o_v, o_v_norm = get_rec_direction(o) 75 | cos = pgeom_v.dot(o_v)/(pgeom_v_norm*o_v_norm) 76 | if 1 - np.abs(cos) < 0.01: # theta < 8 degrees. 
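                    # |cos| > 0.99 means the two rectangles' long axes differ by
                    # less than roughly 8 degrees, i.e. the crossings are nearly
                    # parallel and are merged into a single instance.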
77 | final_pgeom[-1] =\ 78 | final_pgeom[-1].union(o) 79 | # update 80 | remain_idx.pop(remain_idx.index(o_idx)) 81 | 82 | results = [] 83 | for p in final_pgeom: 84 | results.extend(split_collections(p)) 85 | return results 86 | 87 | def get_map_geom(self, 88 | location: str, 89 | translation: Union[List, NDArray], 90 | rotation: Union[List, NDArray]) -> Dict[str, List[Union[LineString, Polygon]]]: 91 | ''' Extract geometries given `location` and self pose, self may be lidar or ego. 92 | 93 | Args: 94 | location (str): city name 95 | translation (array): self2global translation, shape (3,) 96 | rotation (array): self2global quaternion, shape (4, ) 97 | 98 | Returns: 99 | geometries (Dict): extracted geometries by category. 100 | ''' 101 | 102 | # (center_x, center_y, len_y, len_x) in nuscenes format 103 | patch_box = (translation[0], translation[1], 104 | self.roi_size[1], self.roi_size[0]) 105 | rotation = Quaternion(rotation) 106 | yaw = quaternion_yaw(rotation) / np.pi * 180 107 | 108 | # get dividers 109 | lane_dividers = self.map_explorer[location]._get_layer_line( 110 | patch_box, yaw, 'lane_divider') 111 | 112 | road_dividers = self.map_explorer[location]._get_layer_line( 113 | patch_box, yaw, 'road_divider') 114 | 115 | all_dividers = [] 116 | for line in lane_dividers + road_dividers: 117 | all_dividers += split_collections(line) 118 | 119 | # get ped crossings 120 | ped_crossings = [] 121 | ped = self.map_explorer[location]._get_layer_polygon( 122 | patch_box, yaw, 'ped_crossing') 123 | 124 | for p in ped: 125 | ped_crossings += split_collections(p) 126 | # some ped crossings are split into several small parts 127 | # we need to merge them 128 | ped_crossings = self._union_ped(ped_crossings) 129 | 130 | ped_crossing_lines = [] 131 | for p in ped_crossings: 132 | # extract exteriors to get a closed polyline 133 | line = get_ped_crossing_contour(p, self.local_patch) 134 | if line is not None: 135 | ped_crossing_lines.append(line) 136 | 137 | # get boundaries 138 | # we take the union of road segments and lanes as drivable areas 139 | # we don't take drivable area layer in nuScenes since its definition may be ambiguous 140 | road_segments = self.map_explorer[location]._get_layer_polygon( 141 | patch_box, yaw, 'road_segment') 142 | lanes = self.map_explorer[location]._get_layer_polygon( 143 | patch_box, yaw, 'lane') 144 | union_roads = ops.unary_union(road_segments) 145 | union_lanes = ops.unary_union(lanes) 146 | drivable_areas = ops.unary_union([union_roads, union_lanes]) 147 | 148 | drivable_areas = split_collections(drivable_areas) 149 | 150 | # boundaries are defined as the contour of drivable areas 151 | boundaries = get_drivable_area_contour(drivable_areas, self.roi_size) 152 | 153 | return dict( 154 | divider=all_dividers, # List[LineString] 155 | ped_crossing=ped_crossing_lines, # List[LineString] 156 | boundary=boundaries, # List[LineString] 157 | drivable_area=drivable_areas, # List[Polygon], 158 | ) 159 | 160 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/map_utils/utils.py: -------------------------------------------------------------------------------- 1 | from shapely.geometry import LineString, box, Polygon, LinearRing 2 | from shapely.geometry.base import BaseGeometry 3 | from shapely import ops 4 | import numpy as np 5 | from scipy.spatial import distance 6 | from typing import List, Optional, Tuple 7 | from numpy.typing import NDArray 8 | 9 | def split_collections(geom: BaseGeometry) -> 
List[Optional[BaseGeometry]]: 10 | ''' Split Multi-geoms to list and check is valid or is empty. 11 | 12 | Args: 13 | geom (BaseGeometry): geoms to be split or validate. 14 | 15 | Returns: 16 | geometries (List): list of geometries. 17 | ''' 18 | assert geom.geom_type in ['MultiLineString', 'LineString', 'MultiPolygon', 19 | 'Polygon', 'GeometryCollection'], f"got geom type {geom.geom_type}" 20 | if 'Multi' in geom.geom_type: 21 | outs = [] 22 | for g in geom.geoms: 23 | if g.is_valid and not g.is_empty: 24 | outs.append(g) 25 | return outs 26 | else: 27 | if geom.is_valid and not geom.is_empty: 28 | return [geom,] 29 | else: 30 | return [] 31 | 32 | def get_drivable_area_contour(drivable_areas: List[Polygon], 33 | roi_size: Tuple) -> List[LineString]: 34 | ''' Extract drivable area contours to get list of boundaries. 35 | 36 | Args: 37 | drivable_areas (list): list of drivable areas. 38 | roi_size (tuple): bev range size 39 | 40 | Returns: 41 | boundaries (List): list of boundaries. 42 | ''' 43 | max_x = roi_size[0] / 2 44 | max_y = roi_size[1] / 2 45 | 46 | # a bit smaller than roi to avoid unexpected boundaries on edges 47 | local_patch = box(-max_x + 0.2, -max_y + 0.2, max_x - 0.2, max_y - 0.2) 48 | 49 | exteriors = [] 50 | interiors = [] 51 | 52 | for poly in drivable_areas: 53 | exteriors.append(poly.exterior) 54 | for inter in poly.interiors: 55 | interiors.append(inter) 56 | 57 | results = [] 58 | for ext in exteriors: 59 | # NOTE: we make sure all exteriors are clock-wise 60 | # such that each boundary's right-hand-side is drivable area 61 | # and left-hand-side is walk way 62 | 63 | if ext.is_ccw: 64 | ext = LinearRing(list(ext.coords)[::-1]) 65 | lines = ext.intersection(local_patch) 66 | if lines.geom_type == 'MultiLineString': 67 | lines = ops.linemerge(lines) 68 | assert lines.geom_type in ['MultiLineString', 'LineString'] 69 | 70 | results.extend(split_collections(lines)) 71 | 72 | for inter in interiors: 73 | # NOTE: we make sure all interiors are counter-clock-wise 74 | if not inter.is_ccw: 75 | inter = LinearRing(list(inter.coords)[::-1]) 76 | lines = inter.intersection(local_patch) 77 | if lines.geom_type == 'MultiLineString': 78 | lines = ops.linemerge(lines) 79 | assert lines.geom_type in ['MultiLineString', 'LineString'] 80 | 81 | results.extend(split_collections(lines)) 82 | 83 | return results 84 | 85 | def get_ped_crossing_contour(polygon: Polygon, 86 | local_patch: box) -> Optional[LineString]: 87 | ''' Extract ped crossing contours to get a closed polyline. 88 | Different from `get_drivable_area_contour`, this function ensures a closed polyline. 89 | 90 | Args: 91 | polygon (Polygon): ped crossing polygon to be extracted. 92 | local_patch (tuple): local patch params 93 | 94 | Returns: 95 | line (LineString): a closed line 96 | ''' 97 | 98 | ext = polygon.exterior 99 | if not ext.is_ccw: 100 | ext = LinearRing(list(ext.coords)[::-1]) 101 | lines = ext.intersection(local_patch) 102 | if lines.type != 'LineString': 103 | # remove points in intersection results 104 | lines = [l for l in lines.geoms if l.geom_type != 'Point'] 105 | lines = ops.linemerge(lines) 106 | 107 | # same instance but not connected. 
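# Illustrative sketch (editorial addition, not part of the original file):
# the fallback below covers the case where `ops.linemerge` cannot stitch the
# clipped pieces into a single LineString. A minimal standalone example of
# the happy path, assuming only shapely is available:
#
#     from shapely.geometry import MultiLineString
#     from shapely import ops
#     parts = MultiLineString([[(0, 0), (1, 0)], [(1, 0), (1, 1)]])
#     ops.linemerge(parts)   # -> LINESTRING (0 0, 1 0, 1 1)
#
# When the pieces do not share endpoints, linemerge still returns a
# MultiLineString, which is why the code concatenates the coordinate arrays
# by hand in the branch below.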
108 | if lines.type != 'LineString': 109 | ls = [] 110 | for l in lines.geoms: 111 | ls.append(np.array(l.coords)) 112 | 113 | lines = np.concatenate(ls, axis=0) 114 | lines = LineString(lines) 115 | if not lines.is_empty: 116 | return lines 117 | 118 | return None 119 | 120 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import ( 2 | InstanceNameFilter, 3 | CircleObjectRangeFilter, 4 | NormalizeMultiviewImage, 5 | NuScenesSparse4DAdaptor, 6 | MultiScaleDepthMapGenerator, 7 | ) 8 | from .augment import ( 9 | ResizeCropFlipImage, 10 | BBoxRotation, 11 | PhotoMetricDistortionMultiViewImage, 12 | ) 13 | from .loading import LoadMultiViewImageFromFiles, LoadPointsFromFile 14 | from .vectorize import VectorizeMap 15 | 16 | __all__ = [ 17 | "InstanceNameFilter", 18 | "ResizeCropFlipImage", 19 | "BBoxRotation", 20 | "CircleObjectRangeFilter", 21 | "MultiScaleDepthMapGenerator", 22 | "NormalizeMultiviewImage", 23 | "PhotoMetricDistortionMultiViewImage", 24 | "NuScenesSparse4DAdaptor", 25 | "LoadMultiViewImageFromFiles", 26 | "LoadPointsFromFile", 27 | "VectorizeMap", 28 | ] 29 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/pipelines/augment.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import numpy as np 4 | from numpy import random 5 | import mmcv 6 | from mmdet.datasets.builder import PIPELINES 7 | from PIL import Image 8 | 9 | 10 | @PIPELINES.register_module() 11 | class ResizeCropFlipImage(object): 12 | def __call__(self, results): 13 | aug_config = results.get("aug_config") 14 | if aug_config is None: 15 | return results 16 | imgs = results["img"] 17 | N = len(imgs) 18 | new_imgs = [] 19 | for i in range(N): 20 | img, mat = self._img_transform( 21 | np.uint8(imgs[i]), aug_config, 22 | ) 23 | new_imgs.append(np.array(img).astype(np.float32)) 24 | results["lidar2img"][i] = mat @ results["lidar2img"][i] 25 | if "cam_intrinsic" in results: 26 | results["cam_intrinsic"][i][:3, :3] *= aug_config["resize"] 27 | # results["cam_intrinsic"][i][:3, :3] = ( 28 | # mat[:3, :3] @ results["cam_intrinsic"][i][:3, :3] 29 | # ) 30 | 31 | results["img"] = new_imgs 32 | results["img_shape"] = [x.shape[:2] for x in new_imgs] 33 | return results 34 | 35 | def _img_transform(self, img, aug_configs): 36 | H, W = img.shape[:2] 37 | resize = aug_configs.get("resize", 1) 38 | resize_dims = (int(W * resize), int(H * resize)) 39 | crop = aug_configs.get("crop", [0, 0, *resize_dims]) 40 | flip = aug_configs.get("flip", False) 41 | rotate = aug_configs.get("rotate", 0) 42 | 43 | origin_dtype = img.dtype 44 | if origin_dtype != np.uint8: 45 | min_value = img.min() 46 | max_vaule = img.max() 47 | scale = 255 / (max_vaule - min_value) 48 | img = (img - min_value) * scale 49 | img = np.uint8(img) 50 | img = Image.fromarray(img) 51 | img = img.resize(resize_dims).crop(crop) 52 | if flip: 53 | img = img.transpose(method=Image.FLIP_LEFT_RIGHT) 54 | img = img.rotate(rotate) 55 | img = np.array(img).astype(np.float32) 56 | if origin_dtype != np.uint8: 57 | img = img.astype(np.float32) 58 | img = img / scale + min_value 59 | 60 | transform_matrix = np.eye(3) 61 | transform_matrix[:2, :2] *= resize 62 | transform_matrix[:2, 2] -= np.array(crop[:2]) 63 | if flip: 64 | flip_matrix = np.array( 65 | [[-1, 0, crop[2] 
- crop[0]], [0, 1, 0], [0, 0, 1]] 66 | ) 67 | transform_matrix = flip_matrix @ transform_matrix 68 | rotate = rotate / 180 * np.pi 69 | rot_matrix = np.array( 70 | [ 71 | [np.cos(rotate), np.sin(rotate), 0], 72 | [-np.sin(rotate), np.cos(rotate), 0], 73 | [0, 0, 1], 74 | ] 75 | ) 76 | rot_center = np.array([crop[2] - crop[0], crop[3] - crop[1]]) / 2 77 | rot_matrix[:2, 2] = -rot_matrix[:2, :2] @ rot_center + rot_center 78 | transform_matrix = rot_matrix @ transform_matrix 79 | extend_matrix = np.eye(4) 80 | extend_matrix[:3, :3] = transform_matrix 81 | return img, extend_matrix 82 | 83 | 84 | @PIPELINES.register_module() 85 | class BBoxRotation(object): 86 | def __call__(self, results): 87 | angle = results["aug_config"]["rotate_3d"] 88 | rot_cos = np.cos(angle) 89 | rot_sin = np.sin(angle) 90 | 91 | rot_mat = np.array( 92 | [ 93 | [rot_cos, -rot_sin, 0, 0], 94 | [rot_sin, rot_cos, 0, 0], 95 | [0, 0, 1, 0], 96 | [0, 0, 0, 1], 97 | ] 98 | ) 99 | rot_mat_inv = np.linalg.inv(rot_mat) 100 | 101 | num_view = len(results["lidar2img"]) 102 | for view in range(num_view): 103 | results["lidar2img"][view] = ( 104 | results["lidar2img"][view] @ rot_mat_inv 105 | ) 106 | if "lidar2global" in results: 107 | results["lidar2global"] = results["lidar2global"] @ rot_mat_inv 108 | if "gt_bboxes_3d" in results: 109 | results["gt_bboxes_3d"] = self.box_rotate( 110 | results["gt_bboxes_3d"], angle 111 | ) 112 | return results 113 | 114 | @staticmethod 115 | def box_rotate(bbox_3d, angle): 116 | rot_cos = np.cos(angle) 117 | rot_sin = np.sin(angle) 118 | rot_mat_T = np.array( 119 | [[rot_cos, rot_sin, 0], [-rot_sin, rot_cos, 0], [0, 0, 1]] 120 | ) 121 | bbox_3d[:, :3] = bbox_3d[:, :3] @ rot_mat_T 122 | bbox_3d[:, 6] += angle 123 | if bbox_3d.shape[-1] > 7: 124 | vel_dims = bbox_3d[:, 7:].shape[-1] 125 | bbox_3d[:, 7:] = bbox_3d[:, 7:] @ rot_mat_T[:vel_dims, :vel_dims] 126 | return bbox_3d 127 | 128 | 129 | @PIPELINES.register_module() 130 | class PhotoMetricDistortionMultiViewImage: 131 | """Apply photometric distortion to image sequentially, every transformation 132 | is applied with a probability of 0.5. The position of random contrast is in 133 | second or second to last. 134 | 1. random brightness 135 | 2. random contrast (mode 0) 136 | 3. convert color from BGR to HSV 137 | 4. random saturation 138 | 5. random hue 139 | 6. convert color from HSV to BGR 140 | 7. random contrast (mode 1) 141 | 8. randomly swap channels 142 | Args: 143 | brightness_delta (int): delta of brightness. 144 | contrast_range (tuple): range of contrast. 145 | saturation_range (tuple): range of saturation. 146 | hue_delta (int): delta of hue. 147 | """ 148 | 149 | def __init__( 150 | self, 151 | brightness_delta=32, 152 | contrast_range=(0.5, 1.5), 153 | saturation_range=(0.5, 1.5), 154 | hue_delta=18, 155 | ): 156 | self.brightness_delta = brightness_delta 157 | self.contrast_lower, self.contrast_upper = contrast_range 158 | self.saturation_lower, self.saturation_upper = saturation_range 159 | self.hue_delta = hue_delta 160 | 161 | def __call__(self, results): 162 | """Call function to perform photometric distortion on images. 163 | Args: 164 | results (dict): Result dict from loading pipeline. 165 | Returns: 166 | dict: Result dict with images distorted. 
167 | """ 168 | imgs = results["img"] 169 | new_imgs = [] 170 | for img in imgs: 171 | assert img.dtype == np.float32, ( 172 | "PhotoMetricDistortion needs the input image of dtype np.float32," 173 | ' please set "to_float32=True" in "LoadImageFromFile" pipeline' 174 | ) 175 | # random brightness 176 | if random.randint(2): 177 | delta = random.uniform( 178 | -self.brightness_delta, self.brightness_delta 179 | ) 180 | img += delta 181 | 182 | # mode == 0 --> do random contrast first 183 | # mode == 1 --> do random contrast last 184 | mode = random.randint(2) 185 | if mode == 1: 186 | if random.randint(2): 187 | alpha = random.uniform( 188 | self.contrast_lower, self.contrast_upper 189 | ) 190 | img *= alpha 191 | 192 | # convert color from BGR to HSV 193 | img = mmcv.bgr2hsv(img) 194 | 195 | # random saturation 196 | if random.randint(2): 197 | img[..., 1] *= random.uniform( 198 | self.saturation_lower, self.saturation_upper 199 | ) 200 | 201 | # random hue 202 | if random.randint(2): 203 | img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta) 204 | img[..., 0][img[..., 0] > 360] -= 360 205 | img[..., 0][img[..., 0] < 0] += 360 206 | 207 | # convert color from HSV to BGR 208 | img = mmcv.hsv2bgr(img) 209 | 210 | # random contrast 211 | if mode == 0: 212 | if random.randint(2): 213 | alpha = random.uniform( 214 | self.contrast_lower, self.contrast_upper 215 | ) 216 | img *= alpha 217 | 218 | # randomly swap channels 219 | if random.randint(2): 220 | img = img[..., random.permutation(3)] 221 | new_imgs.append(img) 222 | results["img"] = new_imgs 223 | return results 224 | 225 | def __repr__(self): 226 | repr_str = self.__class__.__name__ 227 | repr_str += f"(\nbrightness_delta={self.brightness_delta},\n" 228 | repr_str += "contrast_range=" 229 | repr_str += f"{(self.contrast_lower, self.contrast_upper)},\n" 230 | repr_str += "saturation_range=" 231 | repr_str += f"{(self.saturation_lower, self.saturation_upper)},\n" 232 | repr_str += f"hue_delta={self.hue_delta})" 233 | return repr_str 234 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/pipelines/loading.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import mmcv 3 | from mmdet.datasets.builder import PIPELINES 4 | 5 | 6 | @PIPELINES.register_module() 7 | class LoadMultiViewImageFromFiles(object): 8 | """Load multi channel images from a list of separate channel files. 9 | 10 | Expects results['img_filename'] to be a list of filenames. 11 | 12 | Args: 13 | to_float32 (bool, optional): Whether to convert the img to float32. 14 | Defaults to False. 15 | color_type (str, optional): Color type of the file. 16 | Defaults to 'unchanged'. 17 | """ 18 | 19 | def __init__(self, to_float32=False, color_type="unchanged"): 20 | self.to_float32 = to_float32 21 | self.color_type = color_type 22 | 23 | def __call__(self, results): 24 | """Call function to load multi-view image from files. 25 | 26 | Args: 27 | results (dict): Result dict containing multi-view image filenames. 28 | 29 | Returns: 30 | dict: The result dict containing the multi-view image data. 31 | Added keys and values are described below. 32 | 33 | - filename (str): Multi-view image filenames. 34 | - img (np.ndarray): Multi-view image arrays. 35 | - img_shape (tuple[int]): Shape of multi-view image arrays. 36 | - ori_shape (tuple[int]): Shape of original image arrays. 37 | - pad_shape (tuple[int]): Shape of padded image arrays. 
38 | - scale_factor (float): Scale factor. 39 | - img_norm_cfg (dict): Normalization configuration of images. 40 | """ 41 | filename = results["img_filename"] 42 | # img is of shape (h, w, c, num_views) 43 | img = np.stack( 44 | [mmcv.imread(name, self.color_type) for name in filename], axis=-1 45 | ) 46 | if self.to_float32: 47 | img = img.astype(np.float32) 48 | results["filename"] = filename 49 | # unravel to list, see `DefaultFormatBundle` in formatting.py 50 | # which will transpose each image separately and then stack into array 51 | results["img"] = [img[..., i] for i in range(img.shape[-1])] 52 | results["img_shape"] = img.shape 53 | results["ori_shape"] = img.shape 54 | # Set initial values for default meta_keys 55 | results["pad_shape"] = img.shape 56 | results["scale_factor"] = 1.0 57 | num_channels = 1 if len(img.shape) < 3 else img.shape[2] 58 | results["img_norm_cfg"] = dict( 59 | mean=np.zeros(num_channels, dtype=np.float32), 60 | std=np.ones(num_channels, dtype=np.float32), 61 | to_rgb=False, 62 | ) 63 | return results 64 | 65 | def __repr__(self): 66 | """str: Return a string that describes the module.""" 67 | repr_str = self.__class__.__name__ 68 | repr_str += f"(to_float32={self.to_float32}, " 69 | repr_str += f"color_type='{self.color_type}')" 70 | return repr_str 71 | 72 | 73 | @PIPELINES.register_module() 74 | class LoadPointsFromFile(object): 75 | """Load Points From File. 76 | 77 | Load points from file. 78 | 79 | Args: 80 | coord_type (str): The type of coordinates of points cloud. 81 | Available options includes: 82 | - 'LIDAR': Points in LiDAR coordinates. 83 | - 'DEPTH': Points in depth coordinates, usually for indoor dataset. 84 | - 'CAMERA': Points in camera coordinates. 85 | load_dim (int, optional): The dimension of the loaded points. 86 | Defaults to 6. 87 | use_dim (list[int], optional): Which dimensions of the points to use. 88 | Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4 89 | or use_dim=[0, 1, 2, 3] to use the intensity dimension. 90 | shift_height (bool, optional): Whether to use shifted height. 91 | Defaults to False. 92 | use_color (bool, optional): Whether to use color features. 93 | Defaults to False. 94 | file_client_args (dict, optional): Config dict of file clients, 95 | refer to 96 | https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py 97 | for more details. Defaults to dict(backend='disk'). 98 | """ 99 | 100 | def __init__( 101 | self, 102 | coord_type, 103 | load_dim=6, 104 | use_dim=[0, 1, 2], 105 | shift_height=False, 106 | use_color=False, 107 | file_client_args=dict(backend="disk"), 108 | ): 109 | self.shift_height = shift_height 110 | self.use_color = use_color 111 | if isinstance(use_dim, int): 112 | use_dim = list(range(use_dim)) 113 | assert ( 114 | max(use_dim) < load_dim 115 | ), f"Expect all used dimensions < {load_dim}, got {use_dim}" 116 | assert coord_type in ["CAMERA", "LIDAR", "DEPTH"] 117 | 118 | self.coord_type = coord_type 119 | self.load_dim = load_dim 120 | self.use_dim = use_dim 121 | self.file_client_args = file_client_args.copy() 122 | self.file_client = None 123 | 124 | def _load_points(self, pts_filename): 125 | """Private function to load point clouds data. 126 | 127 | Args: 128 | pts_filename (str): Filename of point clouds data. 129 | 130 | Returns: 131 | np.ndarray: An array containing point clouds data. 
132 | """ 133 | if self.file_client is None: 134 | self.file_client = mmcv.FileClient(**self.file_client_args) 135 | try: 136 | pts_bytes = self.file_client.get(pts_filename) 137 | points = np.frombuffer(pts_bytes, dtype=np.float32) 138 | except ConnectionError: 139 | mmcv.check_file_exist(pts_filename) 140 | if pts_filename.endswith(".npy"): 141 | points = np.load(pts_filename) 142 | else: 143 | points = np.fromfile(pts_filename, dtype=np.float32) 144 | 145 | return points 146 | 147 | def __call__(self, results): 148 | """Call function to load points data from file. 149 | 150 | Args: 151 | results (dict): Result dict containing point clouds data. 152 | 153 | Returns: 154 | dict: The result dict containing the point clouds data. 155 | Added key and value are described below. 156 | 157 | - points (:obj:`BasePoints`): Point clouds data. 158 | """ 159 | pts_filename = results["pts_filename"] 160 | points = self._load_points(pts_filename) 161 | points = points.reshape(-1, self.load_dim) 162 | points = points[:, self.use_dim] 163 | attribute_dims = None 164 | 165 | if self.shift_height: 166 | floor_height = np.percentile(points[:, 2], 0.99) 167 | height = points[:, 2] - floor_height 168 | points = np.concatenate( 169 | [points[:, :3], np.expand_dims(height, 1), points[:, 3:]], 1 170 | ) 171 | attribute_dims = dict(height=3) 172 | 173 | if self.use_color: 174 | assert len(self.use_dim) >= 6 175 | if attribute_dims is None: 176 | attribute_dims = dict() 177 | attribute_dims.update( 178 | dict( 179 | color=[ 180 | points.shape[1] - 3, 181 | points.shape[1] - 2, 182 | points.shape[1] - 1, 183 | ] 184 | ) 185 | ) 186 | 187 | results["points"] = points 188 | return results 189 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/pipelines/vectorize.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union, Dict 2 | 3 | import numpy as np 4 | from shapely.geometry import LineString 5 | from numpy.typing import NDArray 6 | 7 | from mmcv.parallel import DataContainer as DC 8 | from mmdet.datasets.builder import PIPELINES 9 | 10 | 11 | @PIPELINES.register_module(force=True) 12 | class VectorizeMap(object): 13 | """Generate vectoized map and put into `semantic_mask` key. 14 | Concretely, shapely geometry objects are converted into sample points (ndarray). 15 | We use args `sample_num`, `sample_dist`, `simplify` to specify sampling method. 16 | 17 | Args: 18 | roi_size (tuple or list): bev range . 19 | normalize (bool): whether to normalize points to range (0, 1). 20 | coords_dim (int): dimension of point coordinates. 21 | simplify (bool): whether to use simpily function. If true, `sample_num` \ 22 | and `sample_dist` will be ignored. 23 | sample_num (int): number of points to interpolate from a polyline. Set to -1 to ignore. 24 | sample_dist (float): interpolate distance. Set to -1 to ignore. 
25 | """ 26 | 27 | def __init__(self, 28 | roi_size: Union[Tuple, List], 29 | normalize: bool, 30 | coords_dim: int=2, 31 | simplify: bool=False, 32 | sample_num: int=-1, 33 | sample_dist: float=-1, 34 | permute: bool=False 35 | ): 36 | self.coords_dim = coords_dim 37 | self.sample_num = sample_num 38 | self.sample_dist = sample_dist 39 | self.roi_size = np.array(roi_size) 40 | self.normalize = normalize 41 | self.simplify = simplify 42 | self.permute = permute 43 | 44 | if sample_dist > 0: 45 | assert sample_num < 0 and not simplify 46 | self.sample_fn = self.interp_fixed_dist 47 | elif sample_num > 0: 48 | assert sample_dist < 0 and not simplify 49 | self.sample_fn = self.interp_fixed_num 50 | else: 51 | assert simplify 52 | 53 | def interp_fixed_num(self, line: LineString) -> NDArray: 54 | ''' Interpolate a line to fixed number of points. 55 | 56 | Args: 57 | line (LineString): line 58 | 59 | Returns: 60 | points (array): interpolated points, shape (N, 2) 61 | ''' 62 | 63 | distances = np.linspace(0, line.length, self.sample_num) 64 | sampled_points = np.array([list(line.interpolate(distance).coords) 65 | for distance in distances]).squeeze() 66 | 67 | return sampled_points 68 | 69 | def interp_fixed_dist(self, line: LineString) -> NDArray: 70 | ''' Interpolate a line at fixed interval. 71 | 72 | Args: 73 | line (LineString): line 74 | 75 | Returns: 76 | points (array): interpolated points, shape (N, 2) 77 | ''' 78 | 79 | distances = list(np.arange(self.sample_dist, line.length, self.sample_dist)) 80 | # make sure to sample at least two points when sample_dist > line.length 81 | distances = [0,] + distances + [line.length,] 82 | 83 | sampled_points = np.array([list(line.interpolate(distance).coords) 84 | for distance in distances]).squeeze() 85 | 86 | return sampled_points 87 | 88 | def get_vectorized_lines(self, map_geoms: Dict) -> Dict: 89 | ''' Vectorize map elements. Iterate over the input dict and apply the 90 | specified sample funcion. 91 | 92 | Args: 93 | line (LineString): line 94 | 95 | Returns: 96 | vectors (array): dict of vectorized map elements. 97 | ''' 98 | 99 | vectors = {} 100 | for label, geom_list in map_geoms.items(): 101 | vectors[label] = [] 102 | for geom in geom_list: 103 | if geom.geom_type == 'LineString': 104 | if self.simplify: 105 | line = geom.simplify(0.2, preserve_topology=True) 106 | line = np.array(line.coords) 107 | else: 108 | line = self.sample_fn(geom) 109 | line = line[:, :self.coords_dim] 110 | 111 | if self.normalize: 112 | line = self.normalize_line(line) 113 | if self.permute: 114 | line = self.permute_line(line) 115 | vectors[label].append(line) 116 | 117 | elif geom.geom_type == 'Polygon': 118 | # polygon objects will not be vectorized 119 | continue 120 | 121 | else: 122 | raise ValueError('map geoms must be either LineString or Polygon!') 123 | return vectors 124 | 125 | def normalize_line(self, line: NDArray) -> NDArray: 126 | ''' Convert points to range (0, 1). 127 | 128 | Args: 129 | line (LineString): line 130 | 131 | Returns: 132 | normalized (array): normalized points. 
133 | ''' 134 | 135 | origin = -np.array([self.roi_size[0]/2, self.roi_size[1]/2]) 136 | 137 | line[:, :2] = line[:, :2] - origin 138 | 139 | # transform from range [0, 1] to (0, 1) 140 | eps = 1e-5 141 | line[:, :2] = line[:, :2] / (self.roi_size + eps) 142 | 143 | return line 144 | 145 | def permute_line(self, line: np.ndarray, padding=1e5): 146 | ''' 147 | (num_pts, 2) -> (num_permute, num_pts, 2) 148 | where num_permute = 2 * (num_pts - 1) 149 | ''' 150 | is_closed = np.allclose(line[0], line[-1], atol=1e-3) 151 | num_points = len(line) 152 | permute_num = num_points - 1 153 | permute_lines_list = [] 154 | if is_closed: 155 | pts_to_permute = line[:-1, :] # throw away replicate start end pts 156 | for shift_i in range(permute_num): 157 | permute_lines_list.append(np.roll(pts_to_permute, shift_i, axis=0)) 158 | flip_pts_to_permute = np.flip(pts_to_permute, axis=0) 159 | for shift_i in range(permute_num): 160 | permute_lines_list.append(np.roll(flip_pts_to_permute, shift_i, axis=0)) 161 | else: 162 | permute_lines_list.append(line) 163 | permute_lines_list.append(np.flip(line, axis=0)) 164 | 165 | permute_lines_array = np.stack(permute_lines_list, axis=0) 166 | 167 | if is_closed: 168 | tmp = np.zeros((permute_num * 2, num_points, self.coords_dim)) 169 | tmp[:, :-1, :] = permute_lines_array 170 | tmp[:, -1, :] = permute_lines_array[:, 0, :] # add replicate start end pts 171 | permute_lines_array = tmp 172 | 173 | else: 174 | # padding 175 | padding = np.full([permute_num * 2 - 2, num_points, self.coords_dim], padding) 176 | permute_lines_array = np.concatenate((permute_lines_array, padding), axis=0) 177 | 178 | return permute_lines_array 179 | 180 | def __call__(self, input_dict): 181 | if "map_geoms" not in input_dict: 182 | return input_dict 183 | map_geoms = input_dict['map_geoms'] 184 | vectors = self.get_vectorized_lines(map_geoms) 185 | 186 | if self.permute: 187 | gt_map_labels, gt_map_pts = [], [] 188 | for label, vecs in vectors.items(): 189 | for vec in vecs: 190 | gt_map_labels.append(label) 191 | gt_map_pts.append(vec) 192 | input_dict['gt_map_labels'] = np.array(gt_map_labels, dtype=np.int64) 193 | input_dict['gt_map_pts'] = np.array(gt_map_pts, dtype=np.float32).reshape(-1, 2 * (self.sample_num - 1), self.sample_num, self.coords_dim) 194 | else: 195 | input_dict['vectors'] = DC(vectors, stack=False, cpu_only=True) 196 | 197 | return input_dict 198 | 199 | def __repr__(self): 200 | repr_str = self.__class__.__name__ 201 | repr_str += f'(simplify={self.simplify}, ' 202 | repr_str += f'sample_num={self.sample_num}), ' 203 | repr_str += f'sample_dist={self.sample_dist}), ' 204 | repr_str += f'roi_size={self.roi_size})' 205 | repr_str += f'normalize={self.normalize})' 206 | repr_str += f'coords_dim={self.coords_dim})' 207 | 208 | return repr_str -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .group_sampler import DistributedGroupSampler 2 | from .distributed_sampler import DistributedSampler 3 | from .sampler import SAMPLER, build_sampler 4 | from .group_in_batch_sampler import ( 5 | GroupInBatchSampler, 6 | ) 7 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/samplers/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.utils.data 
import DistributedSampler as _DistributedSampler 5 | from .sampler import SAMPLER 6 | 7 | import pdb 8 | import sys 9 | 10 | 11 | class ForkedPdb(pdb.Pdb): 12 | def interaction(self, *args, **kwargs): 13 | _stdin = sys.stdin 14 | try: 15 | sys.stdin = open("/dev/stdin") 16 | pdb.Pdb.interaction(self, *args, **kwargs) 17 | finally: 18 | sys.stdin = _stdin 19 | 20 | 21 | def set_trace(): 22 | ForkedPdb().set_trace(sys._getframe().f_back) 23 | 24 | 25 | @SAMPLER.register_module() 26 | class DistributedSampler(_DistributedSampler): 27 | def __init__( 28 | self, dataset=None, num_replicas=None, rank=None, shuffle=True, seed=0 29 | ): 30 | super().__init__( 31 | dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle 32 | ) 33 | # for the compatibility from PyTorch 1.3+ 34 | self.seed = seed if seed is not None else 0 35 | 36 | def __iter__(self): 37 | # deterministically shuffle based on epoch 38 | assert not self.shuffle 39 | if "data_infos" in dir(self.dataset): 40 | timestamps = [ 41 | x["timestamp"] / 1e6 for x in self.dataset.data_infos 42 | ] 43 | vehicle_idx = [ 44 | x["lidar_path"].split("/")[-1][:4] 45 | if "lidar_path" in x 46 | else None 47 | for x in self.dataset.data_infos 48 | ] 49 | else: 50 | timestamps = [ 51 | x["timestamp"] / 1e6 52 | for x in self.dataset.datasets[0].data_infos 53 | ] * len(self.dataset.datasets) 54 | vehicle_idx = [ 55 | x["lidar_path"].split("/")[-1][:4] 56 | if "lidar_path" in x 57 | else None 58 | for x in self.dataset.datasets[0].data_infos 59 | ] * len(self.dataset.datasets) 60 | 61 | sequence_splits = [] 62 | for i in range(len(timestamps)): 63 | if i == 0 or ( 64 | abs(timestamps[i] - timestamps[i - 1]) > 4 65 | or vehicle_idx[i] != vehicle_idx[i - 1] 66 | ): 67 | sequence_splits.append([i]) 68 | else: 69 | sequence_splits[-1].append(i) 70 | 71 | indices = [] 72 | perfix_sum = 0 73 | split_length = len(self.dataset) // self.num_replicas 74 | for i in range(len(sequence_splits)): 75 | if perfix_sum >= (self.rank + 1) * split_length: 76 | break 77 | elif perfix_sum >= self.rank * split_length: 78 | indices.extend(sequence_splits[i]) 79 | perfix_sum += len(sequence_splits[i]) 80 | 81 | self.num_samples = len(indices) 82 | return iter(indices) 83 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/samplers/group_in_batch_sampler.py: -------------------------------------------------------------------------------- 1 | # https://github.com/Divadi/SOLOFusion/blob/main/mmdet3d/datasets/samplers/infinite_group_each_sample_in_batch_sampler.py 2 | import itertools 3 | import copy 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | from mmcv.runner import get_dist_info 9 | from torch.utils.data.sampler import Sampler 10 | 11 | 12 | # https://github.com/open-mmlab/mmdetection/blob/3b72b12fe9b14de906d1363982b9fba05e7d47c1/mmdet/core/utils/dist_utils.py#L157 13 | def sync_random_seed(seed=None, device="cuda"): 14 | """Make sure different ranks share the same seed. 15 | All workers must call this function, otherwise it will deadlock. 16 | This method is generally used in `DistributedSampler`, 17 | because the seed should be identical across all processes 18 | in the distributed group. 19 | In distributed sampling, different ranks should sample non-overlapped 20 | data in the dataset. Therefore, this function is used to make sure that 21 | each rank shuffles the data indices in the same order based 22 | on the same seed. 
Then different ranks could use different indices 23 | to select non-overlapped data from the same data list. 24 | Args: 25 | seed (int, Optional): The seed. Default to None. 26 | device (str): The device where the seed will be put on. 27 | Default to 'cuda'. 28 | Returns: 29 | int: Seed to be used. 30 | """ 31 | if seed is None: 32 | seed = np.random.randint(2**31) 33 | assert isinstance(seed, int) 34 | 35 | rank, world_size = get_dist_info() 36 | 37 | if world_size == 1: 38 | return seed 39 | 40 | if rank == 0: 41 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 42 | else: 43 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 44 | dist.broadcast(random_num, src=0) 45 | return random_num.item() 46 | 47 | 48 | class GroupInBatchSampler(Sampler): 49 | """ 50 | Pardon this horrendous name. Basically, we want every sample to be from its own group. 51 | If batch size is 4 and # of GPUs is 8, each sample of these 32 should be operating on 52 | its own group. 53 | 54 | Shuffling is only done for group order, not done within groups. 55 | """ 56 | 57 | def __init__( 58 | self, 59 | dataset, 60 | batch_size=1, 61 | world_size=None, 62 | rank=None, 63 | seed=0, 64 | skip_prob=0., 65 | sequence_flip_prob=0., 66 | ): 67 | _rank, _world_size = get_dist_info() 68 | if world_size is None: 69 | world_size = _world_size 70 | if rank is None: 71 | rank = _rank 72 | 73 | self.dataset = dataset 74 | self.batch_size = batch_size 75 | self.world_size = world_size 76 | self.rank = rank 77 | self.seed = sync_random_seed(seed) 78 | 79 | self.size = len(self.dataset) 80 | 81 | assert hasattr(self.dataset, "flag") 82 | self.flag = self.dataset.flag 83 | self.group_sizes = np.bincount(self.flag) 84 | self.groups_num = len(self.group_sizes) 85 | self.global_batch_size = batch_size * world_size 86 | assert self.groups_num >= self.global_batch_size 87 | 88 | # Now, for efficiency, make a dict group_idx: List[dataset sample_idxs] 89 | self.group_idx_to_sample_idxs = { 90 | group_idx: np.where(self.flag == group_idx)[0].tolist() 91 | for group_idx in range(self.groups_num) 92 | } 93 | 94 | # Get a generator per sample idx. 
Considering samples over all 95 | # GPUs, each sample position has its own generator 96 | self.group_indices_per_global_sample_idx = [ 97 | self._group_indices_per_global_sample_idx( 98 | self.rank * self.batch_size + local_sample_idx 99 | ) 100 | for local_sample_idx in range(self.batch_size) 101 | ] 102 | 103 | # Keep track of a buffer of dataset sample idxs for each local sample idx 104 | self.buffer_per_local_sample = [[] for _ in range(self.batch_size)] 105 | self.aug_per_local_sample = [None for _ in range(self.batch_size)] 106 | self.skip_prob = skip_prob 107 | self.sequence_flip_prob = sequence_flip_prob 108 | 109 | def _infinite_group_indices(self): 110 | g = torch.Generator() 111 | g.manual_seed(self.seed) 112 | while True: 113 | yield from torch.randperm(self.groups_num, generator=g).tolist() 114 | 115 | def _group_indices_per_global_sample_idx(self, global_sample_idx): 116 | yield from itertools.islice( 117 | self._infinite_group_indices(), 118 | global_sample_idx, 119 | None, 120 | self.global_batch_size, 121 | ) 122 | 123 | def __iter__(self): 124 | while True: 125 | curr_batch = [] 126 | for local_sample_idx in range(self.batch_size): 127 | skip = ( 128 | np.random.uniform() < self.skip_prob 129 | and len(self.buffer_per_local_sample[local_sample_idx]) > 1 130 | ) 131 | if len(self.buffer_per_local_sample[local_sample_idx]) == 0: 132 | # Finished current group, refill with next group 133 | # skip = False 134 | new_group_idx = next( 135 | self.group_indices_per_global_sample_idx[ 136 | local_sample_idx 137 | ] 138 | ) 139 | self.buffer_per_local_sample[ 140 | local_sample_idx 141 | ] = copy.deepcopy( 142 | self.group_idx_to_sample_idxs[new_group_idx] 143 | ) 144 | if np.random.uniform() < self.sequence_flip_prob: 145 | self.buffer_per_local_sample[ 146 | local_sample_idx 147 | ] = self.buffer_per_local_sample[local_sample_idx][ 148 | ::-1 149 | ] 150 | if self.dataset.keep_consistent_seq_aug: 151 | self.aug_per_local_sample[ 152 | local_sample_idx 153 | ] = self.dataset.get_augmentation() 154 | 155 | if not self.dataset.keep_consistent_seq_aug: 156 | self.aug_per_local_sample[ 157 | local_sample_idx 158 | ] = self.dataset.get_augmentation() 159 | 160 | if skip: 161 | self.buffer_per_local_sample[local_sample_idx].pop(0) 162 | curr_batch.append( 163 | dict( 164 | idx=self.buffer_per_local_sample[local_sample_idx].pop( 165 | 0 166 | ), 167 | aug_config=self.aug_per_local_sample[local_sample_idx], 168 | ) 169 | ) 170 | 171 | yield curr_batch 172 | 173 | def __len__(self): 174 | """Length of base dataset.""" 175 | return self.size 176 | 177 | def set_epoch(self, epoch): 178 | self.epoch = epoch 179 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/samplers/group_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | from mmcv.runner import get_dist_info 7 | from torch.utils.data import Sampler 8 | from .sampler import SAMPLER 9 | import random 10 | from IPython import embed 11 | 12 | 13 | @SAMPLER.register_module() 14 | class DistributedGroupSampler(Sampler): 15 | """Sampler that restricts data loading to a subset of the dataset. 16 | It is especially useful in conjunction with 17 | :class:`torch.nn.parallel.DistributedDataParallel`. 
In such case, each 18 | process can pass a DistributedSampler instance as a DataLoader sampler, 19 | and load a subset of the original dataset that is exclusive to it. 20 | .. note:: 21 | Dataset is assumed to be of constant size. 22 | Arguments: 23 | dataset: Dataset used for sampling. 24 | num_replicas (optional): Number of processes participating in 25 | distributed training. 26 | rank (optional): Rank of the current process within num_replicas. 27 | seed (int, optional): random seed used to shuffle the sampler if 28 | ``shuffle=True``. This number should be identical across all 29 | processes in the distributed group. Default: 0. 30 | """ 31 | 32 | def __init__( 33 | self, dataset, samples_per_gpu=1, num_replicas=None, rank=None, seed=0 34 | ): 35 | _rank, _num_replicas = get_dist_info() 36 | if num_replicas is None: 37 | num_replicas = _num_replicas 38 | if rank is None: 39 | rank = _rank 40 | self.dataset = dataset 41 | self.samples_per_gpu = samples_per_gpu 42 | self.num_replicas = num_replicas 43 | self.rank = rank 44 | self.epoch = 0 45 | self.seed = seed if seed is not None else 0 46 | 47 | assert hasattr(self.dataset, "flag") 48 | self.flag = self.dataset.flag 49 | self.group_sizes = np.bincount(self.flag) 50 | 51 | self.num_samples = 0 52 | for i, j in enumerate(self.group_sizes): 53 | self.num_samples += ( 54 | int( 55 | math.ceil( 56 | self.group_sizes[i] 57 | * 1.0 58 | / self.samples_per_gpu 59 | / self.num_replicas 60 | ) 61 | ) 62 | * self.samples_per_gpu 63 | ) 64 | self.total_size = self.num_samples * self.num_replicas 65 | 66 | def __iter__(self): 67 | # deterministically shuffle based on epoch 68 | g = torch.Generator() 69 | g.manual_seed(self.epoch + self.seed) 70 | 71 | indices = [] 72 | for i, size in enumerate(self.group_sizes): 73 | if size > 0: 74 | indice = np.where(self.flag == i)[0] 75 | assert len(indice) == size 76 | # add .numpy() to avoid bug when selecting indice in parrots. 77 | # TODO: check whether torch.randperm() can be replaced by 78 | # numpy.random.permutation(). 
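# Illustrative note (editorial addition, not part of the original file):
# seeding the generator with `self.epoch + self.seed` makes every rank draw
# the same permutation, so the later `indices[offset : offset + self.num_samples]`
# slice partitions the data across ranks without overlap. A minimal sketch of
# the idea, assuming torch (variable names are placeholders):
#
#     g = torch.Generator()
#     g.manual_seed(epoch + seed)              # identical on every rank
#     perm = torch.randperm(8, generator=g)    # identical on every rank
#     shard = perm[rank * 2 : (rank + 1) * 2]  # disjoint per-rank shard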
79 | indice = indice[ 80 | list(torch.randperm(int(size), generator=g).numpy()) 81 | ].tolist() 82 | extra = int( 83 | math.ceil( 84 | size * 1.0 / self.samples_per_gpu / self.num_replicas 85 | ) 86 | ) * self.samples_per_gpu * self.num_replicas - len(indice) 87 | # pad indice 88 | tmp = indice.copy() 89 | for _ in range(extra // size): 90 | indice.extend(tmp) 91 | indice.extend(tmp[: extra % size]) 92 | indices.extend(indice) 93 | 94 | assert len(indices) == self.total_size 95 | 96 | indices = [ 97 | indices[j] 98 | for i in list( 99 | torch.randperm( 100 | len(indices) // self.samples_per_gpu, generator=g 101 | ) 102 | ) 103 | for j in range( 104 | i * self.samples_per_gpu, (i + 1) * self.samples_per_gpu 105 | ) 106 | ] 107 | 108 | # subsample 109 | offset = self.num_samples * self.rank 110 | indices = indices[offset : offset + self.num_samples] 111 | assert len(indices) == self.num_samples 112 | 113 | return iter(indices) 114 | 115 | def __len__(self): 116 | return self.num_samples 117 | 118 | def set_epoch(self, epoch): 119 | self.epoch = epoch 120 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/samplers/sampler.py: -------------------------------------------------------------------------------- 1 | from mmcv.utils.registry import Registry, build_from_cfg 2 | 3 | SAMPLER = Registry("sampler") 4 | 5 | 6 | def build_sampler(cfg, default_args): 7 | return build_from_cfg(cfg, SAMPLER, default_args) 8 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | from projects.mmdet3d_plugin.core.box3d import * 8 | 9 | 10 | def box3d_to_corners(box3d): 11 | if isinstance(box3d, torch.Tensor): 12 | box3d = box3d.detach().cpu().numpy() 13 | corners_norm = np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1) 14 | corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]] 15 | # use relative origin [0.5, 0.5, 0] 16 | corners_norm = corners_norm - np.array([0.5, 0.5, 0.5]) 17 | corners = box3d[:, None, [W, L, H]] * corners_norm.reshape([1, 8, 3]) 18 | 19 | # rotate around z axis 20 | rot_cos = np.cos(box3d[:, YAW]) 21 | rot_sin = np.sin(box3d[:, YAW]) 22 | rot_mat = np.tile(np.eye(3)[None], (box3d.shape[0], 1, 1)) 23 | rot_mat[:, 0, 0] = rot_cos 24 | rot_mat[:, 0, 1] = -rot_sin 25 | rot_mat[:, 1, 0] = rot_sin 26 | rot_mat[:, 1, 1] = rot_cos 27 | corners = (rot_mat[:, None] @ corners[..., None]).squeeze(axis=-1) 28 | corners += box3d[:, None, :3] 29 | return corners 30 | 31 | 32 | def plot_rect3d_on_img( 33 | img, num_rects, rect_corners, color=(0, 255, 0), thickness=1 34 | ): 35 | """Plot the boundary lines of 3D rectangular on 2D images. 36 | 37 | Args: 38 | img (numpy.array): The numpy array of image. 39 | num_rects (int): Number of 3D rectangulars. 40 | rect_corners (numpy.array): Coordinates of the corners of 3D 41 | rectangulars. Should be in the shape of [num_rect, 8, 2]. 42 | color (tuple[int], optional): The color to draw bboxes. 43 | Default: (0, 255, 0). 44 | thickness (int, optional): The thickness of bboxes. Default: 1. 
45 | """ 46 | line_indices = ( 47 | (0, 1), 48 | (0, 3), 49 | (0, 4), 50 | (1, 2), 51 | (1, 5), 52 | (3, 2), 53 | (3, 7), 54 | (4, 5), 55 | (4, 7), 56 | (2, 6), 57 | (5, 6), 58 | (6, 7), 59 | ) 60 | h, w = img.shape[:2] 61 | for i in range(num_rects): 62 | corners = np.clip(rect_corners[i], -1e4, 1e5).astype(np.int32) 63 | for start, end in line_indices: 64 | if ( 65 | (corners[start, 1] >= h or corners[start, 1] < 0) 66 | or (corners[start, 0] >= w or corners[start, 0] < 0) 67 | ) and ( 68 | (corners[end, 1] >= h or corners[end, 1] < 0) 69 | or (corners[end, 0] >= w or corners[end, 0] < 0) 70 | ): 71 | continue 72 | if isinstance(color[0], int): 73 | cv2.line( 74 | img, 75 | (corners[start, 0], corners[start, 1]), 76 | (corners[end, 0], corners[end, 1]), 77 | color, 78 | thickness, 79 | cv2.LINE_AA, 80 | ) 81 | else: 82 | cv2.line( 83 | img, 84 | (corners[start, 0], corners[start, 1]), 85 | (corners[end, 0], corners[end, 1]), 86 | color[i], 87 | thickness, 88 | cv2.LINE_AA, 89 | ) 90 | 91 | return img.astype(np.uint8) 92 | 93 | 94 | def draw_lidar_bbox3d_on_img( 95 | bboxes3d, raw_img, lidar2img_rt, img_metas=None, color=(0, 255, 0), thickness=1 96 | ): 97 | """Project the 3D bbox on 2D plane and draw on input image. 98 | 99 | Args: 100 | bboxes3d (:obj:`LiDARInstance3DBoxes`): 101 | 3d bbox in lidar coordinate system to visualize. 102 | raw_img (numpy.array): The numpy array of image. 103 | lidar2img_rt (numpy.array, shape=[4, 4]): The projection matrix 104 | according to the camera intrinsic parameters. 105 | img_metas (dict): Useless here. 106 | color (tuple[int], optional): The color to draw bboxes. 107 | Default: (0, 255, 0). 108 | thickness (int, optional): The thickness of bboxes. Default: 1. 109 | """ 110 | img = raw_img.copy() 111 | # corners_3d = bboxes3d.corners 112 | corners_3d = box3d_to_corners(bboxes3d) 113 | num_bbox = corners_3d.shape[0] 114 | pts_4d = np.concatenate( 115 | [corners_3d.reshape(-1, 3), np.ones((num_bbox * 8, 1))], axis=-1 116 | ) 117 | lidar2img_rt = copy.deepcopy(lidar2img_rt).reshape(4, 4) 118 | if isinstance(lidar2img_rt, torch.Tensor): 119 | lidar2img_rt = lidar2img_rt.cpu().numpy() 120 | pts_2d = pts_4d @ lidar2img_rt.T 121 | 122 | pts_2d[:, 2] = np.clip(pts_2d[:, 2], a_min=1e-5, a_max=1e5) 123 | pts_2d[:, 0] /= pts_2d[:, 2] 124 | pts_2d[:, 1] /= pts_2d[:, 2] 125 | imgfov_pts_2d = pts_2d[..., :2].reshape(num_bbox, 8, 2) 126 | 127 | return plot_rect3d_on_img(img, num_bbox, imgfov_pts_2d, color, thickness) 128 | 129 | 130 | def draw_points_on_img(points, img, lidar2img_rt, color=(0, 255, 0), circle=4): 131 | img = img.copy() 132 | N = points.shape[0] 133 | points = points.cpu().numpy() 134 | lidar2img_rt = copy.deepcopy(lidar2img_rt).reshape(4, 4) 135 | if isinstance(lidar2img_rt, torch.Tensor): 136 | lidar2img_rt = lidar2img_rt.cpu().numpy() 137 | pts_2d = ( 138 | np.sum(points[:, :, None] * lidar2img_rt[:3, :3], axis=-1) 139 | + lidar2img_rt[:3, 3] 140 | ) 141 | pts_2d[..., 2] = np.clip(pts_2d[..., 2], a_min=1e-5, a_max=1e5) 142 | pts_2d = pts_2d[..., :2] / pts_2d[..., 2:3] 143 | pts_2d = np.clip(pts_2d, -1e4, 1e4).astype(np.int32) 144 | 145 | for i in range(N): 146 | for point in pts_2d[i]: 147 | if isinstance(color[0], int): 148 | color_tmp = color 149 | else: 150 | color_tmp = color[i] 151 | cv2.circle(img, point.tolist(), circle, color_tmp, thickness=-1) 152 | return img.astype(np.uint8) 153 | 154 | 155 | def draw_lidar_bbox3d_on_bev( 156 | bboxes_3d, bev_size, bev_range=115, color=(255, 0, 0), thickness=3): 157 | if isinstance(bev_size, (list, 
tuple)): 158 | bev_h, bev_w = bev_size 159 | else: 160 | bev_h, bev_w = bev_size, bev_size 161 | bev = np.zeros([bev_h, bev_w, 3]) 162 | 163 | marking_color = (127, 127, 127) 164 | bev_resolution = bev_range / bev_h 165 | for cir in range(int(bev_range / 2 / 10)): 166 | cv2.circle( 167 | bev, 168 | (int(bev_h / 2), int(bev_w / 2)), 169 | int((cir + 1) * 10 / bev_resolution), 170 | marking_color, 171 | thickness=thickness, 172 | ) 173 | cv2.line( 174 | bev, 175 | (0, int(bev_h / 2)), 176 | (bev_w, int(bev_h / 2)), 177 | marking_color, 178 | ) 179 | cv2.line( 180 | bev, 181 | (int(bev_w / 2), 0), 182 | (int(bev_w / 2), bev_h), 183 | marking_color, 184 | ) 185 | if len(bboxes_3d) != 0: 186 | bev_corners = box3d_to_corners(bboxes_3d)[:, [0, 3, 4, 7]][ 187 | ..., [0, 1] 188 | ] 189 | xs = bev_corners[..., 0] / bev_resolution + bev_w / 2 190 | ys = -bev_corners[..., 1] / bev_resolution + bev_h / 2 191 | for obj_idx, (x, y) in enumerate(zip(xs, ys)): 192 | for p1, p2 in ((0, 1), (0, 2), (1, 3), (2, 3)): 193 | if isinstance(color[0], (list, tuple)): 194 | tmp = color[obj_idx] 195 | else: 196 | tmp = color 197 | cv2.line( 198 | bev, 199 | (int(x[p1]), int(y[p1])), 200 | (int(x[p2]), int(y[p2])), 201 | tmp, 202 | thickness=thickness, 203 | ) 204 | return bev.astype(np.uint8) 205 | 206 | 207 | def draw_lidar_bbox3d(bboxes_3d, imgs, lidar2imgs, color=(255, 0, 0)): 208 | vis_imgs = [] 209 | for i, (img, lidar2img) in enumerate(zip(imgs, lidar2imgs)): 210 | vis_imgs.append( 211 | draw_lidar_bbox3d_on_img(bboxes_3d, img, lidar2img, color=color) 212 | ) 213 | 214 | num_imgs = len(vis_imgs) 215 | if num_imgs < 4 or num_imgs % 2 != 0: 216 | vis_imgs = np.concatenate(vis_imgs, axis=1) 217 | else: 218 | vis_imgs = np.concatenate([ 219 | np.concatenate(vis_imgs[:num_imgs//2], axis=1), 220 | np.concatenate(vis_imgs[num_imgs//2:], axis=1) 221 | ], axis=0) 222 | 223 | bev = draw_lidar_bbox3d_on_bev(bboxes_3d, vis_imgs.shape[0], color=color) 224 | vis_imgs = np.concatenate([bev, vis_imgs], axis=1) 225 | return vis_imgs 226 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .sparsedrive import SparseDrive 2 | from .sparsedrive_head import SparseDriveHead 3 | from .blocks import ( 4 | DeformableFeatureAggregation, 5 | DenseDepthNet, 6 | AsymmetricFFN, 7 | ) 8 | from .instance_bank import InstanceBank 9 | from .detection3d import ( 10 | SparseBox3DDecoder, 11 | SparseBox3DTarget, 12 | SparseBox3DRefinementModule, 13 | SparseBox3DKeyPointsGenerator, 14 | SparseBox3DEncoder, 15 | ) 16 | from .map import * 17 | from .motion import * 18 | 19 | 20 | __all__ = [ 21 | "SparseDrive", 22 | "SparseDriveHead", 23 | "DeformableFeatureAggregation", 24 | "DenseDepthNet", 25 | "AsymmetricFFN", 26 | "InstanceBank", 27 | "SparseBox3DDecoder", 28 | "SparseBox3DTarget", 29 | "SparseBox3DRefinementModule", 30 | "SparseBox3DKeyPointsGenerator", 31 | "SparseBox3DEncoder", 32 | ] 33 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/base_target.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | __all__ = ["BaseTargetWithDenoising"] 5 | 6 | 7 | class BaseTargetWithDenoising(ABC): 8 | def __init__(self, num_dn_groups=0, num_temp_dn_groups=0): 9 | super(BaseTargetWithDenoising, self).__init__() 10 | self.num_dn_groups = 
num_dn_groups 11 | self.num_temp_dn_groups = num_temp_dn_groups 12 | self.dn_metas = None 13 | 14 | @abstractmethod 15 | def sample(self, cls_pred, box_pred, cls_target, box_target): 16 | """ 17 | Perform Hungarian matching between predictions and ground truth, 18 | returning the matched ground truth corresponding to the predictions 19 | along with the corresponding regression weights. 20 | """ 21 | 22 | def get_dn_anchors(self, cls_target, box_target, *args, **kwargs): 23 | """ 24 | Generate noisy instances for the current frame, with a total of 25 | 'self.num_dn_groups' groups. 26 | """ 27 | return None 28 | 29 | def update_dn(self, instance_feature, anchor, *args, **kwargs): 30 | """ 31 | Insert the previously saved 'self.dn_metas' into the noisy instances 32 | of the current frame. 33 | """ 34 | 35 | def cache_dn( 36 | self, 37 | dn_instance_feature, 38 | dn_anchor, 39 | dn_cls_target, 40 | valid_mask, 41 | dn_id_target, 42 | ): 43 | """ 44 | Randomly save information for 'self.num_temp_dn_groups' groups of 45 | temporal noisy instances to 'self.dn_metas'. 46 | """ 47 | if self.num_temp_dn_groups < 0: 48 | return 49 | self.dn_metas = dict(dn_anchor=dn_anchor[:, : self.num_temp_dn_groups]) 50 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/detection3d/__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder import SparseBox3DDecoder 2 | from .target import SparseBox3DTarget 3 | from .detection3d_blocks import ( 4 | SparseBox3DRefinementModule, 5 | SparseBox3DKeyPointsGenerator, 6 | SparseBox3DEncoder, 7 | ) 8 | from .losses import SparseBox3DLoss 9 | from .detection3d_head import Sparse4DHead 10 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/detection3d/decoder.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | from mmdet.core.bbox.builder import BBOX_CODERS 6 | 7 | from projects.mmdet3d_plugin.core.box3d import * 8 | 9 | def decode_box(box): 10 | yaw = torch.atan2(box[..., SIN_YAW], box[..., COS_YAW]) 11 | box = torch.cat( 12 | [ 13 | box[..., [X, Y, Z]], 14 | box[..., [W, L, H]].exp(), 15 | yaw[..., None], 16 | box[..., VX:], 17 | ], 18 | dim=-1, 19 | ) 20 | return box 21 | 22 | 23 | @BBOX_CODERS.register_module() 24 | class SparseBox3DDecoder(object): 25 | def __init__( 26 | self, 27 | num_output: int = 300, 28 | score_threshold: Optional[float] = None, 29 | sorted: bool = True, 30 | ): 31 | super(SparseBox3DDecoder, self).__init__() 32 | self.num_output = num_output 33 | self.score_threshold = score_threshold 34 | self.sorted = sorted 35 | 36 | def decode( 37 | self, 38 | cls_scores, 39 | box_preds, 40 | instance_id=None, 41 | quality=None, 42 | output_idx=-1, 43 | ): 44 | squeeze_cls = instance_id is not None 45 | 46 | cls_scores = cls_scores[output_idx].sigmoid() 47 | 48 | if squeeze_cls: 49 | cls_scores, cls_ids = cls_scores.max(dim=-1) 50 | cls_scores = cls_scores.unsqueeze(dim=-1) 51 | 52 | box_preds = box_preds[output_idx] 53 | bs, num_pred, num_cls = cls_scores.shape 54 | cls_scores, indices = cls_scores.flatten(start_dim=1).topk( 55 | self.num_output, dim=1, sorted=self.sorted 56 | ) 57 | if not squeeze_cls: 58 | cls_ids = indices % num_cls 59 | if self.score_threshold is not None: 60 | mask = cls_scores >= self.score_threshold 61 | 62 | if quality[output_idx] is None: 63 | quality = None 64 | 
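# Illustrative note (editorial addition, not part of the original file): the
# block below rescores each candidate by its predicted centerness before the
# final ranking, roughly score_final = sigmoid(cls_logit) * sigmoid(centerness_logit).
# A tiny numeric sketch, assuming torch:
#
#     cls_score  = torch.tensor([0.90, 0.80])   # already sigmoided above
#     centerness = torch.tensor([0.0, 3.0])     # logits from the quality head
#     cls_score * centerness.sigmoid()          # -> tensor([0.4500, 0.7621])
#
# so a well-localized 0.80 box can outrank a poorly localized 0.90 box.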
if quality is not None: 65 | centerness = quality[output_idx][..., CNS] 66 | centerness = torch.gather(centerness, 1, indices // num_cls) 67 | cls_scores_origin = cls_scores.clone() 68 | cls_scores *= centerness.sigmoid() 69 | cls_scores, idx = torch.sort(cls_scores, dim=1, descending=True) 70 | if not squeeze_cls: 71 | cls_ids = torch.gather(cls_ids, 1, idx) 72 | if self.score_threshold is not None: 73 | mask = torch.gather(mask, 1, idx) 74 | indices = torch.gather(indices, 1, idx) 75 | 76 | output = [] 77 | for i in range(bs): 78 | category_ids = cls_ids[i] 79 | if squeeze_cls: 80 | category_ids = category_ids[indices[i]] 81 | scores = cls_scores[i] 82 | box = box_preds[i, indices[i] // num_cls] 83 | if self.score_threshold is not None: 84 | category_ids = category_ids[mask[i]] 85 | scores = scores[mask[i]] 86 | box = box[mask[i]] 87 | if quality is not None: 88 | scores_origin = cls_scores_origin[i] 89 | if self.score_threshold is not None: 90 | scores_origin = scores_origin[mask[i]] 91 | 92 | box = decode_box(box) 93 | output.append( 94 | { 95 | "boxes_3d": box.cpu(), 96 | "scores_3d": scores.cpu(), 97 | "labels_3d": category_ids.cpu(), 98 | } 99 | ) 100 | if quality is not None: 101 | output[-1]["cls_scores"] = scores_origin.cpu() 102 | if instance_id is not None: 103 | ids = instance_id[i, indices[i]] 104 | if self.score_threshold is not None: 105 | ids = ids[mask[i]] 106 | output[-1]["instance_ids"] = ids 107 | return output 108 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/detection3d/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from mmcv.utils import build_from_cfg 5 | from mmdet.models.builder import LOSSES 6 | 7 | from projects.mmdet3d_plugin.core.box3d import * 8 | 9 | 10 | @LOSSES.register_module() 11 | class SparseBox3DLoss(nn.Module): 12 | def __init__( 13 | self, 14 | loss_box, 15 | loss_centerness=None, 16 | loss_yawness=None, 17 | cls_allow_reverse=None, 18 | ): 19 | super().__init__() 20 | 21 | def build(cfg, registry): 22 | if cfg is None: 23 | return None 24 | return build_from_cfg(cfg, registry) 25 | 26 | self.loss_box = build(loss_box, LOSSES) 27 | self.loss_cns = build(loss_centerness, LOSSES) 28 | self.loss_yns = build(loss_yawness, LOSSES) 29 | self.cls_allow_reverse = cls_allow_reverse 30 | 31 | def forward( 32 | self, 33 | box, 34 | box_target, 35 | weight=None, 36 | avg_factor=None, 37 | prefix="", 38 | suffix="", 39 | quality=None, 40 | cls_target=None, 41 | **kwargs, 42 | ): 43 | # Some categories do not distinguish between positive and negative 44 | # directions. For example, barrier in nuScenes dataset. 
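# Illustrative note (editorial addition, not part of the original file): for
# such direction-symmetric classes the target yaw is flipped by 180 degrees
# before the box loss whenever the predicted and target headings point in
# opposite directions. The check below is a cosine similarity on the
# (sin(yaw), cos(yaw)) encoding; a tiny sketch of the criterion, assuming torch:
#
#     import torch.nn.functional as F
#     pred   = torch.tensor([[0.0, 1.0]])    # yaw =   0 deg as (sin, cos)
#     target = torch.tensor([[0.0, -1.0]])   # yaw = 180 deg as (sin, cos)
#     F.cosine_similarity(pred, target, dim=-1) < 0   # -> tensor([True])
#
# Negating the target's (sin, cos) pair then rotates it by 180 degrees.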
45 | if self.cls_allow_reverse is not None and cls_target is not None: 46 | if_reverse = ( 47 | torch.nn.functional.cosine_similarity( 48 | box_target[..., [SIN_YAW, COS_YAW]], 49 | box[..., [SIN_YAW, COS_YAW]], 50 | dim=-1, 51 | ) 52 | < 0 53 | ) 54 | if_reverse = ( 55 | torch.isin( 56 | cls_target, cls_target.new_tensor(self.cls_allow_reverse) 57 | ) 58 | & if_reverse 59 | ) 60 | box_target[..., [SIN_YAW, COS_YAW]] = torch.where( 61 | if_reverse[..., None], 62 | -box_target[..., [SIN_YAW, COS_YAW]], 63 | box_target[..., [SIN_YAW, COS_YAW]], 64 | ) 65 | 66 | output = {} 67 | box_loss = self.loss_box( 68 | box, box_target, weight=weight, avg_factor=avg_factor 69 | ) 70 | output[f"{prefix}loss_box{suffix}"] = box_loss 71 | 72 | if quality is not None: 73 | cns = quality[..., CNS] 74 | yns = quality[..., YNS].sigmoid() 75 | cns_target = torch.norm( 76 | box_target[..., [X, Y, Z]] - box[..., [X, Y, Z]], p=2, dim=-1 77 | ) 78 | cns_target = torch.exp(-cns_target) 79 | cns_loss = self.loss_cns(cns, cns_target, avg_factor=avg_factor) 80 | output[f"{prefix}loss_cns{suffix}"] = cns_loss 81 | 82 | yns_target = ( 83 | torch.nn.functional.cosine_similarity( 84 | box_target[..., [SIN_YAW, COS_YAW]], 85 | box[..., [SIN_YAW, COS_YAW]], 86 | dim=-1, 87 | ) 88 | > 0 89 | ) 90 | yns_target = yns_target.float() 91 | yns_loss = self.loss_yns(yns, yns_target, avg_factor=avg_factor) 92 | output[f"{prefix}loss_yns{suffix}"] = yns_loss 93 | return output 94 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/grid_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | class Grid(object): 8 | def __init__( 9 | self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0 10 | ): 11 | self.use_h = use_h 12 | self.use_w = use_w 13 | self.rotate = rotate 14 | self.offset = offset 15 | self.ratio = ratio 16 | self.mode = mode 17 | self.st_prob = prob 18 | self.prob = prob 19 | 20 | def set_prob(self, epoch, max_epoch): 21 | self.prob = self.st_prob * epoch / max_epoch 22 | 23 | def __call__(self, img, label): 24 | if np.random.rand() > self.prob: 25 | return img, label 26 | h = img.size(1) 27 | w = img.size(2) 28 | self.d1 = 2 29 | self.d2 = min(h, w) 30 | hh = int(1.5 * h) 31 | ww = int(1.5 * w) 32 | d = np.random.randint(self.d1, self.d2) 33 | if self.ratio == 1: 34 | self.l = np.random.randint(1, d) 35 | else: 36 | self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1) 37 | mask = np.ones((hh, ww), np.float32) 38 | st_h = np.random.randint(d) 39 | st_w = np.random.randint(d) 40 | if self.use_h: 41 | for i in range(hh // d): 42 | s = d * i + st_h 43 | t = min(s + self.l, hh) 44 | mask[s:t, :] *= 0 45 | if self.use_w: 46 | for i in range(ww // d): 47 | s = d * i + st_w 48 | t = min(s + self.l, ww) 49 | mask[:, s:t] *= 0 50 | 51 | r = np.random.randint(self.rotate) 52 | mask = Image.fromarray(np.uint8(mask)) 53 | mask = mask.rotate(r) 54 | mask = np.asarray(mask) 55 | mask = mask[ 56 | (hh - h) // 2 : (hh - h) // 2 + h, 57 | (ww - w) // 2 : (ww - w) // 2 + w, 58 | ] 59 | 60 | mask = torch.from_numpy(mask).float() 61 | if self.mode == 1: 62 | mask = 1 - mask 63 | 64 | mask = mask.expand_as(img) 65 | if self.offset: 66 | offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).float() 67 | offset = (1 - mask) * offset 68 | img = img * mask + offset 69 | else: 70 | img = img * mask 71 | 72 | return img, label 
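# Illustrative usage sketch (editorial addition, not part of the original
# grid_mask.py): how the Grid augmentation above can be applied to a single
# CHW image tensor. Parameter values are placeholders.
if __name__ == "__main__":
    _demo_img = torch.rand(3, 224, 224)                    # CHW image tensor
    _demo_aug = Grid(use_h=True, use_w=True, rotate=1, ratio=0.5, prob=1.0)
    _masked, _ = _demo_aug(_demo_img, label=None)
    # Strips of length ~ratio*d are zeroed every d pixels along H and W,
    # where d is sampled per call from [2, min(H, W)).
    print(_masked.shape)                                    # torch.Size([3, 224, 224])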
73 | 74 | 75 | class GridMask(nn.Module): 76 | def __init__( 77 | self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0 78 | ): 79 | super(GridMask, self).__init__() 80 | self.use_h = use_h 81 | self.use_w = use_w 82 | self.rotate = rotate 83 | self.offset = offset 84 | self.ratio = ratio 85 | self.mode = mode 86 | self.st_prob = prob 87 | self.prob = prob 88 | 89 | def set_prob(self, epoch, max_epoch): 90 | self.prob = self.st_prob * epoch / max_epoch # + 1.#0.5 91 | 92 | def forward(self, x): 93 | if np.random.rand() > self.prob or not self.training: 94 | return x 95 | n, c, h, w = x.size() 96 | x = x.view(-1, h, w) 97 | hh = int(1.5 * h) 98 | ww = int(1.5 * w) 99 | d = np.random.randint(2, h) 100 | self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1) 101 | mask = np.ones((hh, ww), np.float32) 102 | st_h = np.random.randint(d) 103 | st_w = np.random.randint(d) 104 | if self.use_h: 105 | for i in range(hh // d): 106 | s = d * i + st_h 107 | t = min(s + self.l, hh) 108 | mask[s:t, :] *= 0 109 | if self.use_w: 110 | for i in range(ww // d): 111 | s = d * i + st_w 112 | t = min(s + self.l, ww) 113 | mask[:, s:t] *= 0 114 | 115 | r = np.random.randint(self.rotate) 116 | mask = Image.fromarray(np.uint8(mask)) 117 | mask = mask.rotate(r) 118 | mask = np.asarray(mask) 119 | mask = mask[ 120 | (hh - h) // 2 : (hh - h) // 2 + h, 121 | (ww - w) // 2 : (ww - w) // 2 + w, 122 | ] 123 | 124 | mask = torch.from_numpy(mask.copy()).float().cuda() 125 | if self.mode == 1: 126 | mask = 1 - mask 127 | mask = mask.expand_as(x) 128 | if self.offset: 129 | offset = ( 130 | torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)) 131 | .float() 132 | .cuda() 133 | ) 134 | x = x * mask + offset * (1 - mask) 135 | else: 136 | x = x * mask 137 | 138 | return x.view(n, c, h, w) 139 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/map/__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder import SparsePoint3DDecoder 2 | from .target import SparsePoint3DTarget, HungarianLinesAssigner 3 | from .match_cost import LinesL1Cost, MapQueriesCost 4 | from .loss import LinesL1Loss, SparseLineLoss 5 | from .map_blocks import ( 6 | SparsePoint3DRefinementModule, 7 | SparsePoint3DKeyPointsGenerator, 8 | SparsePoint3DEncoder, 9 | ) -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/map/decoder.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List 2 | 3 | import torch 4 | 5 | from mmdet.core.bbox.builder import BBOX_CODERS 6 | 7 | 8 | @BBOX_CODERS.register_module() 9 | class SparsePoint3DDecoder(object): 10 | def __init__( 11 | self, 12 | coords_dim: int = 2, 13 | score_threshold: Optional[float] = None, 14 | ): 15 | super(SparsePoint3DDecoder, self).__init__() 16 | self.score_threshold = score_threshold 17 | self.coords_dim = coords_dim 18 | 19 | def decode( 20 | self, 21 | cls_scores, 22 | pts_preds, 23 | instance_id=None, 24 | quality=None, 25 | output_idx=-1, 26 | ): 27 | bs, num_pred, num_cls = cls_scores[-1].shape 28 | cls_scores = cls_scores[-1].sigmoid() 29 | pts_preds = pts_preds[-1].reshape(bs, num_pred, -1, self.coords_dim) 30 | cls_scores, indices = cls_scores.flatten(start_dim=1).topk( 31 | num_pred, dim=1 32 | ) 33 | cls_ids = indices % num_cls 34 | if self.score_threshold is not None: 35 | mask = cls_scores >= self.score_threshold 36 | output = [] 37 | 
for i in range(bs): 38 | category_ids = cls_ids[i] 39 | scores = cls_scores[i] 40 | pts = pts_preds[i, indices[i] // num_cls] 41 | if self.score_threshold is not None: 42 | category_ids = category_ids[mask[i]] 43 | scores = scores[mask[i]] 44 | pts = pts[mask[i]] 45 | 46 | output.append( 47 | { 48 | "vectors": [vec.detach().cpu().numpy() for vec in pts], 49 | "scores": scores.detach().cpu().numpy(), 50 | "labels": category_ids.detach().cpu().numpy(), 51 | } 52 | ) 53 | return output -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/map/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from mmcv.utils import build_from_cfg 5 | from mmdet.models.builder import LOSSES 6 | from mmdet.models.losses import l1_loss, smooth_l1_loss 7 | 8 | 9 | @LOSSES.register_module() 10 | class LinesL1Loss(nn.Module): 11 | 12 | def __init__(self, reduction='mean', loss_weight=1.0, beta=0.5): 13 | """ 14 | L1 loss. The same as the smooth L1 loss 15 | Args: 16 | reduction (str, optional): The method to reduce the loss. 17 | Options are "none", "mean" and "sum". 18 | loss_weight (float, optional): The weight of loss. 19 | """ 20 | 21 | super().__init__() 22 | self.reduction = reduction 23 | self.loss_weight = loss_weight 24 | self.beta = beta 25 | 26 | def forward(self, 27 | pred, 28 | target, 29 | weight=None, 30 | avg_factor=None, 31 | reduction_override=None): 32 | """Forward function. 33 | Args: 34 | pred (torch.Tensor): The prediction. 35 | shape: [bs, ...] 36 | target (torch.Tensor): The learning target of the prediction. 37 | shape: [bs, ...] 38 | weight (torch.Tensor, optional): The weight of loss for each 39 | prediction. Defaults to None. 40 | it's useful when the predictions are not all valid. 41 | avg_factor (int, optional): Average factor that is used to average 42 | the loss. Defaults to None. 43 | reduction_override (str, optional): The reduction method used to 44 | override the original reduction method of the loss. 45 | Defaults to None. 
46 | """ 47 | assert reduction_override in (None, 'none', 'mean', 'sum') 48 | reduction = ( 49 | reduction_override if reduction_override else self.reduction) 50 | 51 | if self.beta > 0: 52 | loss = smooth_l1_loss( 53 | pred, target, weight, reduction=reduction, avg_factor=avg_factor, beta=self.beta) 54 | 55 | else: 56 | loss = l1_loss( 57 | pred, target, weight, reduction=reduction, avg_factor=avg_factor) 58 | 59 | num_points = pred.shape[-1] // 2 60 | loss = loss / num_points 61 | 62 | return loss*self.loss_weight 63 | 64 | 65 | @LOSSES.register_module() 66 | class SparseLineLoss(nn.Module): 67 | def __init__( 68 | self, 69 | loss_line, 70 | num_sample=20, 71 | roi_size=(30, 60), 72 | ): 73 | super().__init__() 74 | 75 | def build(cfg, registry): 76 | if cfg is None: 77 | return None 78 | return build_from_cfg(cfg, registry) 79 | 80 | self.loss_line = build(loss_line, LOSSES) 81 | self.num_sample = num_sample 82 | self.roi_size = roi_size 83 | 84 | def forward( 85 | self, 86 | line, 87 | line_target, 88 | weight=None, 89 | avg_factor=None, 90 | prefix="", 91 | suffix="", 92 | **kwargs, 93 | ): 94 | 95 | output = {} 96 | line = self.normalize_line(line) 97 | line_target = self.normalize_line(line_target) 98 | line_loss = self.loss_line( 99 | line, line_target, weight=weight, avg_factor=avg_factor 100 | ) 101 | output[f"{prefix}loss_line{suffix}"] = line_loss 102 | 103 | return output 104 | 105 | def normalize_line(self, line): 106 | if line.shape[0] == 0: 107 | return line 108 | 109 | line = line.view(line.shape[:-1] + (self.num_sample, -1)) 110 | 111 | origin = -line.new_tensor([self.roi_size[0]/2, self.roi_size[1]/2]) 112 | line = line - origin 113 | 114 | # transform from range [0, 1] to (0, 1) 115 | eps = 1e-5 116 | norm = line.new_tensor([self.roi_size[0], self.roi_size[1]]) + eps 117 | line = line / norm 118 | line = line.flatten(-2, -1) 119 | 120 | return line 121 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/map/map_blocks.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | from mmcv.cnn import Linear, Scale, bias_init_with_prob 9 | from mmcv.runner.base_module import Sequential, BaseModule 10 | from mmcv.cnn import xavier_init 11 | from mmcv.cnn.bricks.registry import ( 12 | PLUGIN_LAYERS, 13 | POSITIONAL_ENCODING, 14 | ) 15 | 16 | from ..blocks import linear_relu_ln 17 | 18 | 19 | @POSITIONAL_ENCODING.register_module() 20 | class SparsePoint3DEncoder(BaseModule): 21 | def __init__( 22 | self, 23 | embed_dims: int = 256, 24 | num_sample: int = 20, 25 | coords_dim: int = 2, 26 | ): 27 | super(SparsePoint3DEncoder, self).__init__() 28 | self.embed_dims = embed_dims 29 | self.input_dims = num_sample * coords_dim 30 | def embedding_layer(input_dims): 31 | return nn.Sequential(*linear_relu_ln(embed_dims, 1, 2, input_dims)) 32 | 33 | self.pos_fc = embedding_layer(self.input_dims) 34 | 35 | def forward(self, anchor: torch.Tensor): 36 | pos_feat = self.pos_fc(anchor) 37 | return pos_feat 38 | 39 | 40 | @PLUGIN_LAYERS.register_module() 41 | class SparsePoint3DRefinementModule(BaseModule): 42 | def __init__( 43 | self, 44 | embed_dims: int = 256, 45 | num_sample: int = 20, 46 | coords_dim: int = 2, 47 | num_cls: int = 3, 48 | with_cls_branch: bool = True, 49 | ): 50 | super(SparsePoint3DRefinementModule, self).__init__() 51 | 
self.embed_dims = embed_dims 52 | self.num_sample = num_sample 53 | self.output_dim = num_sample * coords_dim 54 | self.num_cls = num_cls 55 | 56 | self.layers = nn.Sequential( 57 | *linear_relu_ln(embed_dims, 2, 2), 58 | Linear(self.embed_dims, self.output_dim), 59 | Scale([1.0] * self.output_dim), 60 | ) 61 | 62 | self.with_cls_branch = with_cls_branch 63 | if with_cls_branch: 64 | self.cls_layers = nn.Sequential( 65 | *linear_relu_ln(embed_dims, 1, 2), 66 | Linear(self.embed_dims, self.num_cls), 67 | ) 68 | 69 | def init_weight(self): 70 | if self.with_cls_branch: 71 | bias_init = bias_init_with_prob(0.01) 72 | nn.init.constant_(self.cls_layers[-1].bias, bias_init) 73 | 74 | def forward( 75 | self, 76 | instance_feature: torch.Tensor, 77 | anchor: torch.Tensor, 78 | anchor_embed: torch.Tensor, 79 | time_interval: torch.Tensor = 1.0, 80 | return_cls=True, 81 | ): 82 | output = self.layers(instance_feature + anchor_embed) 83 | output = output + anchor 84 | if return_cls: 85 | assert self.with_cls_branch, "Without classification layers !!!" 86 | cls = self.cls_layers(instance_feature) ## NOTE anchor embed? 87 | else: 88 | cls = None 89 | qt = None 90 | return output, cls, qt 91 | 92 | 93 | @PLUGIN_LAYERS.register_module() 94 | class SparsePoint3DKeyPointsGenerator(BaseModule): 95 | def __init__( 96 | self, 97 | embed_dims: int = 256, 98 | num_sample: int = 20, 99 | num_learnable_pts: int = 0, 100 | fix_height: Tuple = (0,), 101 | ground_height: int = 0, 102 | ): 103 | super(SparsePoint3DKeyPointsGenerator, self).__init__() 104 | self.embed_dims = embed_dims 105 | self.num_sample = num_sample 106 | self.num_learnable_pts = num_learnable_pts 107 | self.num_pts = num_sample * len(fix_height) * num_learnable_pts 108 | if self.num_learnable_pts > 0: 109 | self.learnable_fc = Linear(self.embed_dims, self.num_pts * 2) 110 | 111 | self.fix_height = np.array(fix_height) 112 | self.ground_height = ground_height 113 | 114 | def init_weight(self): 115 | if self.num_learnable_pts > 0: 116 | xavier_init(self.learnable_fc, distribution="uniform", bias=0.0) 117 | 118 | def forward( 119 | self, 120 | anchor, 121 | instance_feature=None, 122 | T_cur2temp_list=None, 123 | cur_timestamp=None, 124 | temp_timestamps=None, 125 | ): 126 | assert self.num_learnable_pts > 0, 'No learnable pts' 127 | bs, num_anchor, _ = anchor.shape 128 | key_points = anchor.view(bs, num_anchor, self.num_sample, -1) 129 | offset = ( 130 | self.learnable_fc(instance_feature) 131 | .reshape(bs, num_anchor, self.num_sample, len(self.fix_height), self.num_learnable_pts, 2) 132 | ) 133 | key_points = offset + key_points[..., None, None, :] 134 | key_points = torch.cat( 135 | [ 136 | key_points, 137 | key_points.new_full(key_points.shape[:-1]+(1,), fill_value=self.ground_height), 138 | ], 139 | dim=-1, 140 | ) 141 | fix_height = key_points.new_tensor(self.fix_height) 142 | height_offset = key_points.new_zeros([len(fix_height), 2]) 143 | height_offset = torch.cat([height_offset, fix_height[:,None]], dim=-1) 144 | key_points = key_points + height_offset[None, None, None, :, None] 145 | key_points = key_points.flatten(2, 4) 146 | if ( 147 | cur_timestamp is None 148 | or temp_timestamps is None 149 | or T_cur2temp_list is None 150 | or len(temp_timestamps) == 0 151 | ): 152 | return key_points 153 | 154 | temp_key_points_list = [] 155 | for i, t_time in enumerate(temp_timestamps): 156 | temp_key_points = key_points 157 | T_cur2temp = T_cur2temp_list[i].to(dtype=key_points.dtype) 158 | temp_key_points = ( 159 | T_cur2temp[:, None, None, :3] 
160 | @ torch.cat( 161 | [ 162 | temp_key_points, 163 | torch.ones_like(temp_key_points[..., :1]), 164 | ], 165 | dim=-1, 166 | ).unsqueeze(-1) 167 | ) 168 | temp_key_points = temp_key_points.squeeze(-1) 169 | temp_key_points_list.append(temp_key_points) 170 | return key_points, temp_key_points_list 171 | 172 | # @staticmethod 173 | def anchor_projection( 174 | self, 175 | anchor, 176 | T_src2dst_list, 177 | src_timestamp=None, 178 | dst_timestamps=None, 179 | time_intervals=None, 180 | ): 181 | dst_anchors = [] 182 | for i in range(len(T_src2dst_list)): 183 | dst_anchor = anchor.clone() 184 | bs, num_anchor, _ = anchor.shape 185 | dst_anchor = dst_anchor.reshape(bs, num_anchor, self.num_sample, -1).flatten(1, 2) 186 | T_src2dst = torch.unsqueeze( 187 | T_src2dst_list[i].to(dtype=anchor.dtype), dim=1 188 | ) 189 | 190 | dst_anchor = ( 191 | torch.matmul( 192 | T_src2dst[..., :2, :2], dst_anchor[..., None] 193 | ).squeeze(dim=-1) 194 | + T_src2dst[..., :2, 3] 195 | ) 196 | 197 | dst_anchor = dst_anchor.reshape(bs, num_anchor, self.num_sample, -1).flatten(2, 3) 198 | dst_anchors.append(dst_anchor) 199 | return dst_anchors -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/map/match_cost.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from mmdet.core.bbox.match_costs.builder import MATCH_COST 3 | from mmdet.core.bbox.match_costs import build_match_cost 4 | from torch.nn.functional import smooth_l1_loss 5 | 6 | 7 | @MATCH_COST.register_module() 8 | class LinesL1Cost(object): 9 | """LinesL1Cost. 10 | Args: 11 | weight (int | float, optional): loss_weight 12 | """ 13 | 14 | def __init__(self, weight=1.0, beta=0.0, permute=False): 15 | self.weight = weight 16 | self.permute = permute 17 | self.beta = beta 18 | 19 | def __call__(self, lines_pred, gt_lines, **kwargs): 20 | """ 21 | Args: 22 | lines_pred (Tensor): predicted normalized lines: 23 | [num_query, 2*num_points] 24 | gt_lines (Tensor): Ground truth lines 25 | [num_gt, 2*num_points] or [num_gt, num_permute, 2*num_points] 26 | Returns: 27 | torch.Tensor: reg_cost value with weight 28 | shape [num_pred, num_gt] 29 | """ 30 | if self.permute: 31 | assert len(gt_lines.shape) == 3 32 | else: 33 | assert len(gt_lines.shape) == 2 34 | 35 | num_pred, num_gt = len(lines_pred), len(gt_lines) 36 | if self.permute: 37 | # permute-invarint labels 38 | gt_lines = gt_lines.flatten(0, 1) # (num_gt*num_permute, 2*num_pts) 39 | 40 | num_pts = lines_pred.shape[-1]//2 41 | 42 | if self.beta > 0: 43 | lines_pred = lines_pred.unsqueeze(1).repeat(1, len(gt_lines), 1) 44 | gt_lines = gt_lines.unsqueeze(0).repeat(num_pred, 1, 1) 45 | dist_mat = smooth_l1_loss(lines_pred, gt_lines, reduction='none', beta=self.beta).sum(-1) 46 | 47 | else: 48 | dist_mat = torch.cdist(lines_pred, gt_lines, p=1) 49 | 50 | dist_mat = dist_mat / num_pts 51 | 52 | if self.permute: 53 | # dist_mat: (num_pred, num_gt*num_permute) 54 | dist_mat = dist_mat.view(num_pred, num_gt, -1) # (num_pred, num_gt, num_permute) 55 | dist_mat, gt_permute_index = torch.min(dist_mat, 2) 56 | return dist_mat * self.weight, gt_permute_index 57 | 58 | return dist_mat * self.weight 59 | 60 | 61 | @MATCH_COST.register_module() 62 | class MapQueriesCost(object): 63 | 64 | def __init__(self, cls_cost, reg_cost, iou_cost=None): 65 | 66 | self.cls_cost = build_match_cost(cls_cost) 67 | self.reg_cost = build_match_cost(reg_cost) 68 | 69 | self.iou_cost = None 70 | if iou_cost is not None: 71 | 
self.iou_cost = build_match_cost(iou_cost) 72 | 73 | def __call__(self, preds: dict, gts: dict, ignore_cls_cost: bool): 74 | 75 | # classification and bboxcost. 76 | cls_cost = self.cls_cost(preds['scores'], gts['labels']) 77 | 78 | # regression cost 79 | regkwargs = {} 80 | if 'masks' in preds and 'masks' in gts: 81 | assert isinstance(self.reg_cost, DynamicLinesCost), ' Issues!!' 82 | regkwargs = { 83 | 'masks_pred': preds['masks'], 84 | 'masks_gt': gts['masks'], 85 | } 86 | 87 | reg_cost = self.reg_cost(preds['lines'], gts['lines'], **regkwargs) 88 | if self.reg_cost.permute: 89 | reg_cost, gt_permute_idx = reg_cost 90 | 91 | # weighted sum of above three costs 92 | if ignore_cls_cost: 93 | cost = reg_cost 94 | else: 95 | cost = cls_cost + reg_cost 96 | 97 | # Iou 98 | if self.iou_cost is not None: 99 | iou_cost = self.iou_cost(preds['lines'],gts['lines']) 100 | cost += iou_cost 101 | 102 | if self.reg_cost.permute: 103 | return cost, gt_permute_idx 104 | return cost 105 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/map/target.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from scipy.optimize import linear_sum_assignment 5 | 6 | from mmdet.core.bbox.builder import (BBOX_SAMPLERS, BBOX_ASSIGNERS) 7 | from mmdet.core.bbox.match_costs import build_match_cost 8 | from mmdet.core import (build_assigner, build_sampler) 9 | from mmdet.core.bbox.assigners import (AssignResult, BaseAssigner) 10 | 11 | from ..base_target import BaseTargetWithDenoising 12 | 13 | 14 | @BBOX_SAMPLERS.register_module() 15 | class SparsePoint3DTarget(BaseTargetWithDenoising): 16 | def __init__( 17 | self, 18 | assigner=None, 19 | num_dn_groups=0, 20 | dn_noise_scale=0.5, 21 | max_dn_gt=32, 22 | add_neg_dn=True, 23 | num_temp_dn_groups=0, 24 | num_cls=3, 25 | num_sample=20, 26 | roi_size=(30, 60), 27 | ): 28 | super(SparsePoint3DTarget, self).__init__( 29 | num_dn_groups, num_temp_dn_groups 30 | ) 31 | self.assigner = build_assigner(assigner) 32 | self.dn_noise_scale = dn_noise_scale 33 | self.max_dn_gt = max_dn_gt 34 | self.add_neg_dn = add_neg_dn 35 | 36 | self.num_cls = num_cls 37 | self.num_sample = num_sample 38 | self.roi_size = roi_size 39 | 40 | def sample( 41 | self, 42 | cls_preds, 43 | pts_preds, 44 | cls_targets, 45 | pts_targets, 46 | ): 47 | pts_targets = [x.flatten(2, 3) if len(x.shape)==4 else x for x in pts_targets] 48 | indices = [] 49 | for(cls_pred, pts_pred, cls_target, pts_target) in zip( 50 | cls_preds, pts_preds, cls_targets, pts_targets 51 | ): 52 | # normalize to (0, 1) 53 | pts_pred = self.normalize_line(pts_pred) 54 | pts_target = self.normalize_line(pts_target) 55 | preds=dict(lines=pts_pred, scores=cls_pred) 56 | gts=dict(lines=pts_target, labels=cls_target) 57 | indice = self.assigner.assign(preds, gts) 58 | indices.append(indice) 59 | 60 | bs, num_pred, num_cls = cls_preds.shape 61 | output_cls_target = cls_targets[0].new_ones([bs, num_pred], dtype=torch.long) * num_cls 62 | output_box_target = pts_preds.new_zeros(pts_preds.shape) 63 | output_reg_weights = pts_preds.new_zeros(pts_preds.shape) 64 | for i, (pred_idx, target_idx, gt_permute_index) in enumerate(indices): 65 | if len(cls_targets[i]) == 0: 66 | continue 67 | permute_idx = gt_permute_index[pred_idx, target_idx] 68 | output_cls_target[i, pred_idx] = cls_targets[i][target_idx] 69 | output_box_target[i, pred_idx] = pts_targets[i][target_idx, 
permute_idx] 70 | output_reg_weights[i, pred_idx] = 1 71 | 72 | return output_cls_target, output_box_target, output_reg_weights 73 | 74 | def normalize_line(self, line): 75 | if line.shape[0] == 0: 76 | return line 77 | 78 | line = line.view(line.shape[:-1] + (self.num_sample, -1)) 79 | 80 | origin = -line.new_tensor([self.roi_size[0]/2, self.roi_size[1]/2]) 81 | line = line - origin 82 | 83 | # transform from range [0, 1] to (0, 1) 84 | eps = 1e-5 85 | norm = line.new_tensor([self.roi_size[0], self.roi_size[1]]) + eps 86 | line = line / norm 87 | line = line.flatten(-2, -1) 88 | 89 | return line 90 | 91 | 92 | @BBOX_ASSIGNERS.register_module() 93 | class HungarianLinesAssigner(BaseAssigner): 94 | """ 95 | Computes one-to-one matching between predictions and ground truth. 96 | This class computes an assignment between the targets and the predictions 97 | based on the costs. The costs are weighted sum of three components: 98 | classification cost and regression L1 cost. The 99 | targets don't include the no_object, so generally there are more 100 | predictions than targets. After the one-to-one matching, the un-matched 101 | are treated as backgrounds. Thus each query prediction will be assigned 102 | with `0` or a positive integer indicating the ground truth index: 103 | - 0: negative sample, no assigned gt 104 | - positive integer: positive sample, index (1-based) of assigned gt 105 | Args: 106 | cls_weight (int | float, optional): The scale factor for classification 107 | cost. Default 1.0. 108 | bbox_weight (int | float, optional): The scale factor for regression 109 | L1 cost. Default 1.0. 110 | """ 111 | 112 | def __init__(self, cost=dict, **kwargs): 113 | self.cost = build_match_cost(cost) 114 | 115 | def assign(self, 116 | preds: dict, 117 | gts: dict, 118 | ignore_cls_cost=False, 119 | gt_bboxes_ignore=None, 120 | eps=1e-7): 121 | """ 122 | Computes one-to-one matching based on the weighted costs. 123 | This method assign each query prediction to a ground truth or 124 | background. The `assigned_gt_inds` with -1 means don't care, 125 | 0 means negative sample, and positive number is the index (1-based) 126 | of assigned gt. 127 | The assignment is done in the following steps, the order matters. 128 | 1. assign every prediction to -1 129 | 2. compute the weighted costs 130 | 3. do Hungarian matching on CPU based on the costs 131 | 4. assign all to 0 (background) first, then for each matched pair 132 | between predictions and gts, treat this prediction as foreground 133 | and assign the corresponding gt index (plus 1) to it. 134 | Args: 135 | lines_pred (Tensor): predicted normalized lines: 136 | [num_query, num_points, 2] 137 | cls_pred (Tensor): Predicted classification logits, shape 138 | [num_query, num_class]. 139 | 140 | lines_gt (Tensor): Ground truth lines 141 | [num_gt, num_points, 2]. 142 | labels_gt (Tensor): Label of `gt_bboxes`, shape (num_gt,). 143 | gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are 144 | labelled as `ignored`. Default None. 145 | eps (int | float, optional): A value added to the denominator for 146 | numerical stability. Default 1e-7. 147 | Returns: 148 | :obj:`AssignResult`: The assigned result. 149 | """ 150 | assert gt_bboxes_ignore is None, \ 151 | 'Only case when gt_bboxes_ignore is None is supported.' 
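        # Editorial note (illustrative, not in the original source): for a toy
        # cost matrix with 3 line predictions and 2 ground-truth lines,
        #     cost = [[0.2, 0.9],
        #             [0.8, 0.1],
        #             [0.5, 0.7]]
        # scipy.optimize.linear_sum_assignment(cost) returns rows [0, 1] and
        # columns [0, 1]: prediction 0 <-> gt 0 and prediction 1 <-> gt 1, the
        # assignment with minimal total cost (0.2 + 0.1).  Prediction 2 stays
        # unmatched and is later treated as background in
        # SparsePoint3DTarget.sample.  When the regression cost is
        # permutation-aware, gt_permute_idx additionally records, per
        # (prediction, gt) pair, which point ordering of the gt polyline was
        # closest, so the matched ordering can be used as the regression target.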
152 | 153 | num_gts, num_lines = gts['lines'].size(0), preds['lines'].size(0) 154 | if num_gts == 0 or num_lines == 0: 155 | return None, None, None 156 | 157 | # compute the weighted costs 158 | gt_permute_idx = None # (num_preds, num_gts) 159 | if self.cost.reg_cost.permute: 160 | cost, gt_permute_idx = self.cost(preds, gts, ignore_cls_cost) 161 | else: 162 | cost = self.cost(preds, gts, ignore_cls_cost) 163 | 164 | # do Hungarian matching on CPU using linear_sum_assignment 165 | cost = cost.detach().cpu().numpy() 166 | matched_row_inds, matched_col_inds = linear_sum_assignment(cost) 167 | return matched_row_inds, matched_col_inds, gt_permute_idx -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/motion/__init__.py: -------------------------------------------------------------------------------- 1 | from .motion_planning_head import MotionPlanningHead 2 | from .motion_blocks import MotionPlanningRefinementModule 3 | from .instance_queue import InstanceQueue 4 | from .target import MotionTarget, PlanningTarget 5 | from .decoder import SparseBox3DMotionDecoder, HierarchicalPlanningDecoder 6 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/motion/instance_queue.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from mmcv.utils import build_from_cfg 8 | from mmcv.cnn.bricks.registry import PLUGIN_LAYERS 9 | 10 | from projects.mmdet3d_plugin.ops import feature_maps_format 11 | from projects.mmdet3d_plugin.core.box3d import * 12 | 13 | 14 | @PLUGIN_LAYERS.register_module() 15 | class InstanceQueue(nn.Module): 16 | def __init__( 17 | self, 18 | embed_dims, 19 | queue_length=0, 20 | tracking_threshold=0, 21 | feature_map_scale=None, 22 | ): 23 | super(InstanceQueue, self).__init__() 24 | self.embed_dims = embed_dims 25 | self.queue_length = queue_length 26 | self.tracking_threshold = tracking_threshold 27 | 28 | kernel_size = tuple([int(x / 2) for x in feature_map_scale]) 29 | self.ego_feature_encoder = nn.Sequential( 30 | nn.Conv2d(embed_dims, embed_dims, 3, stride=1, padding=1, bias=False), 31 | nn.BatchNorm2d(embed_dims), 32 | nn.Conv2d(embed_dims, embed_dims, 3, stride=2, padding=1, bias=False), 33 | nn.BatchNorm2d(embed_dims), 34 | nn.ReLU(), 35 | nn.AvgPool2d(kernel_size), 36 | ) 37 | self.ego_anchor = nn.Parameter( 38 | torch.tensor([[0, 0.5, -1.84 + 1.56/2, np.log(4.08), np.log(1.73), np.log(1.56), 1, 0, 0, 0, 0],], dtype=torch.float32), 39 | requires_grad=False, 40 | ) 41 | 42 | self.reset() 43 | 44 | def reset(self): 45 | self.metas = None 46 | self.prev_instance_id = None 47 | self.prev_confidence = None 48 | self.period = None 49 | self.instance_feature_queue = [] 50 | self.anchor_queue = [] 51 | self.prev_ego_status = None 52 | self.ego_period = None 53 | self.ego_feature_queue = [] 54 | self.ego_anchor_queue = [] 55 | 56 | def get( 57 | self, 58 | det_output, 59 | feature_maps, 60 | metas, 61 | batch_size, 62 | mask, 63 | anchor_handler, 64 | ): 65 | if ( 66 | self.period is not None 67 | and batch_size == self.period.shape[0] 68 | ): 69 | if anchor_handler is not None: 70 | T_temp2cur = feature_maps[0].new_tensor( 71 | np.stack( 72 | [ 73 | x["T_global_inv"] 74 | @ self.metas["img_metas"][i]["T_global"] 75 | for i, x in enumerate(metas["img_metas"]) 76 | ] 77 | ) 78 | ) 79 | for i in 
range(len(self.anchor_queue)): 80 | temp_anchor = self.anchor_queue[i] 81 | temp_anchor = anchor_handler.anchor_projection( 82 | temp_anchor, 83 | [T_temp2cur], 84 | )[0] 85 | self.anchor_queue[i] = temp_anchor 86 | for i in range(len(self.ego_anchor_queue)): 87 | temp_anchor = self.ego_anchor_queue[i] 88 | temp_anchor = anchor_handler.anchor_projection( 89 | temp_anchor, 90 | [T_temp2cur], 91 | )[0] 92 | self.ego_anchor_queue[i] = temp_anchor 93 | else: 94 | self.reset() 95 | 96 | self.prepare_motion(det_output, mask) 97 | ego_feature, ego_anchor = self.prepare_planning(feature_maps, mask, batch_size) 98 | 99 | # temporal 100 | temp_instance_feature = torch.stack(self.instance_feature_queue, dim=2) 101 | temp_anchor = torch.stack(self.anchor_queue, dim=2) 102 | temp_ego_feature = torch.stack(self.ego_feature_queue, dim=2) 103 | temp_ego_anchor = torch.stack(self.ego_anchor_queue, dim=2) 104 | 105 | period = torch.cat([self.period, self.ego_period], dim=1) 106 | temp_instance_feature = torch.cat([temp_instance_feature, temp_ego_feature], dim=1) 107 | temp_anchor = torch.cat([temp_anchor, temp_ego_anchor], dim=1) 108 | num_agent = temp_anchor.shape[1] 109 | 110 | temp_mask = torch.arange(len(self.anchor_queue), 0, -1, device=temp_anchor.device) 111 | temp_mask = temp_mask[None, None].repeat((batch_size, num_agent, 1)) 112 | temp_mask = torch.gt(temp_mask, period[..., None]) 113 | 114 | return ego_feature, ego_anchor, temp_instance_feature, temp_anchor, temp_mask 115 | 116 | def prepare_motion( 117 | self, 118 | det_output, 119 | mask, 120 | ): 121 | instance_feature = det_output["instance_feature"] 122 | det_anchors = det_output["prediction"][-1] 123 | 124 | if self.period == None: 125 | self.period = instance_feature.new_zeros(instance_feature.shape[:2]).long() 126 | else: 127 | instance_id = det_output['instance_id'] 128 | prev_instance_id = self.prev_instance_id 129 | match = instance_id[..., None] == prev_instance_id[:, None] 130 | if self.tracking_threshold > 0: 131 | temp_mask = self.prev_confidence > self.tracking_threshold 132 | match = match * temp_mask.unsqueeze(1) 133 | 134 | for i in range(len(self.instance_feature_queue)): 135 | temp_feature = self.instance_feature_queue[i] 136 | temp_feature = ( 137 | match[..., None] * temp_feature[:, None] 138 | ).sum(dim=2) 139 | self.instance_feature_queue[i] = temp_feature 140 | 141 | temp_anchor = self.anchor_queue[i] 142 | temp_anchor = ( 143 | match[..., None] * temp_anchor[:, None] 144 | ).sum(dim=2) 145 | self.anchor_queue[i] = temp_anchor 146 | 147 | self.period = ( 148 | match * self.period[:, None] 149 | ).sum(dim=2) 150 | 151 | self.instance_feature_queue.append(instance_feature.detach()) 152 | self.anchor_queue.append(det_anchors.detach()) 153 | self.period += 1 154 | 155 | if len(self.instance_feature_queue) > self.queue_length: 156 | self.instance_feature_queue.pop(0) 157 | self.anchor_queue.pop(0) 158 | self.period = torch.clip(self.period, 0, self.queue_length) 159 | 160 | def prepare_planning( 161 | self, 162 | feature_maps, 163 | mask, 164 | batch_size, 165 | ): 166 | ## ego instance init 167 | feature_maps_inv = feature_maps_format(feature_maps, inverse=True) 168 | feature_map = feature_maps_inv[0][-1][:, 0] 169 | ego_feature = self.ego_feature_encoder(feature_map) 170 | ego_feature = ego_feature.unsqueeze(1).squeeze(-1).squeeze(-1) 171 | 172 | ego_anchor = torch.tile( 173 | self.ego_anchor[None], (batch_size, 1, 1) 174 | ) 175 | if self.prev_ego_status is not None: 176 | prev_ego_status = torch.where( 177 | mask[:, None, 
None], 178 | self.prev_ego_status, 179 | self.prev_ego_status.new_tensor(0), 180 | ) 181 | ego_anchor[..., VY] = prev_ego_status[..., 6] 182 | 183 | if self.ego_period == None: 184 | self.ego_period = ego_feature.new_zeros((batch_size, 1)).long() 185 | else: 186 | self.ego_period = torch.where( 187 | mask[:, None], 188 | self.ego_period, 189 | self.ego_period.new_tensor(0), 190 | ) 191 | 192 | self.ego_feature_queue.append(ego_feature.detach()) 193 | self.ego_anchor_queue.append(ego_anchor.detach()) 194 | self.ego_period += 1 195 | 196 | if len(self.ego_feature_queue) > self.queue_length: 197 | self.ego_feature_queue.pop(0) 198 | self.ego_anchor_queue.pop(0) 199 | self.ego_period = torch.clip(self.ego_period, 0, self.queue_length) 200 | 201 | return ego_feature, ego_anchor 202 | 203 | def cache_motion(self, instance_feature, det_output, metas): 204 | det_classification = det_output["classification"][-1].sigmoid() 205 | det_confidence = det_classification.max(dim=-1).values 206 | instance_id = det_output['instance_id'] 207 | self.metas = metas 208 | self.prev_confidence = det_confidence.detach() 209 | self.prev_instance_id = instance_id 210 | 211 | def cache_planning(self, ego_feature, ego_status): 212 | self.prev_ego_status = ego_status.detach() 213 | self.ego_feature_queue[-1] = ego_feature.detach() 214 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/motion/motion_blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from mmcv.cnn import Linear, Scale, bias_init_with_prob 6 | from mmcv.runner.base_module import Sequential, BaseModule 7 | from mmcv.cnn import xavier_init 8 | from mmcv.cnn.bricks.registry import ( 9 | PLUGIN_LAYERS, 10 | ) 11 | 12 | from projects.mmdet3d_plugin.core.box3d import * 13 | from ..blocks import linear_relu_ln 14 | 15 | 16 | @PLUGIN_LAYERS.register_module() 17 | class MotionPlanningRefinementModule(BaseModule): 18 | def __init__( 19 | self, 20 | embed_dims=256, 21 | fut_ts=12, 22 | fut_mode=6, 23 | ego_fut_ts=6, 24 | ego_fut_mode=3, 25 | ): 26 | super(MotionPlanningRefinementModule, self).__init__() 27 | self.embed_dims = embed_dims 28 | self.fut_ts = fut_ts 29 | self.fut_mode = fut_mode 30 | self.ego_fut_ts = ego_fut_ts 31 | self.ego_fut_mode = ego_fut_mode 32 | 33 | self.motion_cls_branch = nn.Sequential( 34 | *linear_relu_ln(embed_dims, 1, 2), 35 | Linear(embed_dims, 1), 36 | ) 37 | self.motion_reg_branch = nn.Sequential( 38 | nn.Linear(embed_dims, embed_dims), 39 | nn.ReLU(), 40 | nn.Linear(embed_dims, embed_dims), 41 | nn.ReLU(), 42 | nn.Linear(embed_dims, fut_ts * 2), 43 | ) 44 | self.plan_cls_branch = nn.Sequential( 45 | *linear_relu_ln(embed_dims, 1, 2), 46 | Linear(embed_dims, 1), 47 | ) 48 | self.plan_reg_branch = nn.Sequential( 49 | nn.Linear(embed_dims, embed_dims), 50 | nn.ReLU(), 51 | nn.Linear(embed_dims, embed_dims), 52 | nn.ReLU(), 53 | nn.Linear(embed_dims, ego_fut_ts * 2), 54 | ) 55 | self.plan_status_branch = nn.Sequential( 56 | nn.Linear(embed_dims, embed_dims), 57 | nn.ReLU(), 58 | nn.Linear(embed_dims, embed_dims), 59 | nn.ReLU(), 60 | nn.Linear(embed_dims, 10), 61 | ) 62 | 63 | def init_weight(self): 64 | bias_init = bias_init_with_prob(0.01) 65 | nn.init.constant_(self.motion_cls_branch[-1].bias, bias_init) 66 | nn.init.constant_(self.plan_cls_branch[-1].bias, bias_init) 67 | 68 | def forward( 69 | self, 70 | motion_query, 71 | plan_query, 72 | ego_feature, 73 | 
ego_anchor_embed, 74 | ): 75 | bs, num_anchor = motion_query.shape[:2] 76 | motion_cls = self.motion_cls_branch(motion_query).squeeze(-1) 77 | motion_reg = self.motion_reg_branch(motion_query).reshape(bs, num_anchor, self.fut_mode, self.fut_ts, 2) 78 | plan_cls = self.plan_cls_branch(plan_query).squeeze(-1) 79 | plan_reg = self.plan_reg_branch(plan_query).reshape(bs, 1, 3 * self.ego_fut_mode, self.ego_fut_ts, 2) 80 | planning_status = self.plan_status_branch(ego_feature + ego_anchor_embed) 81 | return motion_cls, motion_reg, plan_cls, plan_reg, planning_status -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/motion/target.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from mmdet.core.bbox.builder import BBOX_SAMPLERS 4 | 5 | __all__ = ["MotionTarget", "PlanningTarget"] 6 | 7 | 8 | def get_cls_target( 9 | reg_preds, 10 | reg_target, 11 | reg_weight, 12 | ): 13 | bs, num_pred, mode, ts, d = reg_preds.shape 14 | reg_preds_cum = reg_preds.cumsum(dim=-2) 15 | reg_target_cum = reg_target.cumsum(dim=-2) 16 | dist = torch.linalg.norm(reg_target_cum.unsqueeze(2) - reg_preds_cum, dim=-1) 17 | dist = dist * reg_weight.unsqueeze(2) 18 | dist = dist.mean(dim=-1) 19 | mode_idx = torch.argmin(dist, dim=-1) 20 | return mode_idx 21 | 22 | def get_best_reg( 23 | reg_preds, 24 | reg_target, 25 | reg_weight, 26 | ): 27 | bs, num_pred, mode, ts, d = reg_preds.shape 28 | reg_preds_cum = reg_preds.cumsum(dim=-2) 29 | reg_target_cum = reg_target.cumsum(dim=-2) 30 | dist = torch.linalg.norm(reg_target_cum.unsqueeze(2) - reg_preds_cum, dim=-1) 31 | dist = dist * reg_weight.unsqueeze(2) 32 | dist = dist.mean(dim=-1) 33 | mode_idx = torch.argmin(dist, dim=-1) 34 | mode_idx = mode_idx[..., None, None, None].repeat(1, 1, 1, ts, d) 35 | best_reg = torch.gather(reg_preds, 2, mode_idx).squeeze(2) 36 | return best_reg 37 | 38 | 39 | @BBOX_SAMPLERS.register_module() 40 | class MotionTarget(): 41 | def __init__( 42 | self, 43 | ): 44 | super(MotionTarget, self).__init__() 45 | 46 | def sample( 47 | self, 48 | reg_pred, 49 | gt_reg_target, 50 | gt_reg_mask, 51 | motion_loss_cache, 52 | ): 53 | bs, num_anchor, mode, ts, d = reg_pred.shape 54 | reg_target = reg_pred.new_zeros((bs, num_anchor, ts, d)) 55 | reg_weight = reg_pred.new_zeros((bs, num_anchor, ts)) 56 | indices = motion_loss_cache['indices'] 57 | num_pos = reg_pred.new_tensor([0]) 58 | for i, (pred_idx, target_idx) in enumerate(indices): 59 | if len(gt_reg_target[i]) == 0: 60 | continue 61 | reg_target[i, pred_idx] = gt_reg_target[i][target_idx] 62 | reg_weight[i, pred_idx] = gt_reg_mask[i][target_idx] 63 | num_pos += len(pred_idx) 64 | 65 | cls_target = get_cls_target(reg_pred, reg_target, reg_weight) 66 | cls_weight = reg_weight.any(dim=-1) 67 | best_reg = get_best_reg(reg_pred, reg_target, reg_weight) 68 | 69 | return cls_target, cls_weight, best_reg, reg_target, reg_weight, num_pos 70 | 71 | 72 | @BBOX_SAMPLERS.register_module() 73 | class PlanningTarget(): 74 | def __init__( 75 | self, 76 | ego_fut_ts, 77 | ego_fut_mode, 78 | ): 79 | super(PlanningTarget, self).__init__() 80 | self.ego_fut_ts = ego_fut_ts 81 | self.ego_fut_mode = ego_fut_mode 82 | 83 | def sample( 84 | self, 85 | cls_pred, 86 | reg_pred, 87 | gt_reg_target, 88 | gt_reg_mask, 89 | data, 90 | ): 91 | gt_reg_target = gt_reg_target.unsqueeze(1) 92 | gt_reg_mask = gt_reg_mask.unsqueeze(1) 93 | 94 | bs = reg_pred.shape[0] 95 | bs_indices = torch.arange(bs, 
device=reg_pred.device) 96 | cmd = data['gt_ego_fut_cmd'].argmax(dim=-1) 97 | 98 | cls_pred = cls_pred.reshape(bs, 3, 1, self.ego_fut_mode) 99 | reg_pred = reg_pred.reshape(bs, 3, 1, self.ego_fut_mode, self.ego_fut_ts, 2) 100 | cls_pred = cls_pred[bs_indices, cmd] 101 | reg_pred = reg_pred[bs_indices, cmd] 102 | cls_target = get_cls_target(reg_pred, gt_reg_target, gt_reg_mask) 103 | cls_weight = gt_reg_mask.any(dim=-1) 104 | best_reg = get_best_reg(reg_pred, gt_reg_target, gt_reg_mask) 105 | 106 | return cls_pred, cls_target, cls_weight, best_reg, gt_reg_target, gt_reg_mask 107 | 108 | 109 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/sparsedrive.py: -------------------------------------------------------------------------------- 1 | from inspect import signature 2 | 3 | import torch 4 | 5 | from mmcv.runner import force_fp32, auto_fp16 6 | from mmcv.utils import build_from_cfg 7 | from mmcv.cnn.bricks.registry import PLUGIN_LAYERS 8 | from mmdet.models import ( 9 | DETECTORS, 10 | BaseDetector, 11 | build_backbone, 12 | build_head, 13 | build_neck, 14 | ) 15 | from .grid_mask import GridMask 16 | 17 | try: 18 | from ..ops import feature_maps_format 19 | DAF_VALID = True 20 | except: 21 | DAF_VALID = False 22 | 23 | __all__ = ["SparseDrive"] 24 | 25 | 26 | @DETECTORS.register_module() 27 | class SparseDrive(BaseDetector): 28 | def __init__( 29 | self, 30 | img_backbone, 31 | head, 32 | img_neck=None, 33 | init_cfg=None, 34 | train_cfg=None, 35 | test_cfg=None, 36 | pretrained=None, 37 | use_grid_mask=True, 38 | use_deformable_func=False, 39 | depth_branch=None, 40 | ): 41 | super(SparseDrive, self).__init__(init_cfg=init_cfg) 42 | if pretrained is not None: 43 | backbone.pretrained = pretrained 44 | self.img_backbone = build_backbone(img_backbone) 45 | if img_neck is not None: 46 | self.img_neck = build_neck(img_neck) 47 | self.head = build_head(head) 48 | self.use_grid_mask = use_grid_mask 49 | if use_deformable_func: 50 | assert DAF_VALID, "deformable_aggregation needs to be set up." 
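        # Editorial note: DAF_VALID is False whenever the compiled extension
        # `deformable_aggregation_ext` cannot be imported; it is built from the
        # CUDA/C++ sources under projects/mmdet3d_plugin/ops/src via
        # projects/mmdet3d_plugin/ops/setup.py (e.g. `python setup.py develop`
        # run inside that directory; the exact build command is an assumption
        # here, check docs/quick_start.md).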
51 | self.use_deformable_func = use_deformable_func 52 | if depth_branch is not None: 53 | self.depth_branch = build_from_cfg(depth_branch, PLUGIN_LAYERS) 54 | else: 55 | self.depth_branch = None 56 | if use_grid_mask: 57 | self.grid_mask = GridMask( 58 | True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7 59 | ) 60 | 61 | @auto_fp16(apply_to=("img",), out_fp32=True) 62 | def extract_feat(self, img, return_depth=False, metas=None): 63 | bs = img.shape[0] 64 | if img.dim() == 5: # multi-view 65 | num_cams = img.shape[1] 66 | img = img.flatten(end_dim=1) 67 | else: 68 | num_cams = 1 69 | if self.use_grid_mask: 70 | img = self.grid_mask(img) 71 | if "metas" in signature(self.img_backbone.forward).parameters: 72 | feature_maps = self.img_backbone(img, num_cams, metas=metas) 73 | else: 74 | feature_maps = self.img_backbone(img) 75 | if self.img_neck is not None: 76 | feature_maps = list(self.img_neck(feature_maps)) 77 | for i, feat in enumerate(feature_maps): 78 | feature_maps[i] = torch.reshape( 79 | feat, (bs, num_cams) + feat.shape[1:] 80 | ) 81 | if return_depth and self.depth_branch is not None: 82 | depths = self.depth_branch(feature_maps, metas.get("focal")) 83 | else: 84 | depths = None 85 | if self.use_deformable_func: 86 | feature_maps = feature_maps_format(feature_maps) 87 | if return_depth: 88 | return feature_maps, depths 89 | return feature_maps 90 | 91 | @force_fp32(apply_to=("img",)) 92 | def forward(self, img, **data): 93 | if self.training: 94 | return self.forward_train(img, **data) 95 | else: 96 | return self.forward_test(img, **data) 97 | 98 | def forward_train(self, img, **data): 99 | feature_maps, depths = self.extract_feat(img, True, data) 100 | model_outs = self.head(feature_maps, data) 101 | output = self.head.loss(model_outs, data) 102 | if depths is not None and "gt_depth" in data: 103 | output["loss_dense_depth"] = self.depth_branch.loss( 104 | depths, data["gt_depth"] 105 | ) 106 | return output 107 | 108 | def forward_test(self, img, **data): 109 | if isinstance(img, list): 110 | return self.aug_test(img, **data) 111 | else: 112 | return self.simple_test(img, **data) 113 | 114 | def simple_test(self, img, **data): 115 | feature_maps = self.extract_feat(img) 116 | 117 | model_outs = self.head(feature_maps, data) 118 | results = self.head.post_process(model_outs, data) 119 | output = [{"img_bbox": result} for result in results] 120 | return output 121 | 122 | def aug_test(self, img, **data): 123 | # fake test time augmentation 124 | for key in data.keys(): 125 | if isinstance(data[key], list): 126 | data[key] = data[key][0] 127 | return self.simple_test(img[0], **data) 128 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/sparsedrive_head.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | 8 | from mmcv.runner import BaseModule 9 | from mmdet.models import HEADS 10 | from mmdet.models import build_head 11 | 12 | 13 | @HEADS.register_module() 14 | class SparseDriveHead(BaseModule): 15 | def __init__( 16 | self, 17 | task_config: dict, 18 | det_head = dict, 19 | map_head = dict, 20 | motion_plan_head = dict, 21 | init_cfg=None, 22 | **kwargs, 23 | ): 24 | super(SparseDriveHead, self).__init__(init_cfg) 25 | self.task_config = task_config 26 | if self.task_config['with_det']: 27 | self.det_head = build_head(det_head) 
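        # Editorial note: task_config is a plain dict of booleans read below,
        # e.g. dict(with_det=True, with_map=True, with_motion_plan=True)
        # (illustrative values; the real flags come from the project configs).
        # A sub-head is only built here, run in forward(), supervised in loss()
        # and decoded in post_process() when its flag is set.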
28 | if self.task_config['with_map']: 29 | self.map_head = build_head(map_head) 30 | if self.task_config['with_motion_plan']: 31 | self.motion_plan_head = build_head(motion_plan_head) 32 | 33 | def init_weights(self): 34 | if self.task_config['with_det']: 35 | self.det_head.init_weights() 36 | if self.task_config['with_map']: 37 | self.map_head.init_weights() 38 | if self.task_config['with_motion_plan']: 39 | self.motion_plan_head.init_weights() 40 | 41 | def forward( 42 | self, 43 | feature_maps: Union[torch.Tensor, List], 44 | metas: dict, 45 | ): 46 | if self.task_config['with_det']: 47 | det_output = self.det_head(feature_maps, metas) 48 | else: 49 | det_output = None 50 | 51 | if self.task_config['with_map']: 52 | map_output = self.map_head(feature_maps, metas) 53 | else: 54 | map_output = None 55 | 56 | if self.task_config['with_motion_plan']: 57 | motion_output, planning_output = self.motion_plan_head( 58 | det_output, 59 | map_output, 60 | feature_maps, 61 | metas, 62 | self.det_head.anchor_encoder, 63 | self.det_head.instance_bank.mask, 64 | self.det_head.instance_bank.anchor_handler, 65 | ) 66 | else: 67 | motion_output, planning_output = None, None 68 | 69 | return det_output, map_output, motion_output, planning_output 70 | 71 | def loss(self, model_outs, data): 72 | det_output, map_output, motion_output, planning_output = model_outs 73 | losses = dict() 74 | if self.task_config['with_det']: 75 | loss_det = self.det_head.loss(det_output, data) 76 | losses.update(loss_det) 77 | 78 | if self.task_config['with_map']: 79 | loss_map = self.map_head.loss(map_output, data) 80 | losses.update(loss_map) 81 | 82 | if self.task_config['with_motion_plan']: 83 | motion_loss_cache = dict( 84 | indices=self.det_head.sampler.indices, 85 | ) 86 | loss_motion = self.motion_plan_head.loss( 87 | motion_output, 88 | planning_output, 89 | data, 90 | motion_loss_cache 91 | ) 92 | losses.update(loss_motion) 93 | 94 | return losses 95 | 96 | def post_process(self, model_outs, data): 97 | det_output, map_output, motion_output, planning_output = model_outs 98 | if self.task_config['with_det']: 99 | det_result = self.det_head.post_process(det_output) 100 | batch_size = len(det_result) 101 | 102 | if self.task_config['with_map']: 103 | map_result= self.map_head.post_process(map_output) 104 | batch_size = len(map_result) 105 | 106 | if self.task_config['with_motion_plan']: 107 | motion_result, planning_result = self.motion_plan_head.post_process( 108 | det_output, 109 | motion_output, 110 | planning_output, 111 | data, 112 | ) 113 | 114 | results = [dict()] * batch_size 115 | for i in range(batch_size): 116 | if self.task_config['with_det']: 117 | results[i].update(det_result[i]) 118 | if self.task_config['with_map']: 119 | results[i].update(map_result[i]) 120 | if self.task_config['with_motion_plan']: 121 | results[i].update(motion_result[i]) 122 | results[i].update(planning_result[i]) 123 | 124 | return results 125 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/ops/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .deformable_aggregation import DeformableAggregationFunction 4 | 5 | 6 | def deformable_aggregation_function( 7 | feature_maps, 8 | spatial_shape, 9 | scale_start_index, 10 | sampling_location, 11 | weights, 12 | ): 13 | return DeformableAggregationFunction.apply( 14 | feature_maps, 15 | spatial_shape, 16 | scale_start_index, 17 | sampling_location, 18 | weights, 
19 | ) 20 | 21 | 22 | def feature_maps_format(feature_maps, inverse=False): 23 | if inverse: 24 | col_feats, spatial_shape, scale_start_index = feature_maps 25 | num_cams, num_levels = spatial_shape.shape[:2] 26 | 27 | split_size = spatial_shape[..., 0] * spatial_shape[..., 1] 28 | split_size = split_size.cpu().numpy().tolist() 29 | 30 | idx = 0 31 | cam_split = [1] 32 | cam_split_size = [sum(split_size[0])] 33 | for i in range(num_cams - 1): 34 | if not torch.all(spatial_shape[i] == spatial_shape[i + 1]): 35 | cam_split.append(0) 36 | cam_split_size.append(0) 37 | cam_split[-1] += 1 38 | cam_split_size[-1] += sum(split_size[i + 1]) 39 | mc_feat = [ 40 | x.unflatten(1, (cam_split[i], -1)) 41 | for i, x in enumerate(col_feats.split(cam_split_size, dim=1)) 42 | ] 43 | 44 | spatial_shape = spatial_shape.cpu().numpy().tolist() 45 | mc_ms_feat = [] 46 | shape_index = 0 47 | for i, feat in enumerate(mc_feat): 48 | feat = list(feat.split(split_size[shape_index], dim=2)) 49 | for j, f in enumerate(feat): 50 | feat[j] = f.unflatten(2, spatial_shape[shape_index][j]) 51 | feat[j] = feat[j].permute(0, 1, 4, 2, 3) 52 | mc_ms_feat.append(feat) 53 | shape_index += cam_split[i] 54 | return mc_ms_feat 55 | 56 | if isinstance(feature_maps[0], (list, tuple)): 57 | formated = [feature_maps_format(x) for x in feature_maps] 58 | col_feats = torch.cat([x[0] for x in formated], dim=1) 59 | spatial_shape = torch.cat([x[1] for x in formated], dim=0) 60 | scale_start_index = torch.cat([x[2] for x in formated], dim=0) 61 | return [col_feats, spatial_shape, scale_start_index] 62 | 63 | bs, num_cams = feature_maps[0].shape[:2] 64 | spatial_shape = [] 65 | 66 | col_feats = [] 67 | for i, feat in enumerate(feature_maps): 68 | spatial_shape.append(feat.shape[-2:]) 69 | col_feats.append( 70 | torch.reshape(feat, (bs, num_cams, feat.shape[2], -1)) 71 | ) 72 | 73 | col_feats = torch.cat(col_feats, dim=-1).permute(0, 1, 3, 2).flatten(1, 2) 74 | spatial_shape = [spatial_shape] * num_cams 75 | spatial_shape = torch.tensor( 76 | spatial_shape, 77 | dtype=torch.int64, 78 | device=col_feats.device, 79 | ) 80 | scale_start_index = spatial_shape[..., 0] * spatial_shape[..., 1] 81 | scale_start_index = scale_start_index.flatten().cumsum(dim=0) 82 | scale_start_index = torch.cat( 83 | [torch.tensor([0]).to(scale_start_index), scale_start_index[:-1]] 84 | ) 85 | scale_start_index = scale_start_index.reshape(num_cams, -1) 86 | 87 | feature_maps = [ 88 | col_feats, 89 | spatial_shape, 90 | scale_start_index, 91 | ] 92 | return feature_maps 93 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/ops/deformable_aggregation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.function import Function, once_differentiable 3 | 4 | from . 
import deformable_aggregation_ext 5 | 6 | 7 | class DeformableAggregationFunction(Function): 8 | @staticmethod 9 | def forward( 10 | ctx, 11 | mc_ms_feat, 12 | spatial_shape, 13 | scale_start_index, 14 | sampling_location, 15 | weights, 16 | ): 17 | # output: [bs, num_pts, num_embeds] 18 | mc_ms_feat = mc_ms_feat.contiguous().float() 19 | spatial_shape = spatial_shape.contiguous().int() 20 | scale_start_index = scale_start_index.contiguous().int() 21 | sampling_location = sampling_location.contiguous().float() 22 | weights = weights.contiguous().float() 23 | output = deformable_aggregation_ext.deformable_aggregation_forward( 24 | mc_ms_feat, 25 | spatial_shape, 26 | scale_start_index, 27 | sampling_location, 28 | weights, 29 | ) 30 | ctx.save_for_backward( 31 | mc_ms_feat, 32 | spatial_shape, 33 | scale_start_index, 34 | sampling_location, 35 | weights, 36 | ) 37 | return output 38 | 39 | @staticmethod 40 | @once_differentiable 41 | def backward(ctx, grad_output): 42 | ( 43 | mc_ms_feat, 44 | spatial_shape, 45 | scale_start_index, 46 | sampling_location, 47 | weights, 48 | ) = ctx.saved_tensors 49 | mc_ms_feat = mc_ms_feat.contiguous().float() 50 | spatial_shape = spatial_shape.contiguous().int() 51 | scale_start_index = scale_start_index.contiguous().int() 52 | sampling_location = sampling_location.contiguous().float() 53 | weights = weights.contiguous().float() 54 | 55 | grad_mc_ms_feat = torch.zeros_like(mc_ms_feat) 56 | grad_sampling_location = torch.zeros_like(sampling_location) 57 | grad_weights = torch.zeros_like(weights) 58 | deformable_aggregation_ext.deformable_aggregation_backward( 59 | mc_ms_feat, 60 | spatial_shape, 61 | scale_start_index, 62 | sampling_location, 63 | weights, 64 | grad_output.contiguous(), 65 | grad_mc_ms_feat, 66 | grad_sampling_location, 67 | grad_weights, 68 | ) 69 | return ( 70 | grad_mc_ms_feat, 71 | None, 72 | None, 73 | grad_sampling_location, 74 | grad_weights, 75 | ) 76 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/ops/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from setuptools import setup 5 | from torch.utils.cpp_extension import ( 6 | BuildExtension, 7 | CppExtension, 8 | CUDAExtension, 9 | ) 10 | 11 | 12 | def make_cuda_ext( 13 | name, 14 | module, 15 | sources, 16 | sources_cuda=[], 17 | extra_args=[], 18 | extra_include_path=[], 19 | ): 20 | 21 | define_macros = [] 22 | extra_compile_args = {"cxx": [] + extra_args} 23 | 24 | if torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1": 25 | define_macros += [("WITH_CUDA", None)] 26 | extension = CUDAExtension 27 | extra_compile_args["nvcc"] = extra_args + [ 28 | "-D__CUDA_NO_HALF_OPERATORS__", 29 | "-D__CUDA_NO_HALF_CONVERSIONS__", 30 | "-D__CUDA_NO_HALF2_OPERATORS__", 31 | ] 32 | sources += sources_cuda 33 | else: 34 | print("Compiling {} without CUDA".format(name)) 35 | extension = CppExtension 36 | 37 | return extension( 38 | name="{}.{}".format(module, name), 39 | sources=[os.path.join(*module.split("."), p) for p in sources], 40 | include_dirs=extra_include_path, 41 | define_macros=define_macros, 42 | extra_compile_args=extra_compile_args, 43 | ) 44 | 45 | 46 | if __name__ == "__main__": 47 | setup( 48 | name="deformable_aggregation_ext", 49 | ext_modules=[ 50 | make_cuda_ext( 51 | "deformable_aggregation_ext", 52 | module=".", 53 | sources=[ 54 | f"src/deformable_aggregation.cpp", 55 | f"src/deformable_aggregation_cuda.cu", 56 | ], 57 | 
), 58 | ], 59 | cmdclass={"build_ext": BuildExtension}, 60 | ) 61 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/ops/src/deformable_aggregation.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | void deformable_aggregation( 5 | float* output, 6 | const float* mc_ms_feat, 7 | const int* spatial_shape, 8 | const int* scale_start_index, 9 | const float* sample_location, 10 | const float* weights, 11 | int batch_size, 12 | int num_cams, 13 | int num_feat, 14 | int num_embeds, 15 | int num_scale, 16 | int num_anchors, 17 | int num_pts, 18 | int num_groups 19 | ); 20 | 21 | 22 | /* feat: bs, num_feat, c */ 23 | /* _spatial_shape: cam, scale, 2 */ 24 | /* _scale_start_index: cam, scale */ 25 | /* _sampling_location: bs, anchor, pts, cam, 2 */ 26 | /* _weights: bs, anchor, pts, cam, scale, group */ 27 | /* output: bs, anchor, c */ 28 | /* kernel: bs, anchor, pts, c */ 29 | 30 | 31 | at::Tensor deformable_aggregation_forward( 32 | const at::Tensor &_mc_ms_feat, 33 | const at::Tensor &_spatial_shape, 34 | const at::Tensor &_scale_start_index, 35 | const at::Tensor &_sampling_location, 36 | const at::Tensor &_weights 37 | ) { 38 | at::DeviceGuard guard(_mc_ms_feat.device()); 39 | const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat)); 40 | int batch_size = _mc_ms_feat.size(0); 41 | int num_feat = _mc_ms_feat.size(1); 42 | int num_embeds = _mc_ms_feat.size(2); 43 | int num_cams = _spatial_shape.size(0); 44 | int num_scale = _spatial_shape.size(1); 45 | int num_anchors = _sampling_location.size(1); 46 | int num_pts = _sampling_location.size(2); 47 | int num_groups = _weights.size(5); 48 | 49 | const float* mc_ms_feat = _mc_ms_feat.data_ptr(); 50 | const int* spatial_shape = _spatial_shape.data_ptr(); 51 | const int* scale_start_index = _scale_start_index.data_ptr(); 52 | const float* sampling_location = _sampling_location.data_ptr(); 53 | const float* weights = _weights.data_ptr(); 54 | 55 | auto output = at::zeros({batch_size, num_anchors, num_embeds}, _mc_ms_feat.options()); 56 | deformable_aggregation( 57 | output.data_ptr(), 58 | mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights, 59 | batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups 60 | ); 61 | return output; 62 | } 63 | 64 | 65 | void deformable_aggregation_grad( 66 | const float* mc_ms_feat, 67 | const int* spatial_shape, 68 | const int* scale_start_index, 69 | const float* sample_location, 70 | const float* weights, 71 | const float* grad_output, 72 | float* grad_mc_ms_feat, 73 | float* grad_sampling_location, 74 | float* grad_weights, 75 | int batch_size, 76 | int num_cams, 77 | int num_feat, 78 | int num_embeds, 79 | int num_scale, 80 | int num_anchors, 81 | int num_pts, 82 | int num_groups 83 | ); 84 | 85 | 86 | void deformable_aggregation_backward( 87 | const at::Tensor &_mc_ms_feat, 88 | const at::Tensor &_spatial_shape, 89 | const at::Tensor &_scale_start_index, 90 | const at::Tensor &_sampling_location, 91 | const at::Tensor &_weights, 92 | const at::Tensor &_grad_output, 93 | at::Tensor &_grad_mc_ms_feat, 94 | at::Tensor &_grad_sampling_location, 95 | at::Tensor &_grad_weights 96 | ) { 97 | at::DeviceGuard guard(_mc_ms_feat.device()); 98 | const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat)); 99 | int batch_size = _mc_ms_feat.size(0); 100 | int num_feat = _mc_ms_feat.size(1); 101 | int num_embeds = 
_mc_ms_feat.size(2); 102 | int num_cams = _spatial_shape.size(0); 103 | int num_scale = _spatial_shape.size(1); 104 | int num_anchors = _sampling_location.size(1); 105 | int num_pts = _sampling_location.size(2); 106 | int num_groups = _weights.size(5); 107 | 108 | const float* mc_ms_feat = _mc_ms_feat.data_ptr(); 109 | const int* spatial_shape = _spatial_shape.data_ptr(); 110 | const int* scale_start_index = _scale_start_index.data_ptr(); 111 | const float* sampling_location = _sampling_location.data_ptr(); 112 | const float* weights = _weights.data_ptr(); 113 | const float* grad_output = _grad_output.data_ptr(); 114 | 115 | float* grad_mc_ms_feat = _grad_mc_ms_feat.data_ptr(); 116 | float* grad_sampling_location = _grad_sampling_location.data_ptr(); 117 | float* grad_weights = _grad_weights.data_ptr(); 118 | 119 | deformable_aggregation_grad( 120 | mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights, 121 | grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights, 122 | batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups 123 | ); 124 | } 125 | 126 | 127 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 128 | m.def( 129 | "deformable_aggregation_forward", 130 | &deformable_aggregation_forward, 131 | "deformable_aggregation_forward" 132 | ); 133 | m.def( 134 | "deformable_aggregation_backward", 135 | &deformable_aggregation_backward, 136 | "deformable_aggregation_backward" 137 | ); 138 | } 139 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.5 2 | mmcv_full==1.7.1 3 | mmdet==2.28.2 4 | urllib3==1.26.16 5 | pyquaternion==0.9.9 6 | nuscenes-devkit==1.1.10 7 | yapf==0.33.0 8 | tensorboard==2.14.0 9 | motmetrics==1.1.3 10 | pandas==1.1.5 11 | flash-attn==2.3.2 12 | opencv-python==4.8.1.78 13 | prettytable==3.7.0 14 | scikit-learn==1.3.0 15 | -------------------------------------------------------------------------------- /resources/legend.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swc-17/SparseDrive/52c4c05b6d446b710c8a12eb9fb19d698b33cb2b/resources/legend.png -------------------------------------------------------------------------------- /resources/motion_planner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swc-17/SparseDrive/52c4c05b6d446b710c8a12eb9fb19d698b33cb2b/resources/motion_planner.png -------------------------------------------------------------------------------- /resources/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swc-17/SparseDrive/52c4c05b6d446b710c8a12eb9fb19d698b33cb2b/resources/overview.png -------------------------------------------------------------------------------- /resources/sdc_car.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swc-17/SparseDrive/52c4c05b6d446b710c8a12eb9fb19d698b33cb2b/resources/sdc_car.png -------------------------------------------------------------------------------- /resources/sparse_perception.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/swc-17/SparseDrive/52c4c05b6d446b710c8a12eb9fb19d698b33cb2b/resources/sparse_perception.png 
-------------------------------------------------------------------------------- /scripts/create_data.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH="$(dirname $0)/..":$PYTHONPATH 2 | 3 | python tools/data_converter/nuscenes_converter.py nuscenes \ 4 | --root-path ./data/nuscenes \ 5 | --canbus ./data/nuscenes \ 6 | --out-dir ./data/infos/ \ 7 | --extra-tag nuscenes \ 8 | --version v1.0-mini 9 | 10 | python tools/data_converter/nuscenes_converter.py nuscenes \ 11 | --root-path ./data/nuscenes \ 12 | --canbus ./data/nuscenes \ 13 | --out-dir ./data/infos/ \ 14 | --extra-tag nuscenes \ 15 | --version v1.0 16 | 17 | -------------------------------------------------------------------------------- /scripts/kmeans.sh: -------------------------------------------------------------------------------- 1 | python tools/kmeans/kmeans_det.py 2 | python tools/kmeans/kmeans_map.py 3 | python tools/kmeans/kmeans_motion.py 4 | python tools/kmeans/kmeans_plan.py -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | bash ./tools/dist_test.sh \ 2 | projects/configs/sparsedrive_small_stage2.py \ 3 | ckpt/sparsedrive_stage2.pth \ 4 | 8 \ 5 | --deterministic \ 6 | --eval bbox 7 | # --result_file ./work_dirs/sparsedrive_small_stage2/results.pkl -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | ## stage1 2 | bash ./tools/dist_train.sh \ 3 | projects/configs/sparsedrive_small_stage1.py \ 4 | 8 \ 5 | --deterministic 6 | 7 | ## stage2 8 | bash ./tools/dist_train.sh \ 9 | projects/configs/sparsedrive_small_stage2.py \ 10 | 8 \ 11 | --deterministic -------------------------------------------------------------------------------- /scripts/visualize.sh: -------------------------------------------------------------------------------- 1 | export PYTHONPATH="$(dirname $0)/..":$PYTHONPATH 2 | python tools/visualization/visualize.py \ 3 | projects/configs/sparsedrive_small_stage2.py \ 4 | --result-path work_dirs/sparsedrive_small_stage2/results.pkl -------------------------------------------------------------------------------- /tools/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
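# Summary: this script reports model complexity and runtime for a SparseDrive
# config. get_flops_params() counts FLOPs and parameters per module
# (total / img_backbone / img_neck / head) with mmcv's flops counter and adds
# an analytic estimate for the custom deformable aggregation op, which the
# hook-based counter presumably cannot see; get_mem_fps() runs inference over
# the test split and reports FPS and peak GPU memory.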
2 | import argparse 3 | import time 4 | import torch 5 | from mmcv import Config 6 | from mmcv.parallel import MMDataParallel 7 | from mmcv.runner import load_checkpoint, wrap_fp16_model 8 | import sys 9 | sys.path.append('.') 10 | from projects.mmdet3d_plugin.datasets.builder import build_dataloader 11 | from projects.mmdet3d_plugin.datasets import custom_build_dataset 12 | from mmdet.models import build_detector 13 | from mmcv.cnn.utils.flops_counter import add_flops_counting_methods 14 | from mmcv.parallel import scatter 15 | 16 | 17 | def parse_args(): 18 | parser = argparse.ArgumentParser(description='MMDet benchmark a model') 19 | parser.add_argument('config', help='test config file path') 20 | parser.add_argument('--checkpoint', default=None, help='checkpoint file') 21 | parser.add_argument('--samples', default=1000, help='samples to benchmark') 22 | parser.add_argument( 23 | '--log-interval', default=50, help='interval of logging') 24 | parser.add_argument( 25 | '--fuse-conv-bn', 26 | action='store_true', 27 | help='Whether to fuse conv and bn, this will slightly increase' 28 | 'the inference speed') 29 | args = parser.parse_args() 30 | return args 31 | 32 | 33 | def get_max_memory(model): 34 | device = getattr(model, 'output_device', None) 35 | mem = torch.cuda.max_memory_allocated(device=device) 36 | mem_mb = torch.tensor([mem / (1024 * 1024)], 37 | dtype=torch.int, 38 | device=device) 39 | return mem_mb.item() 40 | 41 | 42 | def main(): 43 | args = parse_args() 44 | get_flops_params(args) 45 | get_mem_fps(args) 46 | 47 | def get_mem_fps(args): 48 | cfg = Config.fromfile(args.config) 49 | # set cudnn_benchmark 50 | if cfg.get('cudnn_benchmark', False): 51 | torch.backends.cudnn.benchmark = True 52 | cfg.model.pretrained = None 53 | cfg.data.test.test_mode = True 54 | 55 | # build the dataloader 56 | # TODO: support multiple images per gpu (only minor changes are needed) 57 | print(cfg.data.test) 58 | dataset = custom_build_dataset(cfg.data.test) 59 | data_loader = build_dataloader( 60 | dataset, 61 | samples_per_gpu=1, 62 | workers_per_gpu=cfg.data.workers_per_gpu, 63 | dist=False, 64 | shuffle=False) 65 | 66 | # build the model and load checkpoint 67 | cfg.model.train_cfg = None 68 | model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) 69 | fp16_cfg = cfg.get('fp16', None) 70 | if fp16_cfg is not None: 71 | wrap_fp16_model(model) 72 | if args.checkpoint is not None: 73 | load_checkpoint(model, args.checkpoint, map_location='cpu') 74 | # if args.fuse_conv_bn: 75 | # model = fuse_module(model) 76 | 77 | model = MMDataParallel(model, device_ids=[0]) 78 | 79 | model.eval() 80 | 81 | # the first several iterations may be very slow so skip them 82 | num_warmup = 5 83 | pure_inf_time = 0 84 | 85 | # benchmark with several samples and take the average 86 | max_memory = 0 87 | for i, data in enumerate(data_loader): 88 | # torch.cuda.synchronize() 89 | with torch.no_grad(): 90 | start_time = time.perf_counter() 91 | model(return_loss=False, rescale=True, **data) 92 | 93 | torch.cuda.synchronize() 94 | elapsed = time.perf_counter() - start_time 95 | max_memory = max(max_memory, get_max_memory(model)) 96 | 97 | if i >= num_warmup: 98 | pure_inf_time += elapsed 99 | if (i + 1) % args.log_interval == 0: 100 | fps = (i + 1 - num_warmup) / pure_inf_time 101 | print(f'Done image [{i + 1:<3}/ {args.samples}], ' 102 | f'fps: {fps:.1f} img / s, ' 103 | f"gpu mem: {max_memory} M") 104 | 105 | if (i + 1) == args.samples: 106 | pure_inf_time += elapsed 107 | fps = (i + 1 - num_warmup) / 
pure_inf_time 108 | print(f'Overall fps: {fps:.1f} img / s') 109 | break 110 | 111 | 112 | def get_flops_params(args): 113 | gpu_id = 0 114 | cfg = Config.fromfile(args.config) 115 | dataset = custom_build_dataset(cfg.data.val) 116 | dataloader = build_dataloader( 117 | dataset, 118 | samples_per_gpu=1, 119 | workers_per_gpu=0, 120 | dist=False, 121 | shuffle=False, 122 | ) 123 | data_iter = dataloader.__iter__() 124 | data = next(data_iter) 125 | data = scatter(data, [gpu_id])[0] 126 | 127 | cfg.model.train_cfg = None 128 | model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg')) 129 | fp16_cfg = cfg.get('fp16', None) 130 | if fp16_cfg is not None: 131 | wrap_fp16_model(model) 132 | if args.checkpoint is not None: 133 | load_checkpoint(model, args.checkpoint, map_location='cpu') 134 | model = model.cuda(gpu_id) 135 | model.eval() 136 | 137 | bilinear_flops = 11 138 | num_key_pts_det = ( 139 | cfg.model["head"]['det_head']["deformable_model"]["kps_generator"]["num_learnable_pts"] 140 | + len(cfg.model["head"]['det_head']["deformable_model"]["kps_generator"]["fix_scale"]) 141 | ) 142 | deformable_agg_flops_det = ( 143 | cfg.num_decoder 144 | * cfg.embed_dims 145 | * cfg.num_levels 146 | * cfg.model["head"]['det_head']["instance_bank"]["num_anchor"] 147 | * cfg.model["head"]['det_head']["deformable_model"]["num_cams"] 148 | * num_key_pts_det 149 | * bilinear_flops 150 | ) 151 | num_key_pts_map = ( 152 | cfg.model["head"]['map_head']["deformable_model"]["kps_generator"]["num_learnable_pts"] 153 | + len(cfg.model["head"]['map_head']["deformable_model"]["kps_generator"]["fix_height"]) 154 | ) * cfg.model["head"]['map_head']["deformable_model"]["kps_generator"]["num_sample"] 155 | deformable_agg_flops_map = ( 156 | cfg.num_decoder 157 | * cfg.embed_dims 158 | * cfg.num_levels 159 | * cfg.model["head"]['map_head']["instance_bank"]["num_anchor"] 160 | * cfg.model["head"]['map_head']["deformable_model"]["num_cams"] 161 | * num_key_pts_map 162 | * bilinear_flops 163 | ) 164 | deformable_agg_flops = deformable_agg_flops_det + deformable_agg_flops_map 165 | 166 | for module in ["total", "img_backbone", "img_neck", "head"]: 167 | if module != "total": 168 | flops_model = add_flops_counting_methods(getattr(model, module)) 169 | else: 170 | flops_model = add_flops_counting_methods(model) 171 | flops_model.eval() 172 | flops_model.start_flops_count() 173 | 174 | if module == "img_backbone": 175 | flops_model(data["img"].flatten(0, 1)) 176 | elif module == "img_neck": 177 | flops_model(model.img_backbone(data["img"].flatten(0, 1))) 178 | elif module == "head": 179 | flops_model(model.extract_feat(data["img"], metas=data), data) 180 | else: 181 | flops_model(**data) 182 | flops_count, params_count = flops_model.compute_average_flops_cost() 183 | flops_count *= flops_model.__batch_counter__ 184 | flops_model.stop_flops_count() 185 | if module == "head" or module == "total": 186 | flops_count += deformable_agg_flops 187 | if module == "total": 188 | total_flops = flops_count 189 | total_params = params_count 190 | print( 191 | f"{module:<13} complexity: " 192 | f"FLOPs={flops_count/ 10.**9:>8.4f} G / {flops_count/total_flops*100:>6.2f}%, " 193 | f"Params={params_count/10**6:>8.4f} M / {params_count/total_params*100:>6.2f}%." 
194 | ) 195 | 196 | if __name__ == '__main__': 197 | main() 198 | -------------------------------------------------------------------------------- /tools/data_converter/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | -------------------------------------------------------------------------------- /tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | PORT=${PORT:-29610} 7 | 8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 9 | python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 10 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} 11 | -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | PORT=${PORT:-28651} 6 | 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} 10 | -------------------------------------------------------------------------------- /tools/fuse_conv_bn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | 4 | import torch 5 | from mmcv.runner import save_checkpoint 6 | from torch import nn as nn 7 | 8 | from mmdet3d.apis import init_model 9 | 10 | 11 | def fuse_conv_bn(conv, bn): 12 | """During inference, the functionality of batch norm layers is turned off and 13 | only the per-channel mean and var are used, which exposes the chance to 14 | fuse them with the preceding conv layers to save computation and simplify 15 | the network structure.""" 16 | conv_w = conv.weight 17 | conv_b = conv.bias if conv.bias is not None else torch.zeros_like( 18 | bn.running_mean) 19 | 20 | factor = bn.weight / torch.sqrt(bn.running_var + bn.eps) 21 | conv.weight = nn.Parameter(conv_w * 22 | factor.reshape([conv.out_channels, 1, 1, 1])) 23 | conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias) 24 | return conv 25 | 26 | 27 | def fuse_module(m): 28 | last_conv = None 29 | last_conv_name = None 30 | 31 | for name, child in m.named_children(): 32 | if isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)): 33 | if last_conv is None: # only fuse BN that is after Conv 34 | continue 35 | fused_conv = fuse_conv_bn(last_conv, child) 36 | m._modules[last_conv_name] = fused_conv 37 | # To reduce changes, set BN as Identity instead of deleting it.
38 | m._modules[name] = nn.Identity() 39 | last_conv = None 40 | elif isinstance(child, nn.Conv2d): 41 | last_conv = child 42 | last_conv_name = name 43 | else: 44 | fuse_module(child) 45 | return m 46 | 47 | 48 | def parse_args(): 49 | parser = argparse.ArgumentParser( 50 | description='fuse Conv and BN layers in a model') 51 | parser.add_argument('config', help='config file path') 52 | parser.add_argument('checkpoint', help='checkpoint file path') 53 | parser.add_argument('out', help='output path of the converted model') 54 | args = parser.parse_args() 55 | return args 56 | 57 | 58 | def main(): 59 | args = parse_args() 60 | # build the model from a config file and a checkpoint file 61 | model = init_model(args.config, args.checkpoint) 62 | # fuse conv and bn layers of the model 63 | fused_model = fuse_module(model) 64 | save_checkpoint(fused_model, args.out) 65 | 66 | 67 | if __name__ == '__main__': 68 | main() 69 | -------------------------------------------------------------------------------- /tools/kmeans/kmeans_det.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from sklearn.cluster import KMeans 8 | 9 | import mmcv 10 | 11 | os.makedirs('data/kmeans', exist_ok=True) 12 | os.makedirs('vis/kmeans', exist_ok=True) 13 | 14 | K = 900 15 | DIS_THRESH = 55 16 | 17 | fp = 'data/infos/nuscenes_infos_train.pkl' 18 | data = mmcv.load(fp) 19 | data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"])) 20 | center = [] 21 | for idx in tqdm(range(len(data_infos))): 22 | boxes = data_infos[idx]['gt_boxes'][:,:3] 23 | if len(boxes) == 0: 24 | continue 25 | distance = np.linalg.norm(boxes[:, :2], axis=1) 26 | center.append(boxes[distance < DIS_THRESH]) 27 | center = np.concatenate(center, axis=0) 28 | print("start clustering, may take a few minutes.") 29 | cluster = KMeans(n_clusters=K).fit(center).cluster_centers_ 30 | plt.scatter(cluster[:,0], cluster[:,1]) 31 | plt.savefig(f'vis/kmeans/det_anchor_{K}', bbox_inches='tight') 32 | others = np.array([1,1,1,1,0,0,0,0])[np.newaxis].repeat(K, axis=0) 33 | cluster = np.concatenate([cluster, others], axis=1) 34 | np.save(f'data/kmeans/kmeans_det_{K}.npy', cluster) -------------------------------------------------------------------------------- /tools/kmeans/kmeans_map.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from sklearn.cluster import KMeans 8 | 9 | import mmcv 10 | 11 | K = 100 12 | num_sample = 20 13 | 14 | fp = 'data/infos/nuscenes_infos_train.pkl' 15 | data = mmcv.load(fp) 16 | data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"])) 17 | center = [] 18 | for idx in tqdm(range(len(data_infos))): 19 | for cls, geoms in data_infos[idx]["map_annos"].items(): 20 | for geom in geoms: 21 | center.append(geom.mean(axis=0)) 22 | center = np.stack(center, axis=0) 23 | center = KMeans(n_clusters=K).fit(center).cluster_centers_ 24 | delta_y = np.linspace(-4, 4, num_sample) 25 | delta_x = np.zeros([num_sample]) 26 | delta = np.stack([delta_x, delta_y], axis=-1) 27 | vecs = center[:, np.newaxis] + delta[np.newaxis] 28 | 29 | for i in range(K): 30 | x = vecs[i, :, 0] 31 | y = vecs[i, :, 1] 32 | plt.plot(x, y, linewidth=1, marker='o', linestyle='-', markersize=2) 33 | plt.savefig(f'vis/kmeans/map_anchor_{K}', 
bbox_inches='tight') 34 | np.save(f'data/kmeans/kmeans_map_{K}.npy', vecs) -------------------------------------------------------------------------------- /tools/kmeans/kmeans_motion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from sklearn.cluster import KMeans 8 | 9 | import mmcv 10 | 11 | CLASSES = [ 12 | "car", 13 | "truck", 14 | "construction_vehicle", 15 | "bus", 16 | "trailer", 17 | "barrier", 18 | "motorcycle", 19 | "bicycle", 20 | "pedestrian", 21 | "traffic_cone", 22 | ] 23 | 24 | def lidar2agent(trajs_offset, boxes): 25 | origin = np.zeros((trajs_offset.shape[0], 1, 2), dtype=np.float32) 26 | trajs_offset = np.concatenate([origin, trajs_offset], axis=1) 27 | trajs = trajs_offset.cumsum(axis=1) 28 | yaws = - boxes[:, 6] 29 | rot_sin = np.sin(yaws) 30 | rot_cos = np.cos(yaws) 31 | rot_mat_T = np.stack( 32 | [ 33 | np.stack([rot_cos, rot_sin]), 34 | np.stack([-rot_sin, rot_cos]), 35 | ] 36 | ) 37 | trajs_new = np.einsum('aij,jka->aik', trajs, rot_mat_T) 38 | trajs_new = trajs_new[:, 1:] 39 | return trajs_new 40 | 41 | K = 6 42 | DIS_THRESH = 55 43 | 44 | fp = 'data/infos/nuscenes_infos_train.pkl' 45 | data = mmcv.load(fp) 46 | data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"])) 47 | intention = dict() 48 | for i in range(len(CLASSES)): 49 | intention[i] = [] 50 | for idx in tqdm(range(len(data_infos))): 51 | info = data_infos[idx] 52 | boxes = info['gt_boxes'] 53 | names = info['gt_names'] 54 | fut_masks = info['gt_agent_fut_masks'] 55 | trajs = info['gt_agent_fut_trajs'] 56 | velos = info['gt_velocity'] 57 | labels = [] 58 | for cat in names: 59 | if cat in CLASSES: 60 | labels.append(CLASSES.index(cat)) 61 | else: 62 | labels.append(-1) 63 | labels = np.array(labels) 64 | if len(boxes) == 0: 65 | continue 66 | for i in range(len(CLASSES)): 67 | cls_mask = (labels == i) 68 | box_cls = boxes[cls_mask] 69 | fut_masks_cls = fut_masks[cls_mask] 70 | trajs_cls = trajs[cls_mask] 71 | velos_cls = velos[cls_mask] 72 | 73 | distance = np.linalg.norm(box_cls[:, :2], axis=1) 74 | mask = np.logical_and( 75 | fut_masks_cls.sum(axis=1) == 12, 76 | distance < DIS_THRESH, 77 | ) 78 | trajs_cls = trajs_cls[mask] 79 | box_cls = box_cls[mask] 80 | velos_cls = velos_cls[mask] 81 | 82 | trajs_agent = lidar2agent(trajs_cls, box_cls) 83 | if trajs_agent.shape[0] == 0: 84 | continue 85 | intention[i].append(trajs_agent) 86 | 87 | clusters = [] 88 | for i in range(len(CLASSES)): 89 | intention_cls = np.concatenate(intention[i], axis=0).reshape(-1, 24) 90 | if intention_cls.shape[0] < K: 91 | continue 92 | cluster = KMeans(n_clusters=K).fit(intention_cls).cluster_centers_ 93 | cluster = cluster.reshape(-1, 12, 2) 94 | clusters.append(cluster) 95 | for j in range(K): 96 | plt.scatter(cluster[j, :, 0], cluster[j, :,1]) 97 | plt.savefig(f'vis/kmeans/motion_intention_{CLASSES[i]}_{K}', bbox_inches='tight') 98 | plt.close() 99 | 100 | clusters = np.stack(clusters, axis=0) 101 | np.save(f'data/kmeans/kmeans_motion_{K}.npy', clusters) -------------------------------------------------------------------------------- /tools/kmeans/kmeans_plan.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from tqdm import tqdm 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | from sklearn.cluster import KMeans 8 | 9 | import mmcv 10 | 11 | K = 6 12 | 13 | fp = 
'data/infos/nuscenes_infos_train.pkl' 14 | data = mmcv.load(fp) 15 | data_infos = list(sorted(data["infos"], key=lambda e: e["timestamp"])) 16 | navi_trajs = [[], [], []] 17 | for idx in tqdm(range(len(data_infos))): 18 | info = data_infos[idx] 19 | plan_traj = info['gt_ego_fut_trajs'].cumsum(axis=-2) 20 | plan_mask = info['gt_ego_fut_masks'] 21 | cmd = info['gt_ego_fut_cmd'].astype(np.int32) 22 | cmd = cmd.argmax(axis=-1) 23 | if not plan_mask.sum() == 6: 24 | continue 25 | navi_trajs[cmd].append(plan_traj) 26 | 27 | clusters = [] 28 | for trajs in navi_trajs: 29 | trajs = np.concatenate(trajs, axis=0).reshape(-1, 12) 30 | cluster = KMeans(n_clusters=K).fit(trajs).cluster_centers_ 31 | cluster = cluster.reshape(-1, 6, 2) 32 | clusters.append(cluster) 33 | for j in range(K): 34 | plt.scatter(cluster[j, :, 0], cluster[j, :,1]) 35 | plt.savefig(f'vis/kmeans/plan_{K}', bbox_inches='tight') 36 | plt.close() 37 | 38 | clusters = np.stack(clusters, axis=0) 39 | np.save(f'data/kmeans/kmeans_plan_{K}.npy', clusters) -------------------------------------------------------------------------------- /tools/visualization/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | import cv2 7 | import numpy as np 8 | from PIL import Image 9 | 10 | import mmcv 11 | from mmcv import Config 12 | from mmdet.datasets import build_dataset 13 | 14 | from tools.visualization.bev_render import BEVRender 15 | from tools.visualization.cam_render import CamRender 16 | 17 | plot_choices = dict( 18 | draw_pred = True, # True: draw gt and pred; False: only draw gt 19 | det = True, 20 | track = True, # True: draw history tracked boxes 21 | motion = True, 22 | map = True, 23 | planning = True, 24 | ) 25 | START = 0 26 | END = 81 27 | INTERVAL = 1 28 | 29 | 30 | class Visualizer: 31 | def __init__( 32 | self, 33 | args, 34 | plot_choices, 35 | ): 36 | self.out_dir = args.out_dir 37 | self.combine_dir = os.path.join(self.out_dir, 'combine') 38 | os.makedirs(self.combine_dir, exist_ok=True) 39 | 40 | cfg = Config.fromfile(args.config) 41 | self.dataset = build_dataset(cfg.data.val) 42 | self.results = mmcv.load(args.result_path) 43 | self.bev_render = BEVRender(plot_choices, self.out_dir) 44 | self.cam_render = CamRender(plot_choices, self.out_dir) 45 | 46 | def add_vis(self, index): 47 | data = self.dataset.get_data_info(index) 48 | result = self.results[index]['img_bbox'] 49 | 50 | bev_gt_path, bev_pred_path = self.bev_render.render(data, result, index) 51 | cam_pred_path = self.cam_render.render(data, result, index) 52 | self.combine(bev_gt_path, bev_pred_path, cam_pred_path, index) 53 | 54 | def combine(self, bev_gt_path, bev_pred_path, cam_pred_path, index): 55 | bev_gt = cv2.imread(bev_gt_path) 56 | bev_image = cv2.imread(bev_pred_path) 57 | cam_image = cv2.imread(cam_pred_path) 58 | merge_image = cv2.hconcat([cam_image, bev_image, bev_gt]) 59 | save_path = os.path.join(self.combine_dir, str(index).zfill(4) + '.jpg') 60 | cv2.imwrite(save_path, merge_image) 61 | 62 | def image2video(self, fps=12, downsample=4): 63 | imgs_path = glob.glob(os.path.join(self.combine_dir, '*.jpg')) 64 | imgs_path = sorted(imgs_path) 65 | img_array = [] 66 | for img_path in tqdm(imgs_path): 67 | img = cv2.imread(img_path) 68 | height, width, channel = img.shape 69 | img = cv2.resize(img, (width//downsample, height // 70 | downsample), interpolation=cv2.INTER_AREA) 71 | height, width, channel = img.shape 72 | size = 
(width, height) 73 | img_array.append(img) 74 | out_path = os.path.join(self.out_dir, 'video.mp4') 75 | out = cv2.VideoWriter( 76 | out_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, size) 77 | for i in range(len(img_array)): 78 | out.write(img_array[i]) 79 | out.release() 80 | 81 | 82 | def parse_args(): 83 | parser = argparse.ArgumentParser( 84 | description='Visualize groundtruth and results') 85 | parser.add_argument('config', help='config file path') 86 | parser.add_argument('--result-path', 87 | default=None, 88 | help='prediction result to visualize. ' 89 | 'If submission file is not provided, only gt will be visualized') 90 | parser.add_argument( 91 | '--out-dir', 92 | default='vis', 93 | help='directory where visualization results will be saved') 94 | args = parser.parse_args() 95 | 96 | return args 97 | 98 | def main(): 99 | args = parse_args() 100 | visualizer = Visualizer(args, plot_choices) 101 | 102 | for idx in tqdm(range(START, END, INTERVAL)): 103 | if idx >= len(visualizer.results): 104 | break 105 | visualizer.add_vis(idx) 106 | 107 | visualizer.image2video() 108 | 109 | if __name__ == '__main__': 110 | main() --------------------------------------------------------------------------------
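As a closing note, the anchor files written by the tools/kmeans scripts earlier in this listing can be sanity-checked with a few lines of NumPy; the expected shapes follow directly from those scripts. The snippet below is an illustration added here, not a tool shipped with the repository, and it assumes the default K values (900/100/6/6) and output paths used above.

import numpy as np

# Shapes derived from tools/kmeans/kmeans_{det,map,plan}.py.
expected = {
    "data/kmeans/kmeans_det_900.npy": (900, 11),     # (x, y, z) + 8 placeholder box/velocity values
    "data/kmeans/kmeans_map_100.npy": (100, 20, 2),  # 100 polylines of 20 points
    "data/kmeans/kmeans_plan_6.npy": (3, 6, 6, 2),   # 3 commands x 6 modes x 6 future steps
}
for path, shape in expected.items():
    arr = np.load(path)
    assert arr.shape == shape, f"{path}: got {arr.shape}, expected {shape}"

# kmeans_motion_6.npy is stacked only over classes with enough trajectories, so
# just the trailing dimensions (6 modes, 12 future steps, 2 coords) are fixed.
motion = np.load("data/kmeans/kmeans_motion_6.npy")
assert motion.shape[1:] == (6, 12, 2)
print("all k-means anchor files have the expected shapes")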