├── projects
├── __init__.py
└── mmdet3d_plugin
│ ├── core
│ ├── evaluation
│ │ ├── __init__.py
│ │ └── eval_hooks.py
│ └── box3d.py
│ ├── __init__.py
│ ├── apis
│ ├── __init__.py
│ ├── train.py
│ ├── test.py
│ └── mmdet_train.py
│ ├── models
│ ├── detection2d
│ │ ├── __init__.py
│ │ ├── blocks.py
│ │ └── coster.py
│ ├── detection3d
│ │ ├── __init__.py
│ │ ├── losses.py
│ │ ├── decoder.py
│ │ └── blocks.py
│ ├── __init__.py
│ ├── utils.py
│ ├── grid_mask.py
│ ├── simpb.py
│ ├── aggregation.py
│ ├── allocation.py
│ └── instance_bank.py
│ ├── datasets
│ ├── samplers
│ │ ├── sampler.py
│ │ ├── __init__.py
│ │ ├── distributed_sampler.py
│ │ ├── group_sampler.py
│ │ ├── infinite_group_each_sample_in_batch_sampler.py
│ │ └── group_in_batch_sampler.py
│ ├── __init__.py
│ ├── pipelines
│ │ ├── __init__.py
│ │ ├── loading.py
│ │ └── transform.py
│ ├── builder.py
│ └── utils.py
│ └── ops
│ ├── setup.py
│ ├── deformable_aggregation.py
│ ├── __init__.py
│ └── src
│ ├── deformable_aggregation.cpp
│ └── deformable_aggregation_cuda.cu
├── docs
├── figs
│ └── arch.png
├── prepare_environment.md
├── prepare_dataset.md
└── training_evaluation.md
├── requirement.txt
├── tools
├── dist_train.sh
├── dist_test.sh
├── anchor_generator.py
├── fuse_conv_bn.py
├── benchmark.py
└── train.py
├── README.md
└── LICENSE
/projects/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/docs/figs/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nullmax-vision/SimPB/HEAD/docs/figs/arch.png
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/core/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | from .eval_hooks import CustomDistEvalHook
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import *
2 | from .models import *
3 | from .apis import *
4 | from .core.evaluation import *
5 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/apis/__init__.py:
--------------------------------------------------------------------------------
1 | from .train import custom_train_model
2 | from .mmdet_train import custom_train_detector
3 |
4 | # from .test import custom_multi_gpu_test
5 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/core/box3d.py:
--------------------------------------------------------------------------------
# Column indices of an undecoded 11-value box vector.
(
    X,
    Y,
    Z,
    W,
    L,
    H,
    SIN_YAW,
    COS_YAW,
    VX,
    VY,
    VZ,
) = range(11)

# Centerness and yawness indices in the quality tensor.
CNS, YNS = 0, 1

# Index of the yaw value in a decoded box.
YAW = 6
4 |
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | numpy==1.23.5
2 | mmcv_full==1.7.1
3 | mmdet==2.28.2
4 | urllib3==1.26.16
5 | pyquaternion==0.9.9
6 | nuscenes-devkit==1.1.10
7 | yapf==0.33.0
8 | tensorboard==2.14.0
9 | motmetrics==1.1.3
10 | pandas==1.1.5
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/detection2d/__init__.py:
--------------------------------------------------------------------------------
1 | from .coster import SparseBox2DCoster
2 | from .target import SparseBox2DTarget
3 | from .denoise import Denoise2D
4 | from .blocks import SparseBox2DEncoder, SparseBox2DRefinementModule
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/samplers/sampler.py:
--------------------------------------------------------------------------------
1 | from mmcv.utils.registry import Registry, build_from_cfg
2 |
3 | SAMPLER = Registry("sampler")
4 |
5 |
def build_sampler(cfg, default_args=None):
    """Build a sampler registered in ``SAMPLER`` from a config.

    Args:
        cfg: Config passed to ``build_from_cfg`` to select and configure the
            sampler.
        default_args: Optional extra arguments forwarded to
            ``build_from_cfg``. Defaults to ``None`` so callers that have
            nothing to add can omit it.

    Returns:
        Whatever ``build_from_cfg`` returns for ``cfg`` looked up in
        ``SAMPLER``.
    """
    return build_from_cfg(cfg, SAMPLER, default_args)
8 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | from .group_sampler import DistributedGroupSampler
2 | from .distributed_sampler import DistributedSampler
3 | from .sampler import SAMPLER, build_sampler
4 | from .group_in_batch_sampler import (
5 | GroupInBatchSampler,
6 | )
7 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .nuscenes_dataset import NuScenes3DDetTrackDataset
2 | from .builder import *
3 | from .pipelines import *
4 | from .samplers import *
5 |
6 | __all__ = [
7 | 'NuScenes3DDetTrackDataset',
8 | "custom_build_dataset",
9 | ]
10 |
--------------------------------------------------------------------------------
/tools/dist_train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch train.py through torch.distributed.launch.
# Usage: [PORT=<port>] dist_train.sh <config> <num_gpus> [extra train.py args...]

CONFIG=$1
GPUS=$2
PORT=${PORT:-28650}

# Resolve the directory of this script once; quoting keeps paths that contain
# spaces intact.
SCRIPT_DIR="$(dirname "$0")"

# The parent of the script directory is prepended to PYTHONPATH.
# "${@:3}" forwards every remaining argument without re-splitting it.
PYTHONPATH="${SCRIPT_DIR}/..":$PYTHONPATH \
python3 -m torch.distributed.launch --nproc_per_node="$GPUS" --master_port="$PORT" \
    "${SCRIPT_DIR}/train.py" "$CONFIG" --launcher pytorch "${@:3}"
10 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/detection3d/__init__.py:
--------------------------------------------------------------------------------
1 | from .decoder import SparseBox3DDecoder
2 | from .target import SparseBox3DTarget, SparseBox3DTargetWith2D
3 | from .blocks import (
4 | SparseBox3DRefinementModule,
5 | SparseBox3DKeyPointsGenerator,
6 | SparseBox3DEncoder,
7 | )
8 | from .losses import SparseBox3DLoss
9 |
--------------------------------------------------------------------------------
/tools/dist_test.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch test.py through torch.distributed.launch.
# Usage: [PORT=<port>] dist_test.sh <config> <checkpoint> <num_gpus> [extra test.py args...]

CONFIG=$1
CHECKPOINT=$2
GPUS=$3
PORT=${PORT:-29610}

# Resolve the directory of this script once; quoting keeps paths that contain
# spaces intact.
SCRIPT_DIR="$(dirname "$0")"

# The parent of the script directory is prepended to PYTHONPATH.
# "${@:4}" forwards every remaining argument without re-splitting it.
PYTHONPATH="${SCRIPT_DIR}/..":$PYTHONPATH \
python3 -m torch.distributed.launch --nproc_per_node="$GPUS" --master_port="$PORT" \
    "${SCRIPT_DIR}/test.py" "$CONFIG" "$CHECKPOINT" --launcher pytorch "${@:4}"
