├── projects ├── __init__.py └── mmdet3d_plugin │ ├── core │ ├── evaluation │ │ ├── __init__.py │ │ └── eval_hooks.py │ └── box3d.py │ ├── __init__.py │ ├── apis │ ├── __init__.py │ ├── train.py │ ├── test.py │ └── mmdet_train.py │ ├── models │ ├── detection2d │ │ ├── __init__.py │ │ ├── blocks.py │ │ └── coster.py │ ├── detection3d │ │ ├── __init__.py │ │ ├── losses.py │ │ ├── decoder.py │ │ └── blocks.py │ ├── __init__.py │ ├── utils.py │ ├── grid_mask.py │ ├── simpb.py │ ├── aggregation.py │ ├── allocation.py │ └── instance_bank.py │ ├── datasets │ ├── samplers │ │ ├── sampler.py │ │ ├── __init__.py │ │ ├── distributed_sampler.py │ │ ├── group_sampler.py │ │ ├── infinite_group_each_sample_in_batch_sampler.py │ │ └── group_in_batch_sampler.py │ ├── __init__.py │ ├── pipelines │ │ ├── __init__.py │ │ ├── loading.py │ │ └── transform.py │ ├── builder.py │ └── utils.py │ └── ops │ ├── setup.py │ ├── deformable_aggregation.py │ ├── __init__.py │ └── src │ ├── deformable_aggregation.cpp │ └── deformable_aggregation_cuda.cu ├── docs ├── figs │ └── arch.png ├── prepare_environment.md ├── prepare_dataset.md └── training_evaluation.md ├── requirement.txt ├── tools ├── dist_train.sh ├── dist_test.sh ├── anchor_generator.py ├── fuse_conv_bn.py ├── benchmark.py └── train.py ├── README.md └── LICENSE /projects/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/figs/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nullmax-vision/SimPB/HEAD/docs/figs/arch.png -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval_hooks import CustomDistEvalHook -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import * 2 | from .models import * 3 | from .apis import * 4 | from .core.evaluation import * 5 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/apis/__init__.py: -------------------------------------------------------------------------------- 1 | from .train import custom_train_model 2 | from .mmdet_train import custom_train_detector 3 | 4 | # from .test import custom_multi_gpu_test 5 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/core/box3d.py: -------------------------------------------------------------------------------- 1 | X, Y, Z, W, L, H, SIN_YAW, COS_YAW, VX, VY, VZ = list(range(11)) # undecoded 2 | CNS, YNS = 0, 1 # centerness and yawness indices in quality 3 | YAW = 6 # decoded 4 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | numpy==1.23.5 2 | mmcv_full==1.7.1 3 | mmdet==2.28.2 4 | urllib3==1.26.16 5 | pyquaternion==0.9.9 6 | nuscenes-devkit==1.1.10 7 | yapf==0.33.0 8 | tensorboard==2.14.0 9 | motmetrics==1.1.3 10 | pandas==1.1.5 -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/detection2d/__init__.py:
-------------------------------------------------------------------------------- 1 | from .coster import SparseBox2DCoster 2 | from .target import SparseBox2DTarget 3 | from .denoise import Denoise2D 4 | from .blocks import SparseBox2DEncoder, SparseBox2DRefinementModule -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/samplers/sampler.py: -------------------------------------------------------------------------------- 1 | from mmcv.utils.registry import Registry, build_from_cfg 2 | 3 | SAMPLER = Registry("sampler") 4 | 5 | 6 | def build_sampler(cfg, default_args): 7 | return build_from_cfg(cfg, SAMPLER, default_args) 8 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .group_sampler import DistributedGroupSampler 2 | from .distributed_sampler import DistributedSampler 3 | from .sampler import SAMPLER, build_sampler 4 | from .group_in_batch_sampler import ( 5 | GroupInBatchSampler, 6 | ) 7 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .nuscenes_dataset import NuScenes3DDetTrackDataset 2 | from .builder import * 3 | from .pipelines import * 4 | from .samplers import * 5 | 6 | __all__ = [ 7 | 'NuScenes3DDetTrackDataset', 8 | "custom_build_dataset", 9 | ] 10 | -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | PORT=${PORT:-28650} 6 | 7 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 8 | python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} 10 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/detection3d/__init__.py: -------------------------------------------------------------------------------- 1 | from .decoder import SparseBox3DDecoder 2 | from .target import SparseBox3DTarget, SparseBox3DTargetWith2D 3 | from .blocks import ( 4 | SparseBox3DRefinementModule, 5 | SparseBox3DKeyPointsGenerator, 6 | SparseBox3DEncoder, 7 | ) 8 | from .losses import SparseBox3DLoss 9 | -------------------------------------------------------------------------------- /tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | PORT=${PORT:-29610} 7 | 8 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \ 9 | python3 -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 10 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} 11 | -------------------------------------------------------------------------------- /docs/prepare_environment.md: -------------------------------------------------------------------------------- 1 | ## Prepare Environment 2 | * Linux 3 | * python 3.8 4 | * Pytorch 1.10.0+ 5 | * CUDA 11.1+ 6 | 7 | **1. Create a conda virtual environment** 8 | ```bash 9 | conda create -n simpb python=3.8 -y 10 | conda activate simpb 11 | ``` 12 | 13 | **2. 
Install PyTorch** 14 | ```bash 15 | pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html 16 | ``` 17 | 18 | **3. Install other packages** 19 | ```bash 20 | pip install --upgrade pip 21 | pip install -r requirement.txt 22 | ``` 23 | 24 | **4. Compile the deformable_aggregation CUDA op** 25 | ```bash 26 | cd projects/mmdet3d_plugin/ops 27 | python setup.py develop 28 | cd ../../../ 29 | ``` -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | from .transform import ( 2 | InstanceNameFilter, 3 | CircleObjectRangeFilter, 4 | NormalizeMultiviewImage, 5 | NuScenesSparse4DAdaptor, 6 | MultiScaleDepthMapGenerator, 7 | ) 8 | from .augment import ( 9 | ResizeCropFlipImage, 10 | BBoxRotation, 11 | BBoxScale, 12 | PhotoMetricDistortionMultiViewImage, 13 | ) 14 | from .loading import LoadMultiViewImageFromFiles, LoadPointsFromFile 15 | 16 | __all__ = [ 17 | "InstanceNameFilter", 18 | "ResizeCropFlipImage", 19 | "BBoxRotation", 20 | "BBoxScale", 21 | "CircleObjectRangeFilter", 22 | "MultiScaleDepthMapGenerator", 23 | "NormalizeMultiviewImage", 24 | "PhotoMetricDistortionMultiViewImage", 25 | "NuScenesSparse4DAdaptor", 26 | "LoadMultiViewImageFromFiles", 27 | "LoadPointsFromFile", 28 | ] 29 | -------------------------------------------------------------------------------- /docs/prepare_dataset.md: -------------------------------------------------------------------------------- 1 | ## Prepare Dataset 2 | ### NuScenes 3 | **1. Download [nuScenes](https://www.nuscenes.org/download) V1.0 dataset** 4 | 5 | **2. Link dataset to project** 6 | ```bash 7 | ln -s path/to/nuscenes ./data/nuscenes 8 | ``` 9 | **3. Convert nuScenes dataset** 10 | ```bash 11 | python tools/data_converter/nuscenes_converter.py --info_prefix ./data/nuscenes/simpb_nuscenes 12 | ``` 13 | ### K-means Anchors 14 | ```bash 15 | python tools/anchor_generator.py --ann_file ./data/nuscenes/simpb_nuscenes_infos_train.pkl 16 | ``` 17 | 18 | **Folder structure** 19 | ``` 20 | SimPB 21 | ├── projects/ 22 | ├── tools/ 23 | ├── ckpts/ 24 | ├── data/ 25 | │   ├── nuscenes/ 26 | │   │   ├── maps/ 27 | │   │   ├── samples/ 28 | │   │   ├── sweeps/ 29 | │   │   ├── v1.0-test/ 30 | │   │   ├── v1.0-trainval/ 31 | │   │   ├── nuscenes_kmeans900.npy 32 | │   │   ├── simpb_nuscenes_infos_test.pkl 33 | │   │   ├── simpb_nuscenes_infos_train.pkl 34 | │   │   ├── simpb_nuscenes_infos_val.pkl 35 | ``` -------------------------------------------------------------------------------- /docs/training_evaluation.md: -------------------------------------------------------------------------------- 1 | ## Getting Started 2 | 3 | ## Train 4 | 5 | **1. Download pretrained backbone** 6 | ```bash 7 | cd ckpts 8 | wget https://download.pytorch.org/models/resnet50-19c8e357.pth 9 | ``` 10 | 11 | **2. Train SimPB with multiple GPUs** 12 | ```bash 13 | bash ./tools/dist_train.sh ./projects/configs/simpb_nus_r50_img_704x256.py 8 --no-validate 14 | ``` 15 | 16 | ## Test 17 | **1. Download pretrained model** 18 | 19 | Download the pretrained model [here](https://github.com/nullmax-vision/SimPB/releases/download/weights/simpb_r50_img.pth), or use your own trained weights 20 | 21 | **2. 
Evaluate the pretrained model** 22 | ```bash 23 | bash ./tools/dist_test.sh ./projects/configs/simpb_nus_r50_img_704x256.py path/to/model.pth 8 --eval bbox 24 | ``` 25 | 26 | ## Visualize 27 | **1. Get results file** 28 | ```bash 29 | python ./tools/test.py ./projects/configs/simpb_nus_r50_img_704x256.py path/to/model.pth --out path/to/model.pkl 30 | ``` 31 | 32 | **2. Load and show results** 33 | ```bash 34 | python ./tools/test.py ./projects/configs/simpb_nus_r50_img_704x256.py path/to/model.pth --result_file path/to/model.pkl --show_only --show-dir ./ 35 | ``` -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .simpb import SimPB 2 | from .simpb_head import SimPBHead 3 | from .blocks import ( 4 | DeformableFeatureAggregation, 5 | DenseDepthNet, 6 | AsymmetricFFN, 7 | ) 8 | from .instance_bank import InstanceBank 9 | 10 | from .detection2d import ( 11 | SparseBox2DEncoder, 12 | SparseBox2DRefinementModule, 13 | ) 14 | 15 | from .detection3d import ( 16 | SparseBox3DDecoder, 17 | SparseBox3DTarget, 18 | SparseBox3DRefinementModule, 19 | SparseBox3DKeyPointsGenerator, 20 | SparseBox3DEncoder, 21 | ) 22 | 23 | from .allocation import DynamicQueryAllocation 24 | from .aggregation import AdaptiveQueryAggregation 25 | from .group_attn import (QueryGroupMultiheadAttention, 26 | QueryGroupMultiScaleDeformableAttention) 27 | 28 | __all__ = [ 29 | "SimPB", 30 | "SimPBHead", 31 | "DeformableFeatureAggregation", 32 | "DenseDepthNet", 33 | "AsymmetricFFN", 34 | "InstanceBank", 35 | "SparseBox3DDecoder", 36 | "SparseBox3DTarget", 37 | "SparseBox3DRefinementModule", 38 | "SparseBox3DKeyPointsGenerator", 39 | "SparseBox3DEncoder", 40 | "DynamicQueryAllocation", 41 | "AdaptiveQueryAggregation", 42 | "QueryGroupMultiheadAttention", 43 | "QueryGroupMultiScaleDeformableAttention", 44 | ] 45 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/apis/train.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------- 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | # --------------------------------------------- 4 | # Modified by Zhiqi Li 5 | # --------------------------------------------- 6 | 7 | from .mmdet_train import custom_train_detector 8 | # from mmseg.apis import train_segmentor 9 | from mmdet.apis import train_detector 10 | 11 | 12 | def custom_train_model( 13 | model, 14 | dataset, 15 | cfg, 16 | distributed=False, 17 | validate=False, 18 | timestamp=None, 19 | meta=None, 20 | ): 21 | """A function wrapper for launching model training according to cfg. 22 | 23 | Because we need different eval_hook in runner. Should be deprecated in the 24 | future. 25 | """ 26 | if cfg.model.type in ["EncoderDecoder3D"]: 27 | assert False 28 | else: 29 | custom_train_detector( 30 | model, 31 | dataset, 32 | cfg, 33 | distributed=distributed, 34 | validate=validate, 35 | timestamp=timestamp, 36 | meta=meta, 37 | ) 38 | 39 | 40 | def train_model( 41 | model, 42 | dataset, 43 | cfg, 44 | distributed=False, 45 | validate=False, 46 | timestamp=None, 47 | meta=None, 48 | ): 49 | """A function wrapper for launching model training according to cfg. 50 | 51 | Because we need different eval_hook in runner. Should be deprecated in the 52 | future. 
53 | """ 54 | train_detector( 55 | model, 56 | dataset, 57 | cfg, 58 | distributed=distributed, 59 | validate=validate, 60 | timestamp=timestamp, 61 | meta=meta, 62 | ) 63 | -------------------------------------------------------------------------------- /tools/anchor_generator.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | import argparse 3 | import numpy as np 4 | from sklearn.cluster import KMeans 5 | from projects.mmdet3d_plugin.core.box3d import * 6 | 7 | 8 | def get_kmeans_anchor( 9 | ann_file, 10 | num_anchor=900, 11 | detection_range=55, 12 | output_file_name="nuscenes_kmeans900.npy", 13 | verbose=False, 14 | ): 15 | data = mmcv.load(ann_file, file_format="pkl") 16 | gt_boxes = np.concatenate([x["gt_boxes"] for x in data["infos"]], axis=0) 17 | distance = np.linalg.norm(gt_boxes[:, :3], axis=-1, ord=2) 18 | mask = distance <= detection_range 19 | gt_boxes = gt_boxes[mask] 20 | clf = KMeans(n_clusters=num_anchor, verbose=verbose) 21 | print("===========Starting kmeans, please wait.===========") 22 | clf.fit(gt_boxes[:, [X, Y, Z]]) 23 | anchor = np.zeros((num_anchor, 11)) 24 | anchor[:, [X, Y, Z]] = clf.cluster_centers_ 25 | anchor[:, [W, L, H]] = np.log(gt_boxes[:, [W, L, H]].mean(axis=0)) 26 | anchor[:, COS_YAW] = 1 27 | np.save(output_file_name, anchor) 28 | print(f"===========Done! Save results to {output_file_name}.===========") 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser(description="anchor kmeans") 33 | parser.add_argument("--ann_file", type=str, required=True) 34 | parser.add_argument("--num_anchor", type=int, default=900) 35 | parser.add_argument("--detection_range", type=float, default=55) 36 | parser.add_argument("--output_file_name", type=str, default="kmeans900.npy") 37 | parser.add_argument("--verbose", action="store_true") 38 | args = parser.parse_args() 39 | 40 | get_kmeans_anchor( 41 | args.ann_file, 42 | args.num_anchor, 43 | args.detection_range, 44 | args.output_file_name, 45 | args.verbose, 46 | ) 47 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/ops/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from setuptools import setup 5 | from torch.utils.cpp_extension import ( 6 | BuildExtension, 7 | CppExtension, 8 | CUDAExtension, 9 | ) 10 | 11 | 12 | def make_cuda_ext( 13 | name, 14 | module, 15 | sources, 16 | sources_cuda=[], 17 | extra_args=[], 18 | extra_include_path=[], 19 | ): 20 | 21 | define_macros = [] 22 | extra_compile_args = {"cxx": [] + extra_args} 23 | 24 | if torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1": 25 | define_macros += [("WITH_CUDA", None)] 26 | extension = CUDAExtension 27 | extra_compile_args["nvcc"] = extra_args + [ 28 | "-D__CUDA_NO_HALF_OPERATORS__", 29 | "-D__CUDA_NO_HALF_CONVERSIONS__", 30 | "-D__CUDA_NO_HALF2_OPERATORS__", 31 | ] 32 | sources += sources_cuda 33 | else: 34 | print("Compiling {} without CUDA".format(name)) 35 | extension = CppExtension 36 | 37 | return extension( 38 | name="{}.{}".format(module, name), 39 | sources=[os.path.join(*module.split("."), p) for p in sources], 40 | include_dirs=extra_include_path, 41 | define_macros=define_macros, 42 | extra_compile_args=extra_compile_args, 43 | ) 44 | 45 | 46 | if __name__ == "__main__": 47 | setup( 48 | name="deformable_aggregation_ext", 49 | ext_modules=[ 50 | make_cuda_ext( 51 | "deformable_aggregation_ext", 52 | module=".", 
53 |                 sources=[ 54 |                     "src/deformable_aggregation.cpp", 55 |                     "src/deformable_aggregation_cuda.cu", 56 |                 ], 57 |             ), 58 |         ], 59 |         cmdclass={"build_ext": BuildExtension}, 60 |     ) 61 | -------------------------------------------------------------------------------- /tools/fuse_conv_bn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | 4 | import torch 5 | from torch import nn as nn 6 | from mmcv.runner import save_checkpoint 7 | from mmdet.apis import init_detector 8 | 9 | 10 | def fuse_conv_bn(conv, bn): 11 |     """During inference, the functionality of batch norm layers is turned off; 12 |     only the mean and var along channels are used, which exposes the chance to 13 |     fuse it with the preceding conv layers to save computations and simplify 14 |     network structures.""" 15 |     conv_w = conv.weight 16 |     conv_b = ( 17 |         conv.bias 18 |         if conv.bias is not None 19 |         else torch.zeros_like(bn.running_mean) 20 |     ) 21 | 22 |     factor = bn.weight / torch.sqrt(bn.running_var + bn.eps) 23 |     conv.weight = nn.Parameter( 24 |         conv_w * factor.reshape([conv.out_channels, 1, 1, 1]) 25 |     ) 26 |     conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias) 27 |     return conv 28 | 29 | 30 | def fuse_module(m): 31 |     last_conv = None 32 |     last_conv_name = None 33 | 34 |     for name, child in m.named_children(): 35 |         if isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)): 36 |             if last_conv is None:  # only fuse BN that is after Conv 37 |                 continue 38 |             fused_conv = fuse_conv_bn(last_conv, child) 39 |             m._modules[last_conv_name] = fused_conv 40 |             # To reduce changes, set BN as Identity instead of deleting it. 41 |             m._modules[name] = nn.Identity() 42 |             last_conv = None 43 |         elif isinstance(child, nn.Conv2d): 44 |             last_conv = child 45 |             last_conv_name = name 46 |         else: 47 |             fuse_module(child) 48 |     return m 49 | 50 | 51 | def parse_args(): 52 |     parser = argparse.ArgumentParser( 53 |         description="fuse Conv and BN layers in a model" 54 |     ) 55 |     parser.add_argument("config", help="config file path") 56 |     parser.add_argument("checkpoint", help="checkpoint file path") 57 |     parser.add_argument("out", help="output path of the converted model") 58 |     args = parser.parse_args() 59 |     return args 60 | 61 | 62 | def main(): 63 |     args = parse_args() 64 |     # build the model from a config file and a checkpoint file 65 |     model = init_detector(args.config, args.checkpoint) 66 |     # fuse conv and bn layers of the model 67 |     fused_model = fuse_module(model) 68 |     save_checkpoint(fused_model, args.out) 69 | 70 | 71 | if __name__ == "__main__": 72 |     main() 73 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/ops/deformable_aggregation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd.function import Function, once_differentiable 3 | 4 | from .
import deformable_aggregation_ext 5 | 6 | 7 | class DeformableAggregationFunction(Function): 8 | @staticmethod 9 | def forward( 10 | ctx, 11 | mc_ms_feat, 12 | spatial_shape, 13 | scale_start_index, 14 | sampling_location, 15 | weights, 16 | ): 17 | # output: [bs, num_pts, num_embeds] 18 | mc_ms_feat = mc_ms_feat.contiguous().float() 19 | spatial_shape = spatial_shape.contiguous().int() 20 | scale_start_index = scale_start_index.contiguous().int() 21 | sampling_location = sampling_location.contiguous().float() 22 | weights = weights.contiguous().float() 23 | output = deformable_aggregation_ext.deformable_aggregation_forward( 24 | mc_ms_feat, 25 | spatial_shape, 26 | scale_start_index, 27 | sampling_location, 28 | weights, 29 | ) 30 | ctx.save_for_backward( 31 | mc_ms_feat, 32 | spatial_shape, 33 | scale_start_index, 34 | sampling_location, 35 | weights, 36 | ) 37 | return output 38 | 39 | @staticmethod 40 | @once_differentiable 41 | def backward(ctx, grad_output): 42 | ( 43 | mc_ms_feat, 44 | spatial_shape, 45 | scale_start_index, 46 | sampling_location, 47 | weights, 48 | ) = ctx.saved_tensors 49 | mc_ms_feat = mc_ms_feat.contiguous().float() 50 | spatial_shape = spatial_shape.contiguous().int() 51 | scale_start_index = scale_start_index.contiguous().int() 52 | sampling_location = sampling_location.contiguous().float() 53 | weights = weights.contiguous().float() 54 | 55 | grad_mc_ms_feat = torch.zeros_like(mc_ms_feat) 56 | grad_sampling_location = torch.zeros_like(sampling_location) 57 | grad_weights = torch.zeros_like(weights) 58 | deformable_aggregation_ext.deformable_aggregation_backward( 59 | mc_ms_feat, 60 | spatial_shape, 61 | scale_start_index, 62 | sampling_location, 63 | weights, 64 | grad_output.contiguous(), 65 | grad_mc_ms_feat, 66 | grad_sampling_location, 67 | grad_weights, 68 | ) 69 | return ( 70 | grad_mc_ms_feat, 71 | None, 72 | None, 73 | grad_sampling_location, 74 | grad_weights, 75 | ) 76 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/detection3d/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from mmcv.utils import build_from_cfg 5 | from mmdet.models.builder import LOSSES 6 | 7 | from projects.mmdet3d_plugin.core.box3d import * 8 | from torch.nn.functional import cosine_similarity 9 | 10 | 11 | @LOSSES.register_module() 12 | class SparseBox3DLoss(nn.Module): 13 | def __init__( 14 | self, 15 | loss_box, 16 | loss_centerness=None, 17 | loss_yawness=None, 18 | cls_allow_reverse=None, 19 | ): 20 | super().__init__() 21 | 22 | def build(cfg, registry): 23 | if cfg is None: 24 | return None 25 | return build_from_cfg(cfg, registry) 26 | 27 | self.loss_box = build(loss_box, LOSSES) 28 | self.loss_cns = build(loss_centerness, LOSSES) 29 | self.loss_yns = build(loss_yawness, LOSSES) 30 | self.cls_allow_reverse = cls_allow_reverse 31 | 32 | def forward( 33 | self, 34 | box, 35 | box_target, 36 | weight=None, 37 | avg_factor=None, 38 | suffix="", 39 | quality=None, 40 | cls_target=None, 41 | **kwargs, 42 | ): 43 | # Some categories do not distinguish between positive and negative 44 | # directions. For example, barrier in nuScenes dataset. 
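# A concrete illustration of the block below (hypothetical numbers): a barrier whose target yaw is +pi/2 but whose prediction sits near -pi/2 gives cosine_similarity < 0 between the (sin_yaw, cos_yaw) pairs, so the target's (sin, cos) is negated, rotating the target by pi so the box loss does not penalize a direction the class does not define.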
45 |         if self.cls_allow_reverse is not None and cls_target is not None: 46 |             if_reverse = (cosine_similarity( 47 |                 box_target[..., [SIN_YAW, COS_YAW]], box[..., [SIN_YAW, COS_YAW]], dim=-1) < 0) 48 |             if_reverse = (torch.isin(cls_target, cls_target.new_tensor(self.cls_allow_reverse)) & if_reverse) 49 | 50 |             box_target[..., [SIN_YAW, COS_YAW]] = torch.where( 51 |                 if_reverse[..., None], -box_target[..., [SIN_YAW, COS_YAW]], box_target[..., [SIN_YAW, COS_YAW]]) 52 | 53 |         output = {} 54 |         box_loss = self.loss_box(box, box_target, weight=weight, avg_factor=avg_factor) 55 |         output[f"loss_box{suffix}"] = box_loss 56 | 57 |         if quality is not None: 58 |             cns = quality[..., CNS] 59 |             yns = quality[..., YNS].sigmoid() 60 |             cns_target = torch.norm(box_target[..., [X, Y, Z]] - box[..., [X, Y, Z]], p=2, dim=-1) 61 |             cns_target = torch.exp(-cns_target) 62 |             cns_loss = self.loss_cns(cns, cns_target, avg_factor=avg_factor) 63 |             output[f"loss_cns{suffix}"] = cns_loss 64 | 65 |             yns_target = (cosine_similarity(box_target[..., [SIN_YAW, COS_YAW]], box[..., [SIN_YAW, COS_YAW]], dim=-1) > 0) 66 |             yns_target = yns_target.float() 67 |             yns_loss = self.loss_yns(yns, yns_target, avg_factor=avg_factor) 68 |             output[f"loss_yns{suffix}"] = yns_loss 69 |         return output 70 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | def inverse_sigmoid(x, eps=1e-5): 5 |     x = x.clamp(min=0, max=1) 6 |     x1 = x.clamp(min=eps) 7 |     x2 = (1 - x).clamp(min=eps) 8 |     return torch.log(x1 / x2) 9 | 10 | def get_valid_ratio(mask): 11 |     """Get the valid ratios of feature maps at all levels.""" 12 |     _, H, W = mask.shape 13 |     valid_H = torch.sum(~mask[:, :, 0], 1) 14 |     valid_W = torch.sum(~mask[:, 0, :], 1) 15 |     valid_ratio_h = valid_H.float() / H 16 |     valid_ratio_w = valid_W.float() / W 17 |     valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) 18 |     return valid_ratio 19 | 20 | def get_reference_points(spatial_shapes, valid_ratios, device): 21 |     reference_points_list = [] 22 |     for lvl, (H, W) in enumerate(spatial_shapes): 23 |         # TODO check this 0.5 24 |         ref_y, ref_x = torch.meshgrid( 25 |             torch.linspace( 26 |                 0.5, H - 0.5, H, dtype=torch.float32, device=device), 27 |             torch.linspace( 28 |                 0.5, W - 0.5, W, dtype=torch.float32, device=device)) 29 |         ref_y = ref_y.reshape(-1)[None] / ( 30 |             valid_ratios[:, None, lvl, 1] * H) 31 |         ref_x = ref_x.reshape(-1)[None] / ( 32 |             valid_ratios[:, None, lvl, 0] * W) 33 |         ref = torch.stack((ref_x, ref_y), -1) 34 |         reference_points_list.append(ref) 35 |     reference_points = torch.cat(reference_points_list, 1) 36 |     reference_points = reference_points[:, :, None] * valid_ratios[:, None] 37 |     return reference_points 38 | 39 | 40 | def pos2posemb2d(pos, num_pos_feats=128, temperature=10000): 41 |     scale = 2 * math.pi 42 |     pos = pos * scale 43 |     dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device) 44 |     dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats) 45 |     pos_x = pos[..., 0, None] / dim_t 46 |     pos_y = pos[..., 1, None] / dim_t 47 |     pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2) 48 |     pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2) 49 |     if pos.size(-1) == 2: 50 |         posemb = torch.cat((pos_y, pos_x), dim=-1) 51 |     elif pos.size(-1) == 4: 52 |         w_embed = pos[:, :, 2] * scale 53 |         pos_w = 
w_embed[:, :, None] / dim_t 54 | pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) 55 | 56 | h_embed = pos[:, :, 3] * scale 57 | pos_h = h_embed[:, :, None] / dim_t 58 | pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) 59 | 60 | posemb = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) 61 | else: 62 | raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos.size(-1))) 63 | return posemb -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/samplers/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch.utils.data import DistributedSampler as _DistributedSampler 5 | from .sampler import SAMPLER 6 | 7 | import pdb 8 | import sys 9 | 10 | 11 | class ForkedPdb(pdb.Pdb): 12 | def interaction(self, *args, **kwargs): 13 | _stdin = sys.stdin 14 | try: 15 | sys.stdin = open("/dev/stdin") 16 | pdb.Pdb.interaction(self, *args, **kwargs) 17 | finally: 18 | sys.stdin = _stdin 19 | 20 | 21 | def set_trace(): 22 | ForkedPdb().set_trace(sys._getframe().f_back) 23 | 24 | 25 | @SAMPLER.register_module() 26 | class DistributedSampler(_DistributedSampler): 27 | def __init__( 28 | self, dataset=None, num_replicas=None, rank=None, shuffle=True, seed=0 29 | ): 30 | super().__init__( 31 | dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle 32 | ) 33 | # for the compatibility from PyTorch 1.3+ 34 | self.seed = seed if seed is not None else 0 35 | 36 | def __iter__(self): 37 | # deterministically shuffle based on epoch 38 | assert not self.shuffle 39 | if "data_infos" in dir(self.dataset): 40 | timestamps = [ 41 | x["timestamp"] / 1e6 for x in self.dataset.data_infos 42 | ] 43 | vehicle_idx = [ 44 | x["lidar_path"].split("/")[-1][:4] 45 | if "lidar_path" in x 46 | else None 47 | for x in self.dataset.data_infos 48 | ] 49 | else: 50 | timestamps = [ 51 | x["timestamp"] / 1e6 52 | for x in self.dataset.datasets[0].data_infos 53 | ] * len(self.dataset.datasets) 54 | vehicle_idx = [ 55 | x["lidar_path"].split("/")[-1][:4] 56 | if "lidar_path" in x 57 | else None 58 | for x in self.dataset.datasets[0].data_infos 59 | ] * len(self.dataset.datasets) 60 | 61 | sequence_splits = [] 62 | for i in range(len(timestamps)): 63 | if i == 0 or ( 64 | abs(timestamps[i] - timestamps[i - 1]) > 4 65 | or vehicle_idx[i] != vehicle_idx[i - 1] 66 | ): 67 | sequence_splits.append([i]) 68 | else: 69 | sequence_splits[-1].append(i) 70 | 71 | indices = [] 72 | perfix_sum = 0 73 | split_length = len(self.dataset) // self.num_replicas 74 | for i in range(len(sequence_splits)): 75 | if perfix_sum >= (self.rank + 1) * split_length: 76 | break 77 | elif perfix_sum >= self.rank * split_length: 78 | indices.extend(sequence_splits[i]) 79 | perfix_sum += len(sequence_splits[i]) 80 | 81 | self.num_samples = len(indices) 82 | return iter(indices) 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | # SimPB 3 | ### [ECCV 2024] SimPB: A Single Model for 2D and 3D Object Detection from Multiple Cameras
4 | 5 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2403.10353) 6 | 7 |
8 | 9 | ![method](docs/figs/arch.png "model arch") 10 | 11 | ## Introduction 12 | This repository is an official implementation of *SimPB*, which **Sim**ultaneously detects 2D objects in the 13 | **P**erspective view and 3D objects in the **B**EV space from multiple cameras. 14 | 15 | 16 | ## Getting started 17 | - [Prepare Environment](docs/prepare_environment.md) 18 | - [Prepare Dataset](docs/prepare_dataset.md) 19 | - [Training and Evaluation](docs/training_evaluation.md) 20 | 21 | ## Model Zoo 22 | 23 | **Results on NuScenes validation** 24 | 25 | | method | backbone | pretrain | img size | mAP | NDS | config | ckpt | log | 26 | |:-------:|:---------:|:---------------------------------------------------------------------:|:--------:|:-----:|:-----:|:--------------------------------------------------------------:|:----:|:---:| 27 | | SimPB+ | ResNet50 | [ImageNet](https://download.pytorch.org/models/resnet50-19c8e357.pth) | 704x256 | 0.479 | 0.586 | [config](projects/configs/simpb_nus_r50_img_704x256.py) | [ckpt](https://github.com/nullmax-vision/SimPB/releases/download/weights/simpb_r50_img.pth) | [log](https://github.com/nullmax-vision/SimPB/releases/download/weights/simpb_r50_img.log) | 28 | | SimPB+ | ResNet50 | [nuImg](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth) | 704x256 | 0.489 | 0.591 | [config](projects/configs/simpb_nus_r50_uimg_704x256.py) | [ckpt](https://github.com/nullmax-vision/SimPB/releases/download/weights/simpb_r50_uimg.pth) | [log](https://github.com/nullmax-vision/SimPB/releases/download/weights/simpb_r50_uimg.log) | 29 | | SimPB | ResNet101 | nuImg | 1408x512 | 0.539 | 0.629 | | | | 30 | 31 | Note: SimPB+ is a modified architecture that introduces 2D denoising and removes the encoder, which slightly reduces runtime while maintaining performance comparable to the released configuration. 
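For example, the first SimPB+ row can be reproduced with the released checkpoint and the evaluation command from the docs (a minimal sketch, assuming 8 GPUs and a prepared nuScenes dataset):

```bash
mkdir -p ckpts
wget -P ckpts https://github.com/nullmax-vision/SimPB/releases/download/weights/simpb_r50_img.pth
bash ./tools/dist_test.sh ./projects/configs/simpb_nus_r50_img_704x256.py ckpts/simpb_r50_img.pth 8 --eval bbox
```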
32 | 33 | ## Acknowledgement 34 | Thanks to these excellent open-source works: 35 | 36 | [Sparse4Dv3](https://github.com/HorizonRobotics/Sparse4D), 37 | [StreamPETR](https://github.com/exiawsh/StreamPETR), 38 | [SparseBEV](https://github.com/MCG-NJU/SparseBEV), 39 | [Far3D](https://github.com/megvii-research/Far3D), 40 | [MMDetection3D](https://github.com/open-mmlab/mmdetection3d) 41 | 42 | ## Citation 43 | ```bibtex 44 | @article{simpb, 45 | title={SimPB: A Single Model for 2D and 3D Object Detection from Multiple Cameras}, 46 | author={Yingqi Tang and Zhaotie Meng and Guoliang Chen and Erkang Cheng}, 47 | journal={ECCV}, 48 | year={2024} 49 | } 50 | ``` -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/ops/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .deformable_aggregation import DeformableAggregationFunction 4 | 5 | 6 | def deformable_aggregation_function( 7 | feature_maps, 8 | spatial_shape, 9 | scale_start_index, 10 | sampling_location, 11 | weights, 12 | ): 13 | return DeformableAggregationFunction.apply( 14 | feature_maps, 15 | spatial_shape, 16 | scale_start_index, 17 | sampling_location, 18 | weights, 19 | ) 20 | 21 | 22 | def feature_maps_format(feature_maps, inverse=False): 23 | if inverse: 24 | col_feats, spatial_shape, scale_start_index = feature_maps 25 | num_cams, num_levels = spatial_shape.shape[:2] 26 | 27 | split_size = spatial_shape[..., 0] * spatial_shape[..., 1] 28 | split_size = split_size.cpu().numpy().tolist() 29 | 30 | idx = 0 31 | cam_split = [1] 32 | cam_split_size = [sum(split_size[0])] 33 | for i in range(num_cams - 1): 34 | if not torch.all(spatial_shape[i] == spatial_shape[i + 1]): 35 | cam_split.append(0) 36 | cam_split_size.append(0) 37 | cam_split[-1] += 1 38 | cam_split_size[-1] += sum(split_size[i + 1]) 39 | mc_feat = [ 40 | x.unflatten(1, (cam_split[i], -1)) 41 | for i, x in enumerate(col_feats.split(cam_split_size, dim=1)) 42 | ] 43 | 44 | spatial_shape = spatial_shape.cpu().numpy().tolist() 45 | mc_ms_feat = [] 46 | shape_index = 0 47 | for i, feat in enumerate(mc_feat): 48 | feat = list(feat.split(split_size[shape_index], dim=2)) 49 | for j, f in enumerate(feat): 50 | feat[j] = f.unflatten(2, spatial_shape[shape_index][j]) 51 | feat[j] = feat[j].permute(0, 1, 4, 2, 3) 52 | mc_ms_feat.append(feat) 53 | shape_index += cam_split[i] 54 | return mc_ms_feat 55 | 56 | if isinstance(feature_maps[0], (list, tuple)): 57 | formated = [feature_maps_format(x) for x in feature_maps] 58 | col_feats = torch.cat([x[0] for x in formated], dim=1) 59 | spatial_shape = torch.cat([x[1] for x in formated], dim=0) 60 | scale_start_index = torch.cat([x[2] for x in formated], dim=0) 61 | return [col_feats, spatial_shape, scale_start_index] 62 | 63 | bs, num_cams = feature_maps[0].shape[:2] 64 | spatial_shape = [] 65 | 66 | col_feats = [] 67 | for i, feat in enumerate(feature_maps): 68 | spatial_shape.append(feat.shape[-2:]) 69 | col_feats.append( 70 | torch.reshape(feat, (bs, num_cams, feat.shape[2], -1)) 71 | ) 72 | 73 | col_feats = torch.cat(col_feats, dim=-1).permute(0, 1, 3, 2).flatten(1, 2) 74 | spatial_shape = [spatial_shape] * num_cams 75 | spatial_shape = torch.tensor( 76 | spatial_shape, 77 | dtype=torch.int64, 78 | device=col_feats.device, 79 | ) 80 | scale_start_index = spatial_shape[..., 0] * spatial_shape[..., 1] 81 | scale_start_index = scale_start_index.flatten().cumsum(dim=0) 82 | scale_start_index = torch.cat( 83 | 
[torch.tensor([0]).to(scale_start_index), scale_start_index[:-1]] 84 | ) 85 | scale_start_index = scale_start_index.reshape(num_cams, -1) 86 | 87 | feature_maps = [ 88 | col_feats, 89 | spatial_shape, 90 | scale_start_index, 91 | ] 92 | return feature_maps 93 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/core/evaluation/eval_hooks.py: -------------------------------------------------------------------------------- 1 | # Note: Considering that MMCV's EvalHook updated its interface in V1.3.16, 2 | # in order to avoid strong version dependency, we did not directly 3 | # inherit EvalHook but BaseDistEvalHook. 4 | 5 | import bisect 6 | import os.path as osp 7 | 8 | import mmcv 9 | import torch.distributed as dist 10 | from mmcv.runner import DistEvalHook as BaseDistEvalHook 11 | from mmcv.runner import EvalHook as BaseEvalHook 12 | from torch.nn.modules.batchnorm import _BatchNorm 13 | from mmdet.core.evaluation.eval_hooks import DistEvalHook 14 | 15 | 16 | def _calc_dynamic_intervals(start_interval, dynamic_interval_list): 17 | assert mmcv.is_list_of(dynamic_interval_list, tuple) 18 | 19 | dynamic_milestones = [0] 20 | dynamic_milestones.extend( 21 | [dynamic_interval[0] for dynamic_interval in dynamic_interval_list] 22 | ) 23 | dynamic_intervals = [start_interval] 24 | dynamic_intervals.extend( 25 | [dynamic_interval[1] for dynamic_interval in dynamic_interval_list] 26 | ) 27 | return dynamic_milestones, dynamic_intervals 28 | 29 | 30 | class CustomDistEvalHook(BaseDistEvalHook): 31 | def __init__(self, *args, dynamic_intervals=None, **kwargs): 32 | super(CustomDistEvalHook, self).__init__(*args, **kwargs) 33 | self.use_dynamic_intervals = dynamic_intervals is not None 34 | if self.use_dynamic_intervals: 35 | ( 36 | self.dynamic_milestones, 37 | self.dynamic_intervals, 38 | ) = _calc_dynamic_intervals(self.interval, dynamic_intervals) 39 | 40 | def _decide_interval(self, runner): 41 | if self.use_dynamic_intervals: 42 | progress = runner.epoch if self.by_epoch else runner.iter 43 | step = bisect.bisect(self.dynamic_milestones, (progress + 1)) 44 | # Dynamically modify the evaluation interval 45 | self.interval = self.dynamic_intervals[step - 1] 46 | 47 | def before_train_epoch(self, runner): 48 | """Evaluate the model only at the start of training by epoch.""" 49 | self._decide_interval(runner) 50 | super().before_train_epoch(runner) 51 | 52 | def before_train_iter(self, runner): 53 | self._decide_interval(runner) 54 | super().before_train_iter(runner) 55 | 56 | def _do_evaluate(self, runner): 57 | """perform evaluation and save ckpt.""" 58 | # Synchronization of BatchNorm's buffer (running_mean 59 | # and running_var) is not supported in the DDP of pytorch, 60 | # which may cause the inconsistent performance of models in 61 | # different ranks, so we broadcast BatchNorm's buffers 62 | # of rank 0 to other ranks to avoid this. 
63 |         if self.broadcast_bn_buffer: 64 |             model = runner.model 65 |             for name, module in model.named_modules(): 66 |                 if ( 67 |                     isinstance(module, _BatchNorm) 68 |                     and module.track_running_stats 69 |                 ): 70 |                     dist.broadcast(module.running_var, 0) 71 |                     dist.broadcast(module.running_mean, 0) 72 | 73 |         if not self._should_evaluate(runner): 74 |             return 75 | 76 |         tmpdir = self.tmpdir 77 |         if tmpdir is None: 78 |             tmpdir = osp.join(runner.work_dir, ".eval_hook") 79 | 80 |         from projects.mmdet3d_plugin.apis.test import ( 81 |             custom_multi_gpu_test, 82 |         )  # to solve circular import 83 | 84 |         results = custom_multi_gpu_test( 85 |             runner.model, 86 |             self.dataloader, 87 |             tmpdir=tmpdir, 88 |             gpu_collect=self.gpu_collect, 89 |         ) 90 |         if runner.rank == 0: 91 |             print("\n") 92 |             runner.log_buffer.output["eval_iter_num"] = len(self.dataloader) 93 | 94 |             key_score = self.evaluate(runner, results) 95 | 96 |             if self.save_best: 97 |                 self._save_ckpt(runner, key_score) 98 | -------------------------------------------------------------------------------- /tools/benchmark.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import time 4 | import torch 5 | from mmcv import Config 6 | from mmcv.parallel import MMDataParallel 7 | from mmcv.runner import load_checkpoint, wrap_fp16_model 8 | import sys 9 | 10 | sys.path.append(".") 11 | from projects.mmdet3d_plugin.datasets.builder import build_dataloader 12 | from projects.mmdet3d_plugin.datasets import custom_build_dataset 13 | 14 | # from mmdet3d.datasets import build_dataloader, build_dataset 15 | from mmdet.models import build_detector 16 | 17 | from tools.fuse_conv_bn import fuse_module 18 | 19 | 20 | def parse_args(): 21 |     parser = argparse.ArgumentParser(description="MMDet benchmark a model") 22 |     parser.add_argument("config", help="test config file path") 23 |     parser.add_argument("--checkpoint", default=None, help="checkpoint file") 24 |     parser.add_argument("--samples", type=int, default=2000, help="samples to benchmark") 25 |     parser.add_argument( 26 |         "--log-interval", type=int, default=50, help="interval of logging" 27 |     ) 28 |     parser.add_argument( 29 |         "--fuse-conv-bn", 30 |         action="store_true", 31 |         help="Whether to fuse conv and bn, this will slightly increase " 32 |         "the inference speed", 33 |     ) 34 |     args = parser.parse_args() 35 |     return args 36 | 37 | 38 | def get_max_memory(model): 39 |     device = getattr(model, "output_device", None) 40 |     mem = torch.cuda.max_memory_allocated(device=device) 41 |     mem_mb = torch.tensor( 42 |         [mem / (1024 * 1024)], dtype=torch.int, device=device 43 |     ) 44 |     return mem_mb.item() 45 | 46 | 47 | def main(): 48 |     args = parse_args() 49 | 50 |     cfg = Config.fromfile(args.config) 51 |     # set cudnn_benchmark 52 |     if cfg.get("cudnn_benchmark", False): 53 |         torch.backends.cudnn.benchmark = True 54 |     cfg.model.pretrained = None 55 |     cfg.data.test.test_mode = True 56 | 57 |     # build the dataloader 58 |     # TODO: support multiple images per gpu (only minor changes are needed) 59 |     print(cfg.data.test) 60 |     dataset = custom_build_dataset(cfg.data.test) 61 |     data_loader = build_dataloader( 62 |         dataset, 63 |         samples_per_gpu=1, 64 |         workers_per_gpu=cfg.data.workers_per_gpu, 65 |         dist=False, 66 |         shuffle=False, 67 |     ) 68 | 69 |     # build the model and load checkpoint 70 |     cfg.model.train_cfg = None 71 |     model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg")) 72 |     fp16_cfg = cfg.get("fp16", None) 73 |     if fp16_cfg is not None: 74 |         wrap_fp16_model(model) 75 | 
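# the checkpoint is optional: without it, the benchmark times a randomly initialized model, which is sufficient for pure throughput measurement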
if args.checkpoint is not None: 76 | load_checkpoint(model, args.checkpoint, map_location="cpu") 77 | if args.fuse_conv_bn: 78 | model = fuse_module(model) 79 | 80 | model = MMDataParallel(model, device_ids=[0]) 81 | 82 | model.eval() 83 | 84 | # the first several iterations may be very slow so skip them 85 | num_warmup = 5 86 | pure_inf_time = 0 87 | 88 | # benchmark with several samples and take the average 89 | max_memory = 0 90 | for i, data in enumerate(data_loader): 91 | # torch.cuda.synchronize() 92 | with torch.no_grad(): 93 | start_time = time.perf_counter() 94 | model(return_loss=False, rescale=True, **data) 95 | 96 | torch.cuda.synchronize() 97 | elapsed = time.perf_counter() - start_time 98 | max_memory = max(max_memory, get_max_memory(model)) 99 | 100 | if i >= num_warmup: 101 | pure_inf_time += elapsed 102 | if (i + 1) % args.log_interval == 0: 103 | fps = (i + 1 - num_warmup) / pure_inf_time 104 | print( 105 | f"Done image [{i + 1:<3}/ {args.samples}], " 106 | f"fps: {fps:.1f} img / s, " 107 | f"gpu mem: {max_memory} M" 108 | ) 109 | 110 | if (i + 1) == args.samples: 111 | pure_inf_time += elapsed 112 | fps = (i + 1 - num_warmup) / pure_inf_time 113 | print(f"Overall fps: {fps:.1f} img / s") 114 | break 115 | 116 | 117 | if __name__ == "__main__": 118 | main() 119 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/samplers/group_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import math 3 | 4 | import numpy as np 5 | import torch 6 | from mmcv.runner import get_dist_info 7 | from torch.utils.data import Sampler 8 | from .sampler import SAMPLER 9 | import random 10 | from IPython import embed 11 | 12 | 13 | @SAMPLER.register_module() 14 | class DistributedGroupSampler(Sampler): 15 | """Sampler that restricts data loading to a subset of the dataset. 16 | It is especially useful in conjunction with 17 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 18 | process can pass a DistributedSampler instance as a DataLoader sampler, 19 | and load a subset of the original dataset that is exclusive to it. 20 | .. note:: 21 | Dataset is assumed to be of constant size. 22 | Arguments: 23 | dataset: Dataset used for sampling. 24 | num_replicas (optional): Number of processes participating in 25 | distributed training. 26 | rank (optional): Rank of the current process within num_replicas. 27 | seed (int, optional): random seed used to shuffle the sampler if 28 | ``shuffle=True``. This number should be identical across all 29 | processes in the distributed group. Default: 0. 
30 | """ 31 | 32 | def __init__( 33 | self, dataset, samples_per_gpu=1, num_replicas=None, rank=None, seed=0 34 | ): 35 | _rank, _num_replicas = get_dist_info() 36 | if num_replicas is None: 37 | num_replicas = _num_replicas 38 | if rank is None: 39 | rank = _rank 40 | self.dataset = dataset 41 | self.samples_per_gpu = samples_per_gpu 42 | self.num_replicas = num_replicas 43 | self.rank = rank 44 | self.epoch = 0 45 | self.seed = seed if seed is not None else 0 46 | 47 | assert hasattr(self.dataset, "flag") 48 | self.flag = self.dataset.flag 49 | self.group_sizes = np.bincount(self.flag) 50 | 51 | self.num_samples = 0 52 | for i, j in enumerate(self.group_sizes): 53 | self.num_samples += ( 54 | int( 55 | math.ceil( 56 | self.group_sizes[i] 57 | * 1.0 58 | / self.samples_per_gpu 59 | / self.num_replicas 60 | ) 61 | ) 62 | * self.samples_per_gpu 63 | ) 64 | self.total_size = self.num_samples * self.num_replicas 65 | 66 | def __iter__(self): 67 | # deterministically shuffle based on epoch 68 | g = torch.Generator() 69 | g.manual_seed(self.epoch + self.seed) 70 | 71 | indices = [] 72 | for i, size in enumerate(self.group_sizes): 73 | if size > 0: 74 | indice = np.where(self.flag == i)[0] 75 | assert len(indice) == size 76 | # add .numpy() to avoid bug when selecting indice in parrots. 77 | # TODO: check whether torch.randperm() can be replaced by 78 | # numpy.random.permutation(). 79 | indice = indice[ 80 | list(torch.randperm(int(size), generator=g).numpy()) 81 | ].tolist() 82 | extra = int( 83 | math.ceil( 84 | size * 1.0 / self.samples_per_gpu / self.num_replicas 85 | ) 86 | ) * self.samples_per_gpu * self.num_replicas - len(indice) 87 | # pad indice 88 | tmp = indice.copy() 89 | for _ in range(extra // size): 90 | indice.extend(tmp) 91 | indice.extend(tmp[: extra % size]) 92 | indices.extend(indice) 93 | 94 | assert len(indices) == self.total_size 95 | 96 | indices = [ 97 | indices[j] 98 | for i in list( 99 | torch.randperm( 100 | len(indices) // self.samples_per_gpu, generator=g 101 | ) 102 | ) 103 | for j in range( 104 | i * self.samples_per_gpu, (i + 1) * self.samples_per_gpu 105 | ) 106 | ] 107 | 108 | # subsample 109 | offset = self.num_samples * self.rank 110 | indices = indices[offset : offset + self.num_samples] 111 | assert len(indices) == self.num_samples 112 | 113 | return iter(indices) 114 | 115 | def __len__(self): 116 | return self.num_samples 117 | 118 | def set_epoch(self, epoch): 119 | self.epoch = epoch 120 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/grid_mask.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | class Grid(object): 8 | def __init__( 9 | self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0 10 | ): 11 | self.use_h = use_h 12 | self.use_w = use_w 13 | self.rotate = rotate 14 | self.offset = offset 15 | self.ratio = ratio 16 | self.mode = mode 17 | self.st_prob = prob 18 | self.prob = prob 19 | 20 | def set_prob(self, epoch, max_epoch): 21 | self.prob = self.st_prob * epoch / max_epoch 22 | 23 | def __call__(self, img, label): 24 | if np.random.rand() > self.prob: 25 | return img, label 26 | h = img.size(1) 27 | w = img.size(2) 28 | self.d1 = 2 29 | self.d2 = min(h, w) 30 | hh = int(1.5 * h) 31 | ww = int(1.5 * w) 32 | d = np.random.randint(self.d1, self.d2) 33 | if self.ratio == 1: 34 | self.l = np.random.randint(1, d) 35 | else: 36 | 
self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1) 37 | mask = np.ones((hh, ww), np.float32) 38 | st_h = np.random.randint(d) 39 | st_w = np.random.randint(d) 40 | if self.use_h: 41 | for i in range(hh // d): 42 | s = d * i + st_h 43 | t = min(s + self.l, hh) 44 | mask[s:t, :] *= 0 45 | if self.use_w: 46 | for i in range(ww // d): 47 | s = d * i + st_w 48 | t = min(s + self.l, ww) 49 | mask[:, s:t] *= 0 50 | 51 | r = np.random.randint(self.rotate) 52 | mask = Image.fromarray(np.uint8(mask)) 53 | mask = mask.rotate(r) 54 | mask = np.asarray(mask) 55 | mask = mask[ 56 | (hh - h) // 2 : (hh - h) // 2 + h, 57 | (ww - w) // 2 : (ww - w) // 2 + w, 58 | ] 59 | 60 | mask = torch.from_numpy(mask).float() 61 | if self.mode == 1: 62 | mask = 1 - mask 63 | 64 | mask = mask.expand_as(img) 65 | if self.offset: 66 | offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).float() 67 | offset = (1 - mask) * offset 68 | img = img * mask + offset 69 | else: 70 | img = img * mask 71 | 72 | return img, label 73 | 74 | 75 | class GridMask(nn.Module): 76 | def __init__( 77 | self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0 78 | ): 79 | super(GridMask, self).__init__() 80 | self.use_h = use_h 81 | self.use_w = use_w 82 | self.rotate = rotate 83 | self.offset = offset 84 | self.ratio = ratio 85 | self.mode = mode 86 | self.st_prob = prob 87 | self.prob = prob 88 | 89 | def set_prob(self, epoch, max_epoch): 90 | self.prob = self.st_prob * epoch / max_epoch # + 1.#0.5 91 | 92 | def forward(self, x): 93 | if np.random.rand() > self.prob or not self.training: 94 | return x 95 | n, c, h, w = x.size() 96 | x = x.view(-1, h, w) 97 | hh = int(1.5 * h) 98 | ww = int(1.5 * w) 99 | d = np.random.randint(2, h) 100 | self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1) 101 | mask = np.ones((hh, ww), np.float32) 102 | st_h = np.random.randint(d) 103 | st_w = np.random.randint(d) 104 | if self.use_h: 105 | for i in range(hh // d): 106 | s = d * i + st_h 107 | t = min(s + self.l, hh) 108 | mask[s:t, :] *= 0 109 | if self.use_w: 110 | for i in range(ww // d): 111 | s = d * i + st_w 112 | t = min(s + self.l, ww) 113 | mask[:, s:t] *= 0 114 | 115 | r = np.random.randint(self.rotate) 116 | mask = Image.fromarray(np.uint8(mask)) 117 | mask = mask.rotate(r) 118 | mask = np.asarray(mask) 119 | mask = mask[ 120 | (hh - h) // 2 : (hh - h) // 2 + h, 121 | (ww - w) // 2 : (ww - w) // 2 + w, 122 | ] 123 | 124 | mask = torch.from_numpy(mask.copy()).float().cuda() 125 | if self.mode == 1: 126 | mask = 1 - mask 127 | mask = mask.expand_as(x) 128 | if self.offset: 129 | offset = ( 130 | torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)) 131 | .float() 132 | .cuda() 133 | ) 134 | x = x * mask + offset * (1 - mask) 135 | else: 136 | x = x * mask 137 | 138 | return x.view(n, c, h, w) 139 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/simpb.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from inspect import signature 4 | from mmcv.runner import force_fp32, auto_fp16 5 | from mmcv.utils import build_from_cfg 6 | from mmcv.cnn.bricks.registry import PLUGIN_LAYERS 7 | from mmdet.models import ( 8 | DETECTORS, 9 | BaseDetector, 10 | build_backbone, 11 | build_head, 12 | build_neck, 13 | ) 14 | from .grid_mask import GridMask 15 | 16 | try: 17 | from ..ops import feature_maps_format 18 | DAF_VALID = True 19 | except: 20 | DAF_VALID = False 21 | 22 | __all__ = ["SimPB"] 23 | 24 | 25 | 
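# SimPB couples an image backbone (plus an optional neck) with a single head that emits both 2D and 3D detections: extract_feat flattens the multi-view batch (bs, num_cams, C, H, W) to (bs * num_cams, C, H, W) for the backbone, restores the camera dimension, and optionally reformats the features for the CUDA deformable-aggregation op before the head consumes them.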
@DETECTORS.register_module() 26 | class SimPB(BaseDetector): 27 |     def __init__( 28 |         self, 29 |         img_backbone, 30 |         head, 31 |         img_neck=None, 32 |         init_cfg=None, 33 |         train_cfg=None, 34 |         test_cfg=None, 35 |         pretrained=None, 36 |         use_grid_mask=True, 37 |         use_deformable_func=False, 38 |         depth_branch=None, 39 |     ): 40 |         super(SimPB, self).__init__(init_cfg=init_cfg) 41 |         if pretrained is not None: 42 |             img_backbone.pretrained = pretrained 43 |         self.img_backbone = build_backbone(img_backbone) 44 |         self.img_neck = ( 45 |             build_neck(img_neck) if img_neck is not None else None) 46 |         if test_cfg is not None: 47 |             head['test_cfg'] = test_cfg 48 |         self.head = build_head(head) 49 |         self.use_grid_mask = use_grid_mask 50 |         if use_deformable_func: 51 |             assert DAF_VALID, "deformable_aggregation needs to be set up." 52 |         self.use_deformable_func = use_deformable_func 53 |         self.head.use_deformable_func = self.use_deformable_func 54 |         if depth_branch is not None: 55 |             self.depth_branch = build_from_cfg(depth_branch, PLUGIN_LAYERS) 56 |         else: 57 |             self.depth_branch = None 58 |         if use_grid_mask: 59 |             self.grid_mask = GridMask( 60 |                 True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7 61 |             ) 62 | 63 |     @auto_fp16(apply_to=("img",), out_fp32=True) 64 |     def extract_feat(self, img, return_depth=False, metas=None): 65 |         bs = img.shape[0] 66 |         if img.dim() == 5:  # multi-view 67 |             num_cams = img.shape[1] 68 |             img = img.flatten(end_dim=1) 69 |         else: 70 |             num_cams = 1 71 |         if self.use_grid_mask: 72 |             img = self.grid_mask(img) 73 |         if "metas" in signature(self.img_backbone.forward).parameters: 74 |             feature_maps = self.img_backbone(img, num_cams, metas=metas) 75 |         else: 76 |             feature_maps = self.img_backbone(img) 77 |         if self.img_neck is not None: 78 |             feature_maps = list(self.img_neck(feature_maps)) 79 |         for i, feat in enumerate(feature_maps): 80 |             feature_maps[i] = torch.reshape( 81 |                 feat, (bs, num_cams) + feat.shape[1:] 82 |             ) 83 |         if return_depth and self.depth_branch is not None: 84 |             depths = self.depth_branch(feature_maps, metas.get("focal")) 85 |         else: 86 |             depths = None 87 |         if self.use_deformable_func: 88 |             feature_maps = feature_maps_format(feature_maps) 89 |         if return_depth: 90 |             return feature_maps, depths 91 |         return feature_maps 92 | 93 |     @force_fp32(apply_to=("img",)) 94 |     def forward(self, img, **data): 95 |         if self.training: 96 |             return self.forward_train(img, **data) 97 |         else: 98 |             return self.forward_test(img, **data) 99 | 100 |     def forward_train(self, img, **data): 101 |         feature_maps, depths = self.extract_feat(img, True, data) 102 |         model_outs = self.head(feature_maps, data) 103 |         output = self.head.loss(model_outs, data) 104 |         if depths is not None and "gt_depth" in data: 105 |             output["loss_dense_depth"] = self.depth_branch.loss( 106 |                 depths, data["gt_depth"] 107 |             ) 108 |         return output 109 | 110 |     def forward_test(self, img, **data): 111 |         if isinstance(img, list): 112 |             return self.aug_test(img, **data) 113 |         else: 114 |             return self.simple_test(img, **data) 115 | 116 |     def simple_test(self, img, **data): 117 |         feature_maps = self.extract_feat(img) 118 | 119 |         model_outs = self.head(feature_maps, data) 120 |         results = self.head.post_process(model_outs, data) 121 | 122 |         return results 123 | 124 |     def aug_test(self, img, **data): 125 |         # fake test time augmentation 126 |         for key in data.keys(): 127 |             if isinstance(data[key], list): 128 |                 data[key] = data[key][0] 129 |         return self.simple_test(img[0], **data) 130 | --------------------------------------------------------------------------------
/projects/mmdet3d_plugin/ops/src/deformable_aggregation.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <c10/cuda/CUDAGuard.h> 3 | 4 | void deformable_aggregation( 5 |     float* output, 6 |     const float* mc_ms_feat, 7 |     const int* spatial_shape, 8 |     const int* scale_start_index, 9 |     const float* sample_location, 10 |     const float* weights, 11 |     int batch_size, 12 |     int num_cams, 13 |     int num_feat, 14 |     int num_embeds, 15 |     int num_scale, 16 |     int num_anchors, 17 |     int num_pts, 18 |     int num_groups 19 | ); 20 | 21 | 22 | /* feat: bs, num_feat, c */ 23 | /* _spatial_shape: cam, scale, 2 */ 24 | /* _scale_start_index: cam, scale */ 25 | /* _sampling_location: bs, anchor, pts, cam, 2 */ 26 | /* _weights: bs, anchor, pts, cam, scale, group */ 27 | /* output: bs, anchor, c */ 28 | /* kernel: bs, anchor, pts, c */ 29 | 30 | 31 | at::Tensor deformable_aggregation_forward( 32 |     const at::Tensor &_mc_ms_feat, 33 |     const at::Tensor &_spatial_shape, 34 |     const at::Tensor &_scale_start_index, 35 |     const at::Tensor &_sampling_location, 36 |     const at::Tensor &_weights 37 | ) { 38 |     at::DeviceGuard guard(_mc_ms_feat.device()); 39 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat)); 40 |     int batch_size = _mc_ms_feat.size(0); 41 |     int num_feat = _mc_ms_feat.size(1); 42 |     int num_embeds = _mc_ms_feat.size(2); 43 |     int num_cams = _spatial_shape.size(0); 44 |     int num_scale = _spatial_shape.size(1); 45 |     int num_anchors = _sampling_location.size(1); 46 |     int num_pts = _sampling_location.size(2); 47 |     int num_groups = _weights.size(5); 48 | 49 |     const float* mc_ms_feat = _mc_ms_feat.data_ptr<float>(); 50 |     const int* spatial_shape = _spatial_shape.data_ptr<int>(); 51 |     const int* scale_start_index = _scale_start_index.data_ptr<int>(); 52 |     const float* sampling_location = _sampling_location.data_ptr<float>(); 53 |     const float* weights = _weights.data_ptr<float>(); 54 | 55 |     auto output = at::zeros({batch_size, num_anchors, num_embeds}, _mc_ms_feat.options()); 56 |     deformable_aggregation( 57 |         output.data_ptr<float>(), 58 |         mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights, 59 |         batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups 60 |     ); 61 |     return output; 62 | } 63 | 64 | 65 | void deformable_aggregation_grad( 66 |     const float* mc_ms_feat, 67 |     const int* spatial_shape, 68 |     const int* scale_start_index, 69 |     const float* sample_location, 70 |     const float* weights, 71 |     const float* grad_output, 72 |     float* grad_mc_ms_feat, 73 |     float* grad_sampling_location, 74 |     float* grad_weights, 75 |     int batch_size, 76 |     int num_cams, 77 |     int num_feat, 78 |     int num_embeds, 79 |     int num_scale, 80 |     int num_anchors, 81 |     int num_pts, 82 |     int num_groups 83 | ); 84 | 85 | 86 | void deformable_aggregation_backward( 87 |     const at::Tensor &_mc_ms_feat, 88 |     const at::Tensor &_spatial_shape, 89 |     const at::Tensor &_scale_start_index, 90 |     const at::Tensor &_sampling_location, 91 |     const at::Tensor &_weights, 92 |     const at::Tensor &_grad_output, 93 |     at::Tensor &_grad_mc_ms_feat, 94 |     at::Tensor &_grad_sampling_location, 95 |     at::Tensor &_grad_weights 96 | ) { 97 |     at::DeviceGuard guard(_mc_ms_feat.device()); 98 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat)); 99 |     int batch_size = _mc_ms_feat.size(0); 100 |     int num_feat = _mc_ms_feat.size(1); 101 |     int num_embeds = _mc_ms_feat.size(2); 102 |     int num_cams = _spatial_shape.size(0); 103 |     int num_scale = _spatial_shape.size(1); 104 |     int num_anchors = _sampling_location.size(1);
105 | int num_pts = _sampling_location.size(2);
106 | int num_groups = _weights.size(5);
107 |
108 | const float* mc_ms_feat = _mc_ms_feat.data_ptr<float>();
109 | const int* spatial_shape = _spatial_shape.data_ptr<int>();
110 | const int* scale_start_index = _scale_start_index.data_ptr<int>();
111 | const float* sampling_location = _sampling_location.data_ptr<float>();
112 | const float* weights = _weights.data_ptr<float>();
113 | const float* grad_output = _grad_output.data_ptr<float>();
114 |
115 | float* grad_mc_ms_feat = _grad_mc_ms_feat.data_ptr<float>();
116 | float* grad_sampling_location = _grad_sampling_location.data_ptr<float>();
117 | float* grad_weights = _grad_weights.data_ptr<float>();
118 |
119 | deformable_aggregation_grad(
120 | mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights,
121 | grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
122 | batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
123 | );
124 | }
125 |
126 |
127 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
128 | m.def(
129 | "deformable_aggregation_forward",
130 | &deformable_aggregation_forward,
131 | "deformable_aggregation_forward"
132 | );
133 | m.def(
134 | "deformable_aggregation_backward",
135 | &deformable_aggregation_backward,
136 | "deformable_aggregation_backward"
137 | );
138 | }
139 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/aggregation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | from mmcv.cnn import Linear
6 | from mmcv.utils import build_from_cfg
7 | from mmcv.cnn.bricks.registry import ATTENTION, PLUGIN_LAYERS
8 |
9 |
10 | class ReWeight(nn.Module):
11 | def __init__(self, c_dim, f_dim=256, trans=True, with_pos=False):
12 | super().__init__()
13 | self.c_dim = c_dim
14 | self.f_dim = f_dim
15 | self.trans = trans
16 | self.with_pos = with_pos
17 |
18 | self.reduce = nn.Sequential(
19 | nn.Linear(c_dim, f_dim),
20 | nn.ReLU(),
21 | )
22 | self.alpha = nn.Sequential(
23 | nn.Linear(f_dim, 1),
24 | nn.Sigmoid()
25 | )
26 |
27 | def forward(self, query, query_pos, parameter, trans_matrix=None):
28 |
29 | alpha = self.alpha(self.reduce(parameter))
30 |
31 | if self.trans:
32 | reweight_matrix = (trans_matrix * alpha).permute(0, 2, 1)
33 | reweight_divisor = torch.clamp(reweight_matrix.sum(-1).unsqueeze(-1), 1e-5)
34 | query = torch.div(torch.matmul(reweight_matrix, query), reweight_divisor)
35 | query_pos = torch.div(torch.matmul(reweight_matrix, query_pos), reweight_divisor) if self.with_pos else None
36 | else:
37 | query = alpha * query
38 | query_pos = alpha * query_pos if self.with_pos else None
39 |
40 | return query, query_pos
41 |
42 |
43 | @PLUGIN_LAYERS.register_module()
44 | class AdaptiveQueryAggregation(nn.Module):
45 | def __init__(self, self_attn=None, reweight=None, decouple_attn=False, with_pos=False):
46 | super().__init__()
47 | self.with_pos = with_pos
48 | self.decouple_attn = decouple_attn
49 | trans = True if self_attn is not None else False
50 | self.reweight = ReWeight(c_dim=257, trans=trans, with_pos=with_pos) if reweight is not None else None
51 | self.self_attn = build_from_cfg(self_attn, ATTENTION) if self_attn is not None else None
52 |
53 |
54 | def forward(self,
55 | query2d, query_pos2d, anchor2d,
56 | query3d, query_pos3d, anchor3d,
57 | dn_query2d=None, dn_query_pos2d=None, dn_anchor2d=None,
58 | dn_query3d=None, dn_query_pos3d=None, dn_anchor3d=None,
59 |
trans_matrix=None, center_matrix=None, dn_trans_matrix=None, dn_center_matrix=None,
60 | attn_mask=None, graph_model=None, **kwargs):
61 |
62 | if self.reweight is not None:
63 | center_param = torch.cat([query2d, center_matrix.sum(-1).unsqueeze(-1)], dim=-1)
64 | query3d_from2d, query_pos3d_from2d = self.reweight(query2d, query_pos2d, center_param, trans_matrix)
65 |
66 | if dn_query2d is not None:
67 | dn_center_param = torch.cat([dn_query2d, dn_center_matrix.sum(-1).unsqueeze(-1)], dim=-1)
68 | dn_query3d_from2d, dn_query_pos3d_from2d = self.reweight(dn_query2d, dn_query_pos2d, dn_center_param, dn_trans_matrix)
69 |
70 | else:
71 | trans_matrix_t = trans_matrix.permute(0, 2, 1)
72 | trans_divisor = torch.clamp(trans_matrix_t.sum(-1).unsqueeze(-1), 1e-5)
73 | query3d_from2d = torch.div(torch.matmul(trans_matrix_t, query2d), trans_divisor)
74 | query_pos3d_from2d = torch.div(torch.matmul(trans_matrix_t, query_pos2d), trans_divisor) if self.with_pos else None
75 |
76 | if dn_query2d is not None:
77 | # project the denoise 2D queries with their own transfer matrix
78 | dn_trans_matrix_t = dn_trans_matrix.permute(0, 2, 1)
79 | dn_trans_divisor = torch.clamp(dn_trans_matrix_t.sum(-1).unsqueeze(-1), 1e-5)
80 | dn_query3d_from2d = torch.div(torch.matmul(dn_trans_matrix_t, dn_query2d), dn_trans_divisor)
81 | dn_query_pos3d_from2d = torch.div(torch.matmul(dn_trans_matrix_t, dn_query_pos2d), dn_trans_divisor) if self.with_pos else None
82 |
83 | # merge with denoise
84 | if dn_query3d is not None:
85 | query3d = torch.cat([query3d, dn_query3d], dim=1)
86 | query_pos3d = torch.cat([query_pos3d, dn_query_pos3d], dim=1)
87 | anchor3d = torch.cat([anchor3d, dn_anchor3d], dim=1)
88 |
89 | if dn_query2d is not None:
90 | query3d_from2d = torch.cat([query3d_from2d, dn_query3d_from2d], dim=1)
91 | query_pos3d_from2d = torch.cat([query_pos3d_from2d, dn_query_pos3d_from2d], dim=1) if self.with_pos else None
92 | else:
93 | query3d_from2d = torch.cat([query3d_from2d, torch.zeros_like(dn_query3d)], dim=1)
94 | query_pos3d_from2d = torch.cat([query_pos3d_from2d, torch.zeros_like(dn_query3d)], dim=1) if self.with_pos else None
95 |
96 | query3d = query3d + query3d_from2d
97 | query_pos3d = query_pos3d + query_pos3d_from2d if self.with_pos else query_pos3d
98 |
99 | aggregated_query3d = graph_model(self.self_attn,
100 | query=query3d,
101 | query_pos=query_pos3d,
102 | attn_mask=attn_mask)
103 |
104 | return aggregated_query3d, query_pos3d, anchor3d
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/detection2d/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | from mmcv.cnn import Linear, Scale, bias_init_with_prob
6 | from mmcv.runner.base_module import Sequential, BaseModule
7 | from mmcv.cnn.bricks.registry import (
8 | PLUGIN_LAYERS,
9 | POSITIONAL_ENCODING,
10 | )
11 |
12 | from ..blocks import linear_relu_ln
13 | from ..utils import inverse_sigmoid, pos2posemb2d
14 |
15 | __all__ = [
16 | "SparseBox2DRefinementModule",
17 | "SparseBox2DEncoder",
18 | ]
19 |
20 | @POSITIONAL_ENCODING.register_module()
21 | class SparseBox2DEncoder(BaseModule):
22 | def __init__(
23 | self,
24 | embed_dims=256,
25 | with_size=False,
26 | with_sin_embed=False,
27 | mode="add",
28 | in_loops=1,
29 | out_loops=2,
30 | ):
31 | super(SparseBox2DEncoder, self).__init__()
32 | self.embed_dims = embed_dims
33 | self.mode = mode
34 | self.with_size = with_size
35 | self.with_sin_embed = with_sin_embed
36 |
37 | def embedding_layer(input_dims):
38 | return nn.Sequential(*linear_relu_ln(embed_dims, in_loops, out_loops, input_dims))
39 |
40 | if self.with_sin_embed:
41 |
self.query_embeddings2d = embedding_layer(256) 42 | else: 43 | self.pos_fc = embedding_layer(2) 44 | if self.with_size: 45 | self.size_fc = embedding_layer(2) 46 | self.output_fc = embedding_layer(self.embed_dims) 47 | 48 | def forward(self, box_2d): 49 | if self.with_sin_embed: 50 | output = self.query_embeddings2d(pos2posemb2d(box_2d)) 51 | else: 52 | pos_feat = self.pos_fc(box_2d[..., :2]) 53 | if self.with_size: 54 | size_feat = self.size_fc(box_2d[..., 2:4]) 55 | if self.mode == "add": 56 | output = pos_feat + size_feat 57 | elif self.mode == "cat": 58 | output = torch.cat([pos_feat, size_feat], dim=-1) 59 | output = self.output_fc(output) 60 | else: 61 | output = pos_feat 62 | 63 | return output 64 | 65 | @PLUGIN_LAYERS.register_module() 66 | class SparseBox2DRefinementModule(BaseModule): 67 | def __init__(self, embed_dims=256, output_dim=4, num_cls=10, alpha_dim=2, 68 | with_cls_branch=True, with_alpha_branch=False, with_depth_branch=False, 69 | with_multibin_depth=False, depth_bin_num=64): 70 | super(SparseBox2DRefinementModule, self).__init__() 71 | self.embed_dims = embed_dims 72 | self.output_dim = output_dim 73 | self.num_cls = num_cls 74 | 75 | self.layers = nn.Sequential( 76 | *linear_relu_ln(embed_dims, 2, 2), 77 | Linear(self.embed_dims, self.output_dim), 78 | Scale([1.0] * self.output_dim) 79 | ) 80 | 81 | self.with_cls_branch = with_cls_branch 82 | if with_cls_branch: 83 | self.cls_layers = nn.Sequential( 84 | *linear_relu_ln(embed_dims, 1, 2), 85 | Linear(self.embed_dims, self.num_cls), 86 | ) 87 | self.with_alpha_branch = with_alpha_branch 88 | if with_alpha_branch: 89 | self.alpha_layers = nn.Sequential( 90 | *linear_relu_ln(embed_dims, 1, 2), 91 | Linear(self.embed_dims, alpha_dim), 92 | Scale([1.0] * 2) 93 | ) 94 | self.with_depth_branch = with_depth_branch 95 | self.with_multibin_depth = with_multibin_depth 96 | if with_depth_branch: 97 | if with_multibin_depth: 98 | self.depth_layers = nn.Sequential( 99 | *linear_relu_ln(embed_dims, 2, 2), 100 | Linear(self.embed_dims, depth_bin_num), 101 | ) 102 | else: 103 | self.depth_layers = nn.Sequential( 104 | *linear_relu_ln(embed_dims, 2, 2), 105 | Linear(self.embed_dims, 1), 106 | Scale([1.0] * 1) 107 | ) 108 | 109 | def init_weight(self): 110 | if self.with_cls_branch: 111 | bias_init = bias_init_with_prob(0.01) 112 | nn.init.constant_(self.cls_layers[-1].bias, bias_init) 113 | if self.with_multibin_depth: 114 | bias_init = bias_init_with_prob(0.01) 115 | nn.init.constant_(self.depth_layers[-1].bias, bias_init) 116 | 117 | def forward(self, instance_feature, anchor2d, anchor2d_embed, 118 | metas=None, return_cls=True, query_groups=None): 119 | output = self.layers(instance_feature + anchor2d_embed) 120 | 121 | if anchor2d.shape[-1] == 2: 122 | output[..., :2] = output[..., :2] + inverse_sigmoid(anchor2d) 123 | elif anchor2d.shape[-1] == 4: 124 | output[..., :4] = output[..., :4] + inverse_sigmoid(anchor2d) 125 | 126 | cls = None 127 | if return_cls: 128 | cls = self.cls_layers(instance_feature) 129 | 130 | alpha = None 131 | if self.with_alpha_branch: 132 | alpha = self.alpha_layers(instance_feature) 133 | 134 | depth = None 135 | if self.with_depth_branch: 136 | if self.with_multibin_depth: 137 | depth = self.depth_layers(instance_feature + anchor2d_embed) 138 | else: 139 | focal = torch.cat([ 140 | metas['focal'][:, i:i+1].repeat(1, qg[1]-qg[0]) for i, qg in enumerate(query_groups) 141 | ], dim=-1) 142 | depth = self.depth_layers(instance_feature).exp() 143 | depth = depth * focal.unsqueeze(-1) / 100 144 | 145 | return 
output.sigmoid(), cls, depth, alpha -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/apis/test.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------- 2 | # Copyright (c) OpenMMLab. All rights reserved. 3 | # --------------------------------------------- 4 | # Modified by Zhiqi Li 5 | # --------------------------------------------- 6 | import os.path as osp 7 | import pickle 8 | import shutil 9 | import tempfile 10 | import time 11 | 12 | import mmcv 13 | import torch 14 | import torch.distributed as dist 15 | from mmcv.image import tensor2imgs 16 | from mmcv.runner import get_dist_info 17 | 18 | from mmdet.core import encode_mask_results 19 | 20 | 21 | import mmcv 22 | import numpy as np 23 | import pycocotools.mask as mask_util 24 | 25 | 26 | def custom_encode_mask_results(mask_results): 27 | """Encode bitmap mask to RLE code. Semantic Masks only 28 | Args: 29 | mask_results (list | tuple[list]): bitmap mask results. 30 | In mask scoring rcnn, mask_results is a tuple of (segm_results, 31 | segm_cls_score). 32 | Returns: 33 | list | tuple: RLE encoded mask. 34 | """ 35 | cls_segms = mask_results 36 | num_classes = len(cls_segms) 37 | encoded_mask_results = [] 38 | for i in range(len(cls_segms)): 39 | encoded_mask_results.append( 40 | mask_util.encode( 41 | np.array( 42 | cls_segms[i][:, :, np.newaxis], order="F", dtype="uint8" 43 | ) 44 | )[0] 45 | ) # encoded with RLE 46 | return [encoded_mask_results] 47 | 48 | 49 | def custom_multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False): 50 | """Test model with multiple gpus. 51 | This method tests model with multiple gpus and collects the results 52 | under two different modes: gpu and cpu modes. By setting 'gpu_collect=True' 53 | it encodes results to gpu tensors and use gpu communication for results 54 | collection. On cpu mode it saves the results on different gpus to 'tmpdir' 55 | and collects them by the rank 0 worker. 56 | Args: 57 | model (nn.Module): Model to be tested. 58 | data_loader (nn.Dataloader): Pytorch data loader. 59 | tmpdir (str): Path of directory to save the temporary results from 60 | different gpus under cpu mode. 61 | gpu_collect (bool): Option to use either gpu or cpu to collect results. 62 | Returns: 63 | list: The prediction results. 64 | """ 65 | model.eval() 66 | bbox_results = [] 67 | mask_results = [] 68 | dataset = data_loader.dataset 69 | rank, world_size = get_dist_info() 70 | if rank == 0: 71 | prog_bar = mmcv.ProgressBar(len(dataset)) 72 | time.sleep(2) # This line can prevent deadlock problem in some cases. 
73 | have_mask = False  # set True if any batch yields mask predictions
74 | for i, data in enumerate(data_loader):
75 | with torch.no_grad():
76 | result = model(return_loss=False, rescale=True, **data)
77 | # encode mask results
78 | if isinstance(result, dict):
79 | if "bbox_results" in result.keys():
80 | bbox_result = result["bbox_results"]
81 | batch_size = len(result["bbox_results"])
82 | bbox_results.extend(bbox_result)
83 | if (
84 | "mask_results" in result.keys()
85 | and result["mask_results"] is not None
86 | ):
87 | mask_result = custom_encode_mask_results(
88 | result["mask_results"]
89 | )
90 | mask_results.extend(mask_result)
91 | have_mask = True
92 | else:
93 | batch_size = len(result)
94 | bbox_results.extend(result)
95 |
96 | if rank == 0:
97 | for _ in range(batch_size * world_size):
98 | prog_bar.update()
99 |
100 | # collect results from all ranks
101 | if gpu_collect:
102 | bbox_results = collect_results_gpu(bbox_results, len(dataset))
103 | if have_mask:
104 | mask_results = collect_results_gpu(mask_results, len(dataset))
105 | else:
106 | mask_results = None
107 | else:
108 | bbox_results = collect_results_cpu(bbox_results, len(dataset), tmpdir)
109 | tmpdir = tmpdir + "_mask" if tmpdir is not None else None
110 | if have_mask:
111 | mask_results = collect_results_cpu(
112 | mask_results, len(dataset), tmpdir
113 | )
114 | else:
115 | mask_results = None
116 |
117 | if mask_results is None:
118 | return bbox_results
119 | return {"bbox_results": bbox_results, "mask_results": mask_results}
120 |
121 |
122 | def collect_results_cpu(result_part, size, tmpdir=None):
123 | rank, world_size = get_dist_info()
124 | # create a tmp dir if it is not specified
125 | if tmpdir is None:
126 | MAX_LEN = 512
127 | # 32 is whitespace
128 | dir_tensor = torch.full(
129 | (MAX_LEN,), 32, dtype=torch.uint8, device="cuda"
130 | )
131 | if rank == 0:
132 | mmcv.mkdir_or_exist(".dist_test")
133 | tmpdir = tempfile.mkdtemp(dir=".dist_test")
134 | tmpdir = torch.tensor(
135 | bytearray(tmpdir.encode()), dtype=torch.uint8, device="cuda"
136 | )
137 | dir_tensor[: len(tmpdir)] = tmpdir
138 | dist.broadcast(dir_tensor, 0)
139 | tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
140 | else:
141 | mmcv.mkdir_or_exist(tmpdir)
142 | # dump the part result to the dir
143 | mmcv.dump(result_part, osp.join(tmpdir, f"part_{rank}.pkl"))
144 | dist.barrier()
145 | # collect all parts
146 | if rank != 0:
147 | return None
148 | else:
149 | # load results of all parts from tmp dir
150 | part_list = []
151 | for i in range(world_size):
152 | part_file = osp.join(tmpdir, f"part_{i}.pkl")
153 | part_list.append(mmcv.load(part_file))
154 | # sort the results
155 | ordered_results = []
156 | """
157 | Because the evaluation sampler is changed so that each gpu handles a
158 | contiguous chunk of samples, parts are concatenated in rank order.
159 | """
160 | # for res in zip(*part_list):
161 | for res in part_list:
162 | ordered_results.extend(list(res))
163 | # the dataloader may pad some samples
164 | ordered_results = ordered_results[:size]
165 | # remove tmp dir
166 | shutil.rmtree(tmpdir)
167 | return ordered_results
168 |
169 |
170 | def collect_results_gpu(result_part, size):
171 | return collect_results_cpu(result_part, size)
172 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/builder.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import platform
3 | import random
4 | from functools import partial
5 |
6 | import numpy as np
7 | from
mmcv.parallel import collate 8 | from mmcv.runner import get_dist_info 9 | from mmcv.utils import Registry, build_from_cfg 10 | from torch.utils.data import DataLoader 11 | 12 | from mmdet.datasets.samplers import GroupSampler 13 | from projects.mmdet3d_plugin.datasets.samplers import ( 14 | GroupInBatchSampler, 15 | DistributedGroupSampler, 16 | DistributedSampler, 17 | build_sampler 18 | ) 19 | 20 | 21 | def build_dataloader( 22 | dataset, 23 | samples_per_gpu, 24 | workers_per_gpu, 25 | num_gpus=1, 26 | dist=True, 27 | shuffle=True, 28 | seed=None, 29 | shuffler_sampler=None, 30 | nonshuffler_sampler=None, 31 | runner_type="EpochBasedRunner", 32 | **kwargs 33 | ): 34 | """Build PyTorch DataLoader. 35 | In distributed training, each GPU/process has a dataloader. 36 | In non-distributed training, there is only one dataloader for all GPUs. 37 | Args: 38 | dataset (Dataset): A PyTorch dataset. 39 | samples_per_gpu (int): Number of training samples on each GPU, i.e., 40 | batch size of each GPU. 41 | workers_per_gpu (int): How many subprocesses to use for data loading 42 | for each GPU. 43 | num_gpus (int): Number of GPUs. Only used in non-distributed training. 44 | dist (bool): Distributed training/test or not. Default: True. 45 | shuffle (bool): Whether to shuffle the data at every epoch. 46 | Default: True. 47 | kwargs: any keyword argument to be used to initialize DataLoader 48 | Returns: 49 | DataLoader: A PyTorch dataloader. 50 | """ 51 | rank, world_size = get_dist_info() 52 | batch_sampler = None 53 | if runner_type == 'IterBasedRunner': 54 | print("Use GroupInBatchSampler !!!") 55 | batch_sampler = GroupInBatchSampler( 56 | dataset, 57 | samples_per_gpu, 58 | world_size, 59 | rank, 60 | seed=seed, 61 | ) 62 | batch_size = 1 63 | sampler = None 64 | num_workers = workers_per_gpu 65 | elif dist: 66 | # DistributedGroupSampler will definitely shuffle the data to satisfy 67 | # that images on each GPU are in the same group 68 | if shuffle: 69 | print("Use DistributedGroupSampler !!!") 70 | sampler = build_sampler( 71 | shuffler_sampler 72 | if shuffler_sampler is not None 73 | else dict(type="DistributedGroupSampler"), 74 | dict( 75 | dataset=dataset, 76 | samples_per_gpu=samples_per_gpu, 77 | num_replicas=world_size, 78 | rank=rank, 79 | seed=seed, 80 | ), 81 | ) 82 | else: 83 | sampler = build_sampler( 84 | nonshuffler_sampler 85 | if nonshuffler_sampler is not None 86 | else dict(type="DistributedSampler"), 87 | dict( 88 | dataset=dataset, 89 | num_replicas=world_size, 90 | rank=rank, 91 | shuffle=shuffle, 92 | seed=seed, 93 | ), 94 | ) 95 | 96 | batch_size = samples_per_gpu 97 | num_workers = workers_per_gpu 98 | else: 99 | # assert False, 'not support in bevformer' 100 | print("WARNING!!!!, Only can be used for obtain inference speed!!!!") 101 | sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None 102 | batch_size = num_gpus * samples_per_gpu 103 | num_workers = num_gpus * workers_per_gpu 104 | 105 | init_fn = ( 106 | partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) 107 | if seed is not None 108 | else None 109 | ) 110 | 111 | data_loader = DataLoader( 112 | dataset, 113 | batch_size=batch_size, 114 | sampler=sampler, 115 | batch_sampler=batch_sampler, 116 | num_workers=num_workers, 117 | collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), 118 | pin_memory=False, 119 | worker_init_fn=init_fn, 120 | **kwargs 121 | ) 122 | 123 | return data_loader 124 | 125 | 126 | def worker_init_fn(worker_id, num_workers, rank, seed): 127 | # The 
seed of each worker equals to 128 | # num_worker * rank + worker_id + user_seed 129 | worker_seed = num_workers * rank + worker_id + seed 130 | np.random.seed(worker_seed) 131 | random.seed(worker_seed) 132 | 133 | 134 | # Copyright (c) OpenMMLab. All rights reserved. 135 | import platform 136 | from mmcv.utils import Registry, build_from_cfg 137 | 138 | from mmdet.datasets import DATASETS 139 | from mmdet.datasets.builder import _concat_dataset 140 | 141 | if platform.system() != "Windows": 142 | # https://github.com/pytorch/pytorch/issues/973 143 | import resource 144 | 145 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 146 | base_soft_limit = rlimit[0] 147 | hard_limit = rlimit[1] 148 | soft_limit = min(max(4096, base_soft_limit), hard_limit) 149 | resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit)) 150 | 151 | OBJECTSAMPLERS = Registry("Object sampler") 152 | 153 | 154 | def custom_build_dataset(cfg, default_args=None): 155 | try: 156 | from mmdet3d.datasets.dataset_wrappers import CBGSDataset 157 | except: 158 | CBGSDataset = None 159 | from mmdet.datasets.dataset_wrappers import ( 160 | ClassBalancedDataset, 161 | ConcatDataset, 162 | RepeatDataset, 163 | ) 164 | 165 | if isinstance(cfg, (list, tuple)): 166 | dataset = ConcatDataset( 167 | [custom_build_dataset(c, default_args) for c in cfg] 168 | ) 169 | elif cfg["type"] == "ConcatDataset": 170 | dataset = ConcatDataset( 171 | [custom_build_dataset(c, default_args) for c in cfg["datasets"]], 172 | cfg.get("separate_eval", True), 173 | ) 174 | elif cfg["type"] == "RepeatDataset": 175 | dataset = RepeatDataset( 176 | custom_build_dataset(cfg["dataset"], default_args), cfg["times"] 177 | ) 178 | elif cfg["type"] == "ClassBalancedDataset": 179 | dataset = ClassBalancedDataset( 180 | custom_build_dataset(cfg["dataset"], default_args), 181 | cfg["oversample_thr"], 182 | ) 183 | elif cfg["type"] == "CBGSDataset": 184 | dataset = CBGSDataset( 185 | custom_build_dataset(cfg["dataset"], default_args) 186 | ) 187 | elif isinstance(cfg.get("ann_file"), (list, tuple)): 188 | dataset = _concat_dataset(cfg, default_args) 189 | else: 190 | dataset = build_from_cfg(cfg, DATASETS, default_args) 191 | 192 | return dataset 193 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/samplers/infinite_group_each_sample_in_batch_sampler.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import copy 3 | 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | from mmcv.runner import get_dist_info 8 | from torch.utils.data.sampler import Sampler 9 | 10 | # https://github.com/open-mmlab/mmdetection/blob/3b72b12fe9b14de906d1363982b9fba05e7d47c1/mmdet/core/utils/dist_utils.py#L157 11 | def sync_random_seed(seed=None, device="cuda"): 12 | """Make sure different ranks share the same seed. 13 | All workers must call this function, otherwise it will deadlock. 14 | This method is generally used in `DistributedSampler`, 15 | because the seed should be identical across all processes 16 | in the distributed group. 17 | In distributed sampling, different ranks should sample non-overlapped 18 | data in the dataset. Therefore, this function is used to make sure that 19 | each rank shuffles the data indices in the same order based 20 | on the same seed. Then different ranks could use different indices 21 | to select non-overlapped data from the same data list. 22 | Args: 23 | seed (int, Optional): The seed. 
Default to None. 24 | device (str): The device where the seed will be put on. 25 | Default to 'cuda'. 26 | Returns: 27 | int: Seed to be used. 28 | """ 29 | if seed is None: 30 | seed = np.random.randint(2 ** 31) 31 | assert isinstance(seed, int) 32 | 33 | rank, world_size = get_dist_info() 34 | 35 | if world_size == 1: 36 | return seed 37 | 38 | if rank == 0: 39 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 40 | else: 41 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 42 | dist.broadcast(random_num, src=0) 43 | return random_num.item() 44 | 45 | 46 | class InfiniteGroupEachSampleInBatchSampler(Sampler): 47 | """ 48 | Pardon this horrendous name. Basically, we want every sample to be from its own group. 49 | If batch size is 4 and # of GPUs is 8, each sample of these 32 should be operating on 50 | its own group. 51 | 52 | Shuffling is only done for group order, not done within groups. 53 | """ 54 | 55 | def __init__( 56 | self, 57 | dataset, 58 | batch_size=1, 59 | world_size=None, 60 | rank=None, 61 | seed=0, 62 | skip_prob=0.5, 63 | sequence_flip_prob=0.1, 64 | ): 65 | 66 | _rank, _world_size = get_dist_info() 67 | if world_size is None: 68 | world_size = _world_size 69 | if rank is None: 70 | rank = _rank 71 | 72 | self.dataset = dataset 73 | self.batch_size = batch_size 74 | self.world_size = world_size 75 | self.rank = rank 76 | self.seed = sync_random_seed(seed) 77 | 78 | self.size = len(self.dataset) 79 | 80 | assert hasattr(self.dataset, "flag") 81 | self.flag = self.dataset.flag 82 | self.group_sizes = np.bincount(self.flag) 83 | self.groups_num = len(self.group_sizes) 84 | self.global_batch_size = batch_size * world_size 85 | assert self.groups_num >= self.global_batch_size 86 | 87 | # Now, for efficiency, make a dict group_idx: List[dataset sample_idxs] 88 | self.group_idx_to_sample_idxs = { 89 | group_idx: np.where(self.flag == group_idx)[0].tolist() 90 | for group_idx in range(self.groups_num) 91 | } 92 | 93 | # Get a generator per sample idx. 
Considering samples over all 94 | # GPUs, each sample position has its own generator 95 | self.group_indices_per_global_sample_idx = [ 96 | self._group_indices_per_global_sample_idx( 97 | self.rank * self.batch_size + local_sample_idx 98 | ) 99 | for local_sample_idx in range(self.batch_size) 100 | ] 101 | 102 | # Keep track of a buffer of dataset sample idxs for each local sample idx 103 | self.buffer_per_local_sample = [[] for _ in range(self.batch_size)] 104 | self.aug_per_local_sample = [None for _ in range(self.batch_size)] 105 | self.skip_prob = skip_prob 106 | self.sequence_flip_prob = sequence_flip_prob 107 | 108 | def _infinite_group_indices(self): 109 | g = torch.Generator() 110 | g.manual_seed(self.seed) 111 | while True: 112 | yield from torch.randperm(self.groups_num, generator=g).tolist() 113 | 114 | def _group_indices_per_global_sample_idx(self, global_sample_idx): 115 | yield from itertools.islice( 116 | self._infinite_group_indices(), 117 | global_sample_idx, 118 | None, 119 | self.global_batch_size, 120 | ) 121 | 122 | def __iter__(self): 123 | while True: 124 | curr_batch = [] 125 | for local_sample_idx in range(self.batch_size): 126 | skip = ( 127 | np.random.uniform() < self.skip_prob 128 | and len(self.buffer_per_local_sample[local_sample_idx]) > 1 129 | ) 130 | if len(self.buffer_per_local_sample[local_sample_idx]) == 0: 131 | # Finished current group, refill with next group 132 | # skip = False 133 | new_group_idx = next( 134 | self.group_indices_per_global_sample_idx[ 135 | local_sample_idx 136 | ] 137 | ) 138 | self.buffer_per_local_sample[ 139 | local_sample_idx 140 | ] = copy.deepcopy( 141 | self.group_idx_to_sample_idxs[new_group_idx] 142 | ) 143 | if np.random.uniform() < self.sequence_flip_prob: 144 | self.buffer_per_local_sample[local_sample_idx] = self.buffer_per_local_sample[local_sample_idx][::-1] 145 | if self.dataset.keep_consistent_seq_aug: 146 | self.aug_per_local_sample[local_sample_idx] = self.get_aug() 147 | 148 | if not self.dataset.keep_consistent_seq_aug: 149 | self.aug_per_local_sample[local_sample_idx] = self.get_aug() 150 | 151 | if skip: 152 | self.buffer_per_local_sample[local_sample_idx].pop(0) 153 | curr_batch.append( 154 | dict( 155 | idx=self.buffer_per_local_sample[local_sample_idx].pop(0), 156 | aug=self.aug_per_local_sample[local_sample_idx], 157 | ) 158 | ) 159 | 160 | yield curr_batch 161 | 162 | def __len__(self): 163 | """Length of base dataset.""" 164 | return self.size 165 | 166 | def set_epoch(self, epoch): 167 | self.epoch = epoch 168 | 169 | def get_aug(self): 170 | rot_angle = np.random.uniform(*self.dataset.rot_range) 171 | scale_ratio = np.random.uniform(*self.dataset.scale_ratio_range) 172 | aug_configs = self.dataset._sample_augmentation() 173 | return rot_angle, scale_ratio, aug_configs 174 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/samplers/group_in_batch_sampler.py: -------------------------------------------------------------------------------- 1 | # https://github.com/Divadi/SOLOFusion/blob/main/mmdet3d/datasets/samplers/infinite_group_each_sample_in_batch_sampler.py 2 | import itertools 3 | import copy 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | from mmcv.runner import get_dist_info 9 | from torch.utils.data.sampler import Sampler 10 | 11 | 12 | # https://github.com/open-mmlab/mmdetection/blob/3b72b12fe9b14de906d1363982b9fba05e7d47c1/mmdet/core/utils/dist_utils.py#L157 13 | def 
sync_random_seed(seed=None, device="cuda"): 14 | """Make sure different ranks share the same seed. 15 | All workers must call this function, otherwise it will deadlock. 16 | This method is generally used in `DistributedSampler`, 17 | because the seed should be identical across all processes 18 | in the distributed group. 19 | In distributed sampling, different ranks should sample non-overlapped 20 | data in the dataset. Therefore, this function is used to make sure that 21 | each rank shuffles the data indices in the same order based 22 | on the same seed. Then different ranks could use different indices 23 | to select non-overlapped data from the same data list. 24 | Args: 25 | seed (int, Optional): The seed. Default to None. 26 | device (str): The device where the seed will be put on. 27 | Default to 'cuda'. 28 | Returns: 29 | int: Seed to be used. 30 | """ 31 | if seed is None: 32 | seed = np.random.randint(2**31) 33 | assert isinstance(seed, int) 34 | 35 | rank, world_size = get_dist_info() 36 | 37 | if world_size == 1: 38 | return seed 39 | 40 | if rank == 0: 41 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 42 | else: 43 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 44 | dist.broadcast(random_num, src=0) 45 | return random_num.item() 46 | 47 | 48 | class GroupInBatchSampler(Sampler): 49 | """ 50 | Pardon this horrendous name. Basically, we want every sample to be from its own group. 51 | If batch size is 4 and # of GPUs is 8, each sample of these 32 should be operating on 52 | its own group. 53 | 54 | Shuffling is only done for group order, not done within groups. 55 | """ 56 | 57 | def __init__( 58 | self, 59 | dataset, 60 | batch_size=1, 61 | world_size=None, 62 | rank=None, 63 | seed=0, 64 | skip_prob=0.5, 65 | sequence_flip_prob=0.1, 66 | ): 67 | _rank, _world_size = get_dist_info() 68 | if world_size is None: 69 | world_size = _world_size 70 | if rank is None: 71 | rank = _rank 72 | 73 | self.dataset = dataset 74 | self.batch_size = batch_size 75 | self.world_size = world_size 76 | self.rank = rank 77 | self.seed = sync_random_seed(seed) 78 | 79 | self.size = len(self.dataset) 80 | 81 | assert hasattr(self.dataset, "flag") 82 | self.flag = self.dataset.flag 83 | self.group_sizes = np.bincount(self.flag) 84 | self.groups_num = len(self.group_sizes) 85 | self.global_batch_size = batch_size * world_size 86 | assert self.groups_num >= self.global_batch_size 87 | 88 | # Now, for efficiency, make a dict group_idx: List[dataset sample_idxs] 89 | self.group_idx_to_sample_idxs = { 90 | group_idx: np.where(self.flag == group_idx)[0].tolist() 91 | for group_idx in range(self.groups_num) 92 | } 93 | 94 | # Get a generator per sample idx. 
Considering samples over all 95 | # GPUs, each sample position has its own generator 96 | self.group_indices_per_global_sample_idx = [ 97 | self._group_indices_per_global_sample_idx( 98 | self.rank * self.batch_size + local_sample_idx 99 | ) 100 | for local_sample_idx in range(self.batch_size) 101 | ] 102 | 103 | # Keep track of a buffer of dataset sample idxs for each local sample idx 104 | self.buffer_per_local_sample = [[] for _ in range(self.batch_size)] 105 | self.aug_per_local_sample = [None for _ in range(self.batch_size)] 106 | self.skip_prob = skip_prob 107 | self.sequence_flip_prob = sequence_flip_prob 108 | 109 | def _infinite_group_indices(self): 110 | g = torch.Generator() 111 | g.manual_seed(self.seed) 112 | while True: 113 | yield from torch.randperm(self.groups_num, generator=g).tolist() 114 | 115 | def _group_indices_per_global_sample_idx(self, global_sample_idx): 116 | yield from itertools.islice( 117 | self._infinite_group_indices(), 118 | global_sample_idx, 119 | None, 120 | self.global_batch_size, 121 | ) 122 | 123 | def __iter__(self): 124 | while True: 125 | curr_batch = [] 126 | for local_sample_idx in range(self.batch_size): 127 | skip = ( 128 | np.random.uniform() < self.skip_prob 129 | and len(self.buffer_per_local_sample[local_sample_idx]) > 1 130 | ) 131 | if len(self.buffer_per_local_sample[local_sample_idx]) == 0: 132 | # Finished current group, refill with next group 133 | # skip = False 134 | new_group_idx = next( 135 | self.group_indices_per_global_sample_idx[ 136 | local_sample_idx 137 | ] 138 | ) 139 | self.buffer_per_local_sample[ 140 | local_sample_idx 141 | ] = copy.deepcopy( 142 | self.group_idx_to_sample_idxs[new_group_idx] 143 | ) 144 | if np.random.uniform() < self.sequence_flip_prob: 145 | self.buffer_per_local_sample[ 146 | local_sample_idx 147 | ] = self.buffer_per_local_sample[local_sample_idx][ 148 | ::-1 149 | ] 150 | if self.dataset.keep_consistent_seq_aug: 151 | self.aug_per_local_sample[ 152 | local_sample_idx 153 | ] = self.dataset.get_augmentation() 154 | 155 | if not self.dataset.keep_consistent_seq_aug: 156 | self.aug_per_local_sample[ 157 | local_sample_idx 158 | ] = self.dataset.get_augmentation() 159 | 160 | if skip: 161 | self.buffer_per_local_sample[local_sample_idx].pop(0) 162 | curr_batch.append( 163 | dict( 164 | idx=self.buffer_per_local_sample[local_sample_idx].pop( 165 | 0 166 | ), 167 | aug_config=self.aug_per_local_sample[local_sample_idx], 168 | ) 169 | ) 170 | 171 | yield curr_batch 172 | 173 | def __len__(self): 174 | """Length of base dataset.""" 175 | return self.size 176 | 177 | def set_epoch(self, epoch): 178 | self.epoch = epoch 179 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/datasets/pipelines/loading.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import mmcv 3 | from mmdet.datasets.builder import PIPELINES 4 | 5 | 6 | @PIPELINES.register_module() 7 | class LoadMultiViewImageFromFiles(object): 8 | """Load multi channel images from a list of separate channel files. 9 | 10 | Expects results['img_filename'] to be a list of filenames. 11 | 12 | Args: 13 | to_float32 (bool, optional): Whether to convert the img to float32. 14 | Defaults to False. 15 | color_type (str, optional): Color type of the file. 16 | Defaults to 'unchanged'. 
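Example (illustrative; the image paths below are placeholders):

    >>> loader = LoadMultiViewImageFromFiles(to_float32=True)
    >>> results = loader(dict(img_filename=["cam_front.jpg", "cam_back.jpg"]))
    >>> # results["img"] is a list with one float32 array per camera view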
17 | """ 18 | 19 | def __init__(self, to_float32=False, color_type="unchanged"): 20 | self.to_float32 = to_float32 21 | self.color_type = color_type 22 | 23 | def __call__(self, results): 24 | """Call function to load multi-view image from files. 25 | 26 | Args: 27 | results (dict): Result dict containing multi-view image filenames. 28 | 29 | Returns: 30 | dict: The result dict containing the multi-view image data. 31 | Added keys and values are described below. 32 | 33 | - filename (str): Multi-view image filenames. 34 | - img (np.ndarray): Multi-view image arrays. 35 | - img_shape (tuple[int]): Shape of multi-view image arrays. 36 | - ori_shape (tuple[int]): Shape of original image arrays. 37 | - pad_shape (tuple[int]): Shape of padded image arrays. 38 | - scale_factor (float): Scale factor. 39 | - img_norm_cfg (dict): Normalization configuration of images. 40 | """ 41 | filename = results["img_filename"] 42 | # img is of shape (h, w, c, num_views) 43 | img = np.stack( 44 | [mmcv.imread(name, self.color_type) for name in filename], axis=-1 45 | ) 46 | if self.to_float32: 47 | img = img.astype(np.float32) 48 | results["filename"] = filename 49 | # unravel to list, see `DefaultFormatBundle` in formatting.py 50 | # which will transpose each image separately and then stack into array 51 | results["img"] = [img[..., i] for i in range(img.shape[-1])] 52 | results["img_shape"] = img.shape 53 | results["ori_shape"] = img.shape 54 | # Set initial values for default meta_keys 55 | results["pad_shape"] = img.shape 56 | results["scale_factor"] = 1.0 57 | num_channels = 1 if len(img.shape) < 3 else img.shape[2] 58 | results["img_norm_cfg"] = dict( 59 | mean=np.zeros(num_channels, dtype=np.float32), 60 | std=np.ones(num_channels, dtype=np.float32), 61 | to_rgb=False, 62 | ) 63 | return results 64 | 65 | def __repr__(self): 66 | """str: Return a string that describes the module.""" 67 | repr_str = self.__class__.__name__ 68 | repr_str += f"(to_float32={self.to_float32}, " 69 | repr_str += f"color_type='{self.color_type}')" 70 | return repr_str 71 | 72 | @PIPELINES.register_module() 73 | class LoadPointsFromFile(object): 74 | """Load Points From File. 75 | 76 | Load points from file. 77 | 78 | Args: 79 | coord_type (str): The type of coordinates of points cloud. 80 | Available options includes: 81 | - 'LIDAR': Points in LiDAR coordinates. 82 | - 'DEPTH': Points in depth coordinates, usually for indoor dataset. 83 | - 'CAMERA': Points in camera coordinates. 84 | load_dim (int, optional): The dimension of the loaded points. 85 | Defaults to 6. 86 | use_dim (list[int], optional): Which dimensions of the points to use. 87 | Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4 88 | or use_dim=[0, 1, 2, 3] to use the intensity dimension. 89 | shift_height (bool, optional): Whether to use shifted height. 90 | Defaults to False. 91 | use_color (bool, optional): Whether to use color features. 92 | Defaults to False. 93 | file_client_args (dict, optional): Config dict of file clients, 94 | refer to 95 | https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py 96 | for more details. Defaults to dict(backend='disk'). 
97 | """ 98 | 99 | def __init__( 100 | self, 101 | coord_type, 102 | load_dim=6, 103 | use_dim=[0, 1, 2], 104 | shift_height=False, 105 | use_color=False, 106 | file_client_args=dict(backend="disk"), 107 | feather_points=False, 108 | ): 109 | self.shift_height = shift_height 110 | self.use_color = use_color 111 | if isinstance(use_dim, int): 112 | use_dim = list(range(use_dim)) 113 | assert ( 114 | max(use_dim) < load_dim 115 | ), f"Expect all used dimensions < {load_dim}, got {use_dim}" 116 | assert coord_type in ["CAMERA", "LIDAR", "DEPTH"] 117 | 118 | self.coord_type = coord_type 119 | self.load_dim = load_dim 120 | self.use_dim = use_dim 121 | self.file_client_args = file_client_args.copy() 122 | self.file_client = None 123 | self.feather_points = feather_points 124 | 125 | def _load_points(self, pts_filename): 126 | """Private function to load point clouds data. 127 | 128 | Args: 129 | pts_filename (str): Filename of point clouds data. 130 | 131 | Returns: 132 | np.ndarray: An array containing point clouds data. 133 | """ 134 | if self.file_client is None: 135 | self.file_client = mmcv.FileClient(**self.file_client_args) 136 | try: 137 | pts_bytes = self.file_client.get(pts_filename) 138 | points = np.frombuffer(pts_bytes, dtype=np.float32) 139 | 140 | except ConnectionError: 141 | mmcv.check_file_exist(pts_filename) 142 | if pts_filename.endswith(".npy"): 143 | points = np.load(pts_filename) 144 | else: 145 | points = np.fromfile(pts_filename, dtype=np.float32) 146 | 147 | return points 148 | 149 | def __call__(self, results): 150 | """Call function to load points data from file. 151 | 152 | Args: 153 | results (dict): Result dict containing point clouds data. 154 | 155 | Returns: 156 | dict: The result dict containing the point clouds data. 157 | Added key and value are described below. 158 | 159 | - points (:obj:`BasePoints`): Point clouds data. 160 | """ 161 | pts_filename = results["pts_filename"] 162 | points = self._load_points(pts_filename) 163 | points = points.reshape(-1, self.load_dim) 164 | points = points[:, self.use_dim] 165 | attribute_dims = None 166 | 167 | if self.shift_height: 168 | floor_height = np.percentile(points[:, 2], 0.99) 169 | height = points[:, 2] - floor_height 170 | points = np.concatenate( 171 | [points[:, :3], np.expand_dims(height, 1), points[:, 3:]], 1 172 | ) 173 | attribute_dims = dict(height=3) 174 | 175 | if self.use_color: 176 | assert len(self.use_dim) >= 6 177 | if attribute_dims is None: 178 | attribute_dims = dict() 179 | attribute_dims.update( 180 | dict( 181 | color=[ 182 | points.shape[1] - 3, 183 | points.shape[1] - 2, 184 | points.shape[1] - 1, 185 | ] 186 | ) 187 | ) 188 | 189 | results["points"] = points 190 | return results 191 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/apis/mmdet_train.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------- 2 | # Copyright (c) OpenMMLab. All rights reserved. 
3 | # ---------------------------------------------
4 | # Modified by Zhiqi Li
5 | # ---------------------------------------------
6 | import random
7 | import warnings
8 |
9 | import numpy as np
10 | import torch
11 | import torch.distributed as dist
12 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
13 | from mmcv.runner import (
14 | HOOKS,
15 | DistSamplerSeedHook,
16 | EpochBasedRunner,
17 | Fp16OptimizerHook,
18 | OptimizerHook,
19 | build_optimizer,
20 | build_runner,
21 | get_dist_info,
22 | )
23 | from mmcv.utils import build_from_cfg
24 |
25 | from mmdet.core import EvalHook
26 |
27 | from mmdet.datasets import build_dataset, replace_ImageToTensor
28 | from mmdet.utils import get_root_logger
29 | import time
30 | import os.path as osp
31 | from projects.mmdet3d_plugin.datasets.builder import build_dataloader
32 | from projects.mmdet3d_plugin.core.evaluation.eval_hooks import (
33 | CustomDistEvalHook,
34 | )
35 | from projects.mmdet3d_plugin.datasets import custom_build_dataset
36 |
37 |
38 | def custom_train_detector(
39 | model,
40 | dataset,
41 | cfg,
42 | distributed=False,
43 | validate=False,
44 | timestamp=None,
45 | meta=None,
46 | ):
47 | logger = get_root_logger(cfg.log_level)
48 |
49 | # prepare data loaders
50 |
51 | dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
52 | # assert len(dataset) == 1
53 | if "imgs_per_gpu" in cfg.data:
54 | logger.warning(
55 | '"imgs_per_gpu" is deprecated in MMDet V2.0. '
56 | 'Please use "samples_per_gpu" instead'
57 | )
58 | if "samples_per_gpu" in cfg.data:
59 | logger.warning(
60 | f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
61 | f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
62 | f"={cfg.data.imgs_per_gpu} is used in this experiment"
63 | )
64 | else:
65 | logger.warning(
66 | 'Automatically set "samples_per_gpu"="imgs_per_gpu"='
67 | f"{cfg.data.imgs_per_gpu} in this experiment"
68 | )
69 | cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
70 |
71 | if "runner" in cfg:
72 | runner_type = cfg.runner["type"]
73 | else:
74 | runner_type = "EpochBasedRunner"
75 | data_loaders = [
76 | build_dataloader(
77 | ds,
78 | cfg.data.samples_per_gpu,
79 | cfg.data.workers_per_gpu,
80 | # cfg.gpus will be ignored if distributed
81 | len(cfg.gpu_ids),
82 | dist=distributed,
83 | seed=cfg.seed,
84 | nonshuffler_sampler=dict(
85 | type="DistributedSampler"
86 | ),
87 | runner_type=runner_type,
88 | )
89 | for ds in dataset
90 | ]
91 |
92 | # put model on gpus
93 | if distributed:
94 | find_unused_parameters = cfg.get("find_unused_parameters", False)
95 | # Sets the `find_unused_parameters` parameter in
96 | # torch.nn.parallel.DistributedDataParallel
97 | model = MMDistributedDataParallel(
98 | model.cuda(),
99 | device_ids=[torch.cuda.current_device()],
100 | broadcast_buffers=False,
101 | find_unused_parameters=find_unused_parameters,
102 | )
103 |
104 | else:
105 | model = MMDataParallel(
106 | model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids
107 | )
108 |
109 | # build runner
110 | optimizer = build_optimizer(model, cfg.optimizer)
111 |
112 | if "runner" not in cfg:
113 | cfg.runner = {
114 | "type": "EpochBasedRunner",
115 | "max_epochs": cfg.total_epochs,
116 | }
117 | warnings.warn(
118 | "config is now expected to have a `runner` section, "
119 | "please set `runner` in your config.",
120 | UserWarning,
121 | )
122 | else:
123 | if "total_epochs" in cfg:
124 | assert cfg.total_epochs == cfg.runner.max_epochs
125 |
126 | runner =
build_runner( 127 | cfg.runner, 128 | default_args=dict( 129 | model=model, 130 | optimizer=optimizer, 131 | work_dir=cfg.work_dir, 132 | logger=logger, 133 | meta=meta, 134 | ), 135 | ) 136 | 137 | # an ugly workaround to make .log and .log.json filenames the same 138 | runner.timestamp = timestamp 139 | 140 | # fp16 setting 141 | fp16_cfg = cfg.get("fp16", None) 142 | if fp16_cfg is not None: 143 | optimizer_config = Fp16OptimizerHook( 144 | **cfg.optimizer_config, **fp16_cfg, distributed=distributed 145 | ) 146 | elif distributed and "type" not in cfg.optimizer_config: 147 | optimizer_config = OptimizerHook(**cfg.optimizer_config) 148 | else: 149 | optimizer_config = cfg.optimizer_config 150 | 151 | # register hooks 152 | runner.register_training_hooks( 153 | cfg.lr_config, 154 | optimizer_config, 155 | cfg.checkpoint_config, 156 | cfg.log_config, 157 | cfg.get("momentum_config", None), 158 | ) 159 | 160 | # register profiler hook 161 | # trace_config = dict(type='tb_trace', dir_name='work_dir') 162 | # profiler_config = dict(on_trace_ready=trace_config) 163 | # runner.register_profiler_hook(profiler_config) 164 | 165 | if distributed: 166 | if isinstance(runner, EpochBasedRunner): 167 | runner.register_hook(DistSamplerSeedHook()) 168 | 169 | # register eval hooks 170 | if validate: 171 | # Support batch_size > 1 in validation 172 | val_samples_per_gpu = cfg.data.val.pop("samples_per_gpu", 1) 173 | if val_samples_per_gpu > 1: 174 | assert False 175 | # Replace 'ImageToTensor' to 'DefaultFormatBundle' 176 | cfg.data.val.pipeline = replace_ImageToTensor( 177 | cfg.data.val.pipeline 178 | ) 179 | val_dataset = custom_build_dataset(cfg.data.val, dict(test_mode=True)) 180 | 181 | val_dataloader = build_dataloader( 182 | val_dataset, 183 | samples_per_gpu=val_samples_per_gpu, 184 | workers_per_gpu=cfg.data.workers_per_gpu, 185 | dist=distributed, 186 | shuffle=False, 187 | nonshuffler_sampler=dict(type="DistributedSampler"), 188 | ) 189 | eval_cfg = cfg.get("evaluation", {}) 190 | eval_cfg["by_epoch"] = cfg.runner["type"] != "IterBasedRunner" 191 | eval_cfg["jsonfile_prefix"] = osp.join( 192 | "val", 193 | cfg.work_dir, 194 | time.ctime().replace(" ", "_").replace(":", "_"), 195 | ) 196 | eval_hook = CustomDistEvalHook if distributed else EvalHook 197 | runner.register_hook(eval_hook(val_dataloader, **eval_cfg)) 198 | 199 | # user-defined hooks 200 | if cfg.get("custom_hooks", None): 201 | custom_hooks = cfg.custom_hooks 202 | assert isinstance( 203 | custom_hooks, list 204 | ), f"custom_hooks expect list type, but got {type(custom_hooks)}" 205 | for hook_cfg in cfg.custom_hooks: 206 | assert isinstance(hook_cfg, dict), ( 207 | "Each item in custom_hooks expects dict type, but got " 208 | f"{type(hook_cfg)}" 209 | ) 210 | hook_cfg = hook_cfg.copy() 211 | priority = hook_cfg.pop("priority", "NORMAL") 212 | hook = build_from_cfg(hook_cfg, HOOKS) 213 | runner.register_hook(hook, priority=priority) 214 | 215 | if cfg.resume_from: 216 | runner.resume(cfg.resume_from) 217 | elif cfg.load_from: 218 | if cfg.get('revise_keys', None): 219 | runner.load_checkpoint(cfg.load_from, revise_keys=cfg['revise_keys']) 220 | else: 221 | runner.load_checkpoint(cfg.load_from) 222 | runner.run(data_loaders, cfg.workflow) 223 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/allocation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | import numpy as np 
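# DynamicQueryAllocation (below) projects each 3D anchor into every camera view and
# builds per-camera 2D reference points (box centers, plus projected-corner centers
# when only corners fall inside the image), together with the 2D-to-3D transfer
# matrices consumed by the query aggregation module.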
5 | 6 | from mmcv.cnn.bricks.registry import PLUGIN_LAYERS 7 | from .detection3d.decoder import W, L, H, SIN_YAW, COS_YAW 8 | 9 | @PLUGIN_LAYERS.register_module() 10 | class DynamicQueryAllocation(nn.Module): 11 | def __init__(self, 12 | with_attn_mask=False, 13 | with_project_wh=False, 14 | limit_anchor_size=[35, 35, 10], 15 | limit_corners_num=[100] * 6, 16 | ): 17 | super().__init__() 18 | self.with_attn_mask = with_attn_mask 19 | self.with_project_wh = with_project_wh 20 | self.limit_anchor_size = limit_anchor_size 21 | self.limit_corners_num = limit_corners_num 22 | 23 | def forward(self, anchor3d, metas): 24 | outputs = self.projection_allocation(anchor3d, metas) 25 | return outputs 26 | 27 | def projection_allocation(self, anchor3d, metas): 28 | device = anchor3d.device 29 | anchor3d_center = anchor3d[..., :3] 30 | lidar2imgs = torch.tile(metas['projection_mat'][:, None], (1, anchor3d.shape[1], 1, 1, 1)) 31 | batch_size, num_anchor3d, num_cams = lidar2imgs.shape[:3] 32 | img_w, img_h = map(int, metas['image_wh'][0, 0].tolist()) 33 | 34 | # get rotation mat 35 | rotation_mat = anchor3d.new_zeros([batch_size, num_anchor3d, 3, 3]) 36 | rotation_mat[:, :, 0, 0] = anchor3d[:, :, COS_YAW] 37 | rotation_mat[:, :, 0, 1] = -anchor3d[:, :, SIN_YAW] 38 | rotation_mat[:, :, 1, 0] = anchor3d[:, :, SIN_YAW] 39 | rotation_mat[:, :, 1, 1] = anchor3d[:, :, COS_YAW] 40 | rotation_mat[:, :, 2, 2] = 1 41 | 42 | # get anchor corners 43 | corners_norm = anchor3d.new_tensor(np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)) 44 | corners_norm = corners_norm - anchor3d.new_tensor([0.5, 0.5, 0.5]) 45 | 46 | anchor3d_size = anchor3d[..., [W, L, H]].exp() 47 | anchor3d_size = anchor3d_size.clamp(max=torch.tensor(self.limit_anchor_size, device=device).view(1, 1, -1)) 48 | 49 | anchor3d_corners = anchor3d_size[:, :, None, :] * corners_norm[None, None, :, :] 50 | anchor3d_corners = torch.matmul(rotation_mat[:, :, None], anchor3d_corners[..., None]).squeeze(-1) 51 | anchor3d_corners = anchor3d_corners + anchor3d_center[:, :, None, :] 52 | anchor3d_corners = torch.cat([anchor3d_corners, anchor3d_center[:, :, None, :]], dim=-2) 53 | 54 | # get points in camera and image plane 55 | coord_pts3d = torch.cat([anchor3d_corners, torch.ones_like(anchor3d_corners[..., :1])], -1) 56 | coord_pts3d = coord_pts3d.view(batch_size, num_anchor3d, 1, 9, 4, 1).repeat(1, 1, num_cams, 1, 1, 1) 57 | coord_pts2d = torch.matmul(lidar2imgs[:, :, :, None], coord_pts3d).squeeze(-1) 58 | 59 | center_pts2d = coord_pts2d[..., -1, :] 60 | corner_pts2d = coord_pts2d[..., :-1, :] 61 | center_depth2d = center_pts2d[..., 2:3] 62 | corner_depth2d = corner_pts2d[..., 2:3] 63 | 64 | center_pts2d = center_pts2d[..., :2] / center_depth2d.clamp(1e-5) 65 | corner_pts2d = corner_pts2d[..., :2] / corner_depth2d.clamp(1e-5) 66 | 67 | center_valid = ((0 < center_pts2d[..., 0]) & (center_pts2d[..., 0] < img_w) & 68 | (0 < center_pts2d[..., 1]) & (center_pts2d[..., 1] < img_h)) 69 | 70 | corner_valid1 = (corner_depth2d[..., 0] > 0) 71 | corner_valid2 = ((0 < corner_pts2d[..., 0]) & (corner_pts2d[..., 0] < img_w) & 72 | (0 < corner_pts2d[..., 1]) & (corner_pts2d[..., 1] < img_h)) 73 | corner_valid = torch.logical_and(corner_valid1, corner_valid2).any(-1) 74 | 75 | # project corners to get corner-centers 76 | x_min = torch.clamp(corner_pts2d[..., 0].min(-1).values, min=0, max=img_w) 77 | x_max = torch.clamp(corner_pts2d[..., 0].max(-1).values, min=0, max=img_w) 78 | y_min = torch.clamp(corner_pts2d[..., 1].min(-1).values, min=0, max=img_h) 79 | y_max = 
torch.clamp(corner_pts2d[..., 1].max(-1).values, min=0, max=img_h) 80 | 81 | cx, cy = (x_min + x_max) / 2, (y_min + y_max) / 2 82 | select_centers = torch.stack([cx, cy], dim=-1) 83 | select_centers[center_valid] = center_pts2d[center_valid] # overwrite center points 84 | 85 | if self.training and self.limit_corners_num: 86 | corner_valid = torch.where(torch.logical_and(corner_valid, center_valid), False, corner_valid) 87 | corner_valid = self.random_sample_corner_mask(corner_valid) 88 | 89 | # divide to groups 90 | trans_mask = torch.logical_or(center_valid, corner_valid) 91 | trans_shape = trans_mask.sum(1) 92 | trans_meta_shape = trans_shape.max(0).values 93 | trans_meta_start = torch.cat([torch.zeros_like(trans_meta_shape[:1]), trans_meta_shape]) 94 | trans_meta_cumsum = trans_meta_start.cumsum(-1).tolist() 95 | 96 | trans_start = trans_meta_start.cumsum(-1)[:num_cams].unsqueeze(0).repeat(batch_size, 1) 97 | trans_end = trans_start + trans_shape 98 | 99 | query_groups = [(qs, qe) for qs, qe in zip(trans_meta_cumsum[:-1], trans_meta_cumsum[1:])] 100 | num_anchor2d = trans_meta_shape.sum() 101 | 102 | # create reference points 103 | trans_mask_tmp = trans_mask.permute(0, 2, 1).flatten(0, 1) 104 | select_centers = select_centers.permute(0, 2, 1, 3).flatten(0, 1) 105 | select_depths = center_depth2d.permute(0, 2, 1, 3).flatten(0, 1) # corner depth is fake 106 | 107 | select_centers = select_centers[trans_mask_tmp] 108 | select_depths = select_depths[trans_mask_tmp] 109 | 110 | selected_mask = torch.zeros((batch_size, num_anchor2d), device=device) 111 | attn_mask = torch.ones((batch_size, num_anchor2d, num_anchor2d), device=device).fill_(float("-inf")) 112 | for bs in range(batch_size): 113 | for st, ed in zip(trans_start[bs], trans_end[bs]): 114 | selected_mask[bs, st:ed] = 1.0 115 | if self.with_attn_mask: 116 | attn_mask[bs, st:ed, st:ed] = 0.0 117 | selected_mask = selected_mask.unsqueeze(-1).repeat(1, 1, 2).to(torch.bool) 118 | 119 | ref_pts2d = torch.zeros((batch_size, num_anchor2d, 2), device=device) 120 | ref_depth2d = torch.zeros((batch_size, num_anchor2d, 1), device=device) 121 | 122 | ref_pts2d = torch.masked_scatter(ref_pts2d, selected_mask[..., :2], select_centers) 123 | ref_depth2d = torch.masked_scatter(ref_depth2d, selected_mask[..., :1], select_depths.abs()) 124 | 125 | ref_pts2d = ref_pts2d / ref_pts2d.new_tensor([img_w, img_h]) 126 | 127 | # create trans matrix 128 | trans_matrix = nn.Parameter( 129 | torch.zeros((batch_size, num_anchor2d, num_anchor3d), device=device), requires_grad=False) 130 | 131 | meta_mask = trans_mask.to(torch.float) + center_valid.to(torch.float) 132 | meta_mask = meta_mask.permute(0, 2, 1) 133 | 134 | for bs in range(batch_size): 135 | cam_index, pts3d_index = torch.nonzero(meta_mask[bs]).chunk(2, dim=1) 136 | cam_index, pts3d_index = cam_index[:, 0], pts3d_index[:, 0] 137 | pts2d_index = torch.cat( 138 | [torch.arange(st, ed, device=device) for st, ed in zip(trans_start[bs], trans_end[bs])]) 139 | trans_matrix[bs, pts2d_index, pts3d_index] = meta_mask[bs, cam_index, pts3d_index] 140 | 141 | center_matrix = (trans_matrix == 2).to(torch.float) 142 | trans_matrix = (trans_matrix >= 1).to(torch.float) # include center and corner points 143 | 144 | return ref_pts2d, ref_depth2d, trans_mask, trans_shape, trans_matrix, center_matrix, query_groups, attn_mask 145 | 146 | def random_sample_corner_mask(self, corner_valid): 147 | batch_size, num_anchor3d, num_cams = corner_valid.shape 148 | corner_view = corner_valid.permute(0, 2, 1).reshape(batch_size * 
num_cams, -1)
149 | corner_nums = corner_view.sum(-1).detach().cpu().numpy()
150 | limit_nums = np.array(self.limit_corners_num * batch_size)
151 |
152 | remove_ids = np.where(corner_nums > limit_nums)[0]
153 | for remove_id in remove_ids:
154 | bs = remove_id // num_cams
155 | cam_id = remove_id % num_cams
156 | remove_num = max(corner_nums[remove_id] - limit_nums[remove_id], 0)
157 | # drop a random subset of the currently valid corners so that at most
158 | # limit_nums[remove_id] of them survive
159 | valid_index = torch.nonzero(corner_valid[bs, :, cam_id]).squeeze(-1)
160 | remove_index = valid_index[np.random.permutation(len(valid_index))[:remove_num]]
161 | corner_valid[bs, :, cam_id][remove_index] = False
162 |
163 | return corner_valid
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 |
7 | from projects.mmdet3d_plugin.core.box3d import *
8 |
9 |
10 | def box3d_to_corners(box3d):
11 | if isinstance(box3d, torch.Tensor):
12 | box3d = box3d.detach().cpu().numpy()
13 | corners_norm = np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)
14 | corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]]
15 | # shift corners to be relative to the box center, i.e. origin [0.5, 0.5, 0.5]
16 | corners_norm = corners_norm - np.array([0.5, 0.5, 0.5])
17 | corners = box3d[:, None, [W, L, H]] * corners_norm.reshape([1, 8, 3])
18 |
19 | # rotate around z axis
20 | rot_cos = np.cos(box3d[:, YAW])
21 | rot_sin = np.sin(box3d[:, YAW])
22 | rot_mat = np.tile(np.eye(3)[None], (box3d.shape[0], 1, 1))
23 | rot_mat[:, 0, 0] = rot_cos
24 | rot_mat[:, 0, 1] = -rot_sin
25 | rot_mat[:, 1, 0] = rot_sin
26 | rot_mat[:, 1, 1] = rot_cos
27 | corners = (rot_mat[:, None] @ corners[..., None]).squeeze(axis=-1)
28 | corners += box3d[:, None, :3]
29 | return corners
30 |
31 |
32 | def plot_rect3d_on_img(
33 | img, num_rects, rect_corners, color=(0, 255, 0), thickness=1
34 | ):
35 | """Plot the boundary lines of 3D rectangles on 2D images.
36 |
37 | Args:
38 | img (numpy.array): The numpy array of image.
39 | num_rects (int): Number of 3D rectangles.
40 | rect_corners (numpy.array): Coordinates of the corners of 3D
41 | rectangles. Should be in the shape of [num_rect, 8, 2].
42 | color (tuple[int], optional): The color to draw bboxes.
43 | Default: (0, 255, 0).
44 | thickness (int, optional): The thickness of bboxes. Default: 1.
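Example (runnable sketch with synthetic corners):

    >>> img = np.zeros((900, 1600, 3), dtype=np.uint8)
    >>> corners = np.random.rand(2, 8, 2) * np.array([1600.0, 900.0])
    >>> vis = plot_rect3d_on_img(img, 2, corners)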
45 | """ 46 | line_indices = ( 47 | (0, 1), 48 | (0, 3), 49 | (0, 4), 50 | (1, 2), 51 | (1, 5), 52 | (3, 2), 53 | (3, 7), 54 | (4, 5), 55 | (4, 7), 56 | (2, 6), 57 | (5, 6), 58 | (6, 7), 59 | ) 60 | h, w = img.shape[:2] 61 | for i in range(num_rects): 62 | corners = np.clip(rect_corners[i], -1e4, 1e5).astype(np.int32) 63 | for start, end in line_indices: 64 | if ( 65 | (corners[start, 1] >= h or corners[start, 1] < 0) 66 | or (corners[start, 0] >= w or corners[start, 0] < 0) 67 | ) and ( 68 | (corners[end, 1] >= h or corners[end, 1] < 0) 69 | or (corners[end, 0] >= w or corners[end, 0] < 0) 70 | ): 71 | continue 72 | if isinstance(color[0], int): 73 | cv2.line( 74 | img, 75 | (corners[start, 0], corners[start, 1]), 76 | (corners[end, 0], corners[end, 1]), 77 | color, 78 | thickness, 79 | cv2.LINE_AA, 80 | ) 81 | else: 82 | cv2.line( 83 | img, 84 | (corners[start, 0], corners[start, 1]), 85 | (corners[end, 0], corners[end, 1]), 86 | color[i], 87 | thickness, 88 | cv2.LINE_AA, 89 | ) 90 | 91 | return img.astype(np.uint8) 92 | 93 | 94 | def draw_image_bbox2d_on_img( 95 | bboxes2d, raw_img, img_metas=None, color=(0, 255, 0), thickness=1 96 | ): 97 | for i, bbox2d in enumerate(bboxes2d): 98 | bbox = np.array(bbox2d) 99 | raw_img = cv2.rectangle(raw_img, (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])), color, thickness) 100 | return raw_img 101 | 102 | 103 | def draw_lidar_bbox3d_on_img( 104 | bboxes3d, raw_img, lidar2img_rt, img_metas=None, color=(0, 255, 0), thickness=1 105 | ): 106 | """Project the 3D bbox on 2D plane and draw on input image. 107 | 108 | Args: 109 | bboxes3d (:obj:`LiDARInstance3DBoxes`): 110 | 3d bbox in lidar coordinate system to visualize. 111 | raw_img (numpy.array): The numpy array of image. 112 | lidar2img_rt (numpy.array, shape=[4, 4]): The projection matrix 113 | according to the camera intrinsic parameters. 114 | img_metas (dict): Useless here. 115 | color (tuple[int], optional): The color to draw bboxes. 116 | Default: (0, 255, 0). 117 | thickness (int, optional): The thickness of bboxes. Default: 1. 
118 | """ 119 | img = raw_img.copy() 120 | # corners_3d = bboxes3d.corners 121 | corners_3d = box3d_to_corners(bboxes3d) 122 | num_bbox = corners_3d.shape[0] 123 | pts_4d = np.concatenate( 124 | [corners_3d.reshape(-1, 3), np.ones((num_bbox * 8, 1))], axis=-1 125 | ) 126 | lidar2img_rt = copy.deepcopy(lidar2img_rt).reshape(4, 4) 127 | if isinstance(lidar2img_rt, torch.Tensor): 128 | lidar2img_rt = lidar2img_rt.cpu().numpy() 129 | pts_2d = pts_4d @ lidar2img_rt.T 130 | 131 | pts_2d[:, 2] = np.clip(pts_2d[:, 2], a_min=1e-5, a_max=1e5) 132 | pts_2d[:, 0] /= pts_2d[:, 2] 133 | pts_2d[:, 1] /= pts_2d[:, 2] 134 | imgfov_pts_2d = pts_2d[..., :2].reshape(num_bbox, 8, 2) 135 | 136 | return plot_rect3d_on_img(img, num_bbox, imgfov_pts_2d, color, thickness) 137 | 138 | 139 | def draw_points_on_img(points, img, lidar2img_rt, color=(0, 255, 0), circle=4): 140 | img = img.copy() 141 | N = points.shape[0] 142 | points = points.cpu().numpy() 143 | lidar2img_rt = copy.deepcopy(lidar2img_rt).reshape(4, 4) 144 | if isinstance(lidar2img_rt, torch.Tensor): 145 | lidar2img_rt = lidar2img_rt.cpu().numpy() 146 | pts_2d = ( 147 | np.sum(points[:, :, None] * lidar2img_rt[:3, :3], axis=-1) 148 | + lidar2img_rt[:3, 3] 149 | ) 150 | pts_2d[..., 2] = np.clip(pts_2d[..., 2], a_min=1e-5, a_max=1e5) 151 | pts_2d = pts_2d[..., :2] / pts_2d[..., 2:3] 152 | pts_2d = np.clip(pts_2d, -1e4, 1e4).astype(np.int32) 153 | 154 | for i in range(N): 155 | for point in pts_2d[i]: 156 | if isinstance(color[0], int): 157 | color_tmp = color 158 | else: 159 | color_tmp = color[i] 160 | cv2.circle(img, point.tolist(), circle, color_tmp, thickness=-1) 161 | return img.astype(np.uint8) 162 | 163 | 164 | def draw_lidar_bbox3d_on_bev(bboxes_3d, bev_size, bev_range=115, color=(255, 0, 0), thickness=3, bev_img=None): 165 | if isinstance(bev_size, (list, tuple)): 166 | bev_h, bev_w = bev_size 167 | else: 168 | bev_h, bev_w = bev_size, bev_size 169 | 170 | bev_resolution = bev_range / bev_h 171 | 172 | # init bev image 173 | if bev_img is None: 174 | bev = np.ones([bev_h, bev_w, 3]) * 255 175 | 176 | marking_color = (110, 110, 110) 177 | 178 | for cir in range(int(bev_range / 20)): 179 | cv2.circle(bev, (int(bev_h / 2), int(bev_w / 2)), 180 | int((cir + 1) * 10 / bev_resolution), marking_color, thickness=thickness) 181 | cv2.line(bev, (0, int(bev_h / 2)), (bev_w, int(bev_h / 2)), marking_color, thickness=thickness-1) 182 | cv2.line(bev, (int(bev_w / 2), 0), (int(bev_w / 2), bev_h), marking_color, thickness=thickness-1) 183 | else: 184 | bev = bev_img 185 | 186 | if len(bboxes_3d) != 0: 187 | bev_corners = box3d_to_corners(bboxes_3d)[:, [0, 3, 4, 7]][..., [0, 1]] 188 | xs = bev_corners[..., 0] / bev_resolution + bev_w / 2 189 | ys = -bev_corners[..., 1] / bev_resolution + bev_h / 2 190 | for obj_idx, (x, y) in enumerate(zip(xs, ys)): 191 | for p1, p2 in ((0, 1), (0, 2), (1, 3), (2, 3)): 192 | if isinstance(color[0], (list, tuple)): 193 | tmp = color[obj_idx] 194 | else: 195 | tmp = color 196 | cv2.line( 197 | bev, 198 | (int(x[p1]), int(y[p1])), 199 | (int(x[p2]), int(y[p2])), 200 | tmp, 201 | thickness=thickness, 202 | ) 203 | return bev.astype(np.uint8) 204 | 205 | 206 | def draw_lidar_bbox3d(bboxes_3d, imgs, lidar2imgs, color=(255, 0, 0)): 207 | vis_imgs = [] 208 | for i, (img, lidar2img) in enumerate(zip(imgs, lidar2imgs)): 209 | vis_imgs.append( 210 | draw_lidar_bbox3d_on_img(bboxes_3d, img, lidar2img, color=color) 211 | ) 212 | 213 | num_imgs = len(vis_imgs) 214 | if num_imgs < 4 or num_imgs % 2 != 0: 215 | vis_imgs = np.concatenate(vis_imgs, 
axis=1) 216 | else: 217 | vis_imgs = np.concatenate([ 218 | np.concatenate(vis_imgs[:num_imgs//2], axis=1), 219 | np.concatenate(vis_imgs[num_imgs//2:], axis=1) 220 | ], axis=0) 221 | 222 | bev = draw_lidar_bbox3d_on_bev(bboxes_3d, vis_imgs.shape[0], color=color) 223 | vis_imgs = np.concatenate([bev, vis_imgs], axis=1) 224 | return vis_imgs 225 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/instance_bank.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from mmcv.utils import build_from_cfg 8 | from mmcv.cnn.bricks.registry import PLUGIN_LAYERS 9 | 10 | __all__ = ["InstanceBank"] 11 | 12 | 13 | def topk(confidence, k, *inputs): 14 | bs, N = confidence.shape[:2] 15 | confidence, indices = torch.topk(confidence, k, dim=1) 16 | indices = (indices + torch.arange(bs, device=indices.device)[:, None] * N).reshape(-1) 17 | outputs = [] 18 | for input in inputs: 19 | outputs.append(input.flatten(end_dim=1)[indices].reshape(bs, k, -1)) 20 | return confidence, outputs 21 | 22 | 23 | @PLUGIN_LAYERS.register_module() 24 | class InstanceBank(nn.Module): 25 | def __init__( 26 | self, 27 | num_anchor, 28 | embed_dims, 29 | anchor, 30 | anchor_handler=None, 31 | num_temp_instances=0, 32 | default_time_interval=0.5, 33 | confidence_decay=0.6, 34 | anchor_grad=True, 35 | feat_grad=True, 36 | max_time_interval=2, 37 | ): 38 | super(InstanceBank, self).__init__() 39 | self.embed_dims = embed_dims 40 | self.num_temp_instances = num_temp_instances 41 | self.default_time_interval = default_time_interval 42 | self.confidence_decay = confidence_decay 43 | self.max_time_interval = max_time_interval 44 | 45 | if anchor_handler is not None: 46 | anchor_handler = build_from_cfg(anchor_handler, PLUGIN_LAYERS) 47 | assert hasattr(anchor_handler, "anchor_projection") 48 | self.anchor_handler = anchor_handler 49 | 50 | if isinstance(anchor, str): 51 | anchor = np.load(anchor) 52 | elif isinstance(anchor, (list, tuple)): 53 | anchor = np.array(anchor) 54 | self.num_anchor = min(len(anchor), num_anchor) 55 | anchor = anchor[:num_anchor] 56 | 57 | self.anchor_init = anchor 58 | self.anchor = nn.Parameter( 59 | torch.tensor(anchor, dtype=torch.float32), requires_grad=anchor_grad) 60 | self.instance_feature = nn.Parameter( 61 | torch.zeros([self.anchor.shape[0], self.embed_dims]), requires_grad=feat_grad) 62 | self.reset() 63 | 64 | def init_weight(self): 65 | self.anchor.data = self.anchor.data.new_tensor(self.anchor_init) 66 | if self.instance_feature.requires_grad: 67 | torch.nn.init.xavier_uniform_(self.instance_feature.data, gain=1) 68 | 69 | def reset(self): 70 | self.cached_feature = None 71 | self.cached_anchor = None 72 | self.metas = None 73 | self.mask = None 74 | self.confidence = None 75 | self.temp_confidence = None 76 | self.instance_id = None 77 | self.prev_id = 0 78 | 79 | def get(self, batch_size, metas=None, dn_metas=None): 80 | instance_feature = torch.tile(self.instance_feature[None], (batch_size, 1, 1)) 81 | anchor = torch.tile(self.anchor[None], (batch_size, 1, 1)) 82 | 83 | if self.cached_anchor is not None and batch_size == self.cached_anchor.shape[0]: 84 | history_time = self.metas["timestamp"] 85 | time_interval = metas["timestamp"] - history_time 86 | time_interval = time_interval.to(dtype=instance_feature.dtype) 87 | self.mask = torch.abs(time_interval) <= 
self.max_time_interval 88 | 89 | if self.anchor_handler is not None: 90 | T_temp2cur = self.cached_anchor.new_tensor( 91 | np.stack( 92 | [ 93 | x["T_global_inv"] @ self.metas["img_metas"][i]["T_global"] 94 | for i, x in enumerate(metas["img_metas"]) 95 | ] 96 | ) 97 | ) 98 | self.cached_anchor = self.anchor_handler.anchor_projection( 99 | self.cached_anchor, [T_temp2cur], time_intervals=[-time_interval] 100 | )[0] 101 | 102 | if (self.anchor_handler is not None and dn_metas is not None 103 | and batch_size == dn_metas["dn_anchor"].shape[0]): 104 | num_dn_group, num_dn = dn_metas["dn_anchor"].shape[1:3] 105 | dn_anchor = self.anchor_handler.anchor_projection( 106 | dn_metas["dn_anchor"].flatten(1, 2), [T_temp2cur], time_intervals=[-time_interval] 107 | )[0] 108 | dn_metas["dn_anchor"] = dn_anchor.reshape(batch_size, num_dn_group, num_dn, -1) 109 | 110 | time_interval = torch.where( 111 | torch.logical_and(time_interval != 0, self.mask), 112 | time_interval, 113 | time_interval.new_tensor(self.default_time_interval), 114 | ) 115 | else: 116 | self.reset() 117 | time_interval = instance_feature.new_tensor([self.default_time_interval] * batch_size) 118 | 119 | return instance_feature, anchor, self.cached_feature, self.cached_anchor, time_interval 120 | 121 | def update(self, instance_feature, anchor, confidence): 122 | if self.cached_feature is None: 123 | return instance_feature, anchor 124 | 125 | num_dn = 0 126 | if instance_feature.shape[1] > self.num_anchor: 127 | num_dn = instance_feature.shape[1] - self.num_anchor 128 | dn_instance_feature = instance_feature[:, -num_dn:] 129 | dn_anchor = anchor[:, -num_dn:] 130 | instance_feature = instance_feature[:, : self.num_anchor] 131 | anchor = anchor[:, : self.num_anchor] 132 | confidence = confidence[:, : self.num_anchor] 133 | 134 | N = self.num_anchor - self.num_temp_instances 135 | confidence = confidence.max(dim=-1).values 136 | _, (selected_feature, selected_anchor) = topk(confidence, N, instance_feature, anchor) 137 | 138 | selected_feature = torch.cat([self.cached_feature, selected_feature], dim=1) 139 | selected_anchor = torch.cat([self.cached_anchor, selected_anchor], dim=1) 140 | 141 | instance_feature = torch.where(self.mask[:, None, None], selected_feature, instance_feature) 142 | anchor = torch.where(self.mask[:, None, None], selected_anchor, anchor) 143 | if self.instance_id is not None: 144 | self.instance_id = torch.where(self.mask[:, None], self.instance_id, self.instance_id.new_tensor(-1)) 145 | 146 | if num_dn > 0: 147 | instance_feature = torch.cat([instance_feature, dn_instance_feature], dim=1) 148 | anchor = torch.cat([anchor, dn_anchor], dim=1) 149 | 150 | return instance_feature, anchor 151 | 152 | def cache(self, instance_feature, anchor, confidence, metas=None, feature_maps=None): 153 | if self.num_temp_instances <= 0: 154 | return 155 | instance_feature = instance_feature.detach() 156 | anchor = anchor.detach() 157 | confidence = confidence.detach() 158 | 159 | self.metas = metas 160 | confidence = confidence.max(dim=-1).values.sigmoid() 161 | if self.confidence is not None: 162 | confidence[:, : self.num_temp_instances] = torch.maximum( 163 | self.confidence * self.confidence_decay, confidence[:, : self.num_temp_instances]) 164 | self.temp_confidence = confidence 165 | 166 | self.confidence, (self.cached_feature, self.cached_anchor) = topk( 167 | confidence, self.num_temp_instances, instance_feature, anchor) 168 | 169 | def get_instance_id(self, confidence, anchor=None, threshold=None): 170 | confidence = 
confidence.max(dim=-1).values.sigmoid() 171 | instance_id = confidence.new_full(confidence.shape, -1).long() 172 | 173 | if self.instance_id is not None and self.instance_id.shape[0] == instance_id.shape[0]: 174 | instance_id[:, : self.instance_id.shape[1]] = self.instance_id 175 | 176 | mask = instance_id < 0 177 | if threshold is not None: 178 | mask = mask & (confidence >= threshold) 179 | num_new_instance = mask.sum() 180 | new_ids = torch.arange(num_new_instance).to(instance_id) + self.prev_id 181 | instance_id[torch.where(mask)] = new_ids 182 | self.prev_id += num_new_instance 183 | self.update_instance_id(instance_id, confidence) 184 | return instance_id 185 | 186 | def update_instance_id(self, instance_id=None, confidence=None): 187 | if self.temp_confidence is None: 188 | if confidence.dim() == 3: # bs, num_anchor, num_cls 189 | temp_conf = confidence.max(dim=-1).values 190 | else: # bs, num_anchor 191 | temp_conf = confidence 192 | else: 193 | temp_conf = self.temp_confidence 194 | instance_id = topk(temp_conf, self.num_temp_instances, instance_id)[1][0] 195 | instance_id = instance_id.squeeze(dim=-1) 196 | self.instance_id = F.pad(instance_id, (0, self.num_anchor - self.num_temp_instances), value=-1) 197 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/detection3d/decoder.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | from mmdet.core.bbox.builder import BBOX_CODERS 6 | from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy 7 | from projects.mmdet3d_plugin.core.box3d import * 8 | 9 | 10 | @BBOX_CODERS.register_module() 11 | class SparseBox3DDecoder(object): 12 | def __init__( 13 | self, 14 | num_output: int = 300, 15 | score_threshold: Optional[float] = None, 16 | sorted: bool = True, 17 | ): 18 | super(SparseBox3DDecoder, self).__init__() 19 | self.num_output = num_output 20 | self.score_threshold = score_threshold 21 | self.sorted = sorted 22 | 23 | def decode_box(self, box): 24 | yaw = torch.atan2(box[:, SIN_YAW], box[:, COS_YAW]) 25 | box = torch.cat( 26 | [ 27 | box[:, [X, Y, Z]], 28 | box[:, [W, L, H]].exp(), 29 | yaw[:, None], 30 | box[:, VX:], 31 | ], 32 | dim=-1, 33 | ) 34 | return box 35 | 36 | def decode_box2d(self, box, aug_config): 37 | crop = aug_config['crop'] 38 | scale_factor = aug_config['resize'] 39 | 40 | crop_img_size = (crop[2] - crop[0], crop[3] - crop[1]) 41 | 42 | box = bbox_cxcywh_to_xyxy(box) 43 | 44 | box[..., 0::2] = box[..., 0::2] * crop_img_size[0] 45 | box[..., 1::2] = box[..., 1::2] * crop_img_size[1] 46 | box[..., 0::2].clamp_(min=0, max=crop_img_size[0]) 47 | box[..., 1::2].clamp_(min=0, max=crop_img_size[1]) 48 | 49 | box[..., 1::2] += crop[1] 50 | box /= box.new_tensor(scale_factor) 51 | return box 52 | 53 | def decode( 54 | self, 55 | cls_scores, 56 | box_preds, 57 | instance_id=None, 58 | qulity=None, 59 | output_idx=-1, 60 | ): 61 | squeeze_cls = instance_id is not None 62 | 63 | cls_scores = cls_scores[output_idx].sigmoid() 64 | 65 | if squeeze_cls: 66 | cls_scores, cls_ids = cls_scores.max(dim=-1) 67 | cls_scores = cls_scores.unsqueeze(dim=-1) 68 | 69 | box_preds = box_preds[output_idx] 70 | bs, num_pred, num_cls = cls_scores.shape 71 | cls_scores, indices = cls_scores.flatten(start_dim=1).topk( 72 | self.num_output, dim=1, sorted=self.sorted 73 | ) 74 | if not squeeze_cls: 75 | cls_ids = indices % num_cls 76 | if self.score_threshold is not None: 77 | mask = cls_scores >= 
self.score_threshold 78 | 79 | if qulity is not None: 80 | centerness = qulity[output_idx][..., CNS] 81 | centerness = torch.gather(centerness, 1, indices // num_cls) 82 | cls_scores_origin = cls_scores.clone() 83 | cls_scores *= centerness.sigmoid() 84 | cls_scores, idx = torch.sort(cls_scores, dim=1, descending=True) 85 | if not squeeze_cls: 86 | cls_ids = torch.gather(cls_ids, 1, idx) 87 | if self.score_threshold is not None: 88 | mask = torch.gather(mask, 1, idx) 89 | indices = torch.gather(indices, 1, idx) 90 | 91 | output = [] 92 | for i in range(bs): 93 | category_ids = cls_ids[i] 94 | if squeeze_cls: 95 | category_ids = category_ids[indices[i]] 96 | scores = cls_scores[i] 97 | box = box_preds[i, indices[i] // num_cls] 98 | if self.score_threshold is not None: 99 | category_ids = category_ids[mask[i]] 100 | scores = scores[mask[i]] 101 | box = box[mask[i]] 102 | if qulity is not None: 103 | scores_origin = cls_scores_origin[i] 104 | if self.score_threshold is not None: 105 | scores_origin = scores_origin[mask[i]] 106 | 107 | box = self.decode_box(box) 108 | output.append( 109 | { 110 | "boxes_3d": box.cpu(), 111 | "scores_3d": scores.cpu(), 112 | "labels_3d": category_ids.cpu(), 113 | } 114 | ) 115 | if qulity is not None: 116 | output[-1]["cls_scores"] = scores_origin.cpu() 117 | if instance_id is not None: 118 | ids = instance_id[i, indices[i]] 119 | if self.score_threshold is not None: 120 | ids = ids[mask[i]] 121 | output[-1]["instance_ids"] = ids 122 | return output 123 | 124 | def decode_with2d( 125 | self, 126 | cls_scores, 127 | box_preds, 128 | instance_id=None, 129 | qulity=None, 130 | output_idx=-1, 131 | cls_scores2d=None, 132 | box_preds2d=None, 133 | trans_matrix=None, 134 | query_groups=None, 135 | output_idx2d=-1, 136 | aug_configs=None, 137 | with_association=False 138 | ): 139 | squeeze_cls = instance_id is not None 140 | 141 | cls_scores = cls_scores[output_idx].sigmoid() 142 | 143 | if squeeze_cls: 144 | cls_scores, cls_ids = cls_scores.max(dim=-1) 145 | cls_scores = cls_scores.unsqueeze(dim=-1) 146 | 147 | box_preds = box_preds[output_idx] 148 | bs, num_pred, num_cls = cls_scores.shape 149 | cls_scores, indices = cls_scores.flatten(start_dim=1).topk( 150 | self.num_output, dim=1, sorted=self.sorted 151 | ) 152 | if not squeeze_cls: 153 | cls_ids = indices % num_cls 154 | if self.score_threshold is not None: 155 | mask = cls_scores >= self.score_threshold 156 | 157 | if qulity is not None: 158 | centerness = qulity[output_idx][..., CNS] 159 | centerness = torch.gather(centerness, 1, indices // num_cls) 160 | cls_scores_origin = cls_scores.clone() 161 | cls_scores *= centerness.sigmoid() 162 | cls_scores, idx = torch.sort(cls_scores, dim=1, descending=True) 163 | if not squeeze_cls: 164 | cls_ids = torch.gather(cls_ids, 1, idx) 165 | if self.score_threshold is not None: 166 | mask = torch.gather(mask, 1, idx) 167 | indices = torch.gather(indices, 1, idx) 168 | 169 | cls_scores2d = cls_scores2d[output_idx2d] 170 | box_preds2d = box_preds2d[output_idx2d] 171 | trans_matrix_t = trans_matrix[output_idx2d].permute(0, 2, 1) 172 | query_groups = query_groups[output_idx2d] 173 | aug_config = aug_configs[0] 174 | 175 | output = [] 176 | for i in range(bs): 177 | category_ids = cls_ids[i] 178 | if squeeze_cls: 179 | category_ids = category_ids[indices[i]] 180 | scores = cls_scores[i] 181 | box = box_preds[i, indices[i] // num_cls] 182 | 183 | # 2d 184 | assert num_cls == 1 185 | 186 | # get indices2d and trans_t 187 | if with_association: 188 | trans_t = trans_matrix_t[i, 
indices[i]] 189 | indices2d = torch.where(trans_t.any(0))[0] 190 | trans_t = torch.index_select(trans_t, 1, indices2d).cpu() 191 | else: 192 | indices2d = torch.arange(len(box_preds2d[i])) 193 | trans_t = None 194 | 195 | # get new query_groups 196 | camidx_2d = [] 197 | query_groups_new = [] 198 | for cam_idx, qg in enumerate(query_groups): 199 | parts_index = torch.where(torch.logical_and(qg[0] <= indices2d, indices2d < qg[1]))[0] 200 | 201 | if len(parts_index) > 0: 202 | qg_new = (parts_index[0].cpu().item(), parts_index[-1].cpu().item() + 1) 203 | elif len(query_groups_new) > 0: 204 | qg_new = (query_groups_new[-1][-1], query_groups_new[-1][-1]) 205 | else: 206 | qg_new = (0, 0) 207 | 208 | camidx_2d.append(torch.ones((len(parts_index))) * cam_idx) 209 | query_groups_new.append(qg_new) 210 | 211 | camidx_2d = torch.cat(camidx_2d, dim=0) 212 | query_groups = query_groups_new 213 | 214 | # resize box to original img 215 | scores2d, category_ids2d = cls_scores2d[i, indices2d].sigmoid().max(dim=-1) 216 | box2d = box_preds2d[i, indices2d] 217 | box2d = self.decode_box2d(box2d, aug_config) 218 | 219 | if self.score_threshold is not None: 220 | category_ids = category_ids[mask[i]] 221 | scores = scores[mask[i]] 222 | box = box[mask[i]] 223 | if qulity is not None: 224 | scores_origin = cls_scores_origin[i] 225 | if self.score_threshold is not None: 226 | scores_origin = scores_origin[mask[i]] 227 | 228 | box = self.decode_box(box) 229 | 230 | output.append( 231 | { 232 | # 3d 233 | "boxes_3d": box.cpu(), 234 | "scores_3d": scores.cpu(), 235 | "labels_3d": category_ids.cpu(), 236 | # 2d 237 | "boxes_2d": box2d.cpu(), 238 | "scores_2d": scores2d.cpu(), 239 | "labels_2d": category_ids2d.cpu(), 240 | "camidx_2d": camidx_2d, 241 | "trans_matrix": trans_t, 242 | "query_groups": query_groups 243 | } 244 | ) 245 | if qulity is not None: 246 | output[-1]["cls_scores"] = scores_origin.cpu() 247 | if instance_id is not None: 248 | ids = instance_id[i, indices[i]] 249 | if self.score_threshold is not None: 250 | ids = ids[mask[i]] 251 | output[-1]["instance_ids"] = ids 252 | return output 253 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/detection2d/coster.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from scipy.optimize import linear_sum_assignment 4 | 5 | from mmdet.core.bbox.builder import BBOX_SAMPLERS 6 | from mmdet.core.bbox.builder import BBOX_ASSIGNERS 7 | from mmdet.core.bbox.assigners import HungarianAssigner 8 | from mmdet.core.bbox.assigners.assign_result import AssignResult 9 | from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh 10 | from mmdet.core.bbox.match_costs import build_match_cost 11 | 12 | __all__ = ["SparseBox2DCoster"] 13 | 14 | 15 | @BBOX_SAMPLERS.register_module() 16 | class SparseBox2DCoster(object): 17 | def __init__(self, 18 | eps=1e-12, 19 | cls_cost=None, 20 | reg_cost=None, 21 | iou_cost=None): 22 | super(SparseBox2DCoster, self).__init__() 23 | self.eps = eps 24 | self.cls_cost = build_match_cost(cls_cost) 25 | self.reg_cost = build_match_cost(reg_cost) 26 | self.iou_cost = build_match_cost(iou_cost) 27 | 28 | 29 | def cost(self, cls_pred, box_pred, cls_target, box_target, 30 | data, trans_shape=None, query_groups=None): 31 | bs, num_pred, num_cls = cls_pred.shape 32 | img_w, img_h = map(int, data['image_wh'][0, 0].tolist()) 33 | factor = box_pred[0].new_tensor([img_w, img_h, img_w, 
img_h]).unsqueeze(0)

        all_reg_weights = []
        for i in range(bs):
            reg_weights = []
            for j in range(len(query_groups)):
                weights = torch.logical_not(box_target[i][j].isnan()).to(dtype=box_target[i][j].dtype)
                reg_weights.append(weights)
            all_reg_weights.append(reg_weights)

        cls_cost = self._cls_cost(cls_pred, cls_target, query_groups)
        reg_cost = self._reg_cost(box_pred, box_target, query_groups, factor)
        iou_cost = self._iou_cost(box_pred, box_target, query_groups, factor)

        all_costs = []
        for i in range(bs):
            costs = []
            for j in range(len(query_groups)):
                if cls_cost[i][j] is not None and reg_cost[i][j] is not None and iou_cost[i][j] is not None:
                    cost = (cls_cost[i][j] + reg_cost[i][j] + iou_cost[i][j]).detach().cpu().numpy()
                    if cost.size > 0 and trans_shape is not None:
                        cost[trans_shape[i][j]:, :] = cost.max()
                    cost = np.where(np.isneginf(cost) | np.isnan(cost), 1e8, cost)
                    costs.append(cost)
                else:
                    costs.append(None)
            all_costs.append(costs)

        return all_costs

    def trans_cost(self, all_costs, pred_cls_2d, gt_cls_2d, gt_cls_3d,
                   ref_trans_matrix, query_groups, gt_2d_3d_map):

        bs, num_query2d, num_cls2d = pred_cls_2d.shape
        num_query3d = ref_trans_matrix.shape[-1]

        cost2d_map3d_list = []
        for i in range(bs):
            num_target2d = sum([len(x) for x in gt_cls_2d[i]])
            num_target3d = len(gt_cls_3d[i])

            if num_target2d > 0 and num_target3d > 0:
                # 1. extend cost2d to a large map
                cost2d_extend = np.ones((num_query2d, num_target2d), dtype=np.float32) * (-1 / self.eps)

                gt_qg_shape = np.cumsum([0] + [len(x) for x in gt_cls_2d[i]])
                gt_qg_groups = [(qg[0], qg[1]) for qg in zip(gt_qg_shape[:-1], gt_qg_shape[1:])]

                for j, (qg, gp) in enumerate(zip(query_groups, gt_qg_groups)):
                    if all_costs[i][j] is not None:
                        cost2d_extend[qg[0]:qg[1], gp[0]:gp[1]] = all_costs[i][j]

                if cost2d_extend.max() == (-1 / self.eps):
                    cost2d_extend = np.zeros_like(cost2d_extend)

                cost2d_extend[cost2d_extend == (-1 / self.eps)] = cost2d_extend.max()

                # 2. 
trans cost2d_extend to cost2d_map3d 91 | map_trans_matrix = torch.zeros((num_target2d, num_target3d)) 92 | map_trans_matrix[torch.arange(num_target2d), torch.cat(gt_2d_3d_map[i]).cpu()] = 1 93 | ref_trans_matrix_t = ref_trans_matrix[i].cpu().T 94 | 95 | cost2d_extend = torch.from_numpy(cost2d_extend) 96 | cost2d_map3d = (cost2d_extend @ map_trans_matrix) / torch.clamp(map_trans_matrix.sum(0), 1e-5).unsqueeze(0) 97 | cost2d_map3d = (ref_trans_matrix_t @ cost2d_map3d) / torch.clamp(ref_trans_matrix_t.sum(-1), 1e-5).unsqueeze(-1) 98 | 99 | map_mask = torch.logical_or((cost2d_map3d.sum(0) == 0)[None, :], 100 | (cost2d_map3d.sum(1) == 0)[:, None]) 101 | 102 | cost2d_map3d[map_mask] = cost2d_map3d.max() 103 | cost2d_map3d_list.append(cost2d_map3d.numpy()) 104 | 105 | else: 106 | cost2d_map3d = np.zeros((num_query3d, num_target3d), dtype=np.float32) 107 | cost2d_map3d_list.append(cost2d_map3d) 108 | 109 | return cost2d_map3d_list 110 | 111 | 112 | def sample(self, cls_pred, box_pred, depth_pred, cls_target, box_target, depth_target, 113 | data, trans_matrix = None, query_groups = None, cost_list=None, 114 | alpha=None, alpha_targets=None): 115 | 116 | bs, num_pred, num_cls = cls_pred.shape 117 | 118 | all_reg_weights = [] 119 | for i in range(bs): 120 | reg_weights = [] 121 | for j in range(len(query_groups)): 122 | weights = torch.logical_not(box_target[i][j].isnan()).to(dtype=box_target[i][j].dtype) 123 | reg_weights.append(weights) 124 | all_reg_weights.append(reg_weights) 125 | 126 | 127 | all_indices = [] 128 | for i in range(bs): 129 | indices = [] 130 | for j in range(len(query_groups)): 131 | if cost_list[i][j] is not None and cost_list[i][j].size>0: 132 | cost = cost_list[i][j] 133 | cost = np.where(np.isneginf(cost) | np.isnan(cost), 1e8, cost) 134 | indices.append( 135 | [ 136 | cls_pred.new_tensor(x, dtype=torch.int64) 137 | for x in linear_sum_assignment(cost) 138 | ] 139 | ) 140 | else: 141 | indices.append(None) 142 | all_indices.append(indices) 143 | 144 | output_cls_target = cls_target[0][0].new_ones([bs, num_pred], dtype=torch.long) * -1 145 | output_box_target = box_pred.new_zeros(box_pred.shape) 146 | output_reg_weights = box_pred.new_zeros(box_pred.shape) 147 | 148 | output_box_depth = None 149 | if depth_pred is not None: 150 | output_box_depth = box_pred.new_zeros(box_pred[..., 1].shape) 151 | 152 | output_alpha_target = None 153 | if alpha_targets is not None and alpha is not None: 154 | output_alpha_target = alpha.new_zeros(alpha.shape) 155 | 156 | for i in range(bs): 157 | for j, qg in enumerate(query_groups): 158 | if len(cls_target[i][j]) > 0 and all_indices[i][j] is not None: 159 | pred_idx, target_idx = all_indices[i][j] 160 | 161 | output_cls_target[i, pred_idx + qg[0]] = cls_target[i][j][target_idx] 162 | output_box_target[i, pred_idx + qg[0]] = box_target[i][j][target_idx] 163 | output_reg_weights[i, pred_idx + qg[0]] = all_reg_weights[i][j][target_idx] 164 | 165 | if alpha_targets is not None: 166 | assigned_alpha = alpha_targets[i][j][target_idx] 167 | if len(assigned_alpha.shape) == 1: # for scalar alpha, not multibin 168 | assigned_alpha_target = torch.stack((torch.sin(assigned_alpha), torch.cos(assigned_alpha))).transpose(1,0)#view(-1, 2) 169 | 170 | if output_alpha_target is not None: 171 | output_alpha_target[i, pred_idx + qg[0]] = assigned_alpha_target 172 | 173 | if depth_pred is not None: 174 | output_box_depth[i, pred_idx + qg[0]] = depth_target[i][j][target_idx] 175 | 176 | return output_cls_target, output_box_target, output_alpha_target, 
output_box_depth, output_reg_weights


    def _cls_cost(self, cls_pred, cls_target, query_groups):
        bs = cls_pred.shape[0]
        all_cost = []
        for i in range(bs):
            cost = []
            for j, qg in enumerate(query_groups):
                if len(cls_target[i][j]) > 0:
                    cls_cost = self.cls_cost(cls_pred[i][qg[0]:qg[1]],
                                             cls_target[i][j])
                    cost.append(cls_cost)
                else:
                    cost.append(None)
            all_cost.append(cost)
        return all_cost


    def _reg_cost(self, box_pred, box_target, query_groups, factor):
        bs = box_pred.shape[0]
        all_cost = []
        for i in range(bs):
            cost = []
            for j, qg in enumerate(query_groups):
                if len(box_target[i][j]) > 0:
                    # normalized resolution
                    reg_cost = self.reg_cost(box_pred[i][qg[0]:qg[1]],
                                             box_target[i][j][:, 0:4] / factor)
                    cost.append(reg_cost)
                else:
                    cost.append(None)
            all_cost.append(cost)
        return all_cost


    def _iou_cost(self, box_pred, box_target, query_groups, factor):
        bs = box_pred.shape[0]
        all_cost = []
        for i in range(bs):
            cost = []
            for j, qg in enumerate(query_groups):
                if len(box_target[i][j]) > 0:
                    # original resolution
                    iou_cost = self.iou_cost(bbox_cxcywh_to_xyxy(box_pred[i][qg[0]:qg[1]]) * factor,
                                             box_target[i][j][:, 0:4])
                    cost.append(iou_cost)
                else:
                    cost.append(None)
            all_cost.append(cost)
        return all_cost

--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/pipelines/transform.py:
--------------------------------------------------------------------------------
import numpy as np
import mmcv
import torch
from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import to_tensor

def filter_info2d(input_dict, mask):
    if isinstance(mask, torch.Tensor):
        mask = mask.numpy()

    maskes_2d = [np.ones(len(x), dtype=np.bool_) for x in input_dict['gt_bboxes_2d']]
    for cam_idx, map_2d_3d in enumerate(input_dict['gt_2d_3d_map']):
        for index_2d, index_3d in enumerate(map_2d_3d):
            if index_3d in np.where(~mask)[0]:
                maskes_2d[cam_idx][index_2d] = False

    trans_index = np.ones((len(mask)), dtype=np.int64) * -1
    trans_index[mask] = 1
    trans_index[mask] = np.cumsum(trans_index[mask]) - 1
    trans_index = np.concatenate([trans_index, -np.ones(1, dtype=np.int64)])

    for cam_idx, mask_2d in enumerate(maskes_2d):
        input_dict['gt_bboxes_2d'][cam_idx] = input_dict['gt_bboxes_2d'][cam_idx][mask_2d]
        input_dict['gt_labels_2d'][cam_idx] = input_dict['gt_labels_2d'][cam_idx][mask_2d]
        input_dict['gt_centers_2d'][cam_idx] = input_dict['gt_centers_2d'][cam_idx][mask_2d]
        input_dict['gt_depths_2d'][cam_idx] = input_dict['gt_depths_2d'][cam_idx][mask_2d]
        input_dict['gt_alphas_2d'][cam_idx] = input_dict['gt_alphas_2d'][cam_idx][mask_2d]
        input_dict['gt_2d_3d_map'][cam_idx] = trans_index[input_dict['gt_2d_3d_map'][cam_idx][mask_2d]]

    return input_dict


@PIPELINES.register_module()
class MultiScaleDepthMapGenerator(object):
    def __init__(self, downsample=1, max_depth=60):
        if not isinstance(downsample, (list, tuple)):
            downsample = [downsample]
        self.downsample = downsample
        self.max_depth = max_depth

    def __call__(self, input_dict):
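        # Projects the LiDAR points into every camera view with the per-view
        # lidar2img matrices, keeps points that land inside the image with
        # depth >= 0.1, then splats the depths sorted far-to-near (so the
        # nearest point wins per pixel) into one sparse map per downsample
        # scale; pixels without a point are filled with -1.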
input_dict["points"][..., :3, None] 44 | gt_depth = [] 45 | for i, lidar2img in enumerate(input_dict["lidar2img"]): 46 | H, W = input_dict["img_shape"][i][:2] 47 | 48 | pts_2d = ( 49 | np.squeeze(lidar2img[:3, :3] @ points, axis=-1) 50 | + lidar2img[:3, 3] 51 | ) 52 | pts_2d[:, :2] /= pts_2d[:, 2:3] 53 | U = np.round(pts_2d[:, 0]).astype(np.int32) 54 | V = np.round(pts_2d[:, 1]).astype(np.int32) 55 | depths = pts_2d[:, 2] 56 | mask = np.logical_and.reduce( 57 | [ 58 | V >= 0, 59 | V < H, 60 | U >= 0, 61 | U < W, 62 | depths >= 0.1, 63 | # depths <= self.max_depth, 64 | ] 65 | ) 66 | V, U, depths = V[mask], U[mask], depths[mask] 67 | sort_idx = np.argsort(depths)[::-1] 68 | V, U, depths = V[sort_idx], U[sort_idx], depths[sort_idx] 69 | depths = np.clip(depths, 0.1, self.max_depth) 70 | for j, downsample in enumerate(self.downsample): 71 | if len(gt_depth) < j + 1: 72 | gt_depth.append([]) 73 | h, w = (int(H / downsample), int(W / downsample)) 74 | u = np.floor(U / downsample).astype(np.int32) 75 | v = np.floor(V / downsample).astype(np.int32) 76 | depth_map = np.ones([h, w], dtype=np.float32) * -1 77 | depth_map[v, u] = depths 78 | gt_depth[j].append(depth_map) 79 | 80 | input_dict["gt_depth"] = [np.stack(x) for x in gt_depth] 81 | return input_dict 82 | 83 | 84 | @PIPELINES.register_module() 85 | class NuScenesSparse4DAdaptor(object): 86 | def __init(self): 87 | pass 88 | 89 | def __call__(self, input_dict): 90 | input_dict["projection_mat"] = np.float32(np.stack(input_dict["lidar2img"])) 91 | input_dict["image_wh"] = np.ascontiguousarray( 92 | np.array(input_dict["img_shape"], dtype=np.float32)[:, :2][:, ::-1]) 93 | input_dict["T_global_inv"] = np.linalg.inv(input_dict["lidar2global"]) 94 | input_dict["T_global"] = input_dict["lidar2global"] 95 | 96 | if "cam_intrinsic" in input_dict: 97 | input_dict["cam_intrinsic"] = np.float32( 98 | np.stack(input_dict["cam_intrinsic"]) 99 | ) 100 | input_dict["focal"] = input_dict["cam_intrinsic"][..., 0, 0] 101 | # input_dict["focal"] = np.sqrt( 102 | # np.abs(np.linalg.det(input_dict["cam_intrinsic"][:, :2, :2])) 103 | # ) 104 | if "instance_inds" in input_dict: 105 | input_dict["instance_id"] = input_dict["instance_inds"] 106 | 107 | if "gt_bboxes_3d" in input_dict: 108 | input_dict["gt_bboxes_3d"][:, 6] = self.limit_period( 109 | input_dict["gt_bboxes_3d"][:, 6], offset=0.5, period=2 * np.pi 110 | ) 111 | input_dict["gt_bboxes_3d"] = DC( 112 | to_tensor(input_dict["gt_bboxes_3d"]).float() 113 | ) 114 | if "gt_labels_3d" in input_dict: 115 | input_dict["gt_labels_3d"] = DC( 116 | to_tensor(input_dict["gt_labels_3d"]).long() 117 | ) 118 | 119 | for key in ['gt_bboxes_2d', 'gt_labels_2d', 'gt_centers_2d', 120 | 'gt_depths_2d', 'gt_2d_3d_map', 'gt_alphas_2d']: 121 | if key not in input_dict: 122 | continue 123 | if isinstance(input_dict[key], list): 124 | input_dict[key] = DC([to_tensor(res) for res in input_dict[key]]) 125 | else: 126 | input_dict[key] = DC(to_tensor(input_dict[key])) 127 | 128 | imgs = [img.transpose(2, 0, 1) for img in input_dict["img"]] 129 | imgs = np.ascontiguousarray(np.stack(imgs, axis=0)) 130 | input_dict["img"] = DC(to_tensor(imgs), stack=True) 131 | 132 | input_dict["intrinsics"] = np.float32(np.stack(input_dict["intrinsics"])) 133 | input_dict["extrinsics"] = np.float32(np.stack(input_dict["extrinsics"])) 134 | 135 | return input_dict 136 | 137 | def limit_period( 138 | self, val: np.ndarray, offset: float = 0.5, period: float = np.pi 139 | ) -> np.ndarray: 140 | limited_val = val - np.floor(val / period + offset) * period 
141 | return limited_val 142 | 143 | 144 | @PIPELINES.register_module() 145 | class InstanceNameFilter(object): 146 | """Filter GT objects by their names. 147 | 148 | Args: 149 | classes (list[str]): List of class names to be kept for training. 150 | """ 151 | 152 | def __init__(self, classes): 153 | self.classes = classes 154 | self.labels = list(range(len(self.classes))) 155 | 156 | def __call__(self, input_dict): 157 | """Call function to filter objects by their names. 158 | 159 | Args: 160 | input_dict (dict): Result dict from loading pipeline. 161 | 162 | Returns: 163 | dict: Results after filtering, 'gt_bboxes_3d', 'gt_labels_3d' \ 164 | keys are updated in the result dict. 165 | """ 166 | gt_labels_3d = input_dict["gt_labels_3d"] 167 | gt_bboxes_mask = np.array([n in self.labels for n in gt_labels_3d], dtype=np.bool_) 168 | 169 | input_dict["gt_bboxes_3d"] = input_dict["gt_bboxes_3d"][gt_bboxes_mask] 170 | input_dict["gt_labels_3d"] = input_dict["gt_labels_3d"][gt_bboxes_mask] 171 | if "instance_inds" in input_dict: 172 | input_dict["instance_inds"] = input_dict["instance_inds"][gt_bboxes_mask] 173 | 174 | # update 2d labels 175 | if 'gt_bboxes_2d' in input_dict: 176 | input_dict = filter_info2d(input_dict, gt_bboxes_mask) 177 | 178 | return input_dict 179 | 180 | def __repr__(self): 181 | """str: Return a string that describes the module.""" 182 | repr_str = self.__class__.__name__ 183 | repr_str += f"(classes={self.classes})" 184 | return repr_str 185 | 186 | 187 | @PIPELINES.register_module() 188 | class CircleObjectRangeFilter(object): 189 | def __init__( 190 | self, class_dist_thred=[52.5] * 5 + [31.5] + [42] * 3 + [31.5] 191 | ): 192 | self.class_dist_thred = class_dist_thred 193 | 194 | def __call__(self, input_dict): 195 | gt_bboxes_3d = input_dict["gt_bboxes_3d"] 196 | gt_labels_3d = input_dict["gt_labels_3d"] 197 | dist = np.sqrt( 198 | np.sum(gt_bboxes_3d[:, :2] ** 2, axis=-1) 199 | ) 200 | mask = np.array([False] * len(dist)) 201 | for label_idx, dist_thred in enumerate(self.class_dist_thred): 202 | mask = np.logical_or( 203 | mask, 204 | np.logical_and(gt_labels_3d == label_idx, dist <= dist_thred), 205 | ) 206 | 207 | gt_bboxes_3d = gt_bboxes_3d[mask] 208 | gt_labels_3d = gt_labels_3d[mask] 209 | 210 | # update 2d labels 211 | if 'gt_bboxes_2d' in input_dict: 212 | input_dict = filter_info2d(input_dict, mask) 213 | 214 | input_dict["gt_bboxes_3d"] = gt_bboxes_3d 215 | input_dict["gt_labels_3d"] = gt_labels_3d 216 | 217 | if "instance_inds" in input_dict: 218 | input_dict["instance_inds"] = input_dict["instance_inds"][mask] 219 | 220 | return input_dict 221 | 222 | def __repr__(self): 223 | """str: Return a string that describes the module.""" 224 | repr_str = self.__class__.__name__ 225 | repr_str += f"(class_dist_thred={self.class_dist_thred})" 226 | return repr_str 227 | 228 | 229 | @PIPELINES.register_module() 230 | class NormalizeMultiviewImage(object): 231 | """Normalize the image. 232 | Added key is "img_norm_cfg". 233 | Args: 234 | mean (sequence): Mean values of 3 channels. 235 | std (sequence): Std values of 3 channels. 236 | to_rgb (bool): Whether to convert the image from BGR to RGB, 237 | default is true. 238 | """ 239 | 240 | def __init__(self, mean, std, to_rgb=True): 241 | self.mean = np.array(mean, dtype=np.float32) 242 | self.std = np.array(std, dtype=np.float32) 243 | self.to_rgb = to_rgb 244 | 245 | def __call__(self, results): 246 | """Call function to normalize images. 247 | Args: 248 | results (dict): Result dict from loading pipeline. 
249 | Returns: 250 | dict: Normalized results, 'img_norm_cfg' key is added into 251 | result dict. 252 | """ 253 | results["img"] = [ 254 | mmcv.imnormalize(img, self.mean, self.std, self.to_rgb) 255 | for img in results["img"] 256 | ] 257 | results["img_norm_cfg"] = dict( 258 | mean=self.mean, std=self.std, to_rgb=self.to_rgb 259 | ) 260 | return results 261 | 262 | def __repr__(self): 263 | repr_str = self.__class__.__name__ 264 | repr_str += f"(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})" 265 | return repr_str 266 | -------------------------------------------------------------------------------- /projects/mmdet3d_plugin/models/detection3d/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from mmcv.cnn import Linear, Scale, bias_init_with_prob 6 | from mmcv.runner.base_module import Sequential, BaseModule 7 | from mmcv.cnn import xavier_init 8 | from mmcv.cnn.bricks.registry import ( 9 | PLUGIN_LAYERS, 10 | POSITIONAL_ENCODING, 11 | ) 12 | 13 | from projects.mmdet3d_plugin.core.box3d import * 14 | from ..blocks import linear_relu_ln 15 | 16 | __all__ = [ 17 | "SparseBox3DRefinementModule", 18 | "SparseBox3DKeyPointsGenerator", 19 | "SparseBox3DEncoder", 20 | ] 21 | 22 | 23 | @POSITIONAL_ENCODING.register_module() 24 | class SparseBox3DEncoder(BaseModule): 25 | def __init__( 26 | self, 27 | embed_dims, 28 | vel_dims=3, 29 | mode="add", 30 | output_fc=True, 31 | in_loops=1, 32 | out_loops=2, 33 | ): 34 | super().__init__() 35 | assert mode in ["add", "cat"] 36 | self.embed_dims = embed_dims 37 | self.vel_dims = vel_dims 38 | self.mode = mode 39 | 40 | def embedding_layer(input_dims, output_dims): 41 | return nn.Sequential( 42 | *linear_relu_ln(output_dims, in_loops, out_loops, input_dims) 43 | ) 44 | 45 | if not isinstance(embed_dims, (list, tuple)): 46 | embed_dims = [embed_dims] * 5 47 | self.pos_fc = embedding_layer(3, embed_dims[0]) 48 | self.size_fc = embedding_layer(3, embed_dims[1]) 49 | self.yaw_fc = embedding_layer(2, embed_dims[2]) 50 | if vel_dims > 0: 51 | self.vel_fc = embedding_layer(self.vel_dims, embed_dims[3]) 52 | if output_fc: 53 | self.output_fc = embedding_layer(embed_dims[-1], embed_dims[-1]) 54 | else: 55 | self.output_fc = None 56 | 57 | def forward(self, box_3d: torch.Tensor): 58 | pos_feat = self.pos_fc(box_3d[..., [X, Y, Z]]) 59 | size_feat = self.size_fc(box_3d[..., [W, L, H]]) 60 | yaw_feat = self.yaw_fc(box_3d[..., [SIN_YAW, COS_YAW]]) 61 | if self.mode == "add": 62 | output = pos_feat + size_feat + yaw_feat 63 | elif self.mode == "cat": 64 | output = torch.cat([pos_feat, size_feat, yaw_feat], dim=-1) 65 | 66 | if self.vel_dims > 0: 67 | vel_feat = self.vel_fc(box_3d[..., VX : VX + self.vel_dims]) 68 | if self.mode == "add": 69 | output = output + vel_feat 70 | elif self.mode == "cat": 71 | output = torch.cat([output, vel_feat], dim=-1) 72 | if self.output_fc is not None: 73 | output = self.output_fc(output) 74 | return output 75 | 76 | 77 | @PLUGIN_LAYERS.register_module() 78 | class SparseBox3DRefinementModule(BaseModule): 79 | def __init__( 80 | self, 81 | embed_dims=256, 82 | output_dim=11, 83 | num_cls=10, 84 | normalize_yaw=False, 85 | refine_yaw=False, 86 | with_cls_branch=True, 87 | with_quality_estimation=False, 88 | ): 89 | super(SparseBox3DRefinementModule, self).__init__() 90 | self.embed_dims = embed_dims 91 | self.output_dim = output_dim 92 | self.num_cls = num_cls 93 | self.normalize_yaw = normalize_yaw 94 | 
self.refine_yaw = refine_yaw 95 | 96 | self.refine_state = [X, Y, Z, W, L, H] 97 | if self.refine_yaw: 98 | self.refine_state += [SIN_YAW, COS_YAW] 99 | 100 | self.layers = nn.Sequential( 101 | *linear_relu_ln(embed_dims, 2, 2), 102 | Linear(self.embed_dims, self.output_dim), 103 | Scale([1.0] * self.output_dim), 104 | ) 105 | self.with_cls_branch = with_cls_branch 106 | if with_cls_branch: 107 | self.cls_layers = nn.Sequential( 108 | *linear_relu_ln(embed_dims, 1, 2), 109 | Linear(self.embed_dims, self.num_cls), 110 | ) 111 | self.with_quality_estimation = with_quality_estimation 112 | if with_quality_estimation: 113 | self.quality_layers = nn.Sequential( 114 | *linear_relu_ln(embed_dims, 1, 2), 115 | Linear(self.embed_dims, 2), 116 | ) 117 | 118 | def init_weight(self): 119 | if self.with_cls_branch: 120 | bias_init = bias_init_with_prob(0.01) 121 | nn.init.constant_(self.cls_layers[-1].bias, bias_init) 122 | 123 | def forward( 124 | self, 125 | instance_feature: torch.Tensor, 126 | anchor: torch.Tensor, 127 | anchor_embed: torch.Tensor, 128 | time_interval: torch.Tensor = 1.0, 129 | return_cls=True, 130 | ): 131 | feature = instance_feature + anchor_embed 132 | output = self.layers(feature) 133 | output[..., self.refine_state] = output[..., self.refine_state] + anchor[..., self.refine_state] 134 | 135 | if self.normalize_yaw: 136 | output[..., [SIN_YAW, COS_YAW]] = torch.nn.functional.normalize(output[..., [SIN_YAW, COS_YAW]], dim=-1) 137 | 138 | if self.output_dim > 8: 139 | if not isinstance(time_interval, torch.Tensor): 140 | time_interval = instance_feature.new_tensor(time_interval) 141 | translation = torch.transpose(output[..., VX:], 0, -1) 142 | velocity = torch.transpose(translation / time_interval, 0, -1) 143 | output[..., VX:] = velocity + anchor[..., VX:] 144 | 145 | if return_cls: 146 | assert self.with_cls_branch, "Without classification layers !!!" 
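            # Classification scores are predicted from the raw instance
            # feature, whereas the box refinement above conditions on the
            # anchor embedding (feature = instance_feature + anchor_embed).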
147 | cls = self.cls_layers(instance_feature) 148 | else: 149 | cls = None 150 | if return_cls and self.with_quality_estimation: 151 | quality = self.quality_layers(feature) 152 | else: 153 | quality = None 154 | return output, cls, quality 155 | 156 | 157 | @PLUGIN_LAYERS.register_module() 158 | class SparseBox3DKeyPointsGenerator(BaseModule): 159 | def __init__( 160 | self, 161 | embed_dims=256, 162 | num_learnable_pts=0, 163 | fix_scale=None, 164 | ): 165 | super(SparseBox3DKeyPointsGenerator, self).__init__() 166 | self.embed_dims = embed_dims 167 | self.num_learnable_pts = num_learnable_pts 168 | if fix_scale is None: 169 | fix_scale = ((0.0, 0.0, 0.0),) 170 | self.fix_scale = nn.Parameter( 171 | torch.tensor(fix_scale), requires_grad=False 172 | ) 173 | self.num_pts = len(self.fix_scale) + num_learnable_pts 174 | if num_learnable_pts > 0: 175 | self.learnable_fc = Linear(self.embed_dims, num_learnable_pts * 3) 176 | 177 | def init_weight(self): 178 | if self.num_learnable_pts > 0: 179 | xavier_init(self.learnable_fc, distribution="uniform", bias=0.0) 180 | 181 | def forward( 182 | self, 183 | anchor, 184 | instance_feature=None, 185 | T_cur2temp_list=None, 186 | cur_timestamp=None, 187 | temp_timestamps=None, 188 | ): 189 | bs, num_anchor = anchor.shape[:2] 190 | size = anchor[..., None, [W, L, H]].exp() 191 | key_points = self.fix_scale * size 192 | if self.num_learnable_pts > 0 and instance_feature is not None: 193 | learnable_scale = ( 194 | self.learnable_fc(instance_feature) 195 | .reshape(bs, num_anchor, self.num_learnable_pts, 3) 196 | .sigmoid() 197 | - 0.5 198 | ) 199 | key_points = torch.cat( 200 | [key_points, learnable_scale * size], dim=-2 201 | ) 202 | 203 | rotation_mat = anchor.new_zeros([bs, num_anchor, 3, 3]) 204 | 205 | rotation_mat[:, :, 0, 0] = anchor[:, :, COS_YAW] 206 | rotation_mat[:, :, 0, 1] = -anchor[:, :, SIN_YAW] 207 | rotation_mat[:, :, 1, 0] = anchor[:, :, SIN_YAW] 208 | rotation_mat[:, :, 1, 1] = anchor[:, :, COS_YAW] 209 | rotation_mat[:, :, 2, 2] = 1 210 | 211 | key_points = torch.matmul( 212 | rotation_mat[:, :, None], key_points[..., None] 213 | ).squeeze(-1) 214 | key_points = key_points + anchor[..., None, [X, Y, Z]] 215 | 216 | if ( 217 | cur_timestamp is None 218 | or temp_timestamps is None 219 | or T_cur2temp_list is None 220 | or len(temp_timestamps) == 0 221 | ): 222 | return key_points 223 | 224 | temp_key_points_list = [] 225 | velocity = anchor[..., VX:] 226 | for i, t_time in enumerate(temp_timestamps): 227 | time_interval = cur_timestamp - t_time 228 | translation = ( 229 | velocity 230 | * time_interval.to(dtype=velocity.dtype)[:, None, None] 231 | ) 232 | temp_key_points = key_points - translation[:, :, None] 233 | T_cur2temp = T_cur2temp_list[i].to(dtype=key_points.dtype) 234 | temp_key_points = ( 235 | T_cur2temp[:, None, None, :3] 236 | @ torch.cat( 237 | [ 238 | temp_key_points, 239 | torch.ones_like(temp_key_points[..., :1]), 240 | ], 241 | dim=-1, 242 | ).unsqueeze(-1) 243 | ) 244 | temp_key_points = temp_key_points.squeeze(-1) 245 | temp_key_points_list.append(temp_key_points) 246 | return key_points, temp_key_points_list 247 | 248 | @staticmethod 249 | def anchor_projection(anchor, T_src2dst_list, src_timestamp=None, dst_timestamps=None, time_intervals=None): 250 | dst_anchors = [] 251 | for i in range(len(T_src2dst_list)): 252 | vel = anchor[..., VX:] 253 | vel_dim = vel.shape[-1] 254 | T_src2dst = torch.unsqueeze(T_src2dst_list[i].to(dtype=anchor.dtype), dim=1) 255 | 256 | center = anchor[..., [X, Y, Z]] 257 | if 
time_intervals is not None:
                time_interval = time_intervals[i]
            elif src_timestamp is not None and dst_timestamps is not None:
                time_interval = (src_timestamp - dst_timestamps[i]).to(dtype=vel.dtype)
            else:
                time_interval = None

            if time_interval is not None:
                translation = vel.transpose(0, -1) * time_interval
                translation = translation.transpose(0, -1)
                center = center - translation

            center = torch.matmul(T_src2dst[..., :3, :3], center[..., None]).squeeze(dim=-1) + T_src2dst[..., :3, 3]
            size = anchor[..., [W, L, H]]
            yaw = torch.matmul(T_src2dst[..., :2, :2], anchor[..., [COS_YAW, SIN_YAW], None]).squeeze(-1)
            vel = torch.matmul(T_src2dst[..., :vel_dim, :vel_dim], vel[..., None]).squeeze(-1)
            dst_anchor = torch.cat([center, size, yaw, vel], dim=-1)
            # TODO: Fix bug
            # index = [X, Y, Z, W, L, H, COS_YAW, SIN_YAW] + [VX, VY, VZ][:vel_dim]
            # index = torch.tensor(index, device=dst_anchor.device)
            # index = torch.argsort(index)
            # dst_anchor = dst_anchor.index_select(dim=-1, index=index)
            dst_anchors.append(dst_anchor)
        return dst_anchors

    @staticmethod
    def distance(anchor):
        return torch.norm(anchor[..., :2], p=2, dim=-1)

--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/ops/src/deformable_aggregation_cuda.cu:
--------------------------------------------------------------------------------

#include <cmath>
#include <cstdio>
#include <algorithm>
#include <cstring>

#include <ATen/ATen.h>

#include <THC/THCAtomics.cuh>
#include <math.h>


__device__ float bilinear_sampling(
    const float *&bottom_data, const int &height, const int &width,
    const int &num_embeds, const float &h_im, const float &w_im,
    const int &base_ptr
) {
    const int h_low = floorf(h_im);
    const int w_low = floorf(w_im);
    const int h_high = h_low + 1;
    const int w_high = w_low + 1;

    const float lh = h_im - h_low;
    const float lw = w_im - w_low;
    const float hh = 1 - lh, hw = 1 - lw;

    const int w_stride = num_embeds;
    const int h_stride = width * w_stride;
    const int h_low_ptr_offset = h_low * h_stride;
    const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
    const int w_low_ptr_offset = w_low * w_stride;
    const int w_high_ptr_offset = w_low_ptr_offset + w_stride;

    float v1 = 0;
    if (h_low >= 0 && w_low >= 0) {
        const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
        v1 = bottom_data[ptr1];
    }
    float v2 = 0;
    if (h_low >= 0 && w_high <= width - 1) {
        const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
        v2 = bottom_data[ptr2];
    }
    float v3 = 0;
    if (h_high <= height - 1 && w_low >= 0) {
        const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
        v3 = bottom_data[ptr3];
    }
    float v4 = 0;
    if (h_high <= height - 1 && w_high <= width - 1) {
        const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
        v4 = bottom_data[ptr4];
    }

    const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;

    const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
    return val;
}


__device__ void bilinear_sampling_grad(
    const float *&bottom_data, const float &weight,
    const int &height, const int &width,
    const int &num_embeds, const float &h_im, const float &w_im,
    const int &base_ptr,
    const float
&grad_output, 68 | float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights) { 69 | const int h_low = floorf(h_im); 70 | const int w_low = floorf(w_im); 71 | const int h_high = h_low + 1; 72 | const int w_high = w_low + 1; 73 | 74 | const float lh = h_im - h_low; 75 | const float lw = w_im - w_low; 76 | const float hh = 1 - lh, hw = 1 - lw; 77 | 78 | const int w_stride = num_embeds; 79 | const int h_stride = width * w_stride; 80 | const int h_low_ptr_offset = h_low * h_stride; 81 | const int h_high_ptr_offset = h_low_ptr_offset + h_stride; 82 | const int w_low_ptr_offset = w_low * w_stride; 83 | const int w_high_ptr_offset = w_low_ptr_offset + w_stride; 84 | 85 | const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; 86 | const float top_grad_mc_ms_feat = grad_output * weight; 87 | float grad_h_weight = 0, grad_w_weight = 0; 88 | 89 | float v1 = 0; 90 | if (h_low >= 0 && w_low >= 0) { 91 | const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; 92 | v1 = bottom_data[ptr1]; 93 | grad_h_weight -= hw * v1; 94 | grad_w_weight -= hh * v1; 95 | atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat); 96 | } 97 | float v2 = 0; 98 | if (h_low >= 0 && w_high <= width - 1) { 99 | const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; 100 | v2 = bottom_data[ptr2]; 101 | grad_h_weight -= lw * v2; 102 | grad_w_weight += hh * v2; 103 | atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat); 104 | } 105 | float v3 = 0; 106 | if (h_high <= height - 1 && w_low >= 0) { 107 | const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; 108 | v3 = bottom_data[ptr3]; 109 | grad_h_weight += hw * v3; 110 | grad_w_weight -= lh * v3; 111 | atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat); 112 | } 113 | float v4 = 0; 114 | if (h_high <= height - 1 && w_high <= width - 1) { 115 | const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; 116 | v4 = bottom_data[ptr4]; 117 | grad_h_weight += lw * v4; 118 | grad_w_weight += lh * v4; 119 | atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat); 120 | } 121 | 122 | const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 123 | atomicAdd(grad_weights, grad_output * val); 124 | atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat); 125 | atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat); 126 | } 127 | 128 | 129 | __global__ void deformable_aggregation_kernel( 130 | const int num_kernels, 131 | float* output, 132 | const float* mc_ms_feat, 133 | const int* spatial_shape, 134 | const int* scale_start_index, 135 | const float* sample_location, 136 | const float* weights, 137 | int batch_size, 138 | int num_cams, 139 | int num_feat, 140 | int num_embeds, 141 | int num_scale, 142 | int num_anchors, 143 | int num_pts, 144 | int num_groups 145 | ) { 146 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 147 | if (idx >= num_kernels) return; 148 | 149 | const float weight = *(weights + idx / (num_embeds / num_groups)); 150 | const int channel_index = idx % num_embeds; 151 | idx /= num_embeds; 152 | const int scale_index = idx % num_scale; 153 | idx /= num_scale; 154 | 155 | const int cam_index = idx % num_cams; 156 | idx /= num_cams; 157 | const int pts_index = idx % num_pts; 158 | idx /= num_pts; 159 | 160 | int anchor_index = idx % num_anchors; 161 | idx /= num_anchors; 162 | const int batch_index = idx % batch_size; 163 | idx /= batch_size; 164 | 165 | anchor_index = batch_index * num_anchors + anchor_index; 166 | const int 
loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1; 167 | 168 | const float loc_w = sample_location[loc_offset]; 169 | if (loc_w <= 0 || loc_w >= 1) return; 170 | const float loc_h = sample_location[loc_offset + 1]; 171 | if (loc_h <= 0 || loc_h >= 1) return; 172 | 173 | int cam_scale_index = cam_index * num_scale + scale_index; 174 | const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index; 175 | 176 | cam_scale_index = cam_scale_index << 1; 177 | const int h = spatial_shape[cam_scale_index]; 178 | const int w = spatial_shape[cam_scale_index + 1]; 179 | 180 | const float h_im = loc_h * h - 0.5; 181 | const float w_im = loc_w * w - 0.5; 182 | 183 | atomicAdd( 184 | output + anchor_index * num_embeds + channel_index, 185 | bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight 186 | ); 187 | } 188 | 189 | 190 | __global__ void deformable_aggregation_grad_kernel( 191 | const int num_kernels, 192 | const float* mc_ms_feat, 193 | const int* spatial_shape, 194 | const int* scale_start_index, 195 | const float* sample_location, 196 | const float* weights, 197 | const float* grad_output, 198 | float* grad_mc_ms_feat, 199 | float* grad_sampling_location, 200 | float* grad_weights, 201 | int batch_size, 202 | int num_cams, 203 | int num_feat, 204 | int num_embeds, 205 | int num_scale, 206 | int num_anchors, 207 | int num_pts, 208 | int num_groups 209 | ) { 210 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 211 | if (idx >= num_kernels) return; 212 | 213 | const int weights_ptr = idx / (num_embeds / num_groups); 214 | const int channel_index = idx % num_embeds; 215 | idx /= num_embeds; 216 | const int scale_index = idx % num_scale; 217 | idx /= num_scale; 218 | 219 | const int cam_index = idx % num_cams; 220 | idx /= num_cams; 221 | const int pts_index = idx % num_pts; 222 | idx /= num_pts; 223 | 224 | int anchor_index = idx % num_anchors; 225 | idx /= num_anchors; 226 | const int batch_index = idx % batch_size; 227 | idx /= batch_size; 228 | 229 | anchor_index = batch_index * num_anchors + anchor_index; 230 | const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1; 231 | 232 | const float loc_w = sample_location[loc_offset]; 233 | if (loc_w <= 0 || loc_w >= 1) return; 234 | const float loc_h = sample_location[loc_offset + 1]; 235 | if (loc_h <= 0 || loc_h >= 1) return; 236 | 237 | const float grad = grad_output[anchor_index*num_embeds + channel_index]; 238 | 239 | int cam_scale_index = cam_index * num_scale + scale_index; 240 | const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index; 241 | 242 | cam_scale_index = cam_scale_index << 1; 243 | const int h = spatial_shape[cam_scale_index]; 244 | const int w = spatial_shape[cam_scale_index + 1]; 245 | 246 | const float h_im = loc_h * h - 0.5; 247 | const float w_im = loc_w * w - 0.5; 248 | 249 | /* atomicAdd( */ 250 | /* output + anchor_index * num_embeds + channel_index, */ 251 | /* bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight */ 252 | /* ); */ 253 | const float weight = weights[weights_ptr]; 254 | float *grad_weights_ptr = grad_weights + weights_ptr; 255 | float *grad_location_ptr = grad_sampling_location + loc_offset; 256 | bilinear_sampling_grad( 257 | mc_ms_feat, weight, h, w, num_embeds, h_im, w_im, 258 | value_offset, 259 | grad, 260 | grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr 261 | ); 
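    /* bilinear_sampling_grad scatters three gradients with atomicAdd: each of
       the four neighboring feature taps receives w_i * weight * grad, the
       weight gradient accumulates the bilinearly sampled value times grad,
       and the sampling-location gradient is rescaled by the feature-map
       width/height to undo the normalized-coordinate mapping above. */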
262 | } 263 | 264 | 265 | void deformable_aggregation( 266 | float* output, 267 | const float* mc_ms_feat, 268 | const int* spatial_shape, 269 | const int* scale_start_index, 270 | const float* sample_location, 271 | const float* weights, 272 | int batch_size, 273 | int num_cams, 274 | int num_feat, 275 | int num_embeds, 276 | int num_scale, 277 | int num_anchors, 278 | int num_pts, 279 | int num_groups 280 | ) { 281 | const int num_kernels = batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale; 282 | deformable_aggregation_kernel 283 | <<<(int)ceil(((double)num_kernels/128)), 128>>>( 284 | num_kernels, output, 285 | mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights, 286 | batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups 287 | ); 288 | } 289 | 290 | 291 | void deformable_aggregation_grad( 292 | const float* mc_ms_feat, 293 | const int* spatial_shape, 294 | const int* scale_start_index, 295 | const float* sample_location, 296 | const float* weights, 297 | const float* grad_output, 298 | float* grad_mc_ms_feat, 299 | float* grad_sampling_location, 300 | float* grad_weights, 301 | int batch_size, 302 | int num_cams, 303 | int num_feat, 304 | int num_embeds, 305 | int num_scale, 306 | int num_anchors, 307 | int num_pts, 308 | int num_groups 309 | ) { 310 | const int num_kernels = batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale; 311 | deformable_aggregation_grad_kernel 312 | <<<(int)ceil(((double)num_kernels/128)), 128>>>( 313 | num_kernels, 314 | mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights, 315 | grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights, 316 | batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups 317 | ); 318 | } 319 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
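# Training entry point: supports non-distributed runs and the "pytorch",
# "slurm", "mpi", and "mpi_nccl" launchers; when the config defines
# `plugin` (and optionally `plugin_dir`), the plugin package is imported
# below so its registered modules are available before the model is built.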
2 | from __future__ import division 3 | import sys 4 | import os 5 | 6 | print(sys.executable, os.path.abspath(__file__)) 7 | # import init_paths # for the conda-package submission workflow 8 | import argparse 9 | import copy 10 | import mmcv 11 | import time 12 | import torch 13 | import warnings 14 | from mmcv import Config, DictAction 15 | from mmcv.runner import get_dist_info, init_dist 16 | from os import path as osp 17 | 18 | from mmdet import __version__ as mmdet_version 19 | from mmdet.apis import train_detector 20 | from mmdet.datasets import build_dataset 21 | from mmdet.models import build_detector 22 | from mmdet.utils import collect_env, get_root_logger 23 | from mmdet.apis import set_random_seed 24 | from torch import distributed as dist 25 | from datetime import timedelta 26 | 27 | import cv2 28 | 29 | cv2.setNumThreads(8) 30 | 31 | 32 | def parse_args(): 33 | parser = argparse.ArgumentParser(description="Train a detector") 34 | parser.add_argument("config", help="train config file path") 35 | parser.add_argument("--work-dir", help="the dir to save logs and models") 36 | parser.add_argument( 37 | "--resume-from", help="the checkpoint file to resume from" 38 | ) 39 | parser.add_argument( 40 | "--no-validate", 41 | action="store_true", 42 | help="skip evaluating the checkpoint during training", 43 | ) 44 | group_gpus = parser.add_mutually_exclusive_group() 45 | group_gpus.add_argument( 46 | "--gpus", 47 | type=int, 48 | help="number of gpus to use " 49 | "(only applicable to non-distributed training)", 50 | ) 51 | group_gpus.add_argument( 52 | "--gpu-ids", 53 | type=int, 54 | nargs="+", 55 | help="ids of gpus to use " 56 | "(only applicable to non-distributed training)", 57 | ) 58 | parser.add_argument("--seed", type=int, default=0, help="random seed") 59 | parser.add_argument( 60 | "--deterministic", 61 | action="store_true", 62 | help="whether to set deterministic options for CUDNN backend.", 63 | ) 64 | parser.add_argument( 65 | "--options", 66 | nargs="+", 67 | action=DictAction, 68 | help="override some settings in the used config, the key-value pair " 69 | "in xxx=yyy format will be merged into the config file (deprecated); " 70 | "use --cfg-options instead.", 71 | ) 72 | parser.add_argument( 73 | "--cfg-options", 74 | nargs="+", 75 | action=DictAction, 76 | help="override some settings in the used config, the key-value pair " 77 | "in xxx=yyy format will be merged into the config file. If the value to " 78 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 79 | 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 80 | "Note that the quotation marks are necessary and that no white space " 81 | "is allowed.", 82 | ) 83 | parser.add_argument( 84 | "--dist-url", 85 | type=str, 86 | default="auto", 87 | help="dist url for init process, such as tcp://localhost:8000", 88 | ) 89 | parser.add_argument("--gpus-per-machine", type=int, default=8) 90 | parser.add_argument( 91 | "--launcher", 92 | choices=["none", "pytorch", "slurm", "mpi", "mpi_nccl"], 93 | default="none", 94 | help="job launcher", 95 | ) 96 | parser.add_argument("--local_rank", type=int, default=0) 97 | parser.add_argument( 98 | "--autoscale-lr", 99 | action="store_true", 100 | help="automatically scale lr with the number of gpus", 101 | ) 102 | args = parser.parse_args() 103 | if "LOCAL_RANK" not in os.environ: 104 | os.environ["LOCAL_RANK"] = str(args.local_rank) 105 | 106 | if args.options and args.cfg_options: 107 | raise ValueError( 108 | "--options and --cfg-options cannot both be specified, " 109 | "--options is deprecated in favor of --cfg-options" 110 | ) 111 | if args.options: 112 | warnings.warn("--options is deprecated in favor of --cfg-options") 113 | args.cfg_options = args.options 114 | 115 | return args 116 | 117 | 118 | def main(): 119 | args = parse_args() 120 | 121 | cfg = Config.fromfile(args.config) 122 | if args.cfg_options is not None: 123 | cfg.merge_from_dict(args.cfg_options) 124 | # import modules from string list. 125 | if cfg.get("custom_imports", None): 126 | from mmcv.utils import import_modules_from_strings 127 | 128 | import_modules_from_strings(**cfg["custom_imports"]) 129 | 130 | # import modules from plugin/xx; the registry will be updated 131 | if hasattr(cfg, "plugin"): 132 | if cfg.plugin: 133 | import importlib 134 | 135 | if hasattr(cfg, "plugin_dir"): 136 | plugin_dir = cfg.plugin_dir 137 | _module_dir = os.path.dirname(plugin_dir) 138 | _module_dir = _module_dir.split("/") 139 | _module_path = _module_dir[0] 140 | 141 | for m in _module_dir[1:]: 142 | _module_path = _module_path + "." + m 143 | print(_module_path) 144 | plg_lib = importlib.import_module(_module_path) 145 | else: 146 | # the import dir is the dirpath of the config file 147 | _module_dir = os.path.dirname(args.config) 148 | _module_dir = _module_dir.split("/") 149 | _module_path = _module_dir[0] 150 | for m in _module_dir[1:]: 151 | _module_path = _module_path + "." 
+ m 152 | print(_module_path) 153 | plg_lib = importlib.import_module(_module_path) 154 | from projects.mmdet3d_plugin.apis.train import custom_train_model 155 | 156 | # set cudnn_benchmark 157 | if cfg.get("cudnn_benchmark", False): 158 | torch.backends.cudnn.benchmark = True 159 | 160 | # work_dir is determined in this priority: CLI > segment in file > filename 161 | if args.work_dir is not None: 162 | # update configs according to CLI args if args.work_dir is not None 163 | cfg.work_dir = args.work_dir 164 | elif cfg.get("work_dir", None) is None: 165 | # use config filename as default work_dir if cfg.work_dir is None 166 | cfg.work_dir = osp.join( 167 | "./work_dirs", osp.splitext(osp.basename(args.config))[0] 168 | ) 169 | if args.resume_from is not None: 170 | cfg.resume_from = args.resume_from 171 | if args.gpu_ids is not None: 172 | cfg.gpu_ids = args.gpu_ids 173 | else: 174 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 175 | 176 | if args.autoscale_lr: 177 | # apply the linear scaling rule (https://arxiv.org/abs/1706.02677) 178 | cfg.optimizer["lr"] = cfg.optimizer["lr"] * len(cfg.gpu_ids) / 8 179 | 180 | # init distributed env first, since logger depends on the dist info. 181 | if args.launcher == "none": 182 | distributed = False 183 | elif args.launcher == "mpi_nccl": 184 | distributed = True 185 | 186 | import mpi4py.MPI as MPI 187 | 188 | comm = MPI.COMM_WORLD 189 | mpi_local_rank = comm.Get_rank() 190 | mpi_world_size = comm.Get_size() 191 | print( 192 | "MPI local_rank=%d, world_size=%d" 193 | % (mpi_local_rank, mpi_world_size) 194 | ) 195 | 196 | # num_gpus = torch.cuda.device_count() 197 | device_ids_on_machines = list(range(args.gpus_per_machine)) 198 | str_ids = list(map(str, device_ids_on_machines)) 199 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str_ids) 200 | torch.cuda.set_device(mpi_local_rank % args.gpus_per_machine) 201 | 202 | dist.init_process_group( 203 | backend="nccl", 204 | init_method=args.dist_url, 205 | world_size=mpi_world_size, 206 | rank=mpi_local_rank, 207 | timeout=timedelta(seconds=3600), 208 | ) 209 | 210 | cfg.gpu_ids = range(mpi_world_size) 211 | print("cfg.gpu_ids:", cfg.gpu_ids) 212 | else: 213 | distributed = True 214 | init_dist( 215 | args.launcher, timeout=timedelta(seconds=3600), **cfg.dist_params 216 | ) 217 | # re-set gpu_ids with distributed training mode 218 | _, world_size = get_dist_info() 219 | cfg.gpu_ids = range(world_size) 220 | 221 | # create work_dir 222 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 223 | # dump config 224 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 225 | # init the logger before other steps 226 | timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime()) 227 | log_file = osp.join(cfg.work_dir, f"{timestamp}.log") 228 | # specify logger name, if we still use 'mmdet', the output info will be 229 | # filtered and won't be saved in the log_file 230 | # TODO: ugly workaround to judge whether we are training det or seg model 231 | logger = get_root_logger( 232 | log_file=log_file, log_level=cfg.log_level 233 | ) 234 | 235 | # init the meta dict to record some important information such as 236 | # environment info and seed, which will be logged 237 | meta = dict() 238 | # log env info 239 | env_info_dict = collect_env() 240 | env_info = "\n".join([(f"{k}: {v}") for k, v in env_info_dict.items()]) 241 | dash_line = "-" * 60 + "\n" 242 | logger.info( 243 | "Environment info:\n" + dash_line + env_info + "\n" + dash_line 244 | ) 245 | meta["env_info"] = env_info 246 | 
meta["config"] = cfg.pretty_text 247 | 248 | # log some basic info 249 | logger.info(f"Distributed training: {distributed}") 250 | logger.info(f"Config:\n{cfg.pretty_text}") 251 | 252 | # set random seeds 253 | if args.seed is not None: 254 | logger.info( 255 | f"Set random seed to {args.seed}, " 256 | f"deterministic: {args.deterministic}" 257 | ) 258 | set_random_seed(args.seed, deterministic=args.deterministic) 259 | cfg.seed = args.seed 260 | meta["seed"] = args.seed 261 | meta["exp_name"] = osp.basename(args.config) 262 | 263 | model = build_detector( 264 | cfg.model, train_cfg=cfg.get("train_cfg"), test_cfg=cfg.get("test_cfg") 265 | ) 266 | model.init_weights() 267 | 268 | logger.info(f"Model:\n{model}") 269 | datasets = [build_dataset(cfg.data.train)] 270 | if len(cfg.workflow) == 2: 271 | val_dataset = copy.deepcopy(cfg.data.val) 272 | # in case we use a dataset wrapper 273 | if "dataset" in cfg.data.train: 274 | val_dataset.pipeline = cfg.data.train.dataset.pipeline 275 | else: 276 | val_dataset.pipeline = cfg.data.train.pipeline 277 | # set test_mode=False here in deep copied config 278 | # which does not affect the AP/AR calculation later 279 | # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow # noqa 280 | val_dataset.test_mode = False 281 | datasets.append(build_dataset(val_dataset)) 282 | if cfg.checkpoint_config is not None: 283 | # save mmdet version, config file content and class names in 284 | # checkpoints as meta data 285 | cfg.checkpoint_config.meta = dict( 286 | mmdet_version=mmdet_version, 287 | config=cfg.pretty_text, 288 | CLASSES=datasets[0].CLASSES, 289 | ) 290 | # add an attribute for visualization convenience 291 | model.CLASSES = datasets[0].CLASSES 292 | if hasattr(cfg, "plugin"): 293 | custom_train_model( 294 | model, 295 | datasets, 296 | cfg, 297 | distributed=distributed, 298 | validate=(not args.no_validate), 299 | timestamp=timestamp, 300 | meta=meta, 301 | ) 302 | else: 303 | train_detector( 304 | model, 305 | datasets, 306 | cfg, 307 | distributed=distributed, 308 | validate=(not args.no_validate), 309 | timestamp=timestamp, 310 | meta=meta, 311 | ) 312 | 313 | 314 | if __name__ == "__main__": 315 | torch.multiprocessing.set_start_method( 316 | "fork" 317 | ) # use fork so that workers_per_gpu can be > 1 318 | main() 319 | --------------------------------------------------------------------------------
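For readers wiring the CUDA launchers above into PyTorch, here is a minimal sketch of an autograd wrapper. It is an illustration only: the extension module name `deformable_aggregation_ext` and its binding names are assumptions, it presumes the forward binding allocates and returns the output tensor, and the wrapper actually shipped with this repository may differ. The tensor layouts in the comments follow the indexing used by the two kernels.

```python
# A minimal sketch, NOT the repository's actual wrapper. The module
# `deformable_aggregation_ext` and its binding names are hypothetical;
# they are assumed to mirror the deformable_aggregation /
# deformable_aggregation_grad launchers defined above.
import torch
from torch.autograd import Function

import deformable_aggregation_ext as ext  # hypothetical compiled extension


class DeformableAggregationFunction(Function):
    @staticmethod
    def forward(ctx, mc_ms_feat, spatial_shape, scale_start_index,
                sample_location, weights):
        # mc_ms_feat:        (bs, num_feat, num_embeds) flattened multi-camera,
        #                    multi-scale feature maps
        # spatial_shape:     (num_cams, num_scale, 2) int tensor of (h, w)
        # scale_start_index: (num_cams, num_scale) int offsets into num_feat
        # sample_location:   (bs, num_anchors, num_pts, num_cams, 2) in (0, 1)
        # weights:           (bs, num_anchors, num_pts, num_cams, num_scale,
        #                    num_groups) aggregation weights
        output = ext.deformable_aggregation_forward(
            mc_ms_feat, spatial_shape, scale_start_index,
            sample_location, weights)  # (bs, num_anchors, num_embeds)
        ctx.save_for_backward(mc_ms_feat, spatial_shape, scale_start_index,
                              sample_location, weights)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (mc_ms_feat, spatial_shape, scale_start_index,
         sample_location, weights) = ctx.saved_tensors
        # The grad kernel accumulates with atomicAdd, so the gradient
        # buffers must be zero-initialized before the launch.
        grad_feat = torch.zeros_like(mc_ms_feat)
        grad_loc = torch.zeros_like(sample_location)
        grad_w = torch.zeros_like(weights)
        ext.deformable_aggregation_backward(
            mc_ms_feat, spatial_shape, scale_start_index, sample_location,
            weights, grad_output.contiguous(), grad_feat, grad_loc, grad_w)
        # No gradients for the integer shape/index tensors.
        return grad_feat, None, None, grad_loc, grad_w
```

Calling code would then invoke `DeformableAggregationFunction.apply(...)` on contiguous CUDA tensors; sampling locations outside (0, 1) contribute nothing in either direction, matching the early returns in both kernels.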