11 |
--------------------------------------------------------------------------------
/docs/prepare_environment.md:
--------------------------------------------------------------------------------
1 | ## Prepare Environment
2 | * Linux
3 | * python 3.8
4 | * PyTorch 1.10.0+
5 | * CUDA 11.1+
6 |
7 | **1. Create a conda virtual environment**
8 | ```bash
9 | conda create -n simpb python=3.8 -y
10 | conda activate simpb
11 | ```
12 |
13 | **2. Install PyTorch**
14 | ```bash
15 | pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1+cu117 -f https://download.pytorch.org/whl/torch_stable.html
16 | ```
17 |
18 | **3. Install other packages**
19 | ```bash
20 | pip install --upgrade pip
21 | pip install -r requirement.txt
22 | ```
23 |
24 | **4. Compile the deformable_aggregation CUDA op**
25 | ```bash
26 | cd projects/mmdet3d_plugin/ops
27 | python setup.py develop
28 | cd ../../../
29 | ```
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | from .transform import (
2 | InstanceNameFilter,
3 | CircleObjectRangeFilter,
4 | NormalizeMultiviewImage,
5 | NuScenesSparse4DAdaptor,
6 | MultiScaleDepthMapGenerator,
7 | )
8 | from .augment import (
9 | ResizeCropFlipImage,
10 | BBoxRotation,
11 | BBoxScale,
12 | PhotoMetricDistortionMultiViewImage,
13 | )
14 | from .loading import LoadMultiViewImageFromFiles, LoadPointsFromFile
15 |
16 | __all__ = [
17 | "InstanceNameFilter",
18 | "ResizeCropFlipImage",
19 | "BBoxRotation",
20 | "BBoxScale",
21 | "CircleObjectRangeFilter",
22 | "MultiScaleDepthMapGenerator",
23 | "NormalizeMultiviewImage",
24 | "PhotoMetricDistortionMultiViewImage",
25 | "NuScenesSparse4DAdaptor",
26 | "LoadMultiViewImageFromFiles",
27 | "LoadPointsFromFile",
28 | ]
29 |
--------------------------------------------------------------------------------
/docs/prepare_dataset.md:
--------------------------------------------------------------------------------
1 | ## Prepare Dataset
2 | ### NuScenes
3 | **1. Download [nuScenes](https://www.nuscenes.org/download) V1.0 dataset**
4 |
5 | **2. Link dataset to project**
6 | ```bash
7 | ln -s path/to/nuscenes ./data/nuscenes
8 | ```
9 | **3. Convert nuscenes dataset**
10 | ```bash
11 | python tools/data_converter/nuscenes_converter.py --info_prefix ./data/nuscenes/simpb_nuscenes
12 | ```
13 | ### K-means Anchors
14 | ```bash
15 | python tools/anchor_generator.py --ann_file ./data/nuscenes/simpb_nuscenes_infos_train.pkl
16 | ```
17 |
18 | **Folder structure**
19 | ```
20 | SimPB
21 | ├── projects/
22 | ├── tools/
23 | ├── ckpts/
24 | ├── data/
25 | │ ├── nuscenes/
26 | │ │ ├── maps/
27 | │ │ ├── samples/
28 | │ │ ├── sweeps/
29 | │ │ ├── v1.0-test/
30 | │ │ ├── v1.0-trainval/
31 | │ │ ├── nuscenes_kmeans900.npy
32 | │ │ ├── simpb_nuscenes_infos_test.pkl
33 | │ │ ├── simpb_nuscenes_infos_train.pkl
34 | │ │ ├── simpb_nuscenes_infos_val.pkl
35 | ```
--------------------------------------------------------------------------------
/docs/training_evaluation.md:
--------------------------------------------------------------------------------
1 | ## Getting Started
2 |
3 | ## Train
4 |
5 | **1. Download pretrained backbone**
6 | ```bash
7 | cd ckpts
8 | wget https://download.pytorch.org/models/resnet50-19c8e357.pth
9 | ```
10 |
11 | **2. Train simpb with multiple GPUs**
12 | ```bash
13 | bash ./tools/dist_train.sh ./projects/configs/simpb_nus_r50_img_704x256.py 8 --no-validate
14 | ```
15 |
16 | ## Test
17 | **1. Download pretrained model**
18 |
19 | Download the pretrained model [here](https://github.com/nullmax-vision/SimPB/releases/download/weights/simpb_r50_img.pth), or use your own trained weights.
20 |
21 | **2. Evaluate the pretrained model**
22 | ```bash
23 | bash ./tools/dist_test.sh ./projects/configs/simpb_nus_r50_img_704x256.py path/to/model.pth 8 --eval bbox
24 | ```
25 |
26 | ## Visualize
27 | **1. Get results file**
28 | ```bash
29 | python ./tools/test.py ./projects/configs/simpb_nus_r50_img_704x256.py path/to/model.pth --out path/to/model.pkl
30 | ```
31 |
32 | **2. Load and show results**
33 | ```bash
34 | python ./tools/test.py ./projects/configs/simpb_nus_r50_img_704x256.py path/to/model.pth --result_file path/to/model.pkl --show_only --show-dir ./
35 | ```
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .simpb import SimPB
2 | from .simpb_head import SimPBHead
3 | from .blocks import (
4 | DeformableFeatureAggregation,
5 | DenseDepthNet,
6 | AsymmetricFFN,
7 | )
8 | from .instance_bank import InstanceBank
9 |
10 | from .detection2d import (
11 | SparseBox2DEncoder,
12 | SparseBox2DRefinementModule,
13 | )
14 |
15 | from .detection3d import (
16 | SparseBox3DDecoder,
17 | SparseBox3DTarget,
18 | SparseBox3DRefinementModule,
19 | SparseBox3DKeyPointsGenerator,
20 | SparseBox3DEncoder,
21 | )
22 |
23 | from .allocation import DynamicQueryAllocation
24 | from .aggregation import AdaptiveQueryAggregation
25 | from .group_attn import (QueryGroupMultiheadAttention,
26 | QueryGroupMultiScaleDeformableAttention)
27 |
28 | __all__ = [
29 | "SimPB",
30 | "SimPBHead",
31 | "DeformableFeatureAggregation",
32 | "DenseDepthNet",
33 | "AsymmetricFFN",
34 | "InstanceBank",
35 | "SparseBox3DDecoder",
36 | "SparseBox3DTarget",
37 | "SparseBox3DRefinementModule",
38 | "SparseBox3DKeyPointsGenerator",
39 | "SparseBox3DEncoder",
40 | "DynamicQueryAllocation",
41 | "AdaptiveQueryAggregation",
42 | "QueryGroupMultiheadAttention",
43 | "QueryGroupMultiScaleDeformableAttention",
44 | ]
45 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/apis/train.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 | # ---------------------------------------------
4 | # Modified by Zhiqi Li
5 | # ---------------------------------------------
6 |
7 | from .mmdet_train import custom_train_detector
8 | # from mmseg.apis import train_segmentor
9 | from mmdet.apis import train_detector
10 |
11 |
def custom_train_model(
    model,
    dataset,
    cfg,
    distributed=False,
    validate=False,
    timestamp=None,
    meta=None,
):
    """Launch training with the plugin's custom detector training loop.

    This wrapper exists because a different eval hook is needed in the runner.
    It should be deprecated in the future.

    Args:
        model: Model to train, forwarded to ``custom_train_detector``.
        dataset: Dataset(s) to train on, forwarded unchanged.
        cfg: Config object; ``cfg.model.type`` selects the code path.
        distributed: Forwarded to ``custom_train_detector``.
        validate: Forwarded to ``custom_train_detector``.
        timestamp: Forwarded to ``custom_train_detector``.
        meta: Forwarded to ``custom_train_detector``.

    Raises:
        AssertionError: If ``cfg.model.type`` is ``"EncoderDecoder3D"``,
            which this wrapper does not support.
    """
    if cfg.model.type in ["EncoderDecoder3D"]:
        # Raised explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would turn this branch into a
        # silent no-op. The exception type is unchanged.
        raise AssertionError(
            f"custom_train_model does not support model type "
            f"{cfg.model.type!r}"
        )

    custom_train_detector(
        model,
        dataset,
        cfg,
        distributed=distributed,
        validate=validate,
        timestamp=timestamp,
        meta=meta,
    )
38 |
39 |
def train_model(
    model,
    dataset,
    cfg,
    distributed=False,
    validate=False,
    timestamp=None,
    meta=None,
):
    """Launch training through mmdet's stock ``train_detector``.

    Thin pass-through kept because a different eval hook is needed in the
    runner; it should be deprecated in the future. All arguments are forwarded
    unchanged.
    """
    run_options = {
        "distributed": distributed,
        "validate": validate,
        "timestamp": timestamp,
        "meta": meta,
    }
    train_detector(model, dataset, cfg, **run_options)
63 |
--------------------------------------------------------------------------------
/tools/anchor_generator.py:
--------------------------------------------------------------------------------
1 | import mmcv
2 | import argparse
3 | import numpy as np
4 | from sklearn.cluster import KMeans
5 | from projects.mmdet3d_plugin.core.box3d import *
6 |
7 |
def get_kmeans_anchor(
    ann_file,
    num_anchor=900,
    detection_range=55,
    output_file_name="nuscenes_kmeans900.npy",
    verbose=False,
):
    """Cluster ground-truth box centers with k-means and save them as anchors.

    Args:
        ann_file: Pickle file loaded with ``mmcv.load``; its ``"infos"``
            entries must each provide a ``"gt_boxes"`` array.
        num_anchor: Number of k-means clusters, i.e. number of anchors.
        detection_range: Boxes whose center norm (first three columns)
            exceeds this value are discarded before clustering.
        output_file_name: Path handed to ``np.save``.
        verbose: Forwarded to ``KMeans``.
    """
    infos = mmcv.load(ann_file, file_format="pkl")["infos"]
    all_boxes = np.concatenate([info["gt_boxes"] for info in infos], axis=0)

    # Keep only boxes whose center lies within the detection range.
    center_dist = np.linalg.norm(all_boxes[:, :3], axis=-1, ord=2)
    kept_boxes = all_boxes[center_dist <= detection_range]

    center_cols = [X, Y, Z]
    size_cols = [W, L, H]

    kmeans = KMeans(n_clusters=num_anchor, verbose=verbose)
    print("===========Starting kmeans, please wait.===========")
    kmeans.fit(kept_boxes[:, center_cols])

    anchors = np.zeros((num_anchor, 11))
    anchors[:, center_cols] = kmeans.cluster_centers_
    # Every anchor shares the same size: the log of the mean box size.
    anchors[:, size_cols] = np.log(kept_boxes[:, size_cols].mean(axis=0))
    # sin(yaw) stays 0 and cos(yaw) is 1.
    anchors[:, COS_YAW] = 1

    np.save(output_file_name, anchors)
    print(f"===========Done! Save results to {output_file_name}.===========")
29 |
30 |
if __name__ == "__main__":
    # Command-line entry point: parse the options and run the clustering.
    cli = argparse.ArgumentParser(description="anchor kmeans")
    cli.add_argument("--ann_file", type=str, required=True)
    cli.add_argument("--num_anchor", type=int, default=900)
    cli.add_argument("--detection_range", type=float, default=55)
    # NOTE(review): this default differs from get_kmeans_anchor's own default
    # ("nuscenes_kmeans900.npy") -- confirm which file name is intended.
    cli.add_argument("--output_file_name", type=str, default="kmeans900.npy")
    cli.add_argument("--verbose", action="store_true")
    options = cli.parse_args()

    get_kmeans_anchor(
        ann_file=options.ann_file,
        num_anchor=options.num_anchor,
        detection_range=options.detection_range,
        output_file_name=options.output_file_name,
        verbose=options.verbose,
    )
47 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/ops/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from setuptools import setup
5 | from torch.utils.cpp_extension import (
6 | BuildExtension,
7 | CppExtension,
8 | CUDAExtension,
9 | )
10 |
11 |
def make_cuda_ext(
    name,
    module,
    sources,
    sources_cuda=None,
    extra_args=None,
    extra_include_path=None,
):
    """Create a C++ or CUDA extension description for ``setup()``.

    A ``CUDAExtension`` is built when ``torch.cuda.is_available()`` is true or
    the environment variable ``FORCE_CUDA`` equals ``"1"``; otherwise a
    ``CppExtension`` is built and ``sources_cuda`` is ignored.

    Args:
        name: Extension name; the final name is ``"{module}.{name}"``.
        module: Dotted module path; its components are joined in front of
            every source path.
        sources: Source files that are always compiled.
        sources_cuda: Extra source files compiled only in the CUDA case.
        extra_args: Extra compiler flags for both the C++ and nvcc compilers.
        extra_include_path: Extra include directories.

    Returns:
        The ``CUDAExtension`` or ``CppExtension`` instance.

    None of the list arguments is modified: each one is copied first, so
    neither the caller's lists nor shared defaults can be mutated.
    """
    sources = list(sources)
    sources_cuda = [] if sources_cuda is None else list(sources_cuda)
    extra_args = [] if extra_args is None else list(extra_args)
    extra_include_path = (
        [] if extra_include_path is None else list(extra_include_path)
    )

    define_macros = []
    extra_compile_args = {"cxx": list(extra_args)}

    if torch.cuda.is_available() or os.getenv("FORCE_CUDA", "0") == "1":
        define_macros += [("WITH_CUDA", None)]
        extension = CUDAExtension
        extra_compile_args["nvcc"] = extra_args + [
            "-D__CUDA_NO_HALF_OPERATORS__",
            "-D__CUDA_NO_HALF_CONVERSIONS__",
            "-D__CUDA_NO_HALF2_OPERATORS__",
        ]
        # Build a new list rather than extending the argument in place.
        sources = sources + sources_cuda
    else:
        print("Compiling {} without CUDA".format(name))
        extension = CppExtension

    return extension(
        name="{}.{}".format(module, name),
        sources=[os.path.join(*module.split("."), p) for p in sources],
        include_dirs=extra_include_path,
        define_macros=define_macros,
        extra_compile_args=extra_compile_args,
    )
44 |
45 |
if __name__ == "__main__":
    # Describe the extension first, then hand it to setuptools.
    aggregation_ext = make_cuda_ext(
        "deformable_aggregation_ext",
        module=".",
        sources=[
            "src/deformable_aggregation.cpp",
            "src/deformable_aggregation_cuda.cu",
        ],
    )
    setup(
        name="deformable_aggregation_ext",
        ext_modules=[aggregation_ext],
        cmdclass={"build_ext": BuildExtension},
    )
61 |
--------------------------------------------------------------------------------
/tools/fuse_conv_bn.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 |
4 | import torch
5 | from torch import nn as nn
6 | from mmcv.runner import save_checkpoint
7 | from mmdet.apis import init_detector
8 |
9 |
def fuse_conv_bn(conv, bn):
    """Fold a batch-norm layer into the convolution that precedes it.

    During inference a batch-norm layer only applies its stored per-channel
    mean and variance (plus the optional affine weight and bias), so it can
    be merged into the preceding conv layer to save computation and simplify
    the network structure.

    Args:
        conv: Convolution whose ``weight`` and ``bias`` are replaced. A
            missing bias is treated as zeros.
        bn: Batch-norm layer providing ``running_mean``, ``running_var``,
            ``eps`` and, when affine, ``weight`` and ``bias``. A layer without
            affine parameters is treated as weight 1 and bias 0.

    Returns:
        ``conv`` itself, modified in place.
    """
    conv_w = conv.weight
    conv_b = (
        conv.bias
        if conv.bias is not None
        else torch.zeros_like(bn.running_mean)
    )

    # With affine=False the BN has no learnable weight/bias (both are None);
    # fall back to the identity affine transform instead of failing.
    bn_w = bn.weight if bn.weight is not None else torch.ones_like(bn.running_mean)
    bn_b = bn.bias if bn.bias is not None else torch.zeros_like(bn.running_mean)

    factor = bn_w / torch.sqrt(bn.running_var + bn.eps)
    conv.weight = nn.Parameter(
        conv_w * factor.reshape([conv.out_channels, 1, 1, 1])
    )
    conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn_b)
    return conv
28 |
29 |
def fuse_module(m):
    """Recursively fuse each Conv2d with the batch norm that directly follows it.

    Children are visited in ``named_children()`` order. A ``Conv2d`` is
    remembered; if a ``BatchNorm2d`` or ``SyncBatchNorm`` child comes later
    while a conv is still remembered, the two are fused and the BN child is
    replaced by ``nn.Identity``. Every other child is processed recursively.

    Returns:
        ``m`` itself, modified in place.
    """
    pending_conv = None
    pending_name = None

    for child_name, child in m.named_children():
        if isinstance(child, nn.Conv2d):
            # Remember the conv; it may be fused with a later BN child.
            pending_conv, pending_name = child, child_name
            continue

        if not isinstance(child, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            fuse_module(child)
            continue

        # Only a BN that comes after a conv is fused.
        if pending_conv is None:
            continue

        m._modules[pending_name] = fuse_conv_bn(pending_conv, child)
        # To reduce changes, set BN as Identity instead of deleting it.
        m._modules[child_name] = nn.Identity()
        pending_conv = None

    return m
49 |
50 |
def parse_args():
    """Parse the three positional command-line arguments of this script.

    Returns:
        The argparse namespace with ``config``, ``checkpoint`` and ``out``.
    """
    cli = argparse.ArgumentParser(
        description="fuse Conv and BN layers in a model"
    )
    positional = (
        ("config", "config file path"),
        ("checkpoint", "checkpoint file path"),
        ("out", "output path of the converted model"),
    )
    for arg_name, arg_help in positional:
        cli.add_argument(arg_name, help=arg_help)
    return cli.parse_args()
60 |
61 |
def main():
    """Load a detector, fuse its Conv/BN pairs and save the result."""
    cli_args = parse_args()
    # Build the model from a config file and a checkpoint file.
    detector = init_detector(cli_args.config, cli_args.checkpoint)
    # Fuse conv and bn layers, then write the fused model to the output path.
    save_checkpoint(fuse_module(detector), cli_args.out)


if __name__ == "__main__":
    main()
73 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/ops/deformable_aggregation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd.function import Function, once_differentiable
3 |
4 | from . import deformable_aggregation_ext
5 |
6 |
class DeformableAggregationFunction(Function):
    """Autograd function wrapping the ``deformable_aggregation_ext`` kernels."""

    @staticmethod
    def forward(
        ctx,
        mc_ms_feat,
        spatial_shape,
        scale_start_index,
        sampling_location,
        weights,
    ):
        """Run the forward kernel and save its inputs for backward.

        Float inputs are made contiguous float32 and the two index tensors
        contiguous int32 before being passed to the extension.
        """
        # output: [bs, num_pts, num_embeds]
        feat = mc_ms_feat.contiguous().float()
        shapes = spatial_shape.contiguous().int()
        starts = scale_start_index.contiguous().int()
        locations = sampling_location.contiguous().float()
        attn = weights.contiguous().float()

        kernel_inputs = (feat, shapes, starts, locations, attn)
        output = deformable_aggregation_ext.deformable_aggregation_forward(
            *kernel_inputs
        )
        ctx.save_for_backward(*kernel_inputs)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients for the features, sampling locations and weights.

        The two index tensors receive no gradient (``None``).
        """
        feat, shapes, starts, locations, attn = ctx.saved_tensors
        feat = feat.contiguous().float()
        shapes = shapes.contiguous().int()
        starts = starts.contiguous().int()
        locations = locations.contiguous().float()
        attn = attn.contiguous().float()

        # Zero-initialised gradient buffers handed to the backward kernel.
        grad_feat = torch.zeros_like(feat)
        grad_locations = torch.zeros_like(locations)
        grad_attn = torch.zeros_like(attn)

        deformable_aggregation_ext.deformable_aggregation_backward(
            feat,
            shapes,
            starts,
            locations,
            attn,
            grad_output.contiguous(),
            grad_feat,
            grad_locations,
            grad_attn,
        )
        return grad_feat, None, None, grad_locations, grad_attn
76 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/detection3d/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from mmcv.utils import build_from_cfg
5 | from mmdet.models.builder import LOSSES
6 |
7 | from projects.mmdet3d_plugin.core.box3d import *
8 | from torch.nn.functional import cosine_similarity
9 |
10 |
@LOSSES.register_module()
class SparseBox3DLoss(nn.Module):
    """Box regression loss with optional centerness and yawness quality losses.

    Args:
        loss_box: Config of the box regression loss (built from ``LOSSES``).
        loss_centerness: Optional config of the centerness loss.
        loss_yawness: Optional config of the yawness loss.
        cls_allow_reverse: Optional class ids for which a target whose yaw
            points opposite to the prediction is flipped before the loss.
    """

    def __init__(
        self,
        loss_box,
        loss_centerness=None,
        loss_yawness=None,
        cls_allow_reverse=None,
    ):
        super().__init__()

        def build(cfg, registry):
            if cfg is None:
                return None
            return build_from_cfg(cfg, registry)

        self.loss_box = build(loss_box, LOSSES)
        self.loss_cns = build(loss_centerness, LOSSES)
        self.loss_yns = build(loss_yawness, LOSSES)
        self.cls_allow_reverse = cls_allow_reverse

    def forward(
        self,
        box,
        box_target,
        weight=None,
        avg_factor=None,
        suffix="",
        quality=None,
        cls_target=None,
        **kwargs,
    ):
        """Compute the configured losses.

        Args:
            box: Predicted boxes, indexed with the undecoded box3d indices.
            box_target: Target boxes with the same layout. Not modified.
            weight: Forwarded to the box loss.
            avg_factor: Forwarded to every loss.
            suffix: Appended to every key of the returned dict.
            quality: Optional tensor holding centerness (index ``CNS``) and
                yawness (index ``YNS``) predictions.
            cls_target: Class ids used together with ``cls_allow_reverse``.

        Returns:
            dict with ``loss_box{suffix}`` and, when ``quality`` is given and
            the corresponding loss is configured, ``loss_cns{suffix}`` and
            ``loss_yns{suffix}``.
        """
        yaw_cols = [SIN_YAW, COS_YAW]

        # Some categories do not distinguish between positive and negative
        # directions. For example, barrier in nuScenes dataset.
        if self.cls_allow_reverse is not None and cls_target is not None:
            if_reverse = (
                cosine_similarity(
                    box_target[..., yaw_cols], box[..., yaw_cols], dim=-1
                )
                < 0
            )
            if_reverse = (
                torch.isin(
                    cls_target, cls_target.new_tensor(self.cls_allow_reverse)
                )
                & if_reverse
            )

            # Work on a copy so the caller's target tensor is not flipped
            # in place as a side effect of computing the loss.
            box_target = box_target.clone()
            box_target[..., yaw_cols] = torch.where(
                if_reverse[..., None],
                -box_target[..., yaw_cols],
                box_target[..., yaw_cols],
            )

        output = {}
        box_loss = self.loss_box(
            box, box_target, weight=weight, avg_factor=avg_factor
        )
        output[f"loss_box{suffix}"] = box_loss

        if quality is not None:
            # Both quality losses are optional in __init__, so each one is
            # skipped when it was not configured instead of calling None.
            if self.loss_cns is not None:
                cns = quality[..., CNS]
                cns_target = torch.norm(
                    box_target[..., [X, Y, Z]] - box[..., [X, Y, Z]],
                    p=2,
                    dim=-1,
                )
                cns_target = torch.exp(-cns_target)
                cns_loss = self.loss_cns(
                    cns, cns_target, avg_factor=avg_factor
                )
                output[f"loss_cns{suffix}"] = cns_loss

            if self.loss_yns is not None:
                yns = quality[..., YNS].sigmoid()
                yns_target = (
                    cosine_similarity(
                        box_target[..., yaw_cols], box[..., yaw_cols], dim=-1
                    )
                    > 0
                )
                yns_target = yns_target.float()
                yns_loss = self.loss_yns(
                    yns, yns_target, avg_factor=avg_factor
                )
                output[f"loss_yns{suffix}"] = yns_loss
        return output
70 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
def inverse_sigmoid(x, eps=1e-5):
    """Return the logit ``log(x / (1 - x))`` of ``x``.

    ``x`` is first clamped to ``[0, 1]``; numerator and denominator are then
    each clamped to at least ``eps`` so the result stays finite.
    """
    clipped = x.clamp(min=0, max=1)
    numerator = clipped.clamp(min=eps)
    denominator = (1 - clipped).clamp(min=eps)
    return torch.log(numerator / denominator)
9 |
def get_valid_ratio(mask):
    """Get the valid ratios of a padding mask.

    Args:
        mask: Three-dimensional boolean tensor where ``True`` marks padding.

    Returns:
        Tensor whose last dimension is ``(valid_ratio_w, valid_ratio_h)``:
        the share of non-masked entries in the first row and in the first
        column of each mask.
    """
    height, width = mask.shape[1:]
    unpadded = ~mask
    rows_valid = unpadded[:, :, 0].sum(1)
    cols_valid = unpadded[:, 0, :].sum(1)
    ratio_h = rows_valid.float() / height
    ratio_w = cols_valid.float() / width
    return torch.stack([ratio_w, ratio_h], -1)
19 |
def get_reference_points(spatial_shapes, valid_ratios, device):
    """Build reference points at the cell centers of every feature level.

    Args:
        spatial_shapes: Iterable of ``(H, W)`` pairs, one per level.
        valid_ratios: Tensor indexed as ``[batch, level, (w, h)]``.
        device: Device on which the grids are created.

    Returns:
        Tensor of shape ``(batch, sum(H * W), num_levels, 2)`` holding
        ``(x, y)`` reference points.
    """
    reference_points_list = []
    for lvl, (H, W) in enumerate(spatial_shapes):
        # TODO check this 0.5
        # indexing="ij" is the behaviour torch.meshgrid already used by
        # default; passing it explicitly avoids the deprecation warning
        # raised when the argument is omitted.
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(
                0.5, H - 0.5, H, dtype=torch.float32, device=device),
            torch.linspace(
                0.5, W - 0.5, W, dtype=torch.float32, device=device),
            indexing="ij",
        )
        ref_y = ref_y.reshape(-1)[None] / (
            valid_ratios[:, None, lvl, 1] * H)
        ref_x = ref_x.reshape(-1)[None] / (
            valid_ratios[:, None, lvl, 0] * W)
        ref = torch.stack((ref_x, ref_y), -1)
        reference_points_list.append(ref)
    reference_points = torch.cat(reference_points_list, 1)
    reference_points = reference_points[:, :, None] * valid_ratios[:, None]
    return reference_points
38 |
39 |
def pos2posemb2d(pos, num_pos_feats=128, temperature=10000):
    """Encode 2- or 4-coordinate positions as sine/cosine embeddings.

    Args:
        pos: Tensor whose last dimension has size 2 ``(x, y)`` or 4
            ``(x, y, w, h)``. Any number of leading dimensions is accepted.
        num_pos_feats: Embedding size per coordinate.
        temperature: Base of the frequency schedule.

    Returns:
        Tensor with the same leading dimensions as ``pos`` and a last
        dimension of ``num_pos_feats`` per coordinate, concatenated in the
        order ``(y, x)`` or ``(y, x, w, h)``.

    Raises:
        ValueError: If the last dimension of ``pos`` is neither 2 nor 4.
    """
    scale = 2 * math.pi
    pos = pos * scale
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)

    def _embed(coord):
        # Interleave sin of the even channels with cos of the odd channels.
        angles = coord[..., None] / dim_t
        return torch.stack(
            (angles[..., 0::2].sin(), angles[..., 1::2].cos()), dim=-1
        ).flatten(-2)

    pos_x = _embed(pos[..., 0])
    pos_y = _embed(pos[..., 1])
    if pos.size(-1) == 2:
        posemb = torch.cat((pos_y, pos_x), dim=-1)
    elif pos.size(-1) == 4:
        # NOTE(review): `pos` was already multiplied by `scale` above, so w
        # and h end up scaled twice. Kept as is so results match the existing
        # behaviour -- confirm whether this is intended.
        pos_w = _embed(pos[..., 2] * scale)
        pos_h = _embed(pos[..., 3] * scale)
        posemb = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=-1)
    else:
        raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos.size(-1)))
    return posemb
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/samplers/distributed_sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | from torch.utils.data import DistributedSampler as _DistributedSampler
5 | from .sampler import SAMPLER
6 |
7 | import pdb
8 | import sys
9 |
10 |
class ForkedPdb(pdb.Pdb):
    """Pdb variant that reads its commands from a freshly opened /dev/stdin."""

    def interaction(self, *args, **kwargs):
        """Run the normal pdb interaction with ``sys.stdin`` swapped out.

        ``sys.stdin`` is pointed at a new handle on ``/dev/stdin`` for the
        duration of the interaction and restored afterwards, even on error.
        """
        _stdin = sys.stdin
        try:
            # The context manager closes the extra handle on exit; it used to
            # be left open after every interaction.
            with open("/dev/stdin") as forked_stdin:
                sys.stdin = forked_stdin
                pdb.Pdb.interaction(self, *args, **kwargs)
        finally:
            sys.stdin = _stdin
19 |
20 |
def set_trace():
    """Start a ``ForkedPdb`` session in the frame of the caller."""
    caller_frame = sys._getframe(1)
    debugger = ForkedPdb()
    debugger.set_trace(caller_frame)
23 |
24 |
@SAMPLER.register_module()
class DistributedSampler(_DistributedSampler):
    """Non-shuffling distributed sampler that keeps whole sequences together.

    Consecutive samples are grouped into sequences, and each rank receives
    the sequences that start inside its share of the dataset.
    """

    def __init__(
        self, dataset=None, num_replicas=None, rank=None, shuffle=True, seed=0
    ):
        super().__init__(
            dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle
        )
        # for the compatibility from PyTorch 1.3+
        self.seed = 0 if seed is None else seed

    def __iter__(self):
        """Return an iterator over the sample indices assigned to this rank."""
        # Only sequential (non-shuffled) sampling is supported here.
        assert not self.shuffle

        # A dataset without `data_infos` is expected to expose `.datasets`;
        # the infos of its first entry are repeated once per entry.
        if "data_infos" in dir(self.dataset):
            infos = self.dataset.data_infos
            repeat = 1
        else:
            infos = self.dataset.datasets[0].data_infos
            repeat = len(self.dataset.datasets)

        # presumably microseconds -> seconds; verify against the dataset
        timestamps = [info["timestamp"] / 1e6 for info in infos] * repeat
        # First four characters of the lidar file name, or None if absent.
        vehicle_idx = [
            info["lidar_path"].split("/")[-1][:4]
            if "lidar_path" in info
            else None
            for info in infos
        ] * repeat

        # Start a new sequence at the first sample, when the timestamp gap
        # exceeds 4, or when the vehicle id changes.
        sequence_splits = []
        for idx, (stamp, vehicle) in enumerate(zip(timestamps, vehicle_idx)):
            starts_sequence = (
                idx == 0
                or abs(stamp - timestamps[idx - 1]) > 4
                or vehicle != vehicle_idx[idx - 1]
            )
            if starts_sequence:
                sequence_splits.append([idx])
            else:
                sequence_splits[-1].append(idx)

        # A rank takes every sequence whose first sample falls inside
        # [rank * per_rank, (rank + 1) * per_rank).
        per_rank = len(self.dataset) // self.num_replicas
        lower = self.rank * per_rank
        upper = (self.rank + 1) * per_rank

        indices = []
        seen = 0
        for sequence in sequence_splits:
            if seen >= upper:
                break
            if seen >= lower:
                indices.extend(sequence)
            seen += len(sequence)

        self.num_samples = len(indices)
        return iter(indices)
83 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
SimPB
3 | [ECCV 2024] SimPB: A Single Model for 2D and 3D Object Detection from Multiple Cameras
4 |
5 | [arXiv:2403.10353](https://arxiv.org/abs/2403.10353)
6 |
7 |
8 |
9 | 
10 |
11 | ## Introduction
12 | This repository is an official implementation of *SimPB*, which **Sim**ultaneously detects 2D objects in the
13 | **P**erspective view and 3D objects in the **B**EV space from multiple cameras.
14 |
15 |
16 | ## Getting started
17 | - [Prepare Environment](docs/prepare_environment.md)
18 | - [Prepare Dataset](docs/prepare_dataset.md)
19 | - [Training and Evaluation](docs/training_evaluation.md)
20 |
21 | ## Model Zoo
22 |
23 | **Results on NuScenes validation**
24 |
25 | | method | backbone | pretrain | img size | mAP | NDS | config | ckpt | log |
26 | |:-------:|:---------:|:---------------------------------------------------------------------:|:--------:|:-----:|:-----:|:--------------------------------------------------------------:|:----:|:---:|
27 | | SimPB+ | ResNet50 | [ImageNet](https://download.pytorch.org/models/resnet50-19c8e357.pth) | 704x256 | 0.479 | 0.586 | [config](projects/configs/simpb_nus_r50_img_704x256.py) | [ckpt](https://github.com/nullmax-vision/SimPB/releases/download/weights/simpb_r50_img.pth) | [log](https://github.com/nullmax-vision/SimPB/releases/download/weights/simpb_r50_img.log) |
28 | | SimPB+ | ResNet50 | [nuImg](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/nuimages_semseg/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim/cascade_mask_rcnn_r50_fpn_coco-20e_20e_nuim_20201009_124951-40963960.pth) | 704x256 | 0.489 | 0.591 | [config](projects/configs/simpb_nus_r50_uimg_704x256.py) | [ckpt](https://github.com/nullmax-vision/SimPB/releases/download/weights/simpb_r50_uimg.pth) | [log](https://github.com/nullmax-vision/SimPB/releases/download/weights/simpb_r50_uimg.log) |
29 | | SimPB | ResNet101 | nuImg | 1408x512 | 0.539 | 0.629 | | | |
30 |
31 | Note: SimPB+ is a modified architecture that introduces 2D denoising and removes the encoder. This slightly reduces the runtime while maintaining performance comparable to the released script.
32 |
33 | ## Acknowledgement
34 | Thanks to these excellent open-source works:
35 |
36 | [Sparse4Dv3](https://github.com/HorizonRobotics/Sparse4D),
37 | [StreamPETR](https://github.com/exiawsh/StreamPETR),
38 | [SparseBEV](https://github.com/MCG-NJU/SparseBEV),
39 | [Far3D](https://github.com/megvii-research/Far3D),
40 | [MMDetection3D](https://github.com/open-mmlab/mmdetection3d)
41 |
42 | ## Citation
43 | ```bibtex
44 | @article{simpb,
45 | title={SimPB: A Single Model for 2D and 3D Object Detection from Multiple Cameras},
46 | author={Yingqi Tang and Zhaotie Meng and Guoliang Chen and Erkang Cheng},
47 | journal={ECCV},
48 | year={2024}
49 | }
50 | ```
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/ops/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from .deformable_aggregation import DeformableAggregationFunction
4 |
5 |
def deformable_aggregation_function(
    feature_maps,
    spatial_shape,
    scale_start_index,
    sampling_location,
    weights,
):
    """Functional entry point for the deformable-aggregation autograd op.

    The five arguments are forwarded, unchanged and in the same order, to
    ``DeformableAggregationFunction.apply``; whatever that call returns is
    returned as-is.
    """
    op_inputs = (
        feature_maps,
        spatial_shape,
        scale_start_index,
        sampling_location,
        weights,
    )
    return DeformableAggregationFunction.apply(*op_inputs)
20 |
21 |
def feature_maps_format(feature_maps, inverse=False):
    """Convert multi-camera, multi-level feature maps to/from a flat layout.

    Forward direction (``inverse=False``):
        ``feature_maps`` is either a list of per-level tensors indexed as
        ``(bs, num_cams, C, H, W)``, or a list/tuple of such lists (one list
        per camera group). The result is
        ``[col_feats, spatial_shape, scale_start_index]`` where

        * ``col_feats`` has shape ``(bs, total_positions, C)``: for every
          camera, all levels are flattened over ``H * W`` and concatenated,
          then the cameras are concatenated.
        * ``spatial_shape`` is an int64 tensor ``(num_cams, num_levels, 2)``
          holding ``(H, W)`` of every camera/level map.
        * ``scale_start_index`` is an int64 tensor ``(num_cams, num_levels)``
          with the offset along dim 1 of ``col_feats`` at which every
          camera/level map starts.

    Inverse direction (``inverse=True``):
        ``feature_maps`` is such a triple. Consecutive cameras with identical
        level shapes are regrouped, and a list (one entry per group) of lists
        (one tensor per level, shaped ``(bs, cams_in_group, C, H, W)``) is
        returned.

    Args:
        feature_maps: Input in one of the layouts described above.
        inverse (bool): Direction of the conversion.

    Returns:
        list: See above.
    """
    if inverse:
        # The start indices are not needed to undo the flattening.
        col_feats, spatial_shape, _ = feature_maps
        num_cams = spatial_shape.shape[0]

        # Number of positions (H * W) of every (camera, level) map.
        split_size = (spatial_shape[..., 0] * spatial_shape[..., 1]).tolist()

        # Group consecutive cameras whose level shapes are identical:
        # cam_split[k] is the number of cameras in group k and
        # cam_split_size[k] the number of flattened positions it occupies.
        cam_split = [1]
        cam_split_size = [sum(split_size[0])]
        for i in range(num_cams - 1):
            if not torch.all(spatial_shape[i] == spatial_shape[i + 1]):
                cam_split.append(0)
                cam_split_size.append(0)
            cam_split[-1] += 1
            cam_split_size[-1] += sum(split_size[i + 1])
        mc_feat = [
            x.unflatten(1, (cam_split[i], -1))
            for i, x in enumerate(col_feats.split(cam_split_size, dim=1))
        ]

        spatial_shape = spatial_shape.tolist()
        mc_ms_feat = []
        shape_index = 0  # index of the first camera of the current group
        for i, feat in enumerate(mc_feat):
            feat = list(feat.split(split_size[shape_index], dim=2))
            for j, f in enumerate(feat):
                # (bs, cams, H*W, C) -> (bs, cams, H, W, C) -> (bs, cams, C, H, W)
                f = f.unflatten(2, spatial_shape[shape_index][j])
                feat[j] = f.permute(0, 1, 4, 2, 3)
            mc_ms_feat.append(feat)
            shape_index += cam_split[i]
        return mc_ms_feat

    if isinstance(feature_maps[0], (list, tuple)):
        # Several camera groups: format each one, then concatenate.
        formatted = [feature_maps_format(x) for x in feature_maps]
        col_feats = torch.cat([x[0] for x in formatted], dim=1)
        spatial_shape = torch.cat([x[1] for x in formatted], dim=0)
        # Every group's start indices are relative to its own features, so
        # shift them by the number of positions contributed by the groups
        # placed before it in ``col_feats``.
        shifted = []
        offset = 0
        for group_feats, _, group_start in formatted:
            shifted.append(group_start + offset)
            offset += group_feats.shape[1]
        scale_start_index = torch.cat(shifted, dim=0)
        return [col_feats, spatial_shape, scale_start_index]

    bs, num_cams = feature_maps[0].shape[:2]
    level_shapes = []
    col_feats = []
    for feat in feature_maps:
        level_shapes.append(tuple(feat.shape[-2:]))
        # Flatten the spatial dimensions: (bs, cams, C, H*W).
        col_feats.append(
            torch.reshape(feat, (bs, num_cams, feat.shape[2], -1))
        )

    # (bs, cams, C, sum(H*W)) -> (bs, cams * sum(H*W), C)
    col_feats = torch.cat(col_feats, dim=-1).permute(0, 1, 3, 2).flatten(1, 2)
    spatial_shape = torch.tensor(
        [level_shapes] * num_cams,
        dtype=torch.int64,
        device=col_feats.device,
    )
    # Exclusive cumulative sum of H*W over (camera, level), camera-major.
    scale_start_index = spatial_shape[..., 0] * spatial_shape[..., 1]
    scale_start_index = scale_start_index.flatten().cumsum(dim=0)
    scale_start_index = torch.cat(
        [torch.tensor([0]).to(scale_start_index), scale_start_index[:-1]]
    )
    scale_start_index = scale_start_index.reshape(num_cams, -1)

    return [col_feats, spatial_shape, scale_start_index]
93 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/core/evaluation/eval_hooks.py:
--------------------------------------------------------------------------------
1 | # Note: Considering that MMCV's EvalHook updated its interface in V1.3.16,
2 | # in order to avoid strong version dependency, we did not directly
3 | # inherit EvalHook but BaseDistEvalHook.
4 |
5 | import bisect
6 | import os.path as osp
7 |
8 | import mmcv
9 | import torch.distributed as dist
10 | from mmcv.runner import DistEvalHook as BaseDistEvalHook
11 | from mmcv.runner import EvalHook as BaseEvalHook
12 | from torch.nn.modules.batchnorm import _BatchNorm
13 | from mmdet.core.evaluation.eval_hooks import DistEvalHook
14 |
15 |
def _calc_dynamic_intervals(start_interval, dynamic_interval_list):
    """Split ``(milestone, interval)`` tuples into two parallel lists.

    Args:
        start_interval: Interval placed first, paired with milestone ``0``.
        dynamic_interval_list (list[tuple]): Tuples whose first item is a
            milestone and whose second item is the interval for it.

    Returns:
        tuple[list, list]: ``(milestones, intervals)``; both start with the
        entry for milestone ``0`` and keep the order of the input list.
    """
    assert mmcv.is_list_of(dynamic_interval_list, tuple)

    milestones = [0]
    intervals = [start_interval]
    for entry in dynamic_interval_list:
        milestones.append(entry[0])
        intervals.append(entry[1])
    return milestones, intervals
28 |
29 |
class CustomDistEvalHook(BaseDistEvalHook):
    """Distributed evaluation hook with an optional interval schedule.

    Args:
        *args: Forwarded unchanged to the base hook.
        dynamic_intervals (list[tuple] | None): ``(milestone, interval)``
            tuples. When given, ``self.interval`` is re-selected before every
            training epoch and iteration from the current progress.
        **kwargs: Forwarded unchanged to the base hook.
    """

    def __init__(self, *args, dynamic_intervals=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.use_dynamic_intervals = dynamic_intervals is not None
        if not self.use_dynamic_intervals:
            return
        milestones, intervals = _calc_dynamic_intervals(
            self.interval, dynamic_intervals
        )
        self.dynamic_milestones = milestones
        self.dynamic_intervals = intervals

    def _decide_interval(self, runner):
        """Pick the interval that applies to the runner's current progress."""
        if not self.use_dynamic_intervals:
            return
        progress = runner.epoch if self.by_epoch else runner.iter
        # Position of the last milestone that is <= progress + 1.
        step = bisect.bisect(self.dynamic_milestones, progress + 1)
        self.interval = self.dynamic_intervals[step - 1]

    def before_train_epoch(self, runner):
        """Refresh the interval, then defer to the base hook."""
        self._decide_interval(runner)
        super().before_train_epoch(runner)

    def before_train_iter(self, runner):
        """Refresh the interval, then defer to the base hook."""
        self._decide_interval(runner)
        super().before_train_iter(runner)

    def _do_evaluate(self, runner):
        """Run the multi-GPU test, evaluate on rank 0 and save the best ckpt."""
        # DDP does not synchronise BatchNorm buffers (running_mean and
        # running_var), which may make ranks disagree; broadcast rank 0's
        # buffers to every other rank before testing.
        if self.broadcast_bn_buffer:
            for _, module in runner.model.named_modules():
                if not isinstance(module, _BatchNorm):
                    continue
                if not module.track_running_stats:
                    continue
                dist.broadcast(module.running_var, 0)
                dist.broadcast(module.running_mean, 0)

        if not self._should_evaluate(runner):
            return

        if self.tmpdir is not None:
            tmpdir = self.tmpdir
        else:
            tmpdir = osp.join(runner.work_dir, ".eval_hook")

        # Imported here rather than at module level to avoid a circular
        # import.
        from projects.mmdet3d_plugin.apis.test import custom_multi_gpu_test

        results = custom_multi_gpu_test(
            runner.model,
            self.dataloader,
            tmpdir=tmpdir,
            gpu_collect=self.gpu_collect,
        )
        if runner.rank != 0:
            return

        print("\n")
        runner.log_buffer.output["eval_iter_num"] = len(self.dataloader)

        key_score = self.evaluate(runner, results)

        if self.save_best:
            self._save_ckpt(runner, key_score)
--------------------------------------------------------------------------------
/tools/benchmark.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import time
4 | import torch
5 | from mmcv import Config
6 | from mmcv.parallel import MMDataParallel
7 | from mmcv.runner import load_checkpoint, wrap_fp16_model
8 | import sys
9 |
10 | sys.path.append(".")
11 | from projects.mmdet3d_plugin.datasets.builder import build_dataloader
12 | from projects.mmdet3d_plugin.datasets import custom_build_dataset
13 |
14 | # from mmdet3d.datasets import build_dataloader, build_dataset
15 | from mmdet.models import build_detector
16 |
17 | from tools.fuse_conv_bn import fuse_module
18 |
19 |
def parse_args():
    """Parse the benchmark script's command-line arguments.

    Returns:
        argparse.Namespace: With attributes ``config`` (str), ``checkpoint``
        (str or None), ``samples`` (int), ``log_interval`` (int) and
        ``fuse_conv_bn`` (bool).
    """
    parser = argparse.ArgumentParser(description="MMDet benchmark a model")
    parser.add_argument("config", help="test config file path")
    parser.add_argument("--checkpoint", default=None, help="checkpoint file")
    # type=int: without it a value given on the command line stays a string,
    # which breaks the modulo/equality arithmetic done with these counts.
    parser.add_argument(
        "--samples", type=int, default=2000, help="samples to benchmark"
    )
    parser.add_argument(
        "--log-interval", type=int, default=50, help="interval of logging"
    )
    parser.add_argument(
        "--fuse-conv-bn",
        action="store_true",
        help="Whether to fuse conv and bn, this will slightly increase "
        "the inference speed",
    )
    return parser.parse_args()
36 |
37 |
def get_max_memory(model):
    """Return the peak CUDA memory allocated so far, in whole mebibytes.

    The device queried is ``model.output_device`` when the model has that
    attribute, otherwise ``None`` is passed through to torch.
    """
    device = getattr(model, "output_device", None)
    peak_bytes = torch.cuda.max_memory_allocated(device=device)
    bytes_per_mb = 1024 * 1024
    # The integer tensor conversion turns the fractional MiB into an int.
    peak_mb = torch.tensor(
        [peak_bytes / bytes_per_mb], dtype=torch.int, device=device
    )
    return peak_mb.item()
45 |
46 |
def main():
    """Benchmark single-GPU inference speed and peak memory of a model.

    Builds the test dataset/dataloader and the detector from the config,
    optionally loads a checkpoint and fuses conv+bn, then times inference on
    up to ``--samples`` batches. The first ``num_warmup`` iterations are
    excluded from the timing. Progress is printed every ``--log-interval``
    iterations and the overall FPS once the requested sample count is reached.
    """
    args = parse_args()
    # Coerce defensively so the arithmetic below also works if the parser
    # hands these over as strings.
    num_samples = int(args.samples)
    log_interval = int(args.log_interval)

    cfg = Config.fromfile(args.config)
    # set cudnn_benchmark
    if cfg.get("cudnn_benchmark", False):
        torch.backends.cudnn.benchmark = True
    cfg.model.pretrained = None
    cfg.data.test.test_mode = True

    # build the dataloader
    # TODO: support multiple images per gpu (only minor changes are needed)
    print(cfg.data.test)
    dataset = custom_build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=1,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=False,
        shuffle=False,
    )

    # build the model and load checkpoint
    cfg.model.train_cfg = None
    model = build_detector(cfg.model, test_cfg=cfg.get("test_cfg"))
    fp16_cfg = cfg.get("fp16", None)
    if fp16_cfg is not None:
        wrap_fp16_model(model)
    if args.checkpoint is not None:
        load_checkpoint(model, args.checkpoint, map_location="cpu")
    if args.fuse_conv_bn:
        model = fuse_module(model)

    model = MMDataParallel(model, device_ids=[0])

    model.eval()

    # the first several iterations may be very slow so skip them
    num_warmup = 5
    pure_inf_time = 0

    # benchmark with several samples and take the average
    max_memory = 0
    for i, data in enumerate(data_loader):
        with torch.no_grad():
            start_time = time.perf_counter()
            model(return_loss=False, rescale=True, **data)

            # Wait for the GPU so the measured time covers the whole pass.
            torch.cuda.synchronize()
            elapsed = time.perf_counter() - start_time
            max_memory = max(max_memory, get_max_memory(model))

        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(
                    f"Done image [{i + 1:<3}/ {num_samples}], "
                    f"fps: {fps:.1f} img / s, "
                    f"gpu mem: {max_memory} M"
                )

        if (i + 1) == num_samples:
            # ``elapsed`` of this iteration is already part of
            # ``pure_inf_time``; adding it again would count the last sample
            # twice and under-report the overall FPS.
            if i >= num_warmup:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print(f"Overall fps: {fps:.1f} img / s")
            else:
                print(
                    f"Overall fps: n/a (need more than {num_warmup} "
                    f"samples to exclude warm-up)"
                )
            break


if __name__ == "__main__":
    main()
119 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/samplers/group_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import math
3 |
4 | import numpy as np
5 | import torch
6 | from mmcv.runner import get_dist_info
7 | from torch.utils.data import Sampler
8 | from .sampler import SAMPLER
9 | import random
10 | from IPython import embed
11 |
12 |
@SAMPLER.register_module()
class DistributedGroupSampler(Sampler):
    """Group-aware sampler that hands each process its own slice of indices.

    The dataset must expose a ``flag`` array assigning every sample to a
    group. Within each group the indices are shuffled and padded (by
    repetition) to a multiple of ``samples_per_gpu * num_replicas``; chunks of
    ``samples_per_gpu`` consecutive indices are then shuffled as units, and
    each rank takes one contiguous slice of ``num_samples`` indices.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling; must have a ``flag`` attribute.
        samples_per_gpu (int, optional): Chunk size that is kept within a
            single group. Default: 1.
        num_replicas (optional): Number of participating processes. Taken
            from ``get_dist_info()`` when ``None``.
        rank (optional): Rank of the current process within num_replicas.
            Taken from ``get_dist_info()`` when ``None``.
        seed (int, optional): Seed combined with the epoch to drive the
            shuffle; ``None`` is treated as 0. Default: 0.
    """

    def __init__(
        self, dataset, samples_per_gpu=1, num_replicas=None, rank=None, seed=0
    ):
        default_rank, default_replicas = get_dist_info()
        self.dataset = dataset
        self.samples_per_gpu = samples_per_gpu
        self.num_replicas = (
            default_replicas if num_replicas is None else num_replicas
        )
        self.rank = default_rank if rank is None else rank
        self.epoch = 0
        self.seed = 0 if seed is None else seed

        assert hasattr(self.dataset, "flag")
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)

        # Per-rank sample count: every group is rounded up to a whole number
        # of ``samples_per_gpu`` chunks per replica.
        self.num_samples = 0
        for group_size in self.group_sizes:
            chunks = int(
                math.ceil(
                    group_size
                    * 1.0
                    / self.samples_per_gpu
                    / self.num_replicas
                )
            )
            self.num_samples += chunks * self.samples_per_gpu
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch + self.seed)

        chunk = self.samples_per_gpu
        indices = []
        for group_id, size in enumerate(self.group_sizes):
            if not size > 0:
                continue
            members = np.where(self.flag == group_id)[0]
            assert len(members) == size
            # .numpy() avoids a bug when selecting indices in parrots.
            order = list(torch.randperm(int(size), generator=g).numpy())
            members = members[order].tolist()
            padded_len = (
                int(math.ceil(size * 1.0 / chunk / self.num_replicas))
                * chunk
                * self.num_replicas
            )
            extra = padded_len - len(members)
            # Pad by repeating the shuffled group, whole copies first.
            base = members.copy()
            for _ in range(extra // size):
                members.extend(base)
            members.extend(base[: extra % size])
            indices.extend(members)

        assert len(indices) == self.total_size

        # Shuffle whole chunks so each chunk stays inside one group.
        chunk_order = torch.randperm(len(indices) // chunk, generator=g)
        shuffled = []
        for chunk_id in chunk_order.tolist():
            begin = chunk_id * chunk
            shuffled.extend(indices[begin : begin + chunk])
        indices = shuffled

        # subsample: contiguous slice owned by this rank
        offset = self.num_samples * self.rank
        indices = indices[offset : offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
120 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/grid_mask.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from PIL import Image
5 |
6 |
class Grid(object):
    """Grid-shaped masking applied to a single image tensor.

    Args:
        use_h (bool): Mask horizontal bands (rows).
        use_w (bool): Mask vertical bands (columns).
        rotate (int): Exclusive upper bound of the random rotation angle
            passed to ``PIL.Image.rotate``.
        offset (bool): Fill masked pixels with uniform noise in ``[-1, 1)``
            instead of zero.
        ratio (float): Band length as a fraction of the grid period; the
            value ``1`` selects a random band length instead.
        mode (int): ``1`` inverts the mask.
        prob (float): Probability of applying the augmentation.
    """

    def __init__(
        self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0
    ):
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        self.st_prob = prob  # initial probability, kept for set_prob
        self.prob = prob

    def set_prob(self, epoch, max_epoch):
        """Scale the probability linearly with training progress."""
        self.prob = self.st_prob * epoch / max_epoch

    def __call__(self, img, label):
        """Return ``(masked_img, label)``; the label is passed through.

        ``img`` is indexed as (channel, height, width): height and width are
        read from dimensions 1 and 2.
        """
        if np.random.rand() > self.prob:
            return img, label
        h = img.size(1)
        w = img.size(2)
        self.d1 = 2
        self.d2 = min(h, w)
        # Build the mask on a 1.5x larger canvas so that the centre crop
        # taken after rotation is fully covered.
        hh = int(1.5 * h)
        ww = int(1.5 * w)
        d = np.random.randint(self.d1, self.d2)  # grid period
        if self.ratio == 1:
            self.l = np.random.randint(1, d)
        else:
            self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
        mask = np.ones((hh, ww), np.float32)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                t = min(s + self.l, hh)
                mask[s:t, :] *= 0
        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                t = min(s + self.l, ww)
                mask[:, s:t] *= 0

        r = np.random.randint(self.rotate)
        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)
        mask = mask[
            (hh - h) // 2 : (hh - h) // 2 + h,
            (ww - w) // 2 : (ww - w) // 2 + w,
        ]

        # Copy first: the array obtained from the PIL image is a read-only
        # view, which torch.from_numpy warns about (GridMask does the same).
        mask = torch.from_numpy(mask.copy()).float()
        if self.mode == 1:
            mask = 1 - mask

        mask = mask.expand_as(img)
        if self.offset:
            offset = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).float()
            offset = (1 - mask) * offset
            img = img * mask + offset
        else:
            img = img * mask

        return img, label
73 |
74 |
class GridMask(nn.Module):
    """Grid-shaped masking applied to a batch of images during training.

    Args:
        use_h (bool): Mask horizontal bands (rows).
        use_w (bool): Mask vertical bands (columns).
        rotate (int): Exclusive upper bound of the random rotation angle
            passed to ``PIL.Image.rotate``.
        offset (bool): Fill masked pixels with uniform noise in ``[-1, 1)``
            instead of zero.
        ratio (float): Band length as a fraction of the grid period.
        mode (int): ``1`` inverts the mask.
        prob (float): Probability of applying the augmentation.
    """

    def __init__(
        self, use_h, use_w, rotate=1, offset=False, ratio=0.5, mode=0, prob=1.0
    ):
        super(GridMask, self).__init__()
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        self.st_prob = prob  # initial probability, kept for set_prob
        self.prob = prob

    def set_prob(self, epoch, max_epoch):
        """Scale the probability linearly with training progress."""
        self.prob = self.st_prob * epoch / max_epoch

    def forward(self, x):
        """Mask ``x`` of shape ``(n, c, h, w)``.

        The input is returned untouched in eval mode or when the random draw
        exceeds ``self.prob``. Otherwise one mask is drawn per call and
        applied to every image and channel of the batch.
        """
        if np.random.rand() > self.prob or not self.training:
            return x
        n, c, h, w = x.size()
        x = x.view(-1, h, w)
        # Build the mask on a 1.5x larger canvas so that the centre crop
        # taken after rotation is fully covered.
        hh = int(1.5 * h)
        ww = int(1.5 * w)
        d = np.random.randint(2, h)  # grid period
        self.l = min(max(int(d * self.ratio + 0.5), 1), d - 1)
        mask = np.ones((hh, ww), np.float32)
        st_h = np.random.randint(d)
        st_w = np.random.randint(d)
        if self.use_h:
            for i in range(hh // d):
                s = d * i + st_h
                t = min(s + self.l, hh)
                mask[s:t, :] *= 0
        if self.use_w:
            for i in range(ww // d):
                s = d * i + st_w
                t = min(s + self.l, ww)
                mask[:, s:t] *= 0

        r = np.random.randint(self.rotate)
        mask = Image.fromarray(np.uint8(mask))
        mask = mask.rotate(r)
        mask = np.asarray(mask)
        mask = mask[
            (hh - h) // 2 : (hh - h) // 2 + h,
            (ww - w) // 2 : (ww - w) // 2 + w,
        ]

        # Place the mask on the input's own device rather than on the
        # current CUDA device, so CPU inputs and inputs on another GPU work.
        mask = torch.from_numpy(mask.copy()).float().to(x.device)
        if self.mode == 1:
            mask = 1 - mask
        mask = mask.expand_as(x)
        if self.offset:
            offset = (
                torch.from_numpy(2 * (np.random.rand(h, w) - 0.5))
                .float()
                .to(x.device)
            )
            x = x * mask + offset * (1 - mask)
        else:
            x = x * mask

        return x.view(n, c, h, w)
139 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/simpb.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from inspect import signature
4 | from mmcv.runner import force_fp32, auto_fp16
5 | from mmcv.utils import build_from_cfg
6 | from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
7 | from mmdet.models import (
8 | DETECTORS,
9 | BaseDetector,
10 | build_backbone,
11 | build_head,
12 | build_neck,
13 | )
14 | from .grid_mask import GridMask
15 |
# The deformable-aggregation op is optional: record whether it could be
# imported so that SimPB can refuse ``use_deformable_func=True`` with a clear
# message instead of failing at import time.
try:
    from ..ops import feature_maps_format

    DAF_VALID = True
except Exception:
    # ``Exception`` rather than a bare ``except`` so that KeyboardInterrupt
    # and SystemExit still propagate, while any failure raised by importing
    # the op keeps being tolerated.
    DAF_VALID = False
21 |
22 | __all__ = ["SimPB"]
23 |
24 |
@DETECTORS.register_module()
class SimPB(BaseDetector):
    """Detector made of an image backbone, an optional neck and one head.

    Args:
        img_backbone (dict): Config passed to ``build_backbone``.
        head (dict): Config passed to ``build_head``; ``test_cfg`` is stored
            into it under the key ``"test_cfg"`` when given.
        img_neck (dict, optional): Config passed to ``build_neck``.
        init_cfg (dict, optional): Forwarded to ``BaseDetector``.
        train_cfg (dict, optional): Accepted for config compatibility; not
            used by this class.
        test_cfg (dict, optional): See ``head``.
        pretrained (str, optional): Stored into the backbone config under
            the key ``"pretrained"`` before the backbone is built.
        use_grid_mask (bool): Apply ``GridMask`` to the input images.
        use_deformable_func (bool): Convert feature maps with
            ``feature_maps_format``; requires the optional op to be
            importable.
        depth_branch (dict, optional): Config built from ``PLUGIN_LAYERS``.
    """

    def __init__(
        self,
        img_backbone,
        head,
        img_neck=None,
        init_cfg=None,
        train_cfg=None,
        test_cfg=None,
        pretrained=None,
        use_grid_mask=True,
        use_deformable_func=False,
        depth_branch=None,
    ):
        super(SimPB, self).__init__(init_cfg=init_cfg)
        if pretrained is not None:
            # Item assignment works for both plain dicts and ConfigDict.
            img_backbone["pretrained"] = pretrained
        self.img_backbone = build_backbone(img_backbone)
        # Always define the attribute: extract_feat tests it against None.
        self.img_neck = build_neck(img_neck) if img_neck is not None else None
        if test_cfg is not None:
            head['test_cfg'] = test_cfg
        self.head = build_head(head)
        self.use_grid_mask = use_grid_mask
        if use_deformable_func:
            assert DAF_VALID, "deformable_aggregation needs to be set up."
        self.use_deformable_func = use_deformable_func
        self.head.use_deformable_func = self.use_deformable_func
        if depth_branch is not None:
            self.depth_branch = build_from_cfg(depth_branch, PLUGIN_LAYERS)
        else:
            self.depth_branch = None
        if use_grid_mask:
            self.grid_mask = GridMask(
                True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7
            )

    @auto_fp16(apply_to=("img",), out_fp32=True)
    def extract_feat(self, img, return_depth=False, metas=None):
        """Extract per-camera feature maps (and optionally depths).

        Args:
            img (Tensor): 5-D multi-view input (cameras in dimension 1) or a
                4-D single-view input.
            return_depth (bool): Also return the depth branch output.
            metas (dict, optional): Passed to the backbone when its
                ``forward`` accepts ``metas``; ``metas.get("focal")`` is
                given to the depth branch.

        Returns:
            The feature maps, each reshaped to ``(bs, num_cams, ...)`` and,
            when ``use_deformable_func`` is set, converted by
            ``feature_maps_format``. With ``return_depth=True`` a tuple
            ``(feature_maps, depths)`` is returned, where ``depths`` is
            ``None`` if there is no depth branch.
        """
        bs = img.shape[0]
        if img.dim() == 5:  # multi-view
            num_cams = img.shape[1]
            img = img.flatten(end_dim=1)
        else:
            num_cams = 1
        if self.use_grid_mask:
            img = self.grid_mask(img)
        if "metas" in signature(self.img_backbone.forward).parameters:
            feature_maps = self.img_backbone(img, num_cams, metas=metas)
        else:
            feature_maps = self.img_backbone(img)
        if self.img_neck is not None:
            feature_maps = list(self.img_neck(feature_maps))
        # Restore the separate batch and camera dimensions.
        for i, feat in enumerate(feature_maps):
            feature_maps[i] = torch.reshape(
                feat, (bs, num_cams) + feat.shape[1:]
            )
        if return_depth and self.depth_branch is not None:
            depths = self.depth_branch(feature_maps, metas.get("focal"))
        else:
            depths = None
        if self.use_deformable_func:
            feature_maps = feature_maps_format(feature_maps)
        if return_depth:
            return feature_maps, depths
        return feature_maps

    @force_fp32(apply_to=("img",))
    def forward(self, img, **data):
        """Dispatch to ``forward_train`` or ``forward_test`` by mode."""
        if self.training:
            return self.forward_train(img, **data)
        else:
            return self.forward_test(img, **data)

    def forward_train(self, img, **data):
        """Return the head losses, plus the dense depth loss if available."""
        feature_maps, depths = self.extract_feat(img, True, data)
        model_outs = self.head(feature_maps, data)
        output = self.head.loss(model_outs, data)
        if depths is not None and "gt_depth" in data:
            output["loss_dense_depth"] = self.depth_branch.loss(
                depths, data["gt_depth"]
            )
        return output

    def forward_test(self, img, **data):
        """Use ``aug_test`` for list inputs, ``simple_test`` otherwise."""
        if isinstance(img, list):
            return self.aug_test(img, **data)
        else:
            return self.simple_test(img, **data)

    def simple_test(self, img, **data):
        """Run the head on the extracted features and post-process."""
        feature_maps = self.extract_feat(img)

        model_outs = self.head(feature_maps, data)
        results = self.head.post_process(model_outs, data)

        return results

    def aug_test(self, img, **data):
        """Fake test-time augmentation: only the first entry is evaluated."""
        for key in data.keys():
            if isinstance(data[key], list):
                data[key] = data[key][0]
        return self.simple_test(img[0], **data)
130 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/ops/src/deformable_aggregation.cpp:
--------------------------------------------------------------------------------
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
3 |
4 | void deformable_aggregation(
5 | float* output,
6 | const float* mc_ms_feat,
7 | const int* spatial_shape,
8 | const int* scale_start_index,
9 | const float* sample_location,
10 | const float* weights,
11 | int batch_size,
12 | int num_cams,
13 | int num_feat,
14 | int num_embeds,
15 | int num_scale,
16 | int num_anchors,
17 | int num_pts,
18 | int num_groups
19 | );
20 |
21 |
22 | /* feat: bs, num_feat, c */
23 | /* _spatial_shape: cam, scale, 2 */
24 | /* _scale_start_index: cam, scale */
25 | /* _sampling_location: bs, anchor, pts, cam, 2 */
26 | /* _weights: bs, anchor, pts, cam, scale, group */
27 | /* output: bs, anchor, c */
28 | /* kernel: bs, anchor, pts, c */
29 |
30 |
31 | at::Tensor deformable_aggregation_forward(
32 | const at::Tensor &_mc_ms_feat,
33 | const at::Tensor &_spatial_shape,
34 | const at::Tensor &_scale_start_index,
35 | const at::Tensor &_sampling_location,
36 | const at::Tensor &_weights
37 | ) {
38 | at::DeviceGuard guard(_mc_ms_feat.device());
39 | const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat));
40 | int batch_size = _mc_ms_feat.size(0);
41 | int num_feat = _mc_ms_feat.size(1);
42 | int num_embeds = _mc_ms_feat.size(2);
43 | int num_cams = _spatial_shape.size(0);
44 | int num_scale = _spatial_shape.size(1);
45 | int num_anchors = _sampling_location.size(1);
46 | int num_pts = _sampling_location.size(2);
47 | int num_groups = _weights.size(5);
48 |
49 | const float* mc_ms_feat = _mc_ms_feat.data_ptr();
50 | const int* spatial_shape = _spatial_shape.data_ptr();
51 | const int* scale_start_index = _scale_start_index.data_ptr();
52 | const float* sampling_location = _sampling_location.data_ptr();
53 | const float* weights = _weights.data_ptr();
54 |
55 | auto output = at::zeros({batch_size, num_anchors, num_embeds}, _mc_ms_feat.options());
56 | deformable_aggregation(
57 | output.data_ptr(),
58 | mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights,
59 | batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
60 | );
61 | return output;
62 | }
63 |
64 |
65 | void deformable_aggregation_grad(
66 | const float* mc_ms_feat,
67 | const int* spatial_shape,
68 | const int* scale_start_index,
69 | const float* sample_location,
70 | const float* weights,
71 | const float* grad_output,
72 | float* grad_mc_ms_feat,
73 | float* grad_sampling_location,
74 | float* grad_weights,
75 | int batch_size,
76 | int num_cams,
77 | int num_feat,
78 | int num_embeds,
79 | int num_scale,
80 | int num_anchors,
81 | int num_pts,
82 | int num_groups
83 | );
84 |
85 |
86 | void deformable_aggregation_backward(
87 | const at::Tensor &_mc_ms_feat,
88 | const at::Tensor &_spatial_shape,
89 | const at::Tensor &_scale_start_index,
90 | const at::Tensor &_sampling_location,
91 | const at::Tensor &_weights,
92 | const at::Tensor &_grad_output,
93 | at::Tensor &_grad_mc_ms_feat,
94 | at::Tensor &_grad_sampling_location,
95 | at::Tensor &_grad_weights
96 | ) {
97 | at::DeviceGuard guard(_mc_ms_feat.device());
98 | const at::cuda::OptionalCUDAGuard device_guard(device_of(_mc_ms_feat));
99 | int batch_size = _mc_ms_feat.size(0);
100 | int num_feat = _mc_ms_feat.size(1);
101 | int num_embeds = _mc_ms_feat.size(2);
102 | int num_cams = _spatial_shape.size(0);
103 | int num_scale = _spatial_shape.size(1);
104 | int num_anchors = _sampling_location.size(1);
105 | int num_pts = _sampling_location.size(2);
106 | int num_groups = _weights.size(5);
107 |
108 | const float* mc_ms_feat = _mc_ms_feat.data_ptr();
109 | const int* spatial_shape = _spatial_shape.data_ptr();
110 | const int* scale_start_index = _scale_start_index.data_ptr();
111 | const float* sampling_location = _sampling_location.data_ptr();
112 | const float* weights = _weights.data_ptr();
113 | const float* grad_output = _grad_output.data_ptr();
114 |
115 | float* grad_mc_ms_feat = _grad_mc_ms_feat.data_ptr();
116 | float* grad_sampling_location = _grad_sampling_location.data_ptr();
117 | float* grad_weights = _grad_weights.data_ptr();
118 |
119 | deformable_aggregation_grad(
120 | mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights,
121 | grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
122 | batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
123 | );
124 | }
125 |
126 |
127 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
128 | m.def(
129 | "deformable_aggregation_forward",
130 | &deformable_aggregation_forward,
131 | "deformable_aggregation_forward"
132 | );
133 | m.def(
134 | "deformable_aggregation_backward",
135 | &deformable_aggregation_backward,
136 | "deformable_aggregation_backward"
137 | );
138 | }
139 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/aggregation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | from mmcv.cnn import Linear
6 | from mmcv.utils import build_from_cfg
7 | from mmcv.cnn.bricks.registry import ATTENTION, PLUGIN_LAYERS
8 |
9 |
class ReWeight(nn.Module):
    """Gate queries with a learned scalar and optionally re-project them.

    A two-stage MLP (``reduce`` then ``alpha``) maps ``parameter`` to one
    sigmoid gate per query.

    Args:
        c_dim (int): Feature size of ``parameter``.
        f_dim (int): Hidden size of the gate MLP.
        trans (bool): If True, queries are aggregated through the gated and
            transposed ``trans_matrix`` and normalised by its row sums;
            otherwise they are simply multiplied by the gate.
        with_pos (bool): Process ``query_pos`` the same way; when False the
            returned position is ``None``.
    """

    def __init__(self, c_dim, f_dim=256, trans=True, with_pos=False):
        super().__init__()
        self.c_dim = c_dim
        self.f_dim = f_dim
        self.trans = trans
        self.with_pos = with_pos

        self.reduce = nn.Sequential(nn.Linear(c_dim, f_dim), nn.ReLU())
        self.alpha = nn.Sequential(nn.Linear(f_dim, 1), nn.Sigmoid())

    def forward(self, query, query_pos, parameter, trans_matrix=None):
        """Return ``(query, query_pos)`` after gating / re-projection."""
        gate = self.alpha(self.reduce(parameter))

        if not self.trans:
            gated_pos = gate * query_pos if self.with_pos else None
            return gate * query, gated_pos

        # Gate the matrix, then swap its last two axes so that it maps the
        # incoming queries onto the target slots.
        weights = (trans_matrix * gate).permute(0, 2, 1)
        # Row sums, floored at 1e-5 so empty rows do not divide by zero.
        norm = torch.clamp(weights.sum(-1).unsqueeze(-1), min=1e-5)
        projected = torch.div(torch.matmul(weights, query), norm)
        if self.with_pos:
            projected_pos = torch.div(torch.matmul(weights, query_pos), norm)
        else:
            projected_pos = None
        return projected, projected_pos
41 |
42 |
@PLUGIN_LAYERS.register_module()
class AdaptiveQueryAggregation(nn.Module):
    """Lift 2D queries onto the 3D queries and refine them with attention.

    Args:
        self_attn (dict, optional): Config built from ``ATTENTION``.
        reweight: When not ``None``, a ``ReWeight`` gate is applied while
            lifting the 2D queries; only its presence is inspected.
        decouple_attn (bool): Stored on the module; not used in this class.
        with_pos (bool): Also lift and add the positional embeddings.
        embed_dims (int): Channel size of the 2D queries. The gate input is
            the query concatenated with one extra scalar, so its size is
            ``embed_dims + 1`` (257 for the default).
    """

    def __init__(self, self_attn=None, reweight=None, decouple_attn=False,
                 with_pos=False, embed_dims=256):
        super().__init__()
        self.with_pos = with_pos
        self.decouple_attn = decouple_attn
        trans = self_attn is not None
        if reweight is not None:
            self.reweight = ReWeight(
                c_dim=embed_dims + 1, trans=trans, with_pos=with_pos
            )
        else:
            self.reweight = None
        if self_attn is not None:
            self.self_attn = build_from_cfg(self_attn, ATTENTION)
        else:
            self.self_attn = None

    def _lift_to_3d(self, query2d, query_pos2d, trans_matrix):
        """Average 2D queries per 3D slot using ``trans_matrix`` as weights."""
        trans_matrix_t = trans_matrix.permute(0, 2, 1)
        # Row sums, floored at 1e-5 so empty rows do not divide by zero.
        divisor = torch.clamp(trans_matrix_t.sum(-1).unsqueeze(-1), 1e-5)
        query = torch.div(torch.matmul(trans_matrix_t, query2d), divisor)
        if self.with_pos:
            query_pos = torch.div(
                torch.matmul(trans_matrix_t, query_pos2d), divisor
            )
        else:
            query_pos = None
        return query, query_pos

    def forward(self,
                query2d, query_pos2d, anchor2d,
                query3d, query_pos3d, anchor3d,
                dn_query2d=None, dn_query_pos2d=None, dn_anchor2d=None,
                dn_query3d=None, dn_query_pos3d=None, dn_anchor3d=None,
                trans_matrix=None, center_matrix=None, dn_trans_matrix=None,
                dn_center_matrix=None,
                attn_mask=None, graph_model=None, **kwargs):
        """Aggregate 2D queries into the 3D queries.

        The ``dn_*`` arguments are the optional denoising counterparts of
        the regular inputs. ``graph_model`` is called as
        ``graph_model(self.self_attn, query=..., query_pos=...,
        attn_mask=...)`` and must therefore be provided by the caller.

        Returns:
            tuple: ``(aggregated_query3d, query_pos3d, anchor3d)``, the
            latter two including the denoising entries when
            ``dn_query3d`` is given.
        """
        dn_query3d_from2d = dn_query_pos3d_from2d = None
        if self.reweight is not None:
            # Gate input: the query plus the row sum of the center matrix.
            center_param = torch.cat(
                [query2d, center_matrix.sum(-1).unsqueeze(-1)], dim=-1
            )
            query3d_from2d, query_pos3d_from2d = self.reweight(
                query2d, query_pos2d, center_param, trans_matrix
            )

            if dn_query2d is not None:
                dn_center_param = torch.cat(
                    [dn_query2d, dn_center_matrix.sum(-1).unsqueeze(-1)],
                    dim=-1,
                )
                dn_query3d_from2d, dn_query_pos3d_from2d = self.reweight(
                    dn_query2d, dn_query_pos2d, dn_center_param,
                    dn_trans_matrix
                )
        else:
            query3d_from2d, query_pos3d_from2d = self._lift_to_3d(
                query2d, query_pos2d, trans_matrix
            )

            if dn_query2d is not None:
                # Lift the denoising queries with their own matrix; the
                # merge below needs ``dn_query3d_from2d`` to exist.
                dn_query3d_from2d, dn_query_pos3d_from2d = self._lift_to_3d(
                    dn_query2d, dn_query_pos2d, dn_trans_matrix
                )

        # merge with denoise
        if dn_query3d is not None:
            query3d = torch.cat([query3d, dn_query3d], dim=1)
            query_pos3d = torch.cat([query_pos3d, dn_query_pos3d], dim=1)
            anchor3d = torch.cat([anchor3d, dn_anchor3d], dim=1)

            if dn_query2d is not None:
                query3d_from2d = torch.cat(
                    [query3d_from2d, dn_query3d_from2d], dim=1
                )
                query_pos3d_from2d = torch.cat(
                    [query_pos3d_from2d, dn_query_pos3d_from2d], dim=1
                ) if self.with_pos else None
            else:
                # No 2D denoising queries: pad with zeros so the lifted
                # tensors line up with the merged 3D tensors.
                query3d_from2d = torch.cat(
                    [query3d_from2d, torch.zeros_like(dn_query3d)], dim=1
                )
                query_pos3d_from2d = torch.cat(
                    [query_pos3d_from2d, torch.zeros_like(dn_query_pos3d)],
                    dim=1,
                ) if self.with_pos else None

        query3d = query3d + query3d_from2d
        if self.with_pos:
            query_pos3d = query_pos3d + query_pos3d_from2d

        aggregated_query3d = graph_model(self.self_attn,
                                         query=query3d,
                                         query_pos=query_pos3d,
                                         attn_mask=attn_mask)

        return aggregated_query3d, query_pos3d, anchor3d
102 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/detection2d/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | from mmcv.cnn import Linear, Scale, bias_init_with_prob
6 | from mmcv.runner.base_module import Sequential, BaseModule
7 | from mmcv.cnn.bricks.registry import (
8 | PLUGIN_LAYERS,
9 | POSITIONAL_ENCODING,
10 | )
11 |
12 | from ..blocks import linear_relu_ln
13 | from ..utils import inverse_sigmoid, pos2posemb2d
14 |
15 | __all__ = [
16 | "SparseBox2DRefinementModule",
17 | "SparseBox2DEncoder",
18 | ]
19 |
@POSITIONAL_ENCODING.register_module()
class SparseBox2DEncoder(BaseModule):
    """Embed 2D anchors into feature vectors.

    With ``with_sin_embed`` the anchor is passed through ``pos2posemb2d`` and
    one embedding MLP. Otherwise the first two anchor channels are embedded by
    ``pos_fc`` and, when ``with_size`` is set, channels 2:4 are embedded by
    ``size_fc``, fused with the position feature and passed through
    ``output_fc``.

    NOTE(review): the nesting of ``output_fc`` (built and applied only when
    ``with_size`` is set, for both fusion modes) was reconstructed from a
    listing without indentation -- confirm against the repository.

    Args:
        embed_dims (int): Width of the produced embedding.
        with_size (bool): Also embed anchor channels 2:4.
        with_sin_embed (bool): Use ``pos2posemb2d`` instead of ``pos_fc``.
        mode (str): ``"add"`` or ``"cat"``; how position and size features
            are fused.
        in_loops (int): Forwarded to ``linear_relu_ln``.
        out_loops (int): Forwarded to ``linear_relu_ln``.
    """

    def __init__(
        self,
        embed_dims=256,
        with_size=False,
        with_sin_embed=False,
        mode="add",
        in_loops=1,
        out_loops=2,
    ):
        super(SparseBox2DEncoder, self).__init__()
        self.embed_dims = embed_dims
        self.mode = mode
        self.with_size = with_size
        self.with_sin_embed = with_sin_embed

        def embedding_layer(input_dims):
            return nn.Sequential(
                *linear_relu_ln(embed_dims, in_loops, out_loops, input_dims)
            )

        if self.with_sin_embed:
            self.query_embeddings2d = embedding_layer(256)
        else:
            self.pos_fc = embedding_layer(2)
            if self.with_size:
                self.size_fc = embedding_layer(2)
                # Concatenation doubles the channel count, so the fusion MLP
                # must accept 2 * embed_dims inputs in "cat" mode.
                fused_dims = (
                    self.embed_dims * 2 if self.mode == "cat" else self.embed_dims
                )
                self.output_fc = embedding_layer(fused_dims)

    def forward(self, box_2d):
        """Return the embedding of ``box_2d``.

        Args:
            box_2d (torch.Tensor): Anchors; the last dim holds the position in
                channels 0:2 and, when ``with_size`` is used, the size in
                channels 2:4.

        Raises:
            ValueError: If ``with_size`` is set and ``mode`` is neither
                ``"add"`` nor ``"cat"``.
        """
        if self.with_sin_embed:
            return self.query_embeddings2d(pos2posemb2d(box_2d))

        pos_feat = self.pos_fc(box_2d[..., :2])
        if not self.with_size:
            return pos_feat

        size_feat = self.size_fc(box_2d[..., 2:4])
        if self.mode == "add":
            output = pos_feat + size_feat
        elif self.mode == "cat":
            output = torch.cat([pos_feat, size_feat], dim=-1)
        else:
            raise ValueError(
                f"Unsupported mode {self.mode!r}; expected 'add' or 'cat'."
            )
        return self.output_fc(output)
64 |
@PLUGIN_LAYERS.register_module()
class SparseBox2DRefinementModule(BaseModule):
    """Refine 2D anchors and predict class, depth and alpha from instance features.

    Args:
        embed_dims (int): Width of the instance features.
        output_dim (int): Number of regressed anchor channels.
        num_cls (int): Number of classification outputs.
        alpha_dim (int): Number of outputs of the alpha branch.
        with_cls_branch (bool): Build the classification branch.
        with_alpha_branch (bool): Build the alpha branch.
        with_depth_branch (bool): Build the depth branch.
        with_multibin_depth (bool): Predict ``depth_bin_num`` depth bins
            instead of a single scaled depth value.
        depth_bin_num (int): Number of depth bins in multibin mode.
    """

    def __init__(self, embed_dims=256, output_dim=4, num_cls=10, alpha_dim=2,
                 with_cls_branch=True, with_alpha_branch=False, with_depth_branch=False,
                 with_multibin_depth=False, depth_bin_num=64):
        super(SparseBox2DRefinementModule, self).__init__()
        self.embed_dims = embed_dims
        self.output_dim = output_dim
        self.num_cls = num_cls

        self.layers = nn.Sequential(
            *linear_relu_ln(embed_dims, 2, 2),
            Linear(self.embed_dims, self.output_dim),
            Scale([1.0] * self.output_dim),
        )

        self.with_cls_branch = with_cls_branch
        if with_cls_branch:
            self.cls_layers = nn.Sequential(
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(self.embed_dims, self.num_cls),
            )

        self.with_alpha_branch = with_alpha_branch
        if with_alpha_branch:
            self.alpha_layers = nn.Sequential(
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(self.embed_dims, alpha_dim),
                # One scale per alpha output; identical to the previous
                # hard-coded 2 for the default alpha_dim.
                Scale([1.0] * alpha_dim),
            )

        self.with_depth_branch = with_depth_branch
        self.with_multibin_depth = with_multibin_depth
        if with_depth_branch:
            if with_multibin_depth:
                self.depth_layers = nn.Sequential(
                    *linear_relu_ln(embed_dims, 2, 2),
                    Linear(self.embed_dims, depth_bin_num),
                )
            else:
                self.depth_layers = nn.Sequential(
                    *linear_relu_ln(embed_dims, 2, 2),
                    Linear(self.embed_dims, 1),
                    Scale([1.0] * 1),
                )

    def init_weight(self):
        """Set the bias of the logit-producing heads via ``bias_init_with_prob(0.01)``."""
        if self.with_cls_branch:
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.cls_layers[-1].bias, bias_init)
        # depth_layers only exists when the depth branch was built.
        if self.with_depth_branch and self.with_multibin_depth:
            bias_init = bias_init_with_prob(0.01)
            nn.init.constant_(self.depth_layers[-1].bias, bias_init)

    def forward(self, instance_feature, anchor2d, anchor2d_embed,
                metas=None, return_cls=True, query_groups=None):
        """Predict the refined anchor and the optional auxiliary outputs.

        Args:
            instance_feature (torch.Tensor): Per-query features.
            anchor2d (torch.Tensor): Current anchors; when the last dim is 2
                or 4, its inverse sigmoid is added to the first 2 or 4
                regression outputs.
            anchor2d_embed (torch.Tensor): Anchor embedding added to the
                features for regression (and multibin depth).
            metas (dict, optional): Must provide ``metas['focal']`` when the
                non-multibin depth branch is active.
            return_cls (bool): Whether to run the classification branch.
            query_groups (sequence, optional): ``(start, end)`` pairs; column
                ``i`` of ``metas['focal']`` is repeated ``end - start`` times
                for group ``i``. Required for the non-multibin depth branch.

        Returns:
            tuple: ``(output.sigmoid(), cls, depth, alpha)``; entries of
            disabled branches are None.
        """
        output = self.layers(instance_feature + anchor2d_embed)

        anchor_dims = anchor2d.shape[-1]
        if anchor_dims == 2:
            output[..., :2] = output[..., :2] + inverse_sigmoid(anchor2d)
        elif anchor_dims == 4:
            output[..., :4] = output[..., :4] + inverse_sigmoid(anchor2d)

        cls = self.cls_layers(instance_feature) if return_cls else None
        alpha = (
            self.alpha_layers(instance_feature) if self.with_alpha_branch else None
        )

        depth = None
        if self.with_depth_branch:
            if self.with_multibin_depth:
                depth = self.depth_layers(instance_feature + anchor2d_embed)
            else:
                focal = torch.cat([
                    metas['focal'][:, i:i+1].repeat(1, qg[1]-qg[0])
                    for i, qg in enumerate(query_groups)
                ], dim=-1)
                depth = self.depth_layers(instance_feature).exp()
                depth = depth * focal.unsqueeze(-1) / 100

        return output.sigmoid(), cls, depth, alpha
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/apis/test.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 | # ---------------------------------------------
4 | # Modified by Zhiqi Li
5 | # ---------------------------------------------
6 | import os.path as osp
7 | import pickle
8 | import shutil
9 | import tempfile
10 | import time
11 |
12 | import mmcv
13 | import torch
14 | import torch.distributed as dist
15 | from mmcv.image import tensor2imgs
16 | from mmcv.runner import get_dist_info
17 |
18 | from mmdet.core import encode_mask_results
19 |
20 |
21 | import mmcv
22 | import numpy as np
23 | import pycocotools.mask as mask_util
24 |
25 |
def custom_encode_mask_results(mask_results):
    """Encode bitmap masks (semantic masks only) to RLE.

    Args:
        mask_results (sequence): Bitmap masks; each item is indexed as
            ``mask[:, :, np.newaxis]`` before encoding.

    Returns:
        list: A single-element list wrapping the list of RLE-encoded masks.
    """
    rle_masks = [
        mask_util.encode(
            np.array(segm[:, :, np.newaxis], order="F", dtype="uint8")
        )[0]  # encoded with RLE
        for segm in mask_results
    ]
    return [rle_masks]
47 |
48 |
def custom_multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
    """Run inference on every rank and gather the predictions.

    Each rank evaluates its share of ``data_loader``; the per-rank results are
    then gathered with ``collect_results_gpu`` (``gpu_collect=True``) or
    ``collect_results_cpu`` (results are exchanged through ``tmpdir``).

    Args:
        model (nn.Module): Model to be tested.
        data_loader (DataLoader): PyTorch data loader.
        tmpdir (str, optional): Directory for the temporary per-rank result
            files in cpu mode. Mask results use ``tmpdir + "_mask"``.
        gpu_collect (bool): Whether to gather through the gpu helper.

    Returns:
        The gathered bbox results, or a dict with ``bbox_results`` and
        ``mask_results`` when mask results were produced.
    """
    model.eval()
    dataset = data_loader.dataset
    rank, world_size = get_dist_info()
    bbox_results, mask_results = [], []
    have_mask = False

    if rank == 0:
        prog_bar = mmcv.ProgressBar(len(dataset))
    time.sleep(2)  # This line can prevent deadlock problem in some cases.

    for data in data_loader:
        with torch.no_grad():
            result = model(return_loss=False, rescale=True, **data)

        if not isinstance(result, dict):
            batch_size = len(result)
            bbox_results.extend(result)
        else:
            if "bbox_results" in result:
                batch_size = len(result["bbox_results"])
                bbox_results.extend(result["bbox_results"])
            if result.get("mask_results") is not None:
                # encode mask results
                mask_results.extend(
                    custom_encode_mask_results(result["mask_results"])
                )
                have_mask = True

        if rank == 0:
            for _ in range(batch_size * world_size):
                prog_bar.update()

    # collect results from all ranks
    total = len(dataset)
    if gpu_collect:
        bbox_results = collect_results_gpu(bbox_results, total)
        mask_results = (
            collect_results_gpu(mask_results, total) if have_mask else None
        )
    else:
        bbox_results = collect_results_cpu(bbox_results, total, tmpdir)
        tmpdir = tmpdir + "_mask" if tmpdir is not None else None
        mask_results = (
            collect_results_cpu(mask_results, total, tmpdir)
            if have_mask
            else None
        )

    if mask_results is None:
        return bbox_results
    return {"bbox_results": bbox_results, "mask_results": mask_results}
120 |
121 |
def collect_results_cpu(result_part, size, tmpdir=None):
    """Gather per-rank results on rank 0 through files in a shared directory.

    Args:
        result_part (list): Results produced by the calling rank.
        size (int): Number of results to keep after concatenation.
        tmpdir (str, optional): Directory for the part files. When None, rank
            0 creates one under ``.dist_test`` and broadcasts its path.

    Returns:
        list | None: The first ``size`` results in rank order on rank 0, None
        on every other rank.
    """
    rank, world_size = get_dist_info()

    if tmpdir is not None:
        mmcv.mkdir_or_exist(tmpdir)
    else:
        # create a tmp dir if it is not specified
        max_len = 512
        # 32 is whitespace
        dir_tensor = torch.full(
            (max_len,), 32, dtype=torch.uint8, device="cuda"
        )
        if rank == 0:
            mmcv.mkdir_or_exist(".dist_test")
            created_dir = tempfile.mkdtemp(dir=".dist_test")
            encoded_dir = torch.tensor(
                bytearray(created_dir.encode()),
                dtype=torch.uint8,
                device="cuda",
            )
            dir_tensor[: len(encoded_dir)] = encoded_dir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()

    # dump the part result to the dir
    mmcv.dump(result_part, osp.join(tmpdir, f"part_{rank}.pkl"))
    dist.barrier()

    # only rank 0 collects the parts
    if rank != 0:
        return None

    # load results of all parts from tmp dir
    part_list = [
        mmcv.load(osp.join(tmpdir, f"part_{part_rank}.pkl"))
        for part_rank in range(world_size)
    ]
    # The evaluation sampler gives each gpu a continuous run of samples, so
    # the parts are concatenated in rank order instead of being interleaved.
    ordered_results = [item for part in part_list for item in part]
    # the dataloader may pad some samples
    ordered_results = ordered_results[:size]
    # remove tmp dir
    shutil.rmtree(tmpdir)
    return ordered_results
168 |
169 |
def collect_results_gpu(result_part, size):
    """Gather per-rank results; delegates to ``collect_results_cpu``.

    Args:
        result_part (list): Results produced by the calling rank.
        size (int): Number of results to keep after concatenation.

    Returns:
        Whatever ``collect_results_cpu(result_part, size)`` returns. The
        result used to be dropped, so every caller received None.
    """
    return collect_results_cpu(result_part, size)
172 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/builder.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import platform
3 | import random
4 | from functools import partial
5 |
6 | import numpy as np
7 | from mmcv.parallel import collate
8 | from mmcv.runner import get_dist_info
9 | from mmcv.utils import Registry, build_from_cfg
10 | from torch.utils.data import DataLoader
11 |
12 | from mmdet.datasets.samplers import GroupSampler
13 | from projects.mmdet3d_plugin.datasets.samplers import (
14 | GroupInBatchSampler,
15 | DistributedGroupSampler,
16 | DistributedSampler,
17 | build_sampler
18 | )
19 |
20 |
def build_dataloader(
    dataset,
    samples_per_gpu,
    workers_per_gpu,
    num_gpus=1,
    dist=True,
    shuffle=True,
    seed=None,
    shuffler_sampler=None,
    nonshuffler_sampler=None,
    runner_type="EpochBasedRunner",
    **kwargs
):
    """Build a PyTorch DataLoader with the sampler matching the run mode.

    Args:
        dataset (Dataset): A PyTorch dataset.
        samples_per_gpu (int): Batch size of each GPU.
        workers_per_gpu (int): Data loading subprocesses per GPU.
        num_gpus (int): Number of GPUs; only used in the non-distributed
            branch to scale batch size and worker count.
        dist (bool): Distributed training/test or not. Default: True.
        shuffle (bool): Whether to use the shuffling sampler. Default: True.
        seed (int, optional): Seed passed to the samplers; when given it also
            enables ``worker_init_fn``.
        shuffler_sampler (dict, optional): Sampler config used when
            ``shuffle`` is True; defaults to ``DistributedGroupSampler``.
        nonshuffler_sampler (dict, optional): Sampler config used when
            ``shuffle`` is False; defaults to ``DistributedSampler``.
        runner_type (str): ``'IterBasedRunner'`` selects
            ``GroupInBatchSampler`` as batch sampler.
        kwargs: Any keyword argument used to initialize the DataLoader.

    Returns:
        DataLoader: A PyTorch dataloader.
    """
    rank, world_size = get_dist_info()
    sampler = None
    batch_sampler = None
    num_workers = workers_per_gpu

    if runner_type == 'IterBasedRunner':
        print("Use GroupInBatchSampler !!!")
        batch_sampler = GroupInBatchSampler(
            dataset, samples_per_gpu, world_size, rank, seed=seed
        )
        batch_size = 1
    elif dist:
        if shuffle:
            # DistributedGroupSampler will definitely shuffle the data to
            # satisfy that images on each GPU are in the same group
            print("Use DistributedGroupSampler !!!")
            sampler_cfg = (
                dict(type="DistributedGroupSampler")
                if shuffler_sampler is None
                else shuffler_sampler
            )
            sampler_args = dict(
                dataset=dataset,
                samples_per_gpu=samples_per_gpu,
                num_replicas=world_size,
                rank=rank,
                seed=seed,
            )
        else:
            sampler_cfg = (
                dict(type="DistributedSampler")
                if nonshuffler_sampler is None
                else nonshuffler_sampler
            )
            sampler_args = dict(
                dataset=dataset,
                num_replicas=world_size,
                rank=rank,
                shuffle=shuffle,
                seed=seed,
            )
        sampler = build_sampler(sampler_cfg, sampler_args)
        batch_size = samples_per_gpu
    else:
        print("WARNING!!!!, Only can be used for obtain inference speed!!!!")
        if shuffle:
            sampler = GroupSampler(dataset, samples_per_gpu)
        batch_size = num_gpus * samples_per_gpu
        num_workers = num_gpus * workers_per_gpu

    init_fn = None
    if seed is not None:
        init_fn = partial(
            worker_init_fn, num_workers=num_workers, rank=rank, seed=seed
        )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
        pin_memory=False,
        worker_init_fn=init_fn,
        **kwargs
    )
124 |
125 |
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed NumPy and ``random`` for one data loading worker.

    The worker seed is ``num_workers * rank + worker_id + seed``.
    """
    worker_seed = seed + worker_id + rank * num_workers
    for seed_fn in (np.random.seed, random.seed):
        seed_fn(worker_seed)
132 |
133 |
134 | # Copyright (c) OpenMMLab. All rights reserved.
135 | import platform
136 | from mmcv.utils import Registry, build_from_cfg
137 |
138 | from mmdet.datasets import DATASETS
139 | from mmdet.datasets.builder import _concat_dataset
140 |
if platform.system() != "Windows":
    # https://github.com/pytorch/pytorch/issues/973
    import resource

    # Raise the soft open-file limit to at least 4096 without exceeding the
    # hard limit.
    rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
    base_soft_limit, hard_limit = rlimit[0], rlimit[1]
    soft_limit = max(4096, base_soft_limit)
    if soft_limit > hard_limit:
        soft_limit = hard_limit
    resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
150 |
151 | OBJECTSAMPLERS = Registry("Object sampler")
152 |
153 |
def custom_build_dataset(cfg, default_args=None):
    """Build a dataset, resolving wrapper configs recursively.

    Args:
        cfg (dict | list | tuple): Dataset config. A list/tuple is wrapped in
            ``ConcatDataset``; the types ``ConcatDataset``, ``RepeatDataset``,
            ``ClassBalancedDataset`` and ``CBGSDataset`` wrap the nested
            config(s); a list/tuple ``ann_file`` goes through
            ``_concat_dataset``; everything else is built from ``DATASETS``.
        default_args (dict, optional): Default arguments passed down to the
            leaf dataset builders.

    Returns:
        The built dataset.

    Raises:
        ImportError: If ``CBGSDataset`` is requested but could not be imported
            from mmdet3d.
    """
    try:
        from mmdet3d.datasets.dataset_wrappers import CBGSDataset
    except Exception:
        # mmdet3d is optional here and its import may fail for reasons other
        # than ImportError; unlike a bare ``except`` this lets
        # KeyboardInterrupt and SystemExit propagate.
        CBGSDataset = None
    from mmdet.datasets.dataset_wrappers import (
        ClassBalancedDataset,
        ConcatDataset,
        RepeatDataset,
    )

    if isinstance(cfg, (list, tuple)):
        dataset = ConcatDataset(
            [custom_build_dataset(c, default_args) for c in cfg]
        )
    elif cfg["type"] == "ConcatDataset":
        dataset = ConcatDataset(
            [custom_build_dataset(c, default_args) for c in cfg["datasets"]],
            cfg.get("separate_eval", True),
        )
    elif cfg["type"] == "RepeatDataset":
        dataset = RepeatDataset(
            custom_build_dataset(cfg["dataset"], default_args), cfg["times"]
        )
    elif cfg["type"] == "ClassBalancedDataset":
        dataset = ClassBalancedDataset(
            custom_build_dataset(cfg["dataset"], default_args),
            cfg["oversample_thr"],
        )
    elif cfg["type"] == "CBGSDataset":
        if CBGSDataset is None:
            raise ImportError(
                "CBGSDataset requires mmdet3d.datasets.dataset_wrappers, "
                "which could not be imported."
            )
        dataset = CBGSDataset(
            custom_build_dataset(cfg["dataset"], default_args)
        )
    elif isinstance(cfg.get("ann_file"), (list, tuple)):
        dataset = _concat_dataset(cfg, default_args)
    else:
        dataset = build_from_cfg(cfg, DATASETS, default_args)

    return dataset
193 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/samplers/infinite_group_each_sample_in_batch_sampler.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import copy
3 |
4 | import numpy as np
5 | import torch
6 | import torch.distributed as dist
7 | from mmcv.runner import get_dist_info
8 | from torch.utils.data.sampler import Sampler
9 |
10 | # https://github.com/open-mmlab/mmdetection/blob/3b72b12fe9b14de906d1363982b9fba05e7d47c1/mmdet/core/utils/dist_utils.py#L157
def sync_random_seed(seed=None, device="cuda"):
    """Return a seed that is identical on every rank.

    Rank 0's seed is broadcast to all other ranks, so every worker must call
    this function, otherwise the broadcast deadlocks. With a single process
    the seed is returned without any communication.

    Args:
        seed (int, Optional): The seed. When None, a random one below
            ``2 ** 31`` is drawn from NumPy.
        device (str): The device the broadcast tensor is created on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        seed = np.random.randint(2 ** 31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()
    if world_size == 1:
        return seed

    # Only rank 0 contributes its value; the others send a placeholder that
    # the broadcast overwrites.
    payload = seed if rank == 0 else 0
    random_num = torch.tensor(payload, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
44 |
45 |
class InfiniteGroupEachSampleInBatchSampler(Sampler):
    """Endless batch sampler in which every batch slot follows its own group.

    Each local batch slot walks through one group (as given by
    ``dataset.flag``) at a time and moves on to the next group from a seeded,
    shared permutation once the current one is used up. Group order is
    shuffled; the order inside a group is kept, apart from random frame
    skipping and whole-sequence reversal.

    Args:
        dataset: Dataset exposing ``flag``, ``keep_consistent_seq_aug``,
            ``rot_range``, ``scale_ratio_range`` and
            ``_sample_augmentation()``.
        batch_size (int): Samples per process.
        world_size (int, optional): Defaults to the value of ``get_dist_info``.
        rank (int, optional): Defaults to the value of ``get_dist_info``.
        seed (int): Seed synchronised across ranks with ``sync_random_seed``.
        skip_prob (float): Probability of dropping one extra frame before
            yielding, when more than one frame is buffered.
        sequence_flip_prob (float): Probability of reversing a newly loaded
            group.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        world_size=None,
        rank=None,
        seed=0,
        skip_prob=0.5,
        sequence_flip_prob=0.1,
    ):
        detected_rank, detected_world_size = get_dist_info()
        world_size = detected_world_size if world_size is None else world_size
        rank = detected_rank if rank is None else rank

        self.dataset = dataset
        self.batch_size = batch_size
        self.world_size = world_size
        self.rank = rank
        self.seed = sync_random_seed(seed)
        self.size = len(self.dataset)

        assert hasattr(self.dataset, "flag")
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)
        self.groups_num = len(self.group_sizes)
        self.global_batch_size = batch_size * world_size
        assert self.groups_num >= self.global_batch_size

        # Lookup table: group index -> dataset sample indices of that group.
        self.group_idx_to_sample_idxs = {}
        for group_idx in range(self.groups_num):
            members = np.where(self.flag == group_idx)[0]
            self.group_idx_to_sample_idxs[group_idx] = members.tolist()

        # One group generator per local slot, addressed by its position among
        # the slots of all processes.
        first_global_idx = self.rank * self.batch_size
        self.group_indices_per_global_sample_idx = [
            self._group_indices_per_global_sample_idx(first_global_idx + offset)
            for offset in range(self.batch_size)
        ]

        # Pending dataset indices and current augmentation per local slot.
        self.buffer_per_local_sample = [[] for _ in range(self.batch_size)]
        self.aug_per_local_sample = [None for _ in range(self.batch_size)]
        self.skip_prob = skip_prob
        self.sequence_flip_prob = sequence_flip_prob

    def _infinite_group_indices(self):
        """Yield group indices forever, one seeded permutation after another."""
        rng = torch.Generator()
        rng.manual_seed(self.seed)
        while True:
            for group_idx in torch.randperm(self.groups_num, generator=rng).tolist():
                yield group_idx

    def _group_indices_per_global_sample_idx(self, global_sample_idx):
        """Yield every ``global_batch_size``-th group, starting at the given offset."""
        strided = itertools.islice(
            self._infinite_group_indices(),
            global_sample_idx,
            None,
            self.global_batch_size,
        )
        for group_idx in strided:
            yield group_idx

    def __iter__(self):
        while True:
            curr_batch = []
            for slot in range(self.batch_size):
                # The skip draw happens before a possible refill, so an empty
                # buffer never skips.
                skip = (
                    np.random.uniform() < self.skip_prob
                    and len(self.buffer_per_local_sample[slot]) > 1
                )
                if not self.buffer_per_local_sample[slot]:
                    # Finished current group, refill with next group
                    next_group = next(
                        self.group_indices_per_global_sample_idx[slot]
                    )
                    refill = copy.deepcopy(
                        self.group_idx_to_sample_idxs[next_group]
                    )
                    if np.random.uniform() < self.sequence_flip_prob:
                        refill = refill[::-1]
                    self.buffer_per_local_sample[slot] = refill
                    if self.dataset.keep_consistent_seq_aug:
                        self.aug_per_local_sample[slot] = self.get_aug()

                if not self.dataset.keep_consistent_seq_aug:
                    self.aug_per_local_sample[slot] = self.get_aug()

                pending = self.buffer_per_local_sample[slot]
                if skip:
                    pending.pop(0)
                curr_batch.append(
                    dict(
                        idx=pending.pop(0),
                        aug=self.aug_per_local_sample[slot],
                    )
                )

            yield curr_batch

    def __len__(self):
        """Length of base dataset."""
        return self.size

    def set_epoch(self, epoch):
        """Record the epoch number on the sampler."""
        self.epoch = epoch

    def get_aug(self):
        """Draw a rotation angle, a scale ratio and the dataset's augmentation config."""
        rot_angle = np.random.uniform(*self.dataset.rot_range)
        scale_ratio = np.random.uniform(*self.dataset.scale_ratio_range)
        aug_configs = self.dataset._sample_augmentation()
        return rot_angle, scale_ratio, aug_configs
174 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/samplers/group_in_batch_sampler.py:
--------------------------------------------------------------------------------
1 | # https://github.com/Divadi/SOLOFusion/blob/main/mmdet3d/datasets/samplers/infinite_group_each_sample_in_batch_sampler.py
2 | import itertools
3 | import copy
4 |
5 | import numpy as np
6 | import torch
7 | import torch.distributed as dist
8 | from mmcv.runner import get_dist_info
9 | from torch.utils.data.sampler import Sampler
10 |
11 |
12 | # https://github.com/open-mmlab/mmdetection/blob/3b72b12fe9b14de906d1363982b9fba05e7d47c1/mmdet/core/utils/dist_utils.py#L157
def sync_random_seed(seed=None, device="cuda"):
    """Make all ranks agree on one seed.

    The seed held by rank 0 is broadcast to the other ranks; every worker has
    to call this function or the broadcast deadlocks. When only one process
    exists, the seed is returned directly.

    Args:
        seed (int, Optional): The seed. When None, a random one below
            ``2**31`` is drawn from NumPy.
        device (str): The device where the seed tensor is placed for the
            broadcast. Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()
    if world_size == 1:
        return seed

    # Non-zero ranks send a dummy 0 that the broadcast replaces.
    shared_seed = torch.tensor(
        seed if rank == 0 else 0, dtype=torch.int32, device=device
    )
    dist.broadcast(shared_seed, src=0)
    return shared_seed.item()
46 |
47 |
class GroupInBatchSampler(Sampler):
    """Endless batch sampler in which every batch slot follows its own group.

    Each local batch slot consumes one group (as given by ``dataset.flag``)
    at a time and switches to the next group of a seeded, shared permutation
    when the current one is exhausted. Only the group order is shuffled;
    inside a group the order is kept, apart from random frame skipping and
    whole-sequence reversal.

    Args:
        dataset: Dataset exposing ``flag``, ``keep_consistent_seq_aug`` and
            ``get_augmentation()``.
        batch_size (int): Samples per process.
        world_size (int, optional): Defaults to the value of ``get_dist_info``.
        rank (int, optional): Defaults to the value of ``get_dist_info``.
        seed (int): Seed synchronised across ranks with ``sync_random_seed``.
        skip_prob (float): Probability of dropping one extra frame before
            yielding, when more than one frame is buffered.
        sequence_flip_prob (float): Probability of reversing a newly loaded
            group.
    """

    def __init__(
        self,
        dataset,
        batch_size=1,
        world_size=None,
        rank=None,
        seed=0,
        skip_prob=0.5,
        sequence_flip_prob=0.1,
    ):
        detected_rank, detected_world_size = get_dist_info()
        world_size = detected_world_size if world_size is None else world_size
        rank = detected_rank if rank is None else rank

        self.dataset = dataset
        self.batch_size = batch_size
        self.world_size = world_size
        self.rank = rank
        self.seed = sync_random_seed(seed)
        self.size = len(self.dataset)

        assert hasattr(self.dataset, "flag")
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)
        self.groups_num = len(self.group_sizes)
        self.global_batch_size = batch_size * world_size
        assert self.groups_num >= self.global_batch_size

        # Lookup table: group index -> dataset sample indices of that group.
        self.group_idx_to_sample_idxs = {}
        for group_idx in range(self.groups_num):
            members = np.where(self.flag == group_idx)[0]
            self.group_idx_to_sample_idxs[group_idx] = members.tolist()

        # One group generator per local slot, addressed by its position among
        # the slots of all processes.
        first_global_idx = self.rank * self.batch_size
        self.group_indices_per_global_sample_idx = [
            self._group_indices_per_global_sample_idx(first_global_idx + offset)
            for offset in range(self.batch_size)
        ]

        # Pending dataset indices and current augmentation per local slot.
        self.buffer_per_local_sample = [[] for _ in range(self.batch_size)]
        self.aug_per_local_sample = [None for _ in range(self.batch_size)]
        self.skip_prob = skip_prob
        self.sequence_flip_prob = sequence_flip_prob

    def _infinite_group_indices(self):
        """Yield group indices forever, one seeded permutation after another."""
        rng = torch.Generator()
        rng.manual_seed(self.seed)
        while True:
            for group_idx in torch.randperm(self.groups_num, generator=rng).tolist():
                yield group_idx

    def _group_indices_per_global_sample_idx(self, global_sample_idx):
        """Yield every ``global_batch_size``-th group, starting at the given offset."""
        strided = itertools.islice(
            self._infinite_group_indices(),
            global_sample_idx,
            None,
            self.global_batch_size,
        )
        for group_idx in strided:
            yield group_idx

    def __iter__(self):
        while True:
            curr_batch = []
            for slot in range(self.batch_size):
                # The skip draw happens before a possible refill, so an empty
                # buffer never skips.
                skip = (
                    np.random.uniform() < self.skip_prob
                    and len(self.buffer_per_local_sample[slot]) > 1
                )
                if not self.buffer_per_local_sample[slot]:
                    # Finished current group, refill with next group
                    next_group = next(
                        self.group_indices_per_global_sample_idx[slot]
                    )
                    refill = copy.deepcopy(
                        self.group_idx_to_sample_idxs[next_group]
                    )
                    if np.random.uniform() < self.sequence_flip_prob:
                        refill = refill[::-1]
                    self.buffer_per_local_sample[slot] = refill
                    if self.dataset.keep_consistent_seq_aug:
                        self.aug_per_local_sample[slot] = (
                            self.dataset.get_augmentation()
                        )

                if not self.dataset.keep_consistent_seq_aug:
                    self.aug_per_local_sample[slot] = (
                        self.dataset.get_augmentation()
                    )

                pending = self.buffer_per_local_sample[slot]
                if skip:
                    pending.pop(0)
                curr_batch.append(
                    dict(
                        idx=pending.pop(0),
                        aug_config=self.aug_per_local_sample[slot],
                    )
                )

            yield curr_batch

    def __len__(self):
        """Length of base dataset."""
        return self.size

    def set_epoch(self, epoch):
        """Record the epoch number on the sampler."""
        self.epoch = epoch
179 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/pipelines/loading.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import mmcv
3 | from mmdet.datasets.builder import PIPELINES
4 |
5 |
@PIPELINES.register_module()
class LoadMultiViewImageFromFiles(object):
    """Load the images of all views listed in ``results['img_filename']``.

    Args:
        to_float32 (bool, optional): Whether to convert the stacked images to
            float32. Defaults to False.
        color_type (str, optional): Color type passed to ``mmcv.imread``.
            Defaults to 'unchanged'.
    """

    def __init__(self, to_float32=False, color_type="unchanged"):
        self.to_float32 = to_float32
        self.color_type = color_type

    def __call__(self, results):
        """Read every view and store the images and their meta data.

        Args:
            results (dict): Result dict containing the multi-view image
                filenames under ``img_filename``.

        Returns:
            dict: The same dict with these keys set:

                - filename: The multi-view image filenames.
                - img (list[np.ndarray]): One array per view.
                - img_shape, ori_shape, pad_shape (tuple[int]): Shape of the
                  stacked array, with the views on the last axis.
                - scale_factor (float): Always 1.0.
                - img_norm_cfg (dict): Zero mean, unit std, ``to_rgb=False``.
        """
        filename = results["img_filename"]
        views = [mmcv.imread(name, self.color_type) for name in filename]
        # Stacked array has the views on the last axis.
        img = np.stack(views, axis=-1)
        if self.to_float32:
            img = img.astype(np.float32)

        results["filename"] = filename
        # unravel to list, see `DefaultFormatBundle` in formatting.py
        # which will transpose each image separately and then stack into array
        results["img"] = [img[..., view] for view in range(img.shape[-1])]
        # Set initial values for default meta_keys
        for shape_key in ("img_shape", "ori_shape", "pad_shape"):
            results[shape_key] = img.shape
        results["scale_factor"] = 1.0

        num_channels = img.shape[2] if len(img.shape) >= 3 else 1
        results["img_norm_cfg"] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False,
        )
        return results

    def __repr__(self):
        """str: Return a string that describes the module."""
        return (
            f"{self.__class__.__name__}"
            f"(to_float32={self.to_float32}, color_type='{self.color_type}')"
        )
71 |
@PIPELINES.register_module()
class LoadPointsFromFile(object):
    """Load a point cloud from file into ``results['points']``.

    Args:
        coord_type (str): Coordinate type of the point cloud. Must be one of
            'CAMERA', 'LIDAR' or 'DEPTH'. In this class it is only validated
            and stored.
        load_dim (int, optional): Number of values per point in the file;
            the flat buffer is reshaped to ``(-1, load_dim)``.
            Defaults to 6.
        use_dim (list[int] | int, optional): Columns to keep. An int ``k``
            is expanded to ``list(range(k))``. Defaults to [0, 1, 2].
        shift_height (bool, optional): If True, insert an extra column after
            the first three holding ``z`` minus the 0.99th percentile of
            ``z``. Defaults to False.
        use_color (bool, optional): If True, assert that at least six
            dimensions are in use. It has no other effect in this class.
            Defaults to False.
        file_client_args (dict, optional): Keyword arguments for
            ``mmcv.FileClient``. Copied on construction.
            Defaults to dict(backend='disk').
        feather_points (bool, optional): Stored on the instance; not read
            anywhere in this class. Defaults to False.
    """

    def __init__(
        self,
        coord_type,
        load_dim=6,
        use_dim=[0, 1, 2],
        shift_height=False,
        use_color=False,
        file_client_args=dict(backend="disk"),
        feather_points=False,
    ):
        # NOTE: the list/dict defaults are shared objects; they are never
        # mutated here (`file_client_args` is copied, `use_dim` only read).
        self.shift_height = shift_height
        self.use_color = use_color
        if isinstance(use_dim, int):
            use_dim = list(range(use_dim))
        assert (
            max(use_dim) < load_dim
        ), f"Expect all used dimensions < {load_dim}, got {use_dim}"
        assert coord_type in ["CAMERA", "LIDAR", "DEPTH"]

        self.coord_type = coord_type
        self.load_dim = load_dim
        self.use_dim = use_dim
        self.file_client_args = file_client_args.copy()
        # The file client is created lazily on the first load.
        self.file_client = None
        self.feather_points = feather_points

    def _load_points(self, pts_filename):
        """Read the raw point data for one file.

        Args:
            pts_filename (str): Filename of point clouds data.

        Returns:
            np.ndarray: The loaded values. A flat float32 array, except on
            the fallback path for ``.npy`` files, where whatever
            ``np.load`` returns is passed through.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)
        try:
            pts_bytes = self.file_client.get(pts_filename)
            points = np.frombuffer(pts_bytes, dtype=np.float32)

        except ConnectionError:
            # Fall back to reading straight from the local filesystem.
            mmcv.check_file_exist(pts_filename)
            if pts_filename.endswith(".npy"):
                points = np.load(pts_filename)
            else:
                points = np.fromfile(pts_filename, dtype=np.float32)

        return points

    def __call__(self, results):
        """Load the points named by ``results['pts_filename']``.

        Args:
            results (dict): Must contain ``'pts_filename'``.

        Returns:
            dict: The same dict with ``'points'`` set to an ``np.ndarray``
            of shape ``(num_points, len(use_dim))``, plus one extra column
            when ``shift_height`` is enabled.
        """
        pts_filename = results["pts_filename"]
        points = self._load_points(pts_filename)
        points = points.reshape(-1, self.load_dim)
        points = points[:, self.use_dim]

        if self.shift_height:
            # np.percentile takes q in [0, 100]; 0.99 is close to the
            # minimum z value.
            floor_height = np.percentile(points[:, 2], 0.99)
            height = points[:, 2] - floor_height
            points = np.concatenate(
                [points[:, :3], np.expand_dims(height, 1), points[:, 3:]], 1
            )

        if self.use_color:
            # Kept from the original: colour needs at least six dimensions.
            assert len(self.use_dim) >= 6

        results["points"] = points
        return results
191 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/apis/mmdet_train.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------
2 | # Copyright (c) OpenMMLab. All rights reserved.
3 | # ---------------------------------------------
4 | # Modified by Zhiqi Li
5 | # ---------------------------------------------
6 | import random
7 | import warnings
8 |
9 | import numpy as np
10 | import torch
11 | import torch.distributed as dist
12 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
13 | from mmcv.runner import (
14 | HOOKS,
15 | DistSamplerSeedHook,
16 | EpochBasedRunner,
17 | Fp16OptimizerHook,
18 | OptimizerHook,
19 | build_optimizer,
20 | build_runner,
21 | get_dist_info,
22 | )
23 | from mmcv.utils import build_from_cfg
24 |
25 | from mmdet.core import EvalHook
26 |
27 | from mmdet.datasets import build_dataset, replace_ImageToTensor
28 | from mmdet.utils import get_root_logger
29 | import time
30 | import os.path as osp
31 | from projects.mmdet3d_plugin.datasets.builder import build_dataloader
32 | from projects.mmdet3d_plugin.core.evaluation.eval_hooks import (
33 | CustomDistEvalHook,
34 | )
35 | from projects.mmdet3d_plugin.datasets import custom_build_dataset
36 |
37 |
def custom_train_detector(
    model,
    dataset,
    cfg,
    distributed=False,
    validate=False,
    timestamp=None,
    meta=None,
):
    """Build data loaders, runner and hooks from ``cfg`` and run training.

    Args:
        model: Model to train. It is moved to GPU and wrapped in
            ``MMDistributedDataParallel`` or ``MMDataParallel``.
        dataset: A dataset or a list/tuple of datasets; a single dataset is
            wrapped into a one-element list.
        cfg: Config object. Keys read here include ``data``, ``optimizer``,
            ``optimizer_config``, ``lr_config``, ``checkpoint_config``,
            ``log_config``, ``work_dir``, ``gpu_ids``, ``seed``,
            ``workflow``, ``resume_from`` and ``load_from``.
        distributed (bool): Selects the distributed wrapper, sampler seed
            hook and evaluation hook. Defaults to False.
        validate (bool): If True, build a validation loader from
            ``cfg.data.val`` and register an evaluation hook.
            Defaults to False.
        timestamp: Assigned to ``runner.timestamp``.
        meta: Passed to the runner as ``meta``.
    """
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders

    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    # assert len(dataset)==1s
    # Legacy key: "imgs_per_gpu" overrides "samples_per_gpu" when present.
    if "imgs_per_gpu" in cfg.data:
        logger.warning(
            '"imgs_per_gpu" is deprecated in MMDet V2.0. '
            'Please use "samples_per_gpu" instead'
        )
        if "samples_per_gpu" in cfg.data:
            logger.warning(
                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
                f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
                f"={cfg.data.imgs_per_gpu} is used in this experiments"
            )
        else:
            logger.warning(
                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
                f"{cfg.data.imgs_per_gpu} in this experiments"
            )
        cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu

    # The runner type is forwarded to the project's build_dataloader.
    if "runner" in cfg:
        runner_type = cfg.runner["type"]
    else:
        runner_type = "EpochBasedRunner"
    data_loaders = [
        build_dataloader(
            ds,
            cfg.data.samples_per_gpu,
            cfg.data.workers_per_gpu,
            # cfg.gpus will be ignored if distributed
            len(cfg.gpu_ids),
            dist=distributed,
            seed=cfg.seed,
            nonshuffler_sampler=dict(
                type="DistributedSampler"
            ),  # dict(type='DistributedSampler'),
            runner_type=runner_type,
        )
        for ds in dataset
    ]

    # put model on gpus
    if distributed:
        find_unused_parameters = cfg.get("find_unused_parameters", False)
        # Sets the `find_unused_parameters` parameter in
        # torch.nn.parallel.DistributedDataParallel
        model = MMDistributedDataParallel(
            model.cuda(),
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters,
        )

    else:
        model = MMDataParallel(
            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids
        )

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

    # Configs without a `runner` section get an EpochBasedRunner built from
    # `total_epochs`; otherwise the two settings must agree if both exist.
    if "runner" not in cfg:
        cfg.runner = {
            "type": "EpochBasedRunner",
            "max_epochs": cfg.total_epochs,
        }
        warnings.warn(
            "config is now expected to have a `runner` section, "
            "please set `runner` in your config.",
            UserWarning,
        )
    else:
        if "total_epochs" in cfg:
            assert cfg.total_epochs == cfg.runner.max_epochs

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta,
        ),
    )

    # an ugly workaround to make .log and .log.json filenames the same
    runner.timestamp = timestamp

    # fp16 setting
    # Three cases: an fp16 hook when cfg.fp16 is set, a plain OptimizerHook
    # for distributed runs whose optimizer_config has no "type", otherwise
    # the raw config is handed to the runner.
    fp16_cfg = cfg.get("fp16", None)
    if fp16_cfg is not None:
        optimizer_config = Fp16OptimizerHook(
            **cfg.optimizer_config, **fp16_cfg, distributed=distributed
        )
    elif distributed and "type" not in cfg.optimizer_config:
        optimizer_config = OptimizerHook(**cfg.optimizer_config)
    else:
        optimizer_config = cfg.optimizer_config

    # register hooks
    runner.register_training_hooks(
        cfg.lr_config,
        optimizer_config,
        cfg.checkpoint_config,
        cfg.log_config,
        cfg.get("momentum_config", None),
    )

    # register profiler hook
    # trace_config = dict(type='tb_trace', dir_name='work_dir')
    # profiler_config = dict(on_trace_ready=trace_config)
    # runner.register_profiler_hook(profiler_config)

    if distributed:
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # register eval hooks
    if validate:
        # Support batch_size > 1 in validation
        # NOTE: `samples_per_gpu` is popped from cfg.data.val, and a value
        # above 1 is rejected by the assertion below, which leaves the
        # pipeline replacement after it unreachable.
        val_samples_per_gpu = cfg.data.val.pop("samples_per_gpu", 1)
        if val_samples_per_gpu > 1:
            assert False
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.val.pipeline = replace_ImageToTensor(
                cfg.data.val.pipeline
            )
        val_dataset = custom_build_dataset(cfg.data.val, dict(test_mode=True))

        val_dataloader = build_dataloader(
            val_dataset,
            samples_per_gpu=val_samples_per_gpu,
            workers_per_gpu=cfg.data.workers_per_gpu,
            dist=distributed,
            shuffle=False,
            nonshuffler_sampler=dict(type="DistributedSampler"),
        )
        eval_cfg = cfg.get("evaluation", {})
        eval_cfg["by_epoch"] = cfg.runner["type"] != "IterBasedRunner"
        # NOTE(review): os.path.join drops "val" when cfg.work_dir is an
        # absolute path - confirm that this prefix is what is intended.
        eval_cfg["jsonfile_prefix"] = osp.join(
            "val",
            cfg.work_dir,
            time.ctime().replace(" ", "_").replace(":", "_"),
        )
        eval_hook = CustomDistEvalHook if distributed else EvalHook
        runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

    # user-defined hooks
    if cfg.get("custom_hooks", None):
        custom_hooks = cfg.custom_hooks
        assert isinstance(
            custom_hooks, list
        ), f"custom_hooks expect list type, but got {type(custom_hooks)}"
        for hook_cfg in cfg.custom_hooks:
            assert isinstance(hook_cfg, dict), (
                "Each item in custom_hooks expects dict type, but got "
                f"{type(hook_cfg)}"
            )
            # Copy so that popping "priority" leaves the config untouched.
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop("priority", "NORMAL")
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    # Resuming takes precedence over loading a checkpoint; `revise_keys`
    # is forwarded to load_checkpoint only when the config provides it.
    if cfg.resume_from:
        runner.resume(cfg.resume_from)
    elif cfg.load_from:
        if cfg.get('revise_keys', None):
            runner.load_checkpoint(cfg.load_from, revise_keys=cfg['revise_keys'])
        else:
            runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)
223 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/allocation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random
4 | import numpy as np
5 |
6 | from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
7 | from .detection3d.decoder import W, L, H, SIN_YAW, COS_YAW
8 |
@PLUGIN_LAYERS.register_module()
class DynamicQueryAllocation(nn.Module):
    """Project 3D anchors into every camera and allocate per-view 2D queries.

    Args:
        with_attn_mask (bool): If True, the returned attention mask is set
            to 0 inside each per-camera block of queries; otherwise it stays
            ``-inf`` everywhere.
        with_project_wh (bool): Stored on the module; not read anywhere in
            this class.
        limit_anchor_size (list): Upper bound applied to the exponentiated
            anchor size, in ``[W, L, H]`` index order.
        limit_corners_num (list[int]): During training, per-camera cap on
            the number of anchors kept only because a corner (not the
            centre) projects into the image. A falsy value disables the
            cap. The length is assumed to equal the number of cameras.
    """

    def __init__(self,
                 with_attn_mask=False,
                 with_project_wh=False,
                 limit_anchor_size=[35, 35, 10],
                 limit_corners_num=[100] * 6,
                 ):
        super().__init__()
        # NOTE: the list defaults are shared objects; they are only read
        # by this class, never mutated.
        self.with_attn_mask = with_attn_mask
        self.with_project_wh = with_project_wh
        self.limit_anchor_size = limit_anchor_size
        self.limit_corners_num = limit_corners_num

    def forward(self, anchor3d, metas):
        """Return the result of :meth:`projection_allocation` unchanged."""
        outputs = self.projection_allocation(anchor3d, metas)
        return outputs

    def projection_allocation(self, anchor3d, metas):
        """Project anchors into the cameras and build the 2D query layout.

        Args:
            anchor3d (torch.Tensor): Anchors indexed as
                ``(batch, num_anchor3d, dims)``. The first three values are
                used as the centre, the ``W/L/H`` entries as log-sizes and
                the ``SIN_YAW/COS_YAW`` entries as the heading.
            metas (dict): Must contain ``'projection_mat'`` (presumably
                ``(batch, num_cams, 4, 4)`` - TODO confirm) and
                ``'image_wh'``; only ``image_wh[0, 0]`` is read, so one
                image size is used for every sample and camera.

        Returns:
            tuple: ``(ref_pts2d, ref_depth2d, trans_mask, trans_shape,
            trans_matrix, center_matrix, query_groups, attn_mask)`` where

                - ref_pts2d: ``(batch, num_anchor2d, 2)`` reference points
                  divided by ``(img_w, img_h)``.
                - ref_depth2d: ``(batch, num_anchor2d, 1)`` absolute depth
                  of the projected anchor centre.
                - trans_mask: ``(batch, num_anchor3d, num_cams)`` bool,
                  True where the anchor is assigned to that camera.
                - trans_shape: ``(batch, num_cams)`` number of assigned
                  anchors per camera.
                - trans_matrix: ``(batch, num_anchor2d, num_anchor3d)``
                  0/1 float map between 2D queries and 3D anchors.
                - center_matrix: same shape, 1 only where the anchor centre
                  itself is visible in the camera.
                - query_groups: list of ``(start, end)`` slot ranges, one
                  per camera.
                - attn_mask: ``(batch, num_anchor2d, num_anchor2d)``.
        """
        device = anchor3d.device
        anchor3d_center = anchor3d[..., :3]
        lidar2imgs = torch.tile(metas['projection_mat'][:, None], (1, anchor3d.shape[1], 1, 1, 1))
        batch_size, num_anchor3d, num_cams = lidar2imgs.shape[:3]
        img_w, img_h = map(int, metas['image_wh'][0, 0].tolist())

        # get rotation mat
        rotation_mat = anchor3d.new_zeros([batch_size, num_anchor3d, 3, 3])
        rotation_mat[:, :, 0, 0] = anchor3d[:, :, COS_YAW]
        rotation_mat[:, :, 0, 1] = -anchor3d[:, :, SIN_YAW]
        rotation_mat[:, :, 1, 0] = anchor3d[:, :, SIN_YAW]
        rotation_mat[:, :, 1, 1] = anchor3d[:, :, COS_YAW]
        rotation_mat[:, :, 2, 2] = 1

        # get anchor corners
        # The eight corners of the unit cube, shifted to be centred on 0.
        corners_norm = anchor3d.new_tensor(np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1))
        corners_norm = corners_norm - anchor3d.new_tensor([0.5, 0.5, 0.5])

        # Sizes are stored as logs; exponentiate, then cap each dimension.
        anchor3d_size = anchor3d[..., [W, L, H]].exp()
        anchor3d_size = anchor3d_size.clamp(max=torch.tensor(self.limit_anchor_size, device=device).view(1, 1, -1))

        anchor3d_corners = anchor3d_size[:, :, None, :] * corners_norm[None, None, :, :]
        anchor3d_corners = torch.matmul(rotation_mat[:, :, None], anchor3d_corners[..., None]).squeeze(-1)
        anchor3d_corners = anchor3d_corners + anchor3d_center[:, :, None, :]
        # Append the centre as a ninth point so it is projected together
        # with the corners.
        anchor3d_corners = torch.cat([anchor3d_corners, anchor3d_center[:, :, None, :]], dim=-2)

        # get points in camera and image plane
        coord_pts3d = torch.cat([anchor3d_corners, torch.ones_like(anchor3d_corners[..., :1])], -1)
        coord_pts3d = coord_pts3d.view(batch_size, num_anchor3d, 1, 9, 4, 1).repeat(1, 1, num_cams, 1, 1, 1)
        coord_pts2d = torch.matmul(lidar2imgs[:, :, :, None], coord_pts3d).squeeze(-1)

        center_pts2d = coord_pts2d[..., -1, :]
        corner_pts2d = coord_pts2d[..., :-1, :]
        center_depth2d = center_pts2d[..., 2:3]
        corner_depth2d = corner_pts2d[..., 2:3]

        # Perspective divide; the depth is clamped from below at 1e-5.
        center_pts2d = center_pts2d[..., :2] / center_depth2d.clamp(1e-5)
        corner_pts2d = corner_pts2d[..., :2] / corner_depth2d.clamp(1e-5)

        center_valid = ((0 < center_pts2d[..., 0]) & (center_pts2d[..., 0] < img_w) &
                        (0 < center_pts2d[..., 1]) & (center_pts2d[..., 1] < img_h))

        # A corner counts only with positive depth and inside the image;
        # the anchor is corner-valid if any of its eight corners qualifies.
        corner_valid1 = (corner_depth2d[..., 0] > 0)
        corner_valid2 = ((0 < corner_pts2d[..., 0]) & (corner_pts2d[..., 0] < img_w) &
                         (0 < corner_pts2d[..., 1]) & (corner_pts2d[..., 1] < img_h))
        corner_valid = torch.logical_and(corner_valid1, corner_valid2).any(-1)

        # project corners to get corner-centers
        x_min = torch.clamp(corner_pts2d[..., 0].min(-1).values, min=0, max=img_w)
        x_max = torch.clamp(corner_pts2d[..., 0].max(-1).values, min=0, max=img_w)
        y_min = torch.clamp(corner_pts2d[..., 1].min(-1).values, min=0, max=img_h)
        y_max = torch.clamp(corner_pts2d[..., 1].max(-1).values, min=0, max=img_h)

        cx, cy = (x_min + x_max) / 2, (y_min + y_max) / 2
        select_centers = torch.stack([cx, cy], dim=-1)
        select_centers[center_valid] = center_pts2d[center_valid]  # overwrite center points

        if self.training and self.limit_corners_num:
            # Restrict the cap to anchors that are valid through a corner
            # only; centre-visible anchors are never dropped.
            corner_valid = torch.where(torch.logical_and(corner_valid, center_valid), False, corner_valid)
            corner_valid = self.random_sample_corner_mask(corner_valid)

        # divide to groups
        # Each camera gets a slot range sized by the largest per-sample
        # count in the batch, so samples share one layout.
        trans_mask = torch.logical_or(center_valid, corner_valid)
        trans_shape = trans_mask.sum(1)
        trans_meta_shape = trans_shape.max(0).values
        trans_meta_start = torch.cat([torch.zeros_like(trans_meta_shape[:1]), trans_meta_shape])
        trans_meta_cumsum = trans_meta_start.cumsum(-1).tolist()

        trans_start = trans_meta_start.cumsum(-1)[:num_cams].unsqueeze(0).repeat(batch_size, 1)
        trans_end = trans_start + trans_shape

        query_groups = [(qs, qe) for qs, qe in zip(trans_meta_cumsum[:-1], trans_meta_cumsum[1:])]
        num_anchor2d = trans_meta_shape.sum()

        # create reference points
        trans_mask_tmp = trans_mask.permute(0, 2, 1).flatten(0, 1)
        select_centers = select_centers.permute(0, 2, 1, 3).flatten(0, 1)
        select_depths = center_depth2d.permute(0, 2, 1, 3).flatten(0, 1)  # corner depth is fake

        select_centers = select_centers[trans_mask_tmp]
        select_depths = select_depths[trans_mask_tmp]

        selected_mask = torch.zeros((batch_size, num_anchor2d), device=device)
        attn_mask = torch.ones((batch_size, num_anchor2d, num_anchor2d), device=device).fill_(float("-inf"))
        for bs in range(batch_size):
            for st, ed in zip(trans_start[bs], trans_end[bs]):
                selected_mask[bs, st:ed] = 1.0
                if self.with_attn_mask:
                    attn_mask[bs, st:ed, st:ed] = 0.0
        selected_mask = selected_mask.unsqueeze(-1).repeat(1, 1, 2).to(torch.bool)

        ref_pts2d = torch.zeros((batch_size, num_anchor2d, 2), device=device)
        ref_depth2d = torch.zeros((batch_size, num_anchor2d, 1), device=device)

        # Scatter the selected values into the occupied slots; padded
        # slots keep their zeros.
        ref_pts2d = torch.masked_scatter(ref_pts2d, selected_mask[..., :2], select_centers)
        ref_depth2d = torch.masked_scatter(ref_depth2d, selected_mask[..., :1], select_depths.abs())

        ref_pts2d = ref_pts2d / ref_pts2d.new_tensor([img_w, img_h])

        # create trans matrix
        trans_matrix = nn.Parameter(
            torch.zeros((batch_size, num_anchor2d, num_anchor3d), device=device), requires_grad=False)

        # meta_mask is 2 where the centre is visible, 1 for corner-only
        # assignments and 0 elsewhere.
        meta_mask = trans_mask.to(torch.float) + center_valid.to(torch.float)
        meta_mask = meta_mask.permute(0, 2, 1)

        for bs in range(batch_size):
            cam_index, pts3d_index = torch.nonzero(meta_mask[bs]).chunk(2, dim=1)
            cam_index, pts3d_index = cam_index[:, 0], pts3d_index[:, 0]
            pts2d_index = torch.cat(
                [torch.arange(st, ed, device=device) for st, ed in zip(trans_start[bs], trans_end[bs])])
            trans_matrix[bs, pts2d_index, pts3d_index] = meta_mask[bs, cam_index, pts3d_index]

        center_matrix = (trans_matrix == 2).to(torch.float)
        trans_matrix = (trans_matrix >= 1).to(torch.float)  # include center and corner points

        return ref_pts2d, ref_depth2d, trans_mask, trans_shape, trans_matrix, center_matrix, query_groups, attn_mask

    def random_sample_corner_mask(self, corner_valid):
        """Randomly drop corner-valid anchors so each view respects its cap.

        For every (sample, camera) view holding more valid entries than the
        matching entry of ``limit_corners_num``, the surplus is chosen at
        random among the entries that are actually True and cleared.

        The previous implementation indexed the anchor axis directly with
        ``np.random.permutation(count)[remove_num:]``. That cleared
        ``limit`` arbitrary positions among the first ``count`` anchors,
        whether or not they were valid, so the cap was not enforced.

        Args:
            corner_valid (torch.Tensor): Bool mask of shape
                ``(batch, num_anchor3d, num_cams)``. Modified in place.

        Returns:
            torch.Tensor: The same tensor, with at most the configured
            number of True entries per view.
        """
        batch_size, num_anchor3d, num_cams = corner_valid.shape
        corner_view = corner_valid.permute(0, 2, 1).reshape(batch_size * num_cams, -1)
        corner_nums = corner_view.sum(-1).detach().cpu().numpy()
        # List repetition: the per-camera limits are tiled once per sample.
        limit_nums = np.array(self.limit_corners_num * batch_size)

        remove_ids = np.where(corner_nums > limit_nums)[0]
        for remove_id in remove_ids:
            bs, cam_id = divmod(int(remove_id), num_cams)
            remove_num = int(corner_nums[remove_id] - limit_nums[remove_id])
            # Anchor indices that are currently valid in this view.
            valid_index = torch.nonzero(corner_valid[bs, :, cam_id])[:, 0]
            # Same RNG as before (numpy); take the first `remove_num`
            # positions of a random ordering of the valid entries.
            drop_order = np.random.permutation(valid_index.shape[0])[:remove_num]
            drop_index = valid_index[torch.as_tensor(drop_order, device=valid_index.device)]
            corner_valid[bs, drop_index, cam_id] = False

        return corner_valid
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 |
7 | from projects.mmdet3d_plugin.core.box3d import *
8 |
9 |
def box3d_to_corners(box3d):
    """Compute the eight corner points of each 3D box.

    Args:
        box3d (torch.Tensor | np.ndarray): Boxes indexed as
            ``(num_boxes, dims)``. The first three values are the centre;
            the ``W``, ``L``, ``H`` and ``YAW`` columns give size and
            heading. Tensors are detached and moved to CPU first.

    Returns:
        np.ndarray: Corners of shape ``(num_boxes, 8, 3)``.
    """
    if isinstance(box3d, torch.Tensor):
        box3d = box3d.detach().cpu().numpy()

    # Corners of the unit cube, reordered and centred on the origin so
    # that scaling by the size spans the box symmetrically around its
    # centre on all three axes.
    unit_cube = np.stack(np.unravel_index(np.arange(8), [2] * 3), axis=1)
    unit_cube = unit_cube[[0, 1, 3, 2, 4, 5, 7, 6]]
    unit_cube = unit_cube - np.array([0.5, 0.5, 0.5])
    corners = box3d[:, None, [W, L, H]] * unit_cube.reshape([1, 8, 3])

    # One rotation matrix about the z axis per box.
    yaw = box3d[:, YAW]
    cos_yaw = np.cos(yaw)
    sin_yaw = np.sin(yaw)
    rot_mat = np.tile(np.eye(3)[None], (box3d.shape[0], 1, 1))
    rot_mat[:, 0, 0] = cos_yaw
    rot_mat[:, 0, 1] = -sin_yaw
    rot_mat[:, 1, 0] = sin_yaw
    rot_mat[:, 1, 1] = cos_yaw

    rotated = rot_mat[:, None] @ corners[..., None]
    corners = rotated.squeeze(axis=-1)
    # Translate to the box centre.
    corners += box3d[:, None, :3]
    return corners
30 |
31 |
def plot_rect3d_on_img(
    img, num_rects, rect_corners, color=(0, 255, 0), thickness=1
):
    """Draw the edges of projected 3D boxes onto an image.

    An edge is skipped only when both of its end points lie outside the
    image; an edge with one visible end point is still drawn.

    Args:
        img (numpy.array): Image to draw on (drawn in place by OpenCV).
        num_rects (int): Number of boxes to draw.
        rect_corners (numpy.array): Projected corners, expected in the
            shape of [num_rect, 8, 2].
        color (tuple[int] | sequence, optional): One colour for all boxes
            when its first element is an int, otherwise one colour per box.
            Default: (0, 255, 0).
        thickness (int, optional): Line thickness. Default: 1.

    Returns:
        numpy.array: ``img`` converted to ``np.uint8``.
    """
    # Pairs of corner indices joined by a line (12 edges per box).
    edges = (
        (0, 1), (0, 3), (0, 4),
        (1, 2), (1, 5),
        (3, 2), (3, 7),
        (4, 5), (4, 7),
        (2, 6), (5, 6), (6, 7),
    )
    height, width = img.shape[:2]

    def _outside(point):
        # point is (x, y); True when it falls off the image.
        return not (0 <= point[0] < width and 0 <= point[1] < height)

    for rect_id in range(num_rects):
        # Clip before the integer cast so huge projections stay in range.
        corners = np.clip(rect_corners[rect_id], -1e4, 1e5).astype(np.int32)
        for start, end in edges:
            if _outside(corners[start]) and _outside(corners[end]):
                continue
            if isinstance(color[0], int):
                line_color = color
            else:
                line_color = color[rect_id]
            cv2.line(
                img,
                (corners[start, 0], corners[start, 1]),
                (corners[end, 0], corners[end, 1]),
                line_color,
                thickness,
                cv2.LINE_AA,
            )

    return img.astype(np.uint8)
92 |
93 |
def draw_image_bbox2d_on_img(
    bboxes2d, raw_img, img_metas=None, color=(0, 255, 0), thickness=1
):
    """Draw 2D boxes on an image with ``cv2.rectangle``.

    Args:
        bboxes2d (iterable): Boxes; the first four values of each are cast
            to int and used as the two corner points of the rectangle.
        raw_img (numpy.array): Image to draw on.
        img_metas: Unused.
        color (tuple[int], optional): Rectangle colour. Default: (0, 255, 0).
        thickness (int, optional): Line thickness. Default: 1.

    Returns:
        numpy.array: The image returned by the last ``cv2.rectangle`` call,
        or ``raw_img`` itself when there are no boxes.
    """
    for bbox2d in bboxes2d:
        box = np.array(bbox2d)
        first_corner = (int(box[0]), int(box[1]))
        second_corner = (int(box[2]), int(box[3]))
        raw_img = cv2.rectangle(
            raw_img, first_corner, second_corner, color, thickness
        )
    return raw_img
101 |
102 |
def draw_lidar_bbox3d_on_img(
    bboxes3d, raw_img, lidar2img_rt, img_metas=None, color=(0, 255, 0), thickness=1
):
    """Project 3D boxes into an image and draw their edges on a copy.

    Args:
        bboxes3d: Boxes accepted by ``box3d_to_corners``.
        raw_img (numpy.array): The image; it is copied, not modified.
        lidar2img_rt (numpy.array | torch.Tensor): Projection matrix,
            reshaped to [4, 4].
        img_metas (dict): Unused here.
        color (tuple[int], optional): The color to draw bboxes.
            Default: (0, 255, 0).
        thickness (int, optional): The thickness of bboxes. Default: 1.

    Returns:
        numpy.array: The annotated uint8 image.
    """
    canvas = raw_img.copy()
    corners_3d = box3d_to_corners(bboxes3d)
    num_bbox = corners_3d.shape[0]

    # Homogeneous coordinates: append a column of ones to every corner.
    ones_column = np.ones((num_bbox * 8, 1))
    pts_4d = np.concatenate([corners_3d.reshape(-1, 3), ones_column], axis=-1)

    projection = copy.deepcopy(lidar2img_rt).reshape(4, 4)
    if isinstance(projection, torch.Tensor):
        projection = projection.cpu().numpy()
    pts_2d = pts_4d @ projection.T

    # Keep the depth in a positive range before the perspective divide.
    pts_2d[:, 2] = np.clip(pts_2d[:, 2], a_min=1e-5, a_max=1e5)
    pts_2d[:, :2] /= pts_2d[:, 2:3]
    imgfov_pts_2d = pts_2d[..., :2].reshape(num_bbox, 8, 2)

    return plot_rect3d_on_img(canvas, num_bbox, imgfov_pts_2d, color, thickness)
137 |
138 |
def draw_points_on_img(points, img, lidar2img_rt, color=(0, 255, 0), circle=4):
    """Project groups of 3D points into an image and draw filled circles.

    Args:
        points (torch.Tensor): Points indexed as ``(num_groups, num_points,
            3)``; converted with ``.cpu().numpy()``.
        img (numpy.array): The image; it is copied, not modified.
        lidar2img_rt (numpy.array | torch.Tensor): Projection matrix,
            reshaped to [4, 4].
        color (tuple[int] | sequence, optional): One colour for all groups
            when its first element is an int, otherwise one colour per
            group. Default: (0, 255, 0).
        circle (int, optional): Circle radius. Default: 4.

    Returns:
        numpy.array: The annotated uint8 image.
    """
    canvas = img.copy()
    num_groups = points.shape[0]
    pts = points.cpu().numpy()

    projection = copy.deepcopy(lidar2img_rt).reshape(4, 4)
    if isinstance(projection, torch.Tensor):
        projection = projection.cpu().numpy()
    linear = projection[:3, :3]
    offset = projection[:3, 3]

    pts_2d = np.sum(pts[:, :, None] * linear, axis=-1) + offset
    # Keep the depth in a positive range before the perspective divide.
    pts_2d[..., 2] = np.clip(pts_2d[..., 2], a_min=1e-5, a_max=1e5)
    pts_2d = pts_2d[..., :2] / pts_2d[..., 2:3]
    # Clip before the integer cast so huge projections stay in range.
    pts_2d = np.clip(pts_2d, -1e4, 1e4).astype(np.int32)

    for group_id in range(num_groups):
        for point in pts_2d[group_id]:
            fill = color if isinstance(color[0], int) else color[group_id]
            cv2.circle(canvas, point.tolist(), circle, fill, thickness=-1)
    return canvas.astype(np.uint8)
162 |
163 |
def draw_lidar_bbox3d_on_bev(bboxes_3d, bev_size, bev_range=115, color=(255, 0, 0), thickness=3, bev_img=None):
    """Draw box footprints on a bird's-eye-view canvas.

    Args:
        bboxes_3d: Boxes accepted by ``box3d_to_corners``; an empty
            sequence draws nothing.
        bev_size (int | list | tuple): Canvas size, either ``(height,
            width)`` or a single int for a square canvas.
        bev_range (number, optional): Extent covered by the canvas height;
            the pixel resolution is ``bev_range / bev_h``. Default: 115.
        color (tuple | sequence, optional): One colour for all boxes, or a
            sequence of per-box colours when its first element is a list or
            tuple. Default: (255, 0, 0).
        thickness (int, optional): Line thickness. Default: 3.
        bev_img (numpy.array, optional): Existing canvas to draw on. When
            None, a white canvas with range rings and axis lines is
            created.

    Returns:
        numpy.array: The canvas converted to ``np.uint8``.
    """
    if isinstance(bev_size, (list, tuple)):
        bev_h, bev_w = bev_size
    else:
        bev_h, bev_w = bev_size, bev_size

    bev_resolution = bev_range / bev_h

    # init bev image
    if bev_img is None:
        bev = np.ones([bev_h, bev_w, 3]) * 255

        marking_color = (110, 110, 110)

        # OpenCV points are (x, y) = (column, row), so the canvas centre is
        # (w / 2, h / 2). The previous code passed (h / 2, w / 2), which
        # put the rings off-centre whenever the canvas was not square.
        center = (int(bev_w / 2), int(bev_h / 2))
        for cir in range(int(bev_range / 20)):
            # One ring every 10 range units.
            cv2.circle(bev, center,
                       int((cir + 1) * 10 / bev_resolution), marking_color, thickness=thickness)
        cv2.line(bev, (0, int(bev_h / 2)), (bev_w, int(bev_h / 2)), marking_color, thickness=thickness-1)
        cv2.line(bev, (int(bev_w / 2), 0), (int(bev_w / 2), bev_h), marking_color, thickness=thickness-1)
    else:
        bev = bev_img

    if len(bboxes_3d) != 0:
        # Four of the eight corners, x/y only, form the footprint.
        bev_corners = box3d_to_corners(bboxes_3d)[:, [0, 3, 4, 7]][..., [0, 1]]
        # Metric x/y to pixels; y is negated because image rows grow
        # downwards.
        xs = bev_corners[..., 0] / bev_resolution + bev_w / 2
        ys = -bev_corners[..., 1] / bev_resolution + bev_h / 2
        for obj_idx, (x, y) in enumerate(zip(xs, ys)):
            if isinstance(color[0], (list, tuple)):
                tmp = color[obj_idx]
            else:
                tmp = color
            for p1, p2 in ((0, 1), (0, 2), (1, 3), (2, 3)):
                cv2.line(
                    bev,
                    (int(x[p1]), int(y[p1])),
                    (int(x[p2]), int(y[p2])),
                    tmp,
                    thickness=thickness,
                )
    return bev.astype(np.uint8)
204 |
205 |
def draw_lidar_bbox3d(bboxes_3d, imgs, lidar2imgs, color=(255, 0, 0)):
    """Render boxes on every camera image and on a BEV panel, tiled together.

    Args:
        bboxes_3d: Boxes forwarded to the per-image and BEV drawing helpers.
        imgs (iterable): Camera images.
        lidar2imgs (iterable): One projection matrix per image.
        color (tuple, optional): Colour forwarded to both helpers.
            Default: (255, 0, 0).

    Returns:
        numpy.array: The BEV panel followed by the camera mosaic,
        concatenated along the width axis.
    """
    vis_imgs = [
        draw_lidar_bbox3d_on_img(bboxes_3d, img, lidar2img, color=color)
        for img, lidar2img in zip(imgs, lidar2imgs)
    ]

    num_imgs = len(vis_imgs)
    if num_imgs >= 4 and num_imgs % 2 == 0:
        # An even count of four or more is laid out as two rows.
        half = num_imgs // 2
        top_row = np.concatenate(vis_imgs[:half], axis=1)
        bottom_row = np.concatenate(vis_imgs[half:], axis=1)
        vis_imgs = np.concatenate([top_row, bottom_row], axis=0)
    else:
        vis_imgs = np.concatenate(vis_imgs, axis=1)

    # The BEV panel is square, with its side equal to the mosaic height.
    bev = draw_lidar_bbox3d_on_bev(bboxes_3d, vis_imgs.shape[0], color=color)
    return np.concatenate([bev, vis_imgs], axis=1)
225 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/instance_bank.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 |
7 | from mmcv.utils import build_from_cfg
8 | from mmcv.cnn.bricks.registry import PLUGIN_LAYERS
9 |
10 | __all__ = ["InstanceBank"]
11 |
12 |
def topk(confidence, k, *inputs):
    """Select the ``k`` highest-confidence entries per sample.

    Args:
        confidence (torch.Tensor): Scores of shape ``(bs, N)``.
        k (int): Number of entries to keep per sample.
        *inputs (torch.Tensor): Tensors whose first two dimensions are
            ``(bs, N)``; each is gathered with the same indices.

    Returns:
        tuple: The top-k confidences of shape ``(bs, k)`` and a list with
        one gathered tensor of shape ``(bs, k, -1)`` per input.
    """
    bs, N = confidence.shape[:2]
    confidence, indices = torch.topk(confidence, k, dim=1)
    # Turn per-sample indices into indices of the flattened (bs * N) axis.
    batch_offset = torch.arange(bs, device=indices.device)[:, None] * N
    flat_indices = (indices + batch_offset).reshape(-1)
    outputs = [
        tensor.flatten(end_dim=1)[flat_indices].reshape(bs, k, -1)
        for tensor in inputs
    ]
    return confidence, outputs
21 |
22 |
23 | @PLUGIN_LAYERS.register_module()
24 | class InstanceBank(nn.Module):
def __init__(
    self,
    num_anchor,
    embed_dims,
    anchor,
    anchor_handler=None,
    num_temp_instances=0,
    default_time_interval=0.5,
    confidence_decay=0.6,
    anchor_grad=True,
    feat_grad=True,
    max_time_interval=2,
):
    """Set up the learnable anchors, instance features and temporal cache.

    Args:
        num_anchor (int): Maximum number of anchors to keep; the stored
            count is ``min(len(anchor), num_anchor)``.
        embed_dims (int): Width of each instance feature.
        anchor (str | list | tuple | array): Initial anchors. A string is
            loaded with ``np.load``; a list or tuple is converted with
            ``np.array``.
        anchor_handler (dict, optional): Config built through
            ``PLUGIN_LAYERS``; the built object must expose
            ``anchor_projection``.
        num_temp_instances (int): Stored as given. Defaults to 0.
        default_time_interval (float): Stored as given. Defaults to 0.5.
        confidence_decay (float): Stored as given. Defaults to 0.6.
        anchor_grad (bool): ``requires_grad`` of the anchor parameter.
        feat_grad (bool): ``requires_grad`` of the feature parameter.
        max_time_interval (number): Stored as given. Defaults to 2.
    """
    super(InstanceBank, self).__init__()
    self.embed_dims = embed_dims
    self.num_temp_instances = num_temp_instances
    self.default_time_interval = default_time_interval
    self.confidence_decay = confidence_decay
    self.max_time_interval = max_time_interval

    if anchor_handler is not None:
        anchor_handler = build_from_cfg(anchor_handler, PLUGIN_LAYERS)
        assert hasattr(anchor_handler, "anchor_projection")
    self.anchor_handler = anchor_handler

    # Accept a file path or a plain Python sequence as the anchor source.
    if isinstance(anchor, str):
        anchor = np.load(anchor)
    elif isinstance(anchor, (list, tuple)):
        anchor = np.array(anchor)
    self.num_anchor = min(len(anchor), num_anchor)
    anchor = anchor[:num_anchor]

    # The initial values are kept so that init_weight can restore them.
    self.anchor_init = anchor
    anchor_tensor = torch.tensor(anchor, dtype=torch.float32)
    self.anchor = nn.Parameter(anchor_tensor, requires_grad=anchor_grad)
    feature_tensor = torch.zeros([self.anchor.shape[0], self.embed_dims])
    self.instance_feature = nn.Parameter(
        feature_tensor, requires_grad=feat_grad
    )
    self.reset()
63 |
def init_weight(self):
    """Restore the anchors from ``anchor_init`` and initialise the features.

    The instance features are Xavier-uniform initialised only when they
    require gradients; otherwise they are left untouched.
    """
    self.anchor.data = self.anchor.data.new_tensor(self.anchor_init)
    if not self.instance_feature.requires_grad:
        return
    torch.nn.init.xavier_uniform_(self.instance_feature.data, gain=1)
68 |
def reset(self):
    """Clear the cached temporal state and restart the id counter at 0."""
    cleared_attributes = (
        "cached_feature",
        "cached_anchor",
        "metas",
        "mask",
        "confidence",
        "temp_confidence",
        "instance_id",
    )
    for attribute in cleared_attributes:
        setattr(self, attribute, None)
    self.prev_id = 0
78 |
def get(self, batch_size, metas=None, dn_metas=None):
    """Return the initial queries plus the cached ones for the current frame.

    When a cache with a matching batch size exists, the time elapsed since
    the cached frame is computed, ``self.mask`` records which samples are
    within ``max_time_interval``, and - if an anchor handler is set - the
    cached anchors (and ``dn_metas['dn_anchor']``, updated in place) are
    projected into the current frame. Otherwise the bank is reset.

    Args:
        batch_size (int): Number of samples in the current batch.
        metas (dict, optional): Current-frame metadata; ``'timestamp'`` and
            ``'img_metas'`` are read when a cache exists.
        dn_metas (dict, optional): Holds ``'dn_anchor'``, which is
            re-projected when its batch size matches.

    Returns:
        tuple: ``(instance_feature, anchor, cached_feature, cached_anchor,
        time_interval)``. The first two are the learnable parameters tiled
        over the batch; ``time_interval`` has one entry per sample and
        falls back to ``default_time_interval`` where the elapsed time is
        zero or outside the allowed range.
    """
    tile_dims = (batch_size, 1, 1)
    instance_feature = torch.tile(self.instance_feature[None], tile_dims)
    anchor = torch.tile(self.anchor[None], tile_dims)

    has_history = (
        self.cached_anchor is not None
        and batch_size == self.cached_anchor.shape[0]
    )
    if not has_history:
        self.reset()
        time_interval = instance_feature.new_tensor(
            [self.default_time_interval] * batch_size
        )
    else:
        elapsed = metas["timestamp"] - self.metas["timestamp"]
        time_interval = elapsed.to(dtype=instance_feature.dtype)
        self.mask = torch.abs(time_interval) <= self.max_time_interval

        if self.anchor_handler is not None:
            # Per-sample transform: current `T_global_inv` applied after
            # the cached frame's `T_global`.
            transforms = [
                x["T_global_inv"] @ self.metas["img_metas"][i]["T_global"]
                for i, x in enumerate(metas["img_metas"])
            ]
            T_temp2cur = self.cached_anchor.new_tensor(np.stack(transforms))
            self.cached_anchor = self.anchor_handler.anchor_projection(
                self.cached_anchor,
                [T_temp2cur],
                time_intervals=[-time_interval],
            )[0]

        project_dn = (
            self.anchor_handler is not None
            and dn_metas is not None
            and batch_size == dn_metas["dn_anchor"].shape[0]
        )
        if project_dn:
            num_dn_group, num_dn = dn_metas["dn_anchor"].shape[1:3]
            dn_anchor = self.anchor_handler.anchor_projection(
                dn_metas["dn_anchor"].flatten(1, 2),
                [T_temp2cur],
                time_intervals=[-time_interval],
            )[0]
            dn_metas["dn_anchor"] = dn_anchor.reshape(
                batch_size, num_dn_group, num_dn, -1
            )

        # Replace zero or out-of-range intervals with the default.
        usable = torch.logical_and(time_interval != 0, self.mask)
        time_interval = torch.where(
            usable,
            time_interval,
            time_interval.new_tensor(self.default_time_interval),
        )

    return (
        instance_feature,
        anchor,
        self.cached_feature,
        self.cached_anchor,
        time_interval,
    )
120 |
def update(self, instance_feature, anchor, confidence):
    """Merge the cached temporal instances into the current instance set.

    Without a cache the inputs are returned unchanged. Otherwise the
    ``num_anchor - num_temp_instances`` most confident current instances are
    appended to the cached ones, and samples flagged by ``self.mask`` take
    that merged set. Entries beyond ``num_anchor`` along dim 1 are split off
    first and re-attached at the end untouched.

    Returns:
        tuple: ``(instance_feature, anchor)`` after merging.
    """
    if self.cached_feature is None:
        return instance_feature, anchor

    # Anything past the first num_anchor entries is carried through as-is.
    surplus = 0
    tail_feature = tail_anchor = None
    if instance_feature.shape[1] > self.num_anchor:
        surplus = instance_feature.shape[1] - self.num_anchor
        tail_feature = instance_feature[:, -surplus:]
        tail_anchor = anchor[:, -surplus:]
        instance_feature = instance_feature[:, : self.num_anchor]
        anchor = anchor[:, : self.num_anchor]
        confidence = confidence[:, : self.num_anchor]

    num_fresh = self.num_anchor - self.num_temp_instances
    best_score = confidence.max(dim=-1).values
    _, (fresh_feature, fresh_anchor) = topk(best_score, num_fresh, instance_feature, anchor)

    merged_feature = torch.cat([self.cached_feature, fresh_feature], dim=1)
    merged_anchor = torch.cat([self.cached_anchor, fresh_anchor], dim=1)

    instance_feature = torch.where(self.mask[:, None, None], merged_feature, instance_feature)
    anchor = torch.where(self.mask[:, None, None], merged_anchor, anchor)
    if self.instance_id is not None:
        # Samples that do not use the cache lose their tracked ids.
        no_id = self.instance_id.new_tensor(-1)
        self.instance_id = torch.where(self.mask[:, None], self.instance_id, no_id)

    if surplus > 0:
        instance_feature = torch.cat([instance_feature, tail_feature], dim=1)
        anchor = torch.cat([anchor, tail_anchor], dim=1)

    return instance_feature, anchor
151 |
def cache(self, instance_feature, anchor, confidence, metas=None, feature_maps=None):
    """Store the most confident instances of this frame for the next one.

    Does nothing when ``num_temp_instances`` is not positive. Inputs are
    detached; the per-instance score is the sigmoid of the maximum over the
    last dim of ``confidence``. For the first ``num_temp_instances`` slots
    the score is floored by the decayed score kept from the previous call.
    ``feature_maps`` is accepted but not used.
    """
    num_temp = self.num_temp_instances
    if num_temp <= 0:
        return

    feature = instance_feature.detach()
    box = anchor.detach()
    score = confidence.detach()

    self.metas = metas
    score = score.max(dim=-1).values.sigmoid()
    if self.confidence is not None:
        decayed = self.confidence * self.confidence_decay
        score[:, :num_temp] = torch.maximum(decayed, score[:, :num_temp])
    self.temp_confidence = score

    kept_score, kept_tensors = topk(score, num_temp, feature, box)
    self.confidence = kept_score
    self.cached_feature, self.cached_anchor = kept_tensors
168 |
def get_instance_id(self, confidence, anchor=None, threshold=None):
    """Assign an integer id to every instance, creating new ids on demand.

    Ids stored in ``self.instance_id`` are copied into the leading columns
    when the batch size matches. Every remaining slot (id ``-1``) whose
    score reaches ``threshold`` (all of them if ``threshold`` is ``None``)
    receives a new id counted up from ``self.prev_id``. ``anchor`` is
    accepted but not used.

    Returns:
        Long tensor with the shape of ``confidence`` minus its last dim;
        ``-1`` marks instances without an id.
    """
    score = confidence.max(dim=-1).values.sigmoid()
    ids = score.new_full(score.shape, -1).long()

    known = self.instance_id
    if known is not None and known.shape[0] == ids.shape[0]:
        ids[:, : known.shape[1]] = known

    unassigned = ids < 0
    if threshold is not None:
        unassigned = unassigned & (score >= threshold)
    num_new_instance = unassigned.sum()
    ids[torch.where(unassigned)] = torch.arange(num_new_instance).to(ids) + self.prev_id
    self.prev_id += num_new_instance
    self.update_instance_id(ids, score)
    return ids
185 |
def update_instance_id(self, instance_id=None, confidence=None):
    """Keep the ids of the top ``num_temp_instances`` instances.

    Ranking uses ``self.temp_confidence`` when available; otherwise the
    given ``confidence`` is used, reduced by a max over the last dim if it
    is 3-dimensional. The kept ids are right-padded with ``-1`` up to
    ``num_anchor`` columns and stored in ``self.instance_id``.
    """
    ranking = self.temp_confidence
    if ranking is None:
        # A 3-D input still carries a trailing class dim; a 2-D one is used as-is.
        ranking = confidence.max(dim=-1).values if confidence.dim() == 3 else confidence

    kept_ids = topk(ranking, self.num_temp_instances, instance_id)[1][0]
    kept_ids = kept_ids.squeeze(dim=-1)
    pad_right = self.num_anchor - self.num_temp_instances
    self.instance_id = F.pad(kept_ids, (0, pad_right), value=-1)
197 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/detection3d/decoder.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 |
5 | from mmdet.core.bbox.builder import BBOX_CODERS
6 | from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy
7 | from projects.mmdet3d_plugin.core.box3d import *
8 |
9 |
@BBOX_CODERS.register_module()
class SparseBox3DDecoder(object):
    """Convert raw head outputs into per-sample detection dictionaries.

    Args:
        num_output: Number of top-scoring predictions kept per sample.
        score_threshold: If set, predictions whose sigmoid class score
            (before any quality re-weighting) is below it are dropped.
        sorted: Forwarded to ``Tensor.topk``.
    """

    def __init__(
        self,
        num_output: int = 300,
        score_threshold: Optional[float] = None,
        sorted: bool = True,
    ):
        super(SparseBox3DDecoder, self).__init__()
        self.num_output = num_output
        self.score_threshold = score_threshold
        self.sorted = sorted

    def decode_box(self, box):
        """Replace the (sin, cos) yaw pair by an angle and exponentiate sizes."""
        yaw = torch.atan2(box[:, SIN_YAW], box[:, COS_YAW])
        box = torch.cat(
            [
                box[:, [X, Y, Z]],
                box[:, [W, L, H]].exp(),
                yaw[:, None],
                box[:, VX:],
            ],
            dim=-1,
        )
        return box

    def decode_box2d(self, box, aug_config):
        """Scale cxcywh boxes to the cropped image and undo the resize.

        ``aug_config`` must provide ``'crop'`` (four values) and ``'resize'``.
        """
        crop = aug_config['crop']
        scale_factor = aug_config['resize']

        crop_img_size = (crop[2] - crop[0], crop[3] - crop[1])

        box = bbox_cxcywh_to_xyxy(box)

        box[..., 0::2] = box[..., 0::2] * crop_img_size[0]
        box[..., 1::2] = box[..., 1::2] * crop_img_size[1]
        box[..., 0::2].clamp_(min=0, max=crop_img_size[0])
        box[..., 1::2].clamp_(min=0, max=crop_img_size[1])

        # NOTE(review): only the vertical crop offset is added back; crop[0]
        # is not applied to x -- confirm the horizontal crop is always zero.
        box[..., 1::2] += crop[1]
        box /= box.new_tensor(scale_factor)
        return box

    def _rank_predictions(self, cls_scores, box_preds, instance_id, qulity, output_idx):
        """Select and order the top ``num_output`` predictions of one decoder output.

        Returns a dict holding the selected scores, flat indices, class ids,
        optional threshold mask and optional pre-quality scores.
        """
        # With instance ids, every query yields one detection: keep only its
        # best class so the flattened index equals the query index.
        squeeze_cls = instance_id is not None

        cls_scores = cls_scores[output_idx].sigmoid()
        cls_ids = None
        if squeeze_cls:
            cls_scores, cls_ids = cls_scores.max(dim=-1)
            cls_scores = cls_scores.unsqueeze(dim=-1)

        box_preds = box_preds[output_idx]
        bs, num_pred, num_cls = cls_scores.shape
        cls_scores, indices = cls_scores.flatten(start_dim=1).topk(
            self.num_output, dim=1, sorted=self.sorted
        )
        if not squeeze_cls:
            cls_ids = indices % num_cls

        mask = None
        if self.score_threshold is not None:
            mask = cls_scores >= self.score_threshold

        cls_scores_origin = None
        if qulity is not None:
            centerness = qulity[output_idx][..., CNS]
            centerness = torch.gather(centerness, 1, indices // num_cls)
            cls_scores_origin = cls_scores.clone()
            cls_scores *= centerness.sigmoid()
            cls_scores, idx = torch.sort(cls_scores, dim=1, descending=True)
            # Everything that is indexed per prediction must follow the new order.
            if not squeeze_cls:
                cls_ids = torch.gather(cls_ids, 1, idx)
            if mask is not None:
                mask = torch.gather(mask, 1, idx)
            indices = torch.gather(indices, 1, idx)
            # Fix: the original scores were left in pre-sort order and so did
            # not line up with the boxes they were reported for.
            cls_scores_origin = torch.gather(cls_scores_origin, 1, idx)

        return {
            "bs": bs,
            "num_cls": num_cls,
            "squeeze_cls": squeeze_cls,
            "box_preds": box_preds,
            "cls_scores": cls_scores,
            "cls_ids": cls_ids,
            "indices": indices,
            "mask": mask,
            "cls_scores_origin": cls_scores_origin,
        }

    def _decode_sample(self, ranked, i, instance_id):
        """Build the 3D result for sample ``i``.

        Returns ``(result, extras)``: the mandatory 3D entries and the
        optional ``"cls_scores"`` / ``"instance_ids"`` entries.
        """
        indices = ranked["indices"][i]
        mask = ranked["mask"]

        category_ids = ranked["cls_ids"][i]
        if ranked["squeeze_cls"]:
            category_ids = category_ids[indices]
        scores = ranked["cls_scores"][i]
        box = ranked["box_preds"][i, indices // ranked["num_cls"]]
        if mask is not None:
            category_ids = category_ids[mask[i]]
            scores = scores[mask[i]]
            box = box[mask[i]]

        result = {
            "boxes_3d": self.decode_box(box).cpu(),
            "scores_3d": scores.cpu(),
            "labels_3d": category_ids.cpu(),
        }

        extras = {}
        if ranked["cls_scores_origin"] is not None:
            scores_origin = ranked["cls_scores_origin"][i]
            if mask is not None:
                scores_origin = scores_origin[mask[i]]
            extras["cls_scores"] = scores_origin.cpu()
        if instance_id is not None:
            ids = instance_id[i, indices]
            if mask is not None:
                ids = ids[mask[i]]
            extras["instance_ids"] = ids
        return result, extras

    @staticmethod
    def _regroup_queries(query_groups, indices2d):
        """Re-index per-camera query ranges after selecting ``indices2d``.

        Returns ``(camidx_2d, new_groups)``: the camera index of every kept
        2D query and, per camera, the ``(start, end)`` range inside the kept
        set. A camera without kept queries gets an empty range.
        """
        camidx_2d = []
        new_groups = []
        for cam_idx, qg in enumerate(query_groups):
            parts_index = torch.where(
                torch.logical_and(qg[0] <= indices2d, indices2d < qg[1])
            )[0]

            if len(parts_index) > 0:
                qg_new = (parts_index[0].cpu().item(), parts_index[-1].cpu().item() + 1)
            elif len(new_groups) > 0:
                qg_new = (new_groups[-1][-1], new_groups[-1][-1])
            else:
                qg_new = (0, 0)

            camidx_2d.append(torch.ones((len(parts_index))) * cam_idx)
            new_groups.append(qg_new)
        return torch.cat(camidx_2d, dim=0), new_groups

    def decode(
        self,
        cls_scores,
        box_preds,
        instance_id=None,
        qulity=None,
        output_idx=-1,
    ):
        """Decode 3D detections.

        Args:
            cls_scores: Sequence of class logits; entry ``output_idx`` is used.
            box_preds: Sequence of encoded boxes; entry ``output_idx`` is used.
            instance_id: Optional per-query ids; when given, each query keeps
                only its best class.
            qulity: Optional sequence of quality outputs; its ``CNS`` channel
                re-weights and re-sorts the scores.
            output_idx: Which decoder output to decode.

        Returns:
            list[dict]: One dict per sample with ``boxes_3d``, ``scores_3d``,
            ``labels_3d`` and, when available, ``cls_scores`` and
            ``instance_ids``.
        """
        ranked = self._rank_predictions(
            cls_scores, box_preds, instance_id, qulity, output_idx
        )
        output = []
        for i in range(ranked["bs"]):
            result, extras = self._decode_sample(ranked, i, instance_id)
            result.update(extras)
            output.append(result)
        return output

    def decode_with2d(
        self,
        cls_scores,
        box_preds,
        instance_id=None,
        qulity=None,
        output_idx=-1,
        cls_scores2d=None,
        box_preds2d=None,
        trans_matrix=None,
        query_groups=None,
        output_idx2d=-1,
        aug_configs=None,
        with_association=False
    ):
        """Decode 3D detections together with the 2D detections.

        The 3D part matches :meth:`decode`. For 2D, either all queries are
        returned or, with ``with_association=True``, only those linked by
        ``trans_matrix`` to a selected 3D query. Only ``aug_configs[0]`` is
        used to map 2D boxes back to the original image.

        Returns:
            list[dict]: The entries of :meth:`decode` plus ``boxes_2d``,
            ``scores_2d``, ``labels_2d``, ``camidx_2d``, ``trans_matrix``
            and ``query_groups``.
        """
        ranked = self._rank_predictions(
            cls_scores, box_preds, instance_id, qulity, output_idx
        )

        cls_scores2d = cls_scores2d[output_idx2d]
        box_preds2d = box_preds2d[output_idx2d]
        trans_matrix_t = trans_matrix[output_idx2d].permute(0, 2, 1)
        query_groups = query_groups[output_idx2d]
        aug_config = aug_configs[0]

        output = []
        for i in range(ranked["bs"]):
            # Flat indices must be query indices for the association lookup.
            assert ranked["num_cls"] == 1

            if with_association:
                trans_t = trans_matrix_t[i, ranked["indices"][i]]
                indices2d = torch.where(trans_t.any(0))[0]
                trans_t = torch.index_select(trans_t, 1, indices2d).cpu()
            else:
                indices2d = torch.arange(len(box_preds2d[i]))
                trans_t = None

            # Fix: keep the regrouped ranges in a per-sample variable. They
            # used to overwrite ``query_groups``, so later samples were
            # regrouped against the previous sample's re-indexed ranges.
            camidx_2d, sample_query_groups = self._regroup_queries(
                query_groups, indices2d
            )

            scores2d, category_ids2d = cls_scores2d[i, indices2d].sigmoid().max(dim=-1)
            box2d = self.decode_box2d(box_preds2d[i, indices2d], aug_config)

            result, extras = self._decode_sample(ranked, i, instance_id)
            result.update(
                {
                    "boxes_2d": box2d.cpu(),
                    "scores_2d": scores2d.cpu(),
                    "labels_2d": category_ids2d.cpu(),
                    "camidx_2d": camidx_2d,
                    "trans_matrix": trans_t,
                    "query_groups": sample_query_groups,
                }
            )
            result.update(extras)
            output.append(result)
        return output
253 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/detection2d/coster.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from scipy.optimize import linear_sum_assignment
4 |
5 | from mmdet.core.bbox.builder import BBOX_SAMPLERS
6 | from mmdet.core.bbox.builder import BBOX_ASSIGNERS
7 | from mmdet.core.bbox.assigners import HungarianAssigner
8 | from mmdet.core.bbox.assigners.assign_result import AssignResult
9 | from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
10 | from mmdet.core.bbox.match_costs import build_match_cost
11 |
12 | __all__ = ["SparseBox2DCoster"]
13 |
14 |
@BBOX_SAMPLERS.register_module()
class SparseBox2DCoster(object):
    """Matching costs and Hungarian target assignment for per-camera 2D queries.

    Args:
        eps: ``-1 / eps`` is the placeholder for not-yet-filled cost entries.
        cls_cost, reg_cost, iou_cost: Configs passed to ``build_match_cost``.
    """

    def __init__(self,
                 eps=1e-12,
                 cls_cost=None,
                 reg_cost=None,
                 iou_cost=None):
        super(SparseBox2DCoster, self).__init__()
        self.eps = eps
        self.cls_cost = build_match_cost(cls_cost)
        self.reg_cost = build_match_cost(reg_cost)
        self.iou_cost = build_match_cost(iou_cost)

    def cost(self, cls_pred, box_pred, cls_target, box_target,
             data, trans_shape=None, query_groups=None):
        """Compute one cost matrix per (sample, camera).

        Returns:
            list[list[np.ndarray | None]]: ``None`` where a camera has no
            targets. NaN and ``-inf`` entries are replaced by ``1e8``. With
            ``trans_shape``, rows from ``trans_shape[i][j]`` on are set to
            the matrix maximum.
        """
        bs, num_pred, num_cls = cls_pred.shape
        img_w, img_h = map(int, data['image_wh'][0, 0].tolist())
        factor = box_pred[0].new_tensor([img_w, img_h, img_w, img_h]).unsqueeze(0)

        cls_cost = self._cls_cost(cls_pred, cls_target, query_groups)
        reg_cost = self._reg_cost(box_pred, box_target, query_groups, factor)
        iou_cost = self._iou_cost(box_pred, box_target, query_groups, factor)

        all_costs = []
        for i in range(bs):
            costs = []
            for j in range(len(query_groups)):
                parts = (cls_cost[i][j], reg_cost[i][j], iou_cost[i][j])
                if any(part is None for part in parts):
                    costs.append(None)
                    continue
                cost = (parts[0] + parts[1] + parts[2]).detach().cpu().numpy()
                if cost.size > 0 and trans_shape is not None:
                    cost[trans_shape[i][j]:, :] = cost.max()
                cost = np.where(np.isneginf(cost) | np.isnan(cost), 1e8, cost)
                costs.append(cost)
            all_costs.append(costs)

        return all_costs

    def trans_cost(self, all_costs, pred_cls_2d, gt_cls_2d, gt_cls_3d,
                   ref_trans_matrix, query_groups, gt_2d_3d_map):
        """Project per-camera 2D costs onto the (3D query, 3D target) grid.

        Returns:
            list[np.ndarray]: One ``(num_query3d, num_target3d)`` array per
            sample; all zeros when the sample has no 2D or no 3D targets.
        """
        bs, num_query2d, num_cls2d = pred_cls_2d.shape
        num_query3d = ref_trans_matrix.shape[-1]
        # Placeholder in the array's own dtype so equality tests are exact.
        unset = np.float32(-1 / self.eps)

        cost2d_map3d_list = []
        for i in range(bs):
            num_target2d = sum(len(x) for x in gt_cls_2d[i])
            num_target3d = len(gt_cls_3d[i])

            if not (num_target2d > 0 and num_target3d > 0):
                cost2d_map3d_list.append(
                    np.zeros((num_query3d, num_target3d), dtype=np.float32)
                )
                continue

            # 1. Scatter the per-camera blocks into one large 2D cost map.
            cost2d_extend = np.full((num_query2d, num_target2d), unset, dtype=np.float32)

            gt_qg_shape = np.cumsum([0] + [len(x) for x in gt_cls_2d[i]])
            gt_qg_groups = list(zip(gt_qg_shape[:-1], gt_qg_shape[1:]))

            for j, (qg, gp) in enumerate(zip(query_groups, gt_qg_groups)):
                if all_costs[i][j] is not None:
                    cost2d_extend[qg[0]:qg[1], gp[0]:gp[1]] = all_costs[i][j]

            # Fix: the old code replaced the array by the int 0 when nothing
            # was filled and then called .max() / item assignment on it.
            unfilled = cost2d_extend == unset
            if unfilled.all():
                cost2d_extend[:] = 0
            else:
                cost2d_extend[unfilled] = cost2d_extend.max()

            # 2. Average over 2D targets of each 3D target, then over the 2D
            #    queries of each 3D query.
            map_trans_matrix = torch.zeros((num_target2d, num_target3d))
            map_trans_matrix[torch.arange(num_target2d), torch.cat(gt_2d_3d_map[i]).cpu()] = 1
            ref_trans_matrix_t = ref_trans_matrix[i].cpu().T

            cost2d_extend = torch.from_numpy(cost2d_extend)
            cost2d_map3d = (cost2d_extend @ map_trans_matrix) / torch.clamp(
                map_trans_matrix.sum(0), 1e-5).unsqueeze(0)
            cost2d_map3d = (ref_trans_matrix_t @ cost2d_map3d) / torch.clamp(
                ref_trans_matrix_t.sum(-1), 1e-5).unsqueeze(-1)

            # Rows and columns that sum to zero are set to the maximum cost.
            map_mask = torch.logical_or((cost2d_map3d.sum(0) == 0)[None, :],
                                        (cost2d_map3d.sum(1) == 0)[:, None])

            cost2d_map3d[map_mask] = cost2d_map3d.max()
            cost2d_map3d_list.append(cost2d_map3d.numpy())

        return cost2d_map3d_list

    def sample(self, cls_pred, box_pred, depth_pred, cls_target, box_target, depth_target,
               data, trans_matrix=None, query_groups=None, cost_list=None,
               alpha=None, alpha_targets=None):
        """Run Hungarian matching per (sample, camera) and scatter the targets.

        Returns:
            tuple: ``(cls_target, box_target, alpha_target, box_depth,
            reg_weights)`` laid out like the predictions. Unmatched queries
            get class ``-1`` and zeros elsewhere. ``alpha_target`` and
            ``box_depth`` are ``None`` when their inputs are missing.
        """
        bs, num_pred, num_cls = cls_pred.shape

        # Weight 0 for NaN target entries, 1 otherwise.
        all_reg_weights = [
            [
                torch.logical_not(box_target[i][j].isnan()).to(dtype=box_target[i][j].dtype)
                for j in range(len(query_groups))
            ]
            for i in range(bs)
        ]

        all_indices = []
        for i in range(bs):
            indices = []
            for j in range(len(query_groups)):
                cost = cost_list[i][j]
                if cost is not None and cost.size > 0:
                    cost = np.where(np.isneginf(cost) | np.isnan(cost), 1e8, cost)
                    indices.append(
                        [
                            cls_pred.new_tensor(x, dtype=torch.int64)
                            for x in linear_sum_assignment(cost)
                        ]
                    )
                else:
                    indices.append(None)
            all_indices.append(indices)

        output_cls_target = cls_target[0][0].new_ones([bs, num_pred], dtype=torch.long) * -1
        output_box_target = box_pred.new_zeros(box_pred.shape)
        output_reg_weights = box_pred.new_zeros(box_pred.shape)

        output_box_depth = None
        if depth_pred is not None:
            output_box_depth = box_pred.new_zeros(box_pred[..., 1].shape)

        output_alpha_target = None
        if alpha_targets is not None and alpha is not None:
            output_alpha_target = alpha.new_zeros(alpha.shape)

        for i in range(bs):
            for j, qg in enumerate(query_groups):
                if len(cls_target[i][j]) > 0 and all_indices[i][j] is not None:
                    pred_idx, target_idx = all_indices[i][j]
                    # Matched indices are local to the camera; shift to global.
                    dst = pred_idx + qg[0]

                    output_cls_target[i, dst] = cls_target[i][j][target_idx]
                    output_box_target[i, dst] = box_target[i][j][target_idx]
                    output_reg_weights[i, dst] = all_reg_weights[i][j][target_idx]

                    if alpha_targets is not None:
                        assigned_alpha = alpha_targets[i][j][target_idx]
                        # NOTE(review): only scalar alpha is encoded here; for
                        # any other shape assigned_alpha_target is undefined
                        # or stale -- confirm multibin is never used.
                        if len(assigned_alpha.shape) == 1:  # for scalar alpha, not multibin
                            assigned_alpha_target = torch.stack(
                                (torch.sin(assigned_alpha), torch.cos(assigned_alpha))
                            ).transpose(1, 0)

                        if output_alpha_target is not None:
                            output_alpha_target[i, dst] = assigned_alpha_target

                    if depth_pred is not None:
                        output_box_depth[i, dst] = depth_target[i][j][target_idx]

        return output_cls_target, output_box_target, output_alpha_target, output_box_depth, output_reg_weights

    def _cls_cost(self, cls_pred, cls_target, query_groups):
        """Classification cost per (sample, camera); ``None`` without targets."""
        bs = cls_pred.shape[0]
        all_cost = []
        for i in range(bs):
            cost = []
            for j, qg in enumerate(query_groups):
                if len(cls_target[i][j]) > 0:
                    cost.append(self.cls_cost(cls_pred[i][qg[0]:qg[1]], cls_target[i][j]))
                else:
                    cost.append(None)
            all_cost.append(cost)
        return all_cost

    def _reg_cost(self, box_pred, box_target, query_groups, factor):
        """Regression cost per (sample, camera); targets are divided by ``factor``."""
        bs = box_pred.shape[0]
        all_cost = []
        for i in range(bs):
            cost = []
            for j, qg in enumerate(query_groups):
                if len(box_target[i][j]) > 0:
                    # normalized resolution
                    cost.append(self.reg_cost(box_pred[i][qg[0]:qg[1]],
                                              box_target[i][j][:, 0:4] / factor))
                else:
                    cost.append(None)
            all_cost.append(cost)
        return all_cost

    def _iou_cost(self, box_pred, box_target, query_groups, factor):
        """IoU cost per (sample, camera); predictions are scaled by ``factor``."""
        bs = box_pred.shape[0]
        all_cost = []
        for i in range(bs):
            cost = []
            for j, qg in enumerate(query_groups):
                if len(box_target[i][j]) > 0:
                    # original resolution
                    cost.append(self.iou_cost(
                        bbox_cxcywh_to_xyxy(box_pred[i][qg[0]:qg[1]]) * factor,
                        box_target[i][j][:, 0:4]))
                else:
                    cost.append(None)
            all_cost.append(cost)
        return all_cost
227 |
228 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/datasets/pipelines/transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import mmcv
3 | import torch
4 | from mmcv.parallel import DataContainer as DC
5 | from mmdet.datasets.builder import PIPELINES
6 | from mmdet.datasets.pipelines import to_tensor
7 |
def filter_info2d(input_dict, mask):
    """Drop 2D annotations whose 3D object was filtered out and re-index the rest.

    Args:
        input_dict: Must contain the per-camera lists ``gt_bboxes_2d``,
            ``gt_labels_2d``, ``gt_centers_2d``, ``gt_depths_2d``,
            ``gt_alphas_2d`` and ``gt_2d_3d_map``; they are updated in place.
        mask: Boolean keep-mask over the 3D objects (NumPy array or CPU
            tensor).

    Returns:
        dict: The same ``input_dict``. ``gt_2d_3d_map`` entries are rewritten
        to index into the filtered 3D objects; an entry of ``-1`` stays ``-1``.
    """
    if isinstance(mask, torch.Tensor):
        mask = mask.numpy()

    # np.bool / np.long are absent in NumPy 1.24-1.26; use explicit dtypes.
    removed_3d = np.flatnonzero(~mask)
    maskes_2d = [np.ones(len(x), dtype=np.bool_) for x in input_dict['gt_bboxes_2d']]
    for cam_idx, map_2d_3d in enumerate(input_dict['gt_2d_3d_map']):
        dropped = np.isin(np.asarray(map_2d_3d), removed_3d)
        maskes_2d[cam_idx][np.flatnonzero(dropped)] = False

    # Old 3D index -> new 3D index (-1 if removed). The extra trailing -1
    # makes a map value of -1 translate to -1.
    trans_index = np.full(len(mask), -1, dtype=np.int64)
    trans_index[mask] = np.arange(int(mask.sum()), dtype=np.int64)
    trans_index = np.concatenate([trans_index, -np.ones(1, dtype=np.int64)])

    per_camera_keys = ('gt_bboxes_2d', 'gt_labels_2d', 'gt_centers_2d',
                       'gt_depths_2d', 'gt_alphas_2d')
    for cam_idx, mask_2d in enumerate(maskes_2d):
        for key in per_camera_keys:
            input_dict[key][cam_idx] = input_dict[key][cam_idx][mask_2d]
        input_dict['gt_2d_3d_map'][cam_idx] = trans_index[input_dict['gt_2d_3d_map'][cam_idx][mask_2d]]

    return input_dict
32 |
33 |
@PIPELINES.register_module()
class MultiScaleDepthMapGenerator(object):
    """Rasterise projected points into sparse depth maps at several scales.

    Args:
        downsample: One downsampling factor or a list/tuple of them.
        max_depth: Depth values are clipped to ``[0.1, max_depth]``.
    """

    def __init__(self, downsample=1, max_depth=60):
        if not isinstance(downsample, (list, tuple)):
            downsample = [downsample]
        self.downsample = downsample
        self.max_depth = max_depth

    def __call__(self, input_dict):
        """Add ``"gt_depth"`` to ``input_dict``.

        Reads ``"points"``, ``"lidar2img"`` and ``"img_shape"``. The result
        is a list with one stacked array per downsample factor; pixels
        without a point hold ``-1``.
        """
        points = input_dict["points"][..., :3, None]
        gt_depth = []
        for i, lidar2img in enumerate(input_dict["lidar2img"]):
            H, W = input_dict["img_shape"][i][:2]

            pts_2d = (
                np.squeeze(lidar2img[:3, :3] @ points, axis=-1)
                + lidar2img[:3, 3]
            )
            pts_2d[:, :2] /= pts_2d[:, 2:3]
            U = np.round(pts_2d[:, 0]).astype(np.int32)
            V = np.round(pts_2d[:, 1]).astype(np.int32)
            depths = pts_2d[:, 2]
            # Keep points inside the image and at least 0.1 deep. Far points
            # are kept and clipped below instead of being discarded.
            mask = np.logical_and.reduce(
                [
                    V >= 0,
                    V < H,
                    U >= 0,
                    U < W,
                    depths >= 0.1,
                ]
            )
            V, U, depths = V[mask], U[mask], depths[mask]
            # Far-to-near order, presumably so the nearest point is written
            # last when several points share a pixel.
            sort_idx = np.argsort(depths)[::-1]
            V, U, depths = V[sort_idx], U[sort_idx], depths[sort_idx]
            depths = np.clip(depths, 0.1, self.max_depth)
            for j, downsample in enumerate(self.downsample):
                if len(gt_depth) < j + 1:
                    gt_depth.append([])
                h, w = (int(H / downsample), int(W / downsample))
                u = np.floor(U / downsample).astype(np.int32)
                v = np.floor(V / downsample).astype(np.int32)
                # When H or W is not a multiple of the factor, border pixels
                # map to index h or w, outside the truncated map; drop them
                # instead of raising IndexError.
                inside = (u < w) & (v < h)
                depth_map = np.ones([h, w], dtype=np.float32) * -1
                depth_map[v[inside], u[inside]] = depths[inside]
                gt_depth[j].append(depth_map)

        input_dict["gt_depth"] = [np.stack(x) for x in gt_depth]
        return input_dict
82 |
83 |
@PIPELINES.register_module()
class NuScenesSparse4DAdaptor(object):
    """Repack a loaded sample into the arrays and containers the model reads."""

    def __init__(self):
        # Was misspelled ``__init`` and therefore never used as the constructor.
        pass

    def __call__(self, input_dict):
        """Convert and add entries of ``input_dict`` in place and return it.

        Requires ``lidar2img``, ``img_shape``, ``lidar2global``, ``img``,
        ``intrinsics`` and ``extrinsics``. Ground-truth and camera-intrinsic
        entries are converted only when present.
        """
        input_dict["projection_mat"] = np.float32(np.stack(input_dict["lidar2img"]))
        # Swap the first two img_shape entries (read as H, W elsewhere in this file).
        input_dict["image_wh"] = np.ascontiguousarray(
            np.array(input_dict["img_shape"], dtype=np.float32)[:, :2][:, ::-1])
        input_dict["T_global_inv"] = np.linalg.inv(input_dict["lidar2global"])
        input_dict["T_global"] = input_dict["lidar2global"]

        if "cam_intrinsic" in input_dict:
            input_dict["cam_intrinsic"] = np.float32(
                np.stack(input_dict["cam_intrinsic"])
            )
            input_dict["focal"] = input_dict["cam_intrinsic"][..., 0, 0]
        if "instance_inds" in input_dict:
            input_dict["instance_id"] = input_dict["instance_inds"]

        if "gt_bboxes_3d" in input_dict:
            # Wrap column 6 into [-pi, pi).
            input_dict["gt_bboxes_3d"][:, 6] = self.limit_period(
                input_dict["gt_bboxes_3d"][:, 6], offset=0.5, period=2 * np.pi
            )
            input_dict["gt_bboxes_3d"] = DC(
                to_tensor(input_dict["gt_bboxes_3d"]).float()
            )
        if "gt_labels_3d" in input_dict:
            input_dict["gt_labels_3d"] = DC(
                to_tensor(input_dict["gt_labels_3d"]).long()
            )

        for key in ['gt_bboxes_2d', 'gt_labels_2d', 'gt_centers_2d',
                    'gt_depths_2d', 'gt_2d_3d_map', 'gt_alphas_2d']:
            if key not in input_dict:
                continue
            if isinstance(input_dict[key], list):
                input_dict[key] = DC([to_tensor(res) for res in input_dict[key]])
            else:
                input_dict[key] = DC(to_tensor(input_dict[key]))

        # Move axis 2 of each image to the front and stack all images.
        imgs = [img.transpose(2, 0, 1) for img in input_dict["img"]]
        imgs = np.ascontiguousarray(np.stack(imgs, axis=0))
        input_dict["img"] = DC(to_tensor(imgs), stack=True)

        input_dict["intrinsics"] = np.float32(np.stack(input_dict["intrinsics"]))
        input_dict["extrinsics"] = np.float32(np.stack(input_dict["extrinsics"]))

        return input_dict

    def limit_period(
        self, val: np.ndarray, offset: float = 0.5, period: float = np.pi
    ) -> np.ndarray:
        """Wrap ``val`` into ``[-offset * period, (1 - offset) * period)``."""
        limited_val = val - np.floor(val / period + offset) * period
        return limited_val
142 |
143 |
@PIPELINES.register_module()
class InstanceNameFilter(object):
    """Keep only ground-truth objects whose label is a valid class index.

    Args:
        classes (list[str]): Class names kept for training; valid labels are
            ``0 .. len(classes) - 1``.
    """

    def __init__(self, classes):
        self.classes = classes
        self.labels = list(range(len(self.classes)))

    def __call__(self, input_dict):
        """Filter 3D boxes and labels, plus instance ids and 2D annotations if present.

        Args:
            input_dict (dict): Result dict from the loading pipeline.

        Returns:
            dict: The same dict with ``gt_bboxes_3d`` and ``gt_labels_3d``
            filtered.
        """
        keep = np.isin(input_dict["gt_labels_3d"], self.labels)

        for key in ("gt_bboxes_3d", "gt_labels_3d"):
            input_dict[key] = input_dict[key][keep]
        if "instance_inds" in input_dict:
            input_dict["instance_inds"] = input_dict["instance_inds"][keep]

        # Keep the 2D annotations consistent with the surviving 3D objects.
        if 'gt_bboxes_2d' in input_dict:
            input_dict = filter_info2d(input_dict, keep)

        return input_dict

    def __repr__(self):
        """str: Return a string that describes the module."""
        return self.__class__.__name__ + f"(classes={self.classes})"
185 |
186 |
@PIPELINES.register_module()
class CircleObjectRangeFilter(object):
    """Drop ground-truth objects beyond a per-class distance from the origin.

    Args:
        class_dist_thred: Maximum distance per label, indexed by label id.
            Defaults to ``[52.5] * 5 + [31.5] + [42] * 3 + [31.5]``.
            Objects whose label has no entry are dropped.
    """

    def __init__(self, class_dist_thred=None):
        # A fresh list per instance instead of one shared mutable default.
        if class_dist_thred is None:
            class_dist_thred = [52.5] * 5 + [31.5] + [42] * 3 + [31.5]
        self.class_dist_thred = class_dist_thred

    def __call__(self, input_dict):
        """Filter 3D boxes and labels, plus 2D annotations and instance ids if present."""
        gt_bboxes_3d = input_dict["gt_bboxes_3d"]
        gt_labels_3d = input_dict["gt_labels_3d"]
        # Distance in the plane of the first two box coordinates.
        dist = np.sqrt(
            np.sum(gt_bboxes_3d[:, :2] ** 2, axis=-1)
        )
        mask = np.array([False] * len(dist))
        for label_idx, dist_thred in enumerate(self.class_dist_thred):
            mask = np.logical_or(
                mask,
                np.logical_and(gt_labels_3d == label_idx, dist <= dist_thred),
            )

        gt_bboxes_3d = gt_bboxes_3d[mask]
        gt_labels_3d = gt_labels_3d[mask]

        # Keep the 2D annotations consistent with the surviving 3D objects.
        if 'gt_bboxes_2d' in input_dict:
            input_dict = filter_info2d(input_dict, mask)

        input_dict["gt_bboxes_3d"] = gt_bboxes_3d
        input_dict["gt_labels_3d"] = gt_labels_3d

        if "instance_inds" in input_dict:
            input_dict["instance_inds"] = input_dict["instance_inds"][mask]

        return input_dict

    def __repr__(self):
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f"(class_dist_thred={self.class_dist_thred})"
        return repr_str
227 |
228 |
@PIPELINES.register_module()
class NormalizeMultiviewImage(object):
    """Normalize every image of a multi-view sample.

    Adds the key "img_norm_cfg".

    Args:
        mean (sequence): Mean values of 3 channels.
        std (sequence): Std values of 3 channels.
        to_rgb (bool): Whether to convert the image from BGR to RGB,
            default is true.
    """

    def __init__(self, mean, std, to_rgb=True):
        self.to_rgb = to_rgb
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def __call__(self, results):
        """Normalize ``results["img"]`` with ``mmcv.imnormalize``.

        Args:
            results (dict): Result dict from the loading pipeline.

        Returns:
            dict: The same dict with normalized images and the added
            'img_norm_cfg' entry.
        """
        normalized = []
        for image in results["img"]:
            normalized.append(
                mmcv.imnormalize(image, self.mean, self.std, self.to_rgb)
            )
        results["img"] = normalized
        results["img_norm_cfg"] = {
            "mean": self.mean,
            "std": self.std,
            "to_rgb": self.to_rgb,
        }
        return results

    def __repr__(self):
        name = self.__class__.__name__
        return name + f"(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})"
266 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/models/detection3d/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | from mmcv.cnn import Linear, Scale, bias_init_with_prob
6 | from mmcv.runner.base_module import Sequential, BaseModule
7 | from mmcv.cnn import xavier_init
8 | from mmcv.cnn.bricks.registry import (
9 | PLUGIN_LAYERS,
10 | POSITIONAL_ENCODING,
11 | )
12 |
13 | from projects.mmdet3d_plugin.core.box3d import *
14 | from ..blocks import linear_relu_ln
15 |
16 | __all__ = [
17 | "SparseBox3DRefinementModule",
18 | "SparseBox3DKeyPointsGenerator",
19 | "SparseBox3DEncoder",
20 | ]
21 |
22 |
@POSITIONAL_ENCODING.register_module()
class SparseBox3DEncoder(BaseModule):
    """Embed a 3D box through separate position, size, yaw and velocity branches.

    Args:
        embed_dims: Output width of each branch; a scalar is repeated five
            times, otherwise the entries are used by position.
        vel_dims: Number of velocity components to embed; ``0`` disables
            the velocity branch.
        mode: ``"add"`` sums the branch outputs, ``"cat"`` concatenates them.
        output_fc: Whether to apply a final embedding layer sized by the
            last entry of ``embed_dims``.
        in_loops, out_loops: Forwarded to ``linear_relu_ln``.
    """

    def __init__(
        self,
        embed_dims,
        vel_dims=3,
        mode="add",
        output_fc=True,
        in_loops=1,
        out_loops=2,
    ):
        super().__init__()
        assert mode in ["add", "cat"]
        self.embed_dims = embed_dims
        self.vel_dims = vel_dims
        self.mode = mode

        def make_branch(in_dims, out_dims):
            return nn.Sequential(
                *linear_relu_ln(out_dims, in_loops, out_loops, in_dims)
            )

        dims = embed_dims
        if not isinstance(dims, (list, tuple)):
            dims = [dims] * 5

        # Branches are created in this fixed order: pos, size, yaw, vel.
        branch_specs = [
            ("pos_fc", 3, dims[0]),
            ("size_fc", 3, dims[1]),
            ("yaw_fc", 2, dims[2]),
        ]
        if vel_dims > 0:
            branch_specs.append(("vel_fc", self.vel_dims, dims[3]))
        for attr_name, in_dims, out_dims in branch_specs:
            setattr(self, attr_name, make_branch(in_dims, out_dims))

        self.output_fc = make_branch(dims[-1], dims[-1]) if output_fc else None

    def forward(self, box_3d: torch.Tensor):
        """Return the embedding of ``box_3d`` (box components on the last dim)."""
        branch_feats = [
            self.pos_fc(box_3d[..., [X, Y, Z]]),
            self.size_fc(box_3d[..., [W, L, H]]),
            self.yaw_fc(box_3d[..., [SIN_YAW, COS_YAW]]),
        ]
        if self.vel_dims > 0:
            branch_feats.append(
                self.vel_fc(box_3d[..., VX : VX + self.vel_dims])
            )

        if self.mode == "add":
            output = branch_feats[0]
            for feat in branch_feats[1:]:
                output = output + feat
        elif self.mode == "cat":
            output = torch.cat(branch_feats, dim=-1)

        if self.output_fc is not None:
            output = self.output_fc(output)
        return output
75 |
76 |
@PLUGIN_LAYERS.register_module()
class SparseBox3DRefinementModule(BaseModule):
    """Refine box anchors and optionally predict class / quality scores.

    The regression head maps ``instance_feature + anchor_embed`` to
    ``output_dim`` channels. Channels listed in ``refine_state`` are
    predicted as residuals on top of the incoming anchor; the channels
    from ``VX`` onward are divided by the time interval and added to the
    anchor's velocity channels (only when ``output_dim > 8``).

    Args:
        embed_dims: Width of the input features.
        output_dim: Number of regressed box channels.
        num_cls: Number of classification outputs.
        normalize_yaw: L2-normalise the (sin, cos) pair after refinement.
        refine_yaw: Also treat the yaw channels as residuals.
        with_cls_branch: Build the classification head.
        with_quality_estimation: Build the 2-channel quality head.
    """

    def __init__(
        self,
        embed_dims=256,
        output_dim=11,
        num_cls=10,
        normalize_yaw=False,
        refine_yaw=False,
        with_cls_branch=True,
        with_quality_estimation=False,
    ):
        super().__init__()
        self.embed_dims = embed_dims
        self.output_dim = output_dim
        self.num_cls = num_cls
        self.normalize_yaw = normalize_yaw
        self.refine_yaw = refine_yaw

        # Channels that are regressed as "anchor + delta".
        residual_channels = [X, Y, Z, W, L, H]
        if refine_yaw:
            residual_channels.extend([SIN_YAW, COS_YAW])
        self.refine_state = residual_channels

        regression_stack = [
            *linear_relu_ln(embed_dims, 2, 2),
            Linear(embed_dims, output_dim),
            Scale([1.0] * output_dim),
        ]
        self.layers = nn.Sequential(*regression_stack)

        self.with_cls_branch = with_cls_branch
        if with_cls_branch:
            cls_stack = [
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(embed_dims, num_cls),
            ]
            self.cls_layers = nn.Sequential(*cls_stack)

        self.with_quality_estimation = with_quality_estimation
        if with_quality_estimation:
            quality_stack = [
                *linear_relu_ln(embed_dims, 1, 2),
                Linear(embed_dims, 2),
            ]
            self.quality_layers = nn.Sequential(*quality_stack)

    def init_weight(self):
        """Set the classification bias from a prior probability of 0.01."""
        if not self.with_cls_branch:
            return
        prior_bias = bias_init_with_prob(0.01)
        nn.init.constant_(self.cls_layers[-1].bias, prior_bias)

    def forward(
        self,
        instance_feature: torch.Tensor,
        anchor: torch.Tensor,
        anchor_embed: torch.Tensor,
        time_interval: torch.Tensor = 1.0,
        return_cls=True,
    ):
        """Return ``(refined_anchor, cls, quality)``.

        ``cls`` is ``None`` when ``return_cls`` is false. ``quality`` is
        ``None`` unless both ``return_cls`` and the quality head are
        enabled. The classification head reads ``instance_feature`` alone,
        while the regression and quality heads read
        ``instance_feature + anchor_embed``.

        Raises:
            AssertionError: If ``return_cls`` is true but the module was
                built without a classification branch.
        """
        feature = instance_feature + anchor_embed
        output = self.layers(feature)

        refined = self.refine_state
        output[..., refined] = output[..., refined] + anchor[..., refined]

        if self.normalize_yaw:
            yaw_channels = [SIN_YAW, COS_YAW]
            output[..., yaw_channels] = nn.functional.normalize(
                output[..., yaw_channels], dim=-1
            )

        if self.output_dim > 8:
            if not isinstance(time_interval, torch.Tensor):
                time_interval = instance_feature.new_tensor(time_interval)
            # Swap the first and last axes so that time_interval broadcasts
            # against what was the leading axis, then swap back.
            displacement = output[..., VX:].transpose(0, -1)
            velocity = (displacement / time_interval).transpose(0, -1)
            output[..., VX:] = velocity + anchor[..., VX:]

        cls = None
        quality = None
        if return_cls:
            assert self.with_cls_branch, "Without classification layers !!!"
            cls = self.cls_layers(instance_feature)
            if self.with_quality_estimation:
                quality = self.quality_layers(feature)
        return output, cls, quality
155 |
156 |
@PLUGIN_LAYERS.register_module()
class SparseBox3DKeyPointsGenerator(BaseModule):
    """Generate 3D key points around box anchors.

    Each anchor yields ``len(fix_scale)`` fixed key points plus
    ``num_learnable_pts`` points whose offsets are predicted from the
    instance feature. Offsets are expressed as fractions of the box size,
    rotated about the z axis by the anchor yaw and shifted to the box
    centre.

    Args:
        embed_dims: Width of the instance feature fed to the learnable
            offset layer.
        num_learnable_pts: Number of learned key points per anchor.
        fix_scale: Sequence of ``(x, y, z)`` size fractions for the fixed
            key points. Defaults to a single point at the box centre.
    """

    def __init__(
        self,
        embed_dims=256,
        num_learnable_pts=0,
        fix_scale=None,
    ):
        super(SparseBox3DKeyPointsGenerator, self).__init__()
        self.embed_dims = embed_dims
        self.num_learnable_pts = num_learnable_pts
        if fix_scale is None:
            fix_scale = ((0.0, 0.0, 0.0),)
        # Stored as a frozen parameter so it moves with the module across
        # devices and is saved in the state dict.
        self.fix_scale = nn.Parameter(
            torch.tensor(fix_scale), requires_grad=False
        )
        self.num_pts = len(self.fix_scale) + num_learnable_pts
        if num_learnable_pts > 0:
            self.learnable_fc = Linear(self.embed_dims, num_learnable_pts * 3)

    def init_weight(self):
        """Xavier-initialise the learnable offset layer, if present."""
        if self.num_learnable_pts > 0:
            xavier_init(self.learnable_fc, distribution="uniform", bias=0.0)

    def forward(
        self,
        anchor,
        instance_feature=None,
        T_cur2temp_list=None,
        cur_timestamp=None,
        temp_timestamps=None,
    ):
        """Compute key points for ``anchor``.

        Args:
            anchor: Box anchors; the first two axes are unpacked as
                ``(bs, num_anchor)`` and the last axis is indexed with the
                box-layout constants.
            instance_feature: Features used to predict learnable offsets.
                When ``None`` only the fixed key points are produced.
            T_cur2temp_list: One transform per temporal frame; its first
                three rows are applied to homogeneous key points.
            cur_timestamp: Timestamp of the current frame.
            temp_timestamps: Timestamps of the temporal frames.

        Returns:
            ``key_points`` alone when any temporal argument is missing or
            ``temp_timestamps`` is empty. Otherwise the tuple
            ``(key_points, temp_key_points_list)`` with one entry per
            temporal frame.
        """
        bs, num_anchor = anchor.shape[:2]
        # exp() is applied to the stored W/L/H before scaling
        # (presumably sizes are kept in log space; confirm with box3d).
        size = anchor[..., None, [W, L, H]].exp()
        key_points = self.fix_scale * size
        if self.num_learnable_pts > 0 and instance_feature is not None:
            # sigmoid() - 0.5 keeps learned offsets within half a box size.
            learnable_scale = (
                self.learnable_fc(instance_feature)
                .reshape(bs, num_anchor, self.num_learnable_pts, 3)
                .sigmoid()
                - 0.5
            )
            key_points = torch.cat(
                [key_points, learnable_scale * size], dim=-2
            )

        # Rotation about the z axis built from the anchor's yaw.
        rotation_mat = anchor.new_zeros([bs, num_anchor, 3, 3])
        rotation_mat[:, :, 0, 0] = anchor[:, :, COS_YAW]
        rotation_mat[:, :, 0, 1] = -anchor[:, :, SIN_YAW]
        rotation_mat[:, :, 1, 0] = anchor[:, :, SIN_YAW]
        rotation_mat[:, :, 1, 1] = anchor[:, :, COS_YAW]
        rotation_mat[:, :, 2, 2] = 1

        key_points = torch.matmul(
            rotation_mat[:, :, None], key_points[..., None]
        ).squeeze(-1)
        key_points = key_points + anchor[..., None, [X, Y, Z]]

        if (
            cur_timestamp is None
            or temp_timestamps is None
            or T_cur2temp_list is None
            or len(temp_timestamps) == 0
        ):
            return key_points

        temp_key_points_list = []
        velocity = anchor[..., VX:]
        for i, t_time in enumerate(temp_timestamps):
            # Move the points back along the velocity by the elapsed time,
            # then map them into the temporal frame.
            time_interval = cur_timestamp - t_time
            translation = (
                velocity
                * time_interval.to(dtype=velocity.dtype)[:, None, None]
            )
            temp_key_points = key_points - translation[:, :, None]
            T_cur2temp = T_cur2temp_list[i].to(dtype=key_points.dtype)
            homogeneous = torch.cat(
                [
                    temp_key_points,
                    torch.ones_like(temp_key_points[..., :1]),
                ],
                dim=-1,
            ).unsqueeze(-1)
            temp_key_points = (
                T_cur2temp[:, None, None, :3] @ homogeneous
            ).squeeze(-1)
            temp_key_points_list.append(temp_key_points)
        return key_points, temp_key_points_list

    @staticmethod
    def anchor_projection(
        anchor,
        T_src2dst_list,
        src_timestamp=None,
        dst_timestamps=None,
        time_intervals=None,
    ):
        """Project anchors from the source frame into each target frame.

        For every transform in ``T_src2dst_list`` the centre is optionally
        motion-compensated (``center - velocity * time_interval``) and then
        transformed; yaw and velocity are rotated; size is copied.

        The heading is rotated as a ``(cos, sin)`` vector and then written
        back in the order given by ``SIN_YAW`` / ``COS_YAW``, so the
        returned anchors use the same channel order as the input. Earlier
        revisions always emitted ``(cos, sin)``. Checkpoints trained
        against that behaviour may need re-validation.

        Args:
            anchor: Source-frame anchors.
            T_src2dst_list: Sequence of source-to-target transforms.
            src_timestamp: Source timestamp, used with ``dst_timestamps``
                when ``time_intervals`` is not given.
            dst_timestamps: Target timestamps, one per transform.
            time_intervals: Explicit per-transform intervals. Takes
                precedence over the timestamps.

        Returns:
            A list with one projected anchor tensor per transform.
        """
        # Loop-invariant slices of the source anchor.
        src_center = anchor[..., [X, Y, Z]]
        size = anchor[..., [W, L, H]]
        src_vel = anchor[..., VX:]
        vel_dim = src_vel.shape[-1]
        heading = anchor[..., [COS_YAW, SIN_YAW], None]
        # Position of (cos, sin) inside the anchor's yaw pair. Like the
        # concatenation below, this assumes the two yaw channels sit
        # together between the size and velocity channels.
        yaw_order = [0, 1] if COS_YAW < SIN_YAW else [1, 0]

        dst_anchors = []
        for i, T_src2dst in enumerate(T_src2dst_list):
            T_src2dst = torch.unsqueeze(
                T_src2dst.to(dtype=anchor.dtype), dim=1
            )

            if time_intervals is not None:
                time_interval = time_intervals[i]
            elif src_timestamp is not None and dst_timestamps is not None:
                time_interval = (src_timestamp - dst_timestamps[i]).to(
                    dtype=src_vel.dtype
                )
            else:
                time_interval = None

            center = src_center
            if time_interval is not None:
                # Transposing lets time_interval broadcast against the
                # leading axis.
                translation = src_vel.transpose(0, -1) * time_interval
                center = center - translation.transpose(0, -1)

            center = (
                torch.matmul(
                    T_src2dst[..., :3, :3], center[..., None]
                ).squeeze(dim=-1)
                + T_src2dst[..., :3, 3]
            )
            cos_sin = torch.matmul(
                T_src2dst[..., :2, :2], heading
            ).squeeze(-1)
            yaw = cos_sin[..., yaw_order]
            vel = torch.matmul(
                T_src2dst[..., :vel_dim, :vel_dim], src_vel[..., None]
            ).squeeze(-1)
            dst_anchors.append(torch.cat([center, size, yaw, vel], dim=-1))
        return dst_anchors

    @staticmethod
    def distance(anchor):
        """Return the L2 norm of the first two channels of ``anchor``."""
        return torch.norm(anchor[..., :2], p=2, dim=-1)
285 |
--------------------------------------------------------------------------------
/projects/mmdet3d_plugin/ops/src/deformable_aggregation_cuda.cu:
--------------------------------------------------------------------------------
1 |
2 | #include
3 | #include
4 | #include
5 | #include
6 |
7 | #include
8 |
9 | #include
10 | #include
11 |
12 |
// Bilinearly interpolate one channel of a feature map at the fractional
// pixel position (h_im, w_im).
//
// Pixel (y, x) of the map is read at
//   bottom_data[base_ptr + (y * width + x) * num_embeds],
// so base_ptr selects both the map and the channel. A corner contributes
// zero when it lies above/left of the map (low corners) or below/right of
// it (high corners).
__device__ float bilinear_sampling(
    const float *&bottom_data, const int &height, const int &width,
    const int &num_embeds, const float &h_im, const float &w_im,
    const int &base_ptr
) {
  // Integer corners surrounding the sample point.
  const int y0 = floorf(h_im);
  const int x0 = floorf(w_im);
  const int y1 = y0 + 1;
  const int x1 = x0 + 1;

  // Fractional distances to the low corner and their complements.
  const float frac_y = h_im - y0;
  const float frac_x = w_im - x0;
  const float inv_y = 1 - frac_y;
  const float inv_x = 1 - frac_x;

  // Element offsets of the two rows and two columns involved.
  const int col_stride = num_embeds;
  const int row_stride = width * col_stride;
  const int row0 = y0 * row_stride;
  const int row1 = row0 + row_stride;
  const int col0 = x0 * col_stride;
  const int col1 = col0 + col_stride;

  const bool top_ok = y0 >= 0;
  const bool bottom_ok = y1 <= height - 1;
  const bool left_ok = x0 >= 0;
  const bool right_ok = x1 <= width - 1;

  float v_tl = 0;
  if (top_ok && left_ok) {
    v_tl = bottom_data[row0 + col0 + base_ptr];
  }
  float v_tr = 0;
  if (top_ok && right_ok) {
    v_tr = bottom_data[row0 + col1 + base_ptr];
  }
  float v_bl = 0;
  if (bottom_ok && left_ok) {
    v_bl = bottom_data[row1 + col0 + base_ptr];
  }
  float v_br = 0;
  if (bottom_ok && right_ok) {
    v_br = bottom_data[row1 + col1 + base_ptr];
  }

  const float w_tl = inv_y * inv_x;
  const float w_tr = inv_y * frac_x;
  const float w_bl = frac_y * inv_x;
  const float w_br = frac_y * frac_x;

  return (w_tl * v_tl + w_tr * v_tr + w_bl * v_bl + w_br * v_br);
}
60 |
61 |
// Backward pass of one bilinear sample for a single channel.
//
// Recomputes the same four corners and weights as bilinear_sampling and
// accumulates, with atomicAdd:
//   * grad_mc_ms_feat[ptr]      += corner_weight * grad_output * weight
//   * grad_weights[0]           += grad_output * interpolated_value
//   * grad_sampling_location[0] += width  * d(val)/d(w_im) * grad_output * weight
//   * grad_sampling_location[1] += height * d(val)/d(h_im) * grad_output * weight
// grad_sampling_location and grad_weights must already point at the entry
// belonging to this sample; grad_mc_ms_feat is the base of the whole
// feature-gradient buffer and is indexed with the same offsets as
// bottom_data.
__device__ void bilinear_sampling_grad(
    const float *&bottom_data, const float &weight,
    const int &height, const int &width,
    const int &num_embeds, const float &h_im, const float &w_im,
    const int &base_ptr,
    const float &grad_output,
    float *&grad_mc_ms_feat, float *grad_sampling_location, float *grad_weights) {
  // Integer corners surrounding the sample point.
  const int h_low = floorf(h_im);
  const int w_low = floorf(w_im);
  const int h_high = h_low + 1;
  const int w_high = w_low + 1;

  // Fractional parts (lh, lw) and their complements (hh, hw).
  const float lh = h_im - h_low;
  const float lw = w_im - w_low;
  const float hh = 1 - lh, hw = 1 - lw;

  // Pixel (y, x) lives at base_ptr + (y * width + x) * num_embeds.
  const int w_stride = num_embeds;
  const int h_stride = width * w_stride;
  const int h_low_ptr_offset = h_low * h_stride;
  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
  const int w_low_ptr_offset = w_low * w_stride;
  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;

  const float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
  // Gradient reaching the feature map: upstream gradient times the
  // attention weight applied in the forward pass.
  const float top_grad_mc_ms_feat = grad_output * weight;
  // Partial derivatives of the interpolated value w.r.t. h_im and w_im,
  // accumulated corner by corner below.
  float grad_h_weight = 0, grad_w_weight = 0;

  // Corner (h_low, w_low).
  float v1 = 0;
  if (h_low >= 0 && w_low >= 0) {
    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
    v1 = bottom_data[ptr1];
    grad_h_weight -= hw * v1;
    grad_w_weight -= hh * v1;
    atomicAdd(grad_mc_ms_feat + ptr1, w1 * top_grad_mc_ms_feat);
  }
  // Corner (h_low, w_high).
  float v2 = 0;
  if (h_low >= 0 && w_high <= width - 1) {
    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
    v2 = bottom_data[ptr2];
    grad_h_weight -= lw * v2;
    grad_w_weight += hh * v2;
    atomicAdd(grad_mc_ms_feat + ptr2, w2 * top_grad_mc_ms_feat);
  }
  // Corner (h_high, w_low).
  float v3 = 0;
  if (h_high <= height - 1 && w_low >= 0) {
    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
    v3 = bottom_data[ptr3];
    grad_h_weight += hw * v3;
    grad_w_weight -= lh * v3;
    atomicAdd(grad_mc_ms_feat + ptr3, w3 * top_grad_mc_ms_feat);
  }
  // Corner (h_high, w_high).
  float v4 = 0;
  if (h_high <= height - 1 && w_high <= width - 1) {
    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
    v4 = bottom_data[ptr4];
    grad_h_weight += lw * v4;
    grad_w_weight += lh * v4;
    atomicAdd(grad_mc_ms_feat + ptr4, w4 * top_grad_mc_ms_feat);
  }

  // Forward value, needed for the gradient of the attention weight.
  const float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
  atomicAdd(grad_weights, grad_output * val);
  // The callers compute w_im = loc_w * w - 0.5 and h_im = loc_h * h - 0.5,
  // hence the chain-rule factors width and height.
  atomicAdd(grad_sampling_location, width * grad_w_weight * top_grad_mc_ms_feat);
  atomicAdd(grad_sampling_location + 1, height * grad_h_weight * top_grad_mc_ms_feat);
}
127 |
128 |
// Forward kernel: one thread per (batch, anchor, point, camera, scale,
// channel) combination.
//
// The flat thread index is decomposed from fastest to slowest varying as
// channel, scale, camera, point, anchor, batch. Each thread samples one
// channel of one feature map at one projected location, multiplies it by
// its attention weight and atomically adds the result to
//   output[(batch * num_anchors + anchor) * num_embeds + channel].
// Threads whose sample location is outside the open interval (0, 1) in
// either coordinate contribute nothing.
__global__ void deformable_aggregation_kernel(
    const int num_kernels,
    float* output,
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= num_kernels) return;

  // Consecutive runs of (num_embeds / num_groups) channels share a weight.
  const float weight = *(weights + idx / (num_embeds / num_groups));
  const int channel_index = idx % num_embeds;
  idx /= num_embeds;
  const int scale_index = idx % num_scale;
  idx /= num_scale;

  const int cam_index = idx % num_cams;
  idx /= num_cams;
  const int pts_index = idx % num_pts;
  idx /= num_pts;

  int anchor_index = idx % num_anchors;
  idx /= num_anchors;
  const int batch_index = idx % batch_size;
  idx /= batch_size;

  // From here on anchor_index is the anchor's index across the whole batch.
  anchor_index = batch_index * num_anchors + anchor_index;
  // Two floats per (anchor, point, camera): w first, then h.
  const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;

  const float loc_w = sample_location[loc_offset];
  if (loc_w <= 0 || loc_w >= 1) return;
  const float loc_h = sample_location[loc_offset + 1];
  if (loc_h <= 0 || loc_h >= 1) return;

  // Start of this (camera, scale) feature map inside the batch element's
  // num_feat rows, already offset to the requested channel.
  int cam_scale_index = cam_index * num_scale + scale_index;
  const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;

  // spatial_shape stores (h, w) pairs per (camera, scale).
  cam_scale_index = cam_scale_index << 1;
  const int h = spatial_shape[cam_scale_index];
  const int w = spatial_shape[cam_scale_index + 1];

  // Scale the (0, 1) location to pixel coordinates, with a half-pixel shift.
  const float h_im = loc_h * h - 0.5;
  const float w_im = loc_w * w - 0.5;

  atomicAdd(
      output + anchor_index * num_embeds + channel_index,
      bilinear_sampling(mc_ms_feat, h, w, num_embeds, h_im, w_im, value_offset) * weight
  );
}
188 |
189 |
// Backward kernel: one thread per (batch, anchor, point, camera, scale,
// channel) combination, using the same index decomposition as the forward
// kernel (fastest to slowest: channel, scale, camera, point, anchor, batch).
//
// Each thread reads the upstream gradient of its output element and lets
// bilinear_sampling_grad accumulate into grad_mc_ms_feat,
// grad_sampling_location and grad_weights. Threads whose sample location
// is outside the open interval (0, 1) return without contributing,
// mirroring the forward pass.
__global__ void deformable_aggregation_grad_kernel(
    const int num_kernels,
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    const float* grad_output,
    float* grad_mc_ms_feat,
    float* grad_sampling_location,
    float* grad_weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= num_kernels) return;

  // Consecutive runs of (num_embeds / num_groups) channels share a weight.
  const int weights_ptr = idx / (num_embeds / num_groups);
  const int channel_index = idx % num_embeds;
  idx /= num_embeds;
  const int scale_index = idx % num_scale;
  idx /= num_scale;

  const int cam_index = idx % num_cams;
  idx /= num_cams;
  const int pts_index = idx % num_pts;
  idx /= num_pts;

  int anchor_index = idx % num_anchors;
  idx /= num_anchors;
  const int batch_index = idx % batch_size;
  idx /= batch_size;

  // From here on anchor_index is the anchor's index across the whole batch.
  anchor_index = batch_index * num_anchors + anchor_index;
  // Two floats per (anchor, point, camera): w first, then h.
  const int loc_offset = ((anchor_index * num_pts + pts_index) * num_cams + cam_index) << 1;

  const float loc_w = sample_location[loc_offset];
  if (loc_w <= 0 || loc_w >= 1) return;
  const float loc_h = sample_location[loc_offset + 1];
  if (loc_h <= 0 || loc_h >= 1) return;

  // Upstream gradient of the output element this thread contributed to.
  const float grad = grad_output[anchor_index*num_embeds + channel_index];

  // Start of this (camera, scale) feature map inside the batch element's
  // num_feat rows, already offset to the requested channel.
  int cam_scale_index = cam_index * num_scale + scale_index;
  const int value_offset = (batch_index * num_feat + scale_start_index[cam_scale_index]) * num_embeds + channel_index;

  // spatial_shape stores (h, w) pairs per (camera, scale).
  cam_scale_index = cam_scale_index << 1;
  const int h = spatial_shape[cam_scale_index];
  const int w = spatial_shape[cam_scale_index + 1];

  // Same pixel-coordinate mapping as the forward kernel.
  const float h_im = loc_h * h - 0.5;
  const float w_im = loc_w * w - 0.5;

  // The forward pass accumulates bilinear_sampling(...) * weight into the
  // output at this point; here the matching gradients are accumulated.
  const float weight = weights[weights_ptr];
  float *grad_weights_ptr = grad_weights + weights_ptr;
  float *grad_location_ptr = grad_sampling_location + loc_offset;
  bilinear_sampling_grad(
      mc_ms_feat, weight, h, w, num_embeds, h_im, w_im,
      value_offset,
      grad,
      grad_mc_ms_feat, grad_location_ptr, grad_weights_ptr
  );
}
263 |
264 |
// Host launcher for the forward pass.
//
// Launches deformable_aggregation_kernel with one thread per
// (batch, point, channel, anchor, camera, scale) combination in blocks of
// 128 threads. All pointers are passed straight to the kernel. The kernel
// only adds into `output`, so the caller must have initialised it.
//
// NOTE(review): the work-item count is a 32-bit product and the kernel
// indexes with `int`; very large problem sizes would overflow. Confirm the
// upper bounds used by callers.
void deformable_aggregation(
    float* output,
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
  const int num_kernels = batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
  // Nothing to do for an empty problem. A grid of zero blocks is an invalid
  // launch configuration, and the resulting error would surface at some
  // later, unrelated CUDA call.
  if (num_kernels <= 0) return;

  const int threads_per_block = 128;
  // Ceil-division written so that it cannot overflow for large num_kernels.
  const int num_blocks = (num_kernels - 1) / threads_per_block + 1;
  deformable_aggregation_kernel
      <<<num_blocks, threads_per_block>>>(
      num_kernels, output,
      mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
      batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
  );
}
289 |
290 |
// Host launcher for the backward pass.
//
// Launches deformable_aggregation_grad_kernel with one thread per
// (batch, point, channel, anchor, camera, scale) combination in blocks of
// 128 threads. The kernel only adds into grad_mc_ms_feat,
// grad_sampling_location and grad_weights, so the caller must have
// initialised them.
//
// NOTE(review): the work-item count is a 32-bit product and the kernel
// indexes with `int`; very large problem sizes would overflow. Confirm the
// upper bounds used by callers.
void deformable_aggregation_grad(
    const float* mc_ms_feat,
    const int* spatial_shape,
    const int* scale_start_index,
    const float* sample_location,
    const float* weights,
    const float* grad_output,
    float* grad_mc_ms_feat,
    float* grad_sampling_location,
    float* grad_weights,
    int batch_size,
    int num_cams,
    int num_feat,
    int num_embeds,
    int num_scale,
    int num_anchors,
    int num_pts,
    int num_groups
) {
  const int num_kernels = batch_size * num_pts * num_embeds * num_anchors * num_cams * num_scale;
  // Nothing to do for an empty problem. A grid of zero blocks is an invalid
  // launch configuration, and the resulting error would surface at some
  // later, unrelated CUDA call.
  if (num_kernels <= 0) return;

  const int threads_per_block = 128;
  // Ceil-division written so that it cannot overflow for large num_kernels.
  const int num_blocks = (num_kernels - 1) / threads_per_block + 1;
  deformable_aggregation_grad_kernel
      <<<num_blocks, threads_per_block>>>(
      num_kernels,
      mc_ms_feat, spatial_shape, scale_start_index, sample_location, weights,
      grad_output, grad_mc_ms_feat, grad_sampling_location, grad_weights,
      batch_size, num_cams, num_feat, num_embeds, num_scale, num_anchors, num_pts, num_groups
  );
}
319 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from __future__ import division
3 | import sys
4 | import os
5 |
6 | print(sys.executable, os.path.abspath(__file__))
7 | # import init_paths # for conda pkgs submitting method
8 | import argparse
9 | import copy
10 | import mmcv
11 | import time
12 | import torch
13 | import warnings
14 | from mmcv import Config, DictAction
15 | from mmcv.runner import get_dist_info, init_dist
16 | from os import path as osp
17 |
18 | from mmdet import __version__ as mmdet_version
19 | from mmdet.apis import train_detector
20 | from mmdet.datasets import build_dataset
21 | from mmdet.models import build_detector
22 | from mmdet.utils import collect_env, get_root_logger
23 | from mmdet.apis import set_random_seed
24 | from torch import distributed as dist
25 | from datetime import timedelta
26 |
27 | import cv2
28 |
29 | cv2.setNumThreads(8)
30 |
31 |
def parse_args():
    """Parse the command-line arguments of the training script.

    Besides plain parsing, two adjustments are made to the process state and
    the returned namespace:

    * ``--local_rank`` is copied into the ``LOCAL_RANK`` environment variable
      when that variable is not already set.
    * The deprecated ``--options`` is folded into ``cfg_options`` (with a
      warning).

    Returns:
        argparse.Namespace: The parsed arguments.

    Raises:
        ValueError: If both ``--options`` and ``--cfg-options`` are given.
    """
    # Shared fragments of help text; concatenation yields the exact strings
    # shown by ``--help``.
    non_dist_note = "(only applicable to non-distributed training)"
    override_intro = (
        "override some settings in the used config, the key-value pair "
    )

    cli = argparse.ArgumentParser(description="Train a detector")
    cli.add_argument("config", help="train config file path")
    cli.add_argument("--work-dir", help="the dir to save logs and models")
    cli.add_argument(
        "--resume-from", help="the checkpoint file to resume from"
    )
    cli.add_argument(
        "--no-validate",
        action="store_true",
        help="whether not to evaluate the checkpoint during training",
    )

    # --gpus and --gpu-ids are alternatives; argparse rejects using both.
    gpu_group = cli.add_mutually_exclusive_group()
    gpu_group.add_argument(
        "--gpus",
        type=int,
        help="number of gpus to use " + non_dist_note,
    )
    gpu_group.add_argument(
        "--gpu-ids",
        type=int,
        nargs="+",
        help="ids of gpus to use " + non_dist_note,
    )

    cli.add_argument("--seed", type=int, default=0, help="random seed")
    cli.add_argument(
        "--deterministic",
        action="store_true",
        help="whether to set deterministic options for CUDNN backend.",
    )
    cli.add_argument(
        "--options",
        nargs="+",
        action=DictAction,
        help=override_intro
        + "in xxx=yyy format will be merged into config file (deprecate), "
        + "change to --cfg-options instead.",
    )
    cli.add_argument(
        "--cfg-options",
        nargs="+",
        action=DictAction,
        help=override_intro
        + "in xxx=yyy format will be merged into config file. If the value to "
        + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        + "Note that the quotation marks are necessary and that no white space "
        + "is allowed.",
    )
    cli.add_argument(
        "--dist-url",
        type=str,
        default="auto",
        help="dist url for init process, such as tcp://localhost:8000",
    )
    cli.add_argument("--gpus-per-machine", type=int, default=8)
    cli.add_argument(
        "--launcher",
        choices=["none", "pytorch", "slurm", "mpi", "mpi_nccl"],
        default="none",
        help="job launcher",
    )
    cli.add_argument("--local_rank", type=int, default=0)
    cli.add_argument(
        "--autoscale-lr",
        action="store_true",
        help="automatically scale lr with the number of gpus",
    )

    parsed = cli.parse_args()

    # Only fill LOCAL_RANK when the launcher has not exported it already.
    os.environ.setdefault("LOCAL_RANK", str(parsed.local_rank))

    if parsed.options:
        if parsed.cfg_options:
            raise ValueError(
                "--options and --cfg-options cannot be both specified, "
                "--options is deprecated in favor of --cfg-options"
            )
        warnings.warn("--options is deprecated in favor of --cfg-options")
        parsed.cfg_options = parsed.options

    return parsed
116 |
117 |
def main():
    """Run a training job described by a config file and CLI arguments.

    Steps, in order:

    1. Parse CLI arguments, load the config and merge ``--cfg-options``.
    2. Import ``custom_imports`` and, when the config enables ``plugin``,
       the plugin package (so its registries get populated).
    3. Resolve ``work_dir``, ``resume_from`` and ``gpu_ids``.
    4. Initialise the distributed backend for the selected launcher.
    5. Optionally scale the learning rate with the number of GPUs.
    6. Create the work dir, dump the config, set up the logger and the
       ``meta`` record, and seed the RNGs.
    7. Build the detector and dataset(s) and start training with
       ``custom_train_model`` (plugin enabled) or ``train_detector``.
    """
    args = parse_args()

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # import modules from string list.
    if cfg.get("custom_imports", None):
        from mmcv.utils import import_modules_from_strings

        import_modules_from_strings(**cfg["custom_imports"])

    # Import modules from plugin/xx so that the registries are updated.
    # A single flag decides both whether the plugin is imported here and
    # which trainer is used at the end; previously the final dispatch only
    # checked that the ``plugin`` key existed, so ``plugin = False`` reached
    # ``custom_train_model`` without it ever having been imported.
    use_plugin = bool(cfg.get("plugin", False))
    if use_plugin:
        import importlib

        # The package to import is the directory part of ``plugin_dir`` when
        # given, otherwise the directory holding the config file, with "/"
        # turned into ".".
        if hasattr(cfg, "plugin_dir"):
            plugin_root = cfg.plugin_dir
        else:
            plugin_root = args.config
        module_path = ".".join(os.path.dirname(plugin_root).split("/"))
        print(module_path)
        importlib.import_module(module_path)
        from projects.mmdet3d_plugin.apis.train import custom_train_model

    # set cudnn_benchmark
    if cfg.get("cudnn_benchmark", False):
        torch.backends.cudnn.benchmark = True

    # work_dir is determined in this priority: CLI > segment in file > filename
    if args.work_dir is not None:
        # update configs according to CLI args if args.work_dir is not None
        cfg.work_dir = args.work_dir
    elif cfg.get("work_dir", None) is None:
        # use config filename as default work_dir if cfg.work_dir is None
        cfg.work_dir = osp.join(
            "./work_dirs", osp.splitext(osp.basename(args.config))[0]
        )
    if args.resume_from is not None:
        cfg.resume_from = args.resume_from
    if args.gpu_ids is not None:
        cfg.gpu_ids = args.gpu_ids
    else:
        cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)

    # init distributed env first, since logger depends on the dist info.
    if args.launcher == "none":
        distributed = False
    elif args.launcher == "mpi_nccl":
        distributed = True

        import mpi4py.MPI as MPI

        comm = MPI.COMM_WORLD
        # Rank within COMM_WORLD; it is used both as the process-group rank
        # and (modulo gpus_per_machine) to pick the CUDA device.
        mpi_rank = comm.Get_rank()
        mpi_world_size = comm.Get_size()
        print(
            "MPI local_rank=%d, world_size=%d" % (mpi_rank, mpi_world_size)
        )

        # num_gpus = torch.cuda.device_count()
        device_ids_on_machines = list(range(args.gpus_per_machine))
        str_ids = list(map(str, device_ids_on_machines))
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str_ids)
        torch.cuda.set_device(mpi_rank % args.gpus_per_machine)

        dist.init_process_group(
            backend="nccl",
            init_method=args.dist_url,
            world_size=mpi_world_size,
            rank=mpi_rank,
            timeout=timedelta(seconds=3600),
        )

        cfg.gpu_ids = range(mpi_world_size)
        print("cfg.gpu_ids:", cfg.gpu_ids)
    else:
        distributed = True
        init_dist(
            args.launcher, timeout=timedelta(seconds=3600), **cfg.dist_params
        )
        # re-set gpu_ids with distributed training mode
        _, world_size = get_dist_info()
        cfg.gpu_ids = range(world_size)

    if args.autoscale_lr:
        # apply the linear scaling rule (https://arxiv.org/abs/1706.02677)
        # This runs after the distributed setup on purpose: only then does
        # len(cfg.gpu_ids) equal the world size for distributed launchers.
        cfg.optimizer["lr"] = cfg.optimizer["lr"] * len(cfg.gpu_ids) / 8

    # create work_dir
    mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
    # dump config
    cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
    # init the logger before other steps
    timestamp = time.strftime("%Y%m%d_%H%M%S", time.localtime())
    log_file = osp.join(cfg.work_dir, f"{timestamp}.log")
    logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)

    # init the meta dict to record some important information such as
    # environment info and seed, which will be logged
    meta = dict()
    # log env info
    env_info_dict = collect_env()
    env_info = "\n".join(f"{k}: {v}" for k, v in env_info_dict.items())
    dash_line = "-" * 60 + "\n"
    logger.info(
        "Environment info:\n" + dash_line + env_info + "\n" + dash_line
    )
    meta["env_info"] = env_info
    meta["config"] = cfg.pretty_text

    # log some basic info
    logger.info(f"Distributed training: {distributed}")
    logger.info(f"Config:\n{cfg.pretty_text}")

    # set random seeds
    if args.seed is not None:
        logger.info(
            f"Set random seed to {args.seed}, "
            f"deterministic: {args.deterministic}"
        )
        set_random_seed(args.seed, deterministic=args.deterministic)
    cfg.seed = args.seed
    meta["seed"] = args.seed
    meta["exp_name"] = osp.basename(args.config)

    model = build_detector(
        cfg.model, train_cfg=cfg.get("train_cfg"), test_cfg=cfg.get("test_cfg")
    )
    model.init_weights()

    logger.info(f"Model:\n{model}")
    datasets = [build_dataset(cfg.data.train)]
    if len(cfg.workflow) == 2:
        val_dataset = copy.deepcopy(cfg.data.val)
        # in case we use a dataset wrapper
        if "dataset" in cfg.data.train:
            val_dataset.pipeline = cfg.data.train.dataset.pipeline
        else:
            val_dataset.pipeline = cfg.data.train.pipeline
        # set test_mode=False here in deep copied config
        # which do not affect AP/AR calculation later
        # refer to https://mmdetection3d.readthedocs.io/en/latest/tutorials/customize_runtime.html#customize-workflow  # noqa
        val_dataset.test_mode = False
        datasets.append(build_dataset(val_dataset))
    if cfg.checkpoint_config is not None:
        # save mmdet version, config file content and class names in
        # checkpoints as meta data
        cfg.checkpoint_config.meta = dict(
            mmdet_version=mmdet_version,
            config=cfg.pretty_text,
            CLASSES=datasets[0].CLASSES,
        )
    # add an attribute for visualization convenience
    model.CLASSES = datasets[0].CLASSES

    # Both trainers are called with the same arguments; pick the plugin one
    # only when it was actually imported above.
    train_fn = custom_train_model if use_plugin else train_detector
    train_fn(
        model,
        datasets,
        cfg,
        distributed=distributed,
        validate=(not args.no_validate),
        timestamp=timestamp,
        meta=meta,
    )
312 |
313 |
if __name__ == "__main__":
    # Select the "fork" start method before training starts; per the original
    # author's note, with fork workers_per_gpu can be > 1.
    torch.multiprocessing.set_start_method(
        "fork"
    )
    main()
319 |
--------------------------------------------------------------------------